
I bolted MiniMax's MSA sparse attention onto a 3B model on a single 4090

A dark wall of tiles with a sparse few glowing blue and amber: sparse attention keeping a handful of blocks and ignoring the rest.

The headline number is real. The useful findings aren't in the paper.


In June 2026 MiniMax open-sourced MSA, a sparse-attention method that ships with three claims. It cuts per-token attention compute 28.4x at a million tokens of context. It keeps model quality. And paired with their hand-written kernel it runs 14.2x faster on prefill.

The third claim is the one everyone quotes, and it's the one I can't check. That kernel, fmha_sm100, only builds for Blackwell datacenter GPUs: B200, GB200, the SM100 architecture. I have an RTX 4090. It is the wrong chip by two generations. The kernel won't compile, let alone run.

So the speedup is off the table for me, and for almost everyone who isn't renting B200 time by the hour. But only one of those three claims needs a B200. The compute reduction is arithmetic. Whether quality survives is a property of the algorithm, not the silicon. Both of those I can test on hardware I own.

So I did. I took MSA's selection rule, bolted it onto Qwen2.5-3B with no retraining, and ran it on the 4090. The 28.4x checked out and turned out to be the least interesting thing I found. The quality holds, on one condition the paper doesn't mention. And the design choice that makes the whole thing work is the one most people would have thrown away.


What MSA sparse attention is, in one paragraph

Attention compares every token against every earlier token. At a million tokens that's roughly half a trillion query-key comparisons per head in every layer, and almost all of them resolve to near-zero weight. MSA's bet is that you can skip the near-zero ones cheaply. It chops the past into blocks of 128 tokens and runs a lightweight pass that scores each block for the current query. It then keeps 16 blocks: the first block and the query's own block, which are always included, plus the 14 highest-scoring others. Real attention runs only over those. Sixteen blocks is 2,048 tokens. At a million tokens of context you're attending to 0.2% of it. That fixed budget, held constant while the context grows, is where the 28.4x comes from.


MSA in one picture: score every 128-token block (max-pool of Q·K), then attend only to 16 blocks: the forced sink and local blocks plus the 14 top-scoring others. Everything else is skipped.
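Here is a minimal pure-Python sketch of the selection rule described above. It assumes per-token relevance scores for one query are already computed. The function names are hypothetical, and real MSA does this on GPU tensors. The sketch follows the article's description in which the sink and local blocks count against the 16-block budget.

```python
import math
from typing import List

BLOCK = 128  # tokens per block, as in MSA


def block_scores(token_scores: List[float], block: int = BLOCK) -> List[float]:
    """Max-pool per-token relevance into one score per block."""
    n_blocks = math.ceil(len(token_scores) / block)
    return [max(token_scores[b * block:(b + 1) * block]) for b in range(n_blocks)]


def select_blocks(token_scores: List[float], k: int = 16, block: int = BLOCK) -> List[int]:
    """Pick k blocks: the sink (first) and local (last) block are pinned,
    and the remaining k - 2 slots go to the highest-scoring other blocks."""
    scores = block_scores(token_scores, block)
    n = len(scores)
    pinned = {0, n - 1}
    others = sorted((b for b in range(n) if b not in pinned),
                    key=lambda b: scores[b], reverse=True)
    free_slots = max(k - len(pinned), 0)
    return sorted(pinned | set(others[:free_slots]))


# Fixed budget vs. growing context: 16 blocks of 128 tokens out of 2**20 tokens.
print(f"{16 * BLOCK / 2**20:.2%} of a million-token context")  # 0.20% of ...
```

At k=2, `free_slots` is zero and `others[:0]` is empty, so the two pinned blocks are the entire selection. This is the budget floor examined later.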


What I copied, and what I modified


An honest account of this experiment starts by separating what's mine from what's MiniMax's, because that line decides which numbers below you can lean on. MSA is two things stacked: a structure that picks which blocks of the past to look at, and a trained scorer that does the picking. I copied the structure exactly. I didn't have the trained scorer, so I replaced it with a signal the model already computes.

  • Sparsity structure — MiniMax: 128-token blocks, keep top 16, force the local + sink block. Mine: identical.

  • Block scorer — MiniMax: a trained Index Branch with its own learned projections. Mine: no training; I reuse the model's own query·key dot product as the relevance signal.

  • Base model — MiniMax: pretrained (or a 3T-token continued-pretraining run) with sparse attention on. Mine: stock Qwen2.5-3B, dense-trained, made sparse only at inference.

  • Kernel — MiniMax: fused fmha_sm100, Blackwell-only. Mine: plain PyTorch, correct but slow.

That first item is why the compute number is a true reproduction: FLOPs depend on how many blocks you attend to, not on how you chose them. The scorer and base-model rows are why “quality holds” is a softer claim than it reads. My scorer is dumber than MiniMax's by construction, so every quality result here is a floor, not a match. If the cheap heuristic keeps quality, the trained one should keep at least as much. That's a fair inference. Reproducing their exact quality numbers is not something this rig can do, and I won't pretend it does.

What keeps this from being a different method wearing MSA's name is that the findings worth keeping live in the structure I copied, not the scorer I improvised. The budget floor where retrieval falls to zero, and the pooling choice that decides whether it works at all, are both properties of the block structure. Real MSA pins the same blocks and pools the same way, so they carry straight over. The one result leaning on my stand-in is “quality survives,” and that's the one I hedge hardest.


The 28.4x is real, and it has a ceiling

I rebuilt the FLOP math first, partly to confirm the number and partly because one parameter, the index dimension, isn't stated outright in the paper. The 28.4x only lands if that dimension is 128, the same as the main attention heads. Guess 64 and you get 50x. It's a small thing, but it's the kind of small thing that tells you whether you actually understand the method or just retyped its abstract. The number checks out. It also has a ceiling I hadn't expected: the reduction tops out at 32x no matter how long the context gets, because the block-scoring pass keeps its own quadratic term. 28.4x at a million tokens is a point on a curve that's already flattening, not the start of an exponential.
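Below is one simple cost model that reproduces the three headline figures: 28.4x at a million tokens, roughly 50x with a 64-dim index, and a 32x ceiling. It is my own assumption, not MiniMax's published accounting. "A million" is taken as 2^20 tokens. The index pass is assumed to compute only QK, with one head per 16 main query heads (`head_ratio`). The sparse term counts the full 2,048-token budget per query.

```python
def msa_flop_ratio(n: int, budget: int = 16 * 128, d_head: int = 128,
                   d_index: int = 128, head_ratio: int = 16) -> float:
    """Dense / MSA attention FLOPs for a causal prefill of n tokens (toy model).

    Per main query head, per earlier key (causal prefill averages ~n/2 keys):
      dense attention, QK + AV:                      2 * d_head
      index pass, QK only, 1/head_ratio as many heads: d_index / head_ratio
    Sparse attention, QK + AV over a fixed budget:     2 * d_head * budget
    """
    dense = 2 * d_head * n / 2
    index = d_index / head_ratio * n / 2
    sparse = 2 * d_head * budget
    return dense / (index + sparse)


print(f"{msa_flop_ratio(2**20):.1f}x")              # 28.4x
print(f"{msa_flop_ratio(2**20, d_index=64):.1f}x")  # 51.2x
print(f"{msa_flop_ratio(10**12):.1f}x")             # 32.0x
```

The ceiling comes from this structure. Both the dense term and the index term grow with n, while the sparse term is fixed. As n grows, the ratio therefore tends to 2 · d_head · head_ratio / d_index, which is 32 when `d_index` equals `d_head`.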

That's the fact-check. It's reassuring and it's boring. The interesting part starts when you ask whether a model you already have can use this.


The retrofit, and the trick that makes it training-free

MSA, as MiniMax built it, is trained. The block-scoring pass has its own learned weights, and the published quality numbers come from either pretraining from scratch or a 3-trillion-token continued-pretraining run. That is not a thing a builder casually reproduces.

But the question a builder actually has is different: can I switch this on for a model I already run, today, without a training budget? The paper doesn't answer that. So I built the version that does.

Two pieces made it work. The first is a seam in how modern transformers are written. A Qwen attention layer doesn't hardcode its attention math; it looks the implementation up by name at every forward pass and calls whatever's registered under that name. You register your own function and flip one config string. No retraining, no surgery on the weights, no subclassing. The model starts calling your code instead of dense attention, and flips back to the fast path when you want the baseline.
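The following is a toy, pure-Python model of that seam. All names are hypothetical (`ATTENTION_FUNCTIONS`, `register`, `AttentionLayer`), and this is not the Hugging Face code. As far as I know, recent `transformers` releases expose the real hook as `AttentionInterface.register(name, fn)` together with the `attn_implementation` setting. Check the docs for your version before relying on exact names.

```python
import math
from typing import Callable, Dict, List

Vector = List[float]
ATTENTION_FUNCTIONS: Dict[str, Callable] = {}  # name -> attention implementation


def register(name: str) -> Callable:
    def wrap(fn: Callable) -> Callable:
        ATTENTION_FUNCTIONS[name] = fn
        return fn
    return wrap


@register("dense")
def dense_attention(q: Vector, keys: List[Vector], values: List[Vector]) -> Vector:
    """Plain softmax attention for a single query vector."""
    scale = 1.0 / math.sqrt(len(q))
    logits = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(logits)
    weights = [math.exp(x - top) for x in logits]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


CALLS = {"custom": 0}


@register("custom")
def custom_attention(q, keys, values):
    CALLS["custom"] += 1  # stand-in for your own implementation; here it just delegates
    return dense_attention(q, keys, values)


class Config:
    def __init__(self, attn_implementation: str = "dense"):
        self.attn_implementation = attn_implementation


class AttentionLayer:
    def __init__(self, config: Config):
        self.config = config

    def forward(self, q, keys, values):
        # Looked up by name on every call, so flipping the string swaps the math.
        return ATTENTION_FUNCTIONS[self.config.attn_implementation](q, keys, values)
```

Setting `cfg.attn_implementation = "custom"` on a live layer routes the next forward pass through the custom function. Setting it back to `"dense"` restores the baseline. The layer's weights and code are not touched.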

The second piece is the substitution that makes it training-free. MSA's scoring pass has learned weights I don't have. The model, though, is already computing query-key dot products as part of normal attention, and a high dot product is exactly the signal I want: this key matters to this query. So I reuse it. Average the query heads in each group into one scoring query, take the dot product against the keys, and for each block keep the single highest token score as that block's relevance. That last detail, taking the max over the block rather than the average, looked like an arbitrary implementation choice when I read it. It is not. More on that below.
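Here is a minimal sketch of that stand-in scorer for one query position and one KV group. It uses plain lists so the arithmetic is visible. The names are hypothetical; the actual retrofit works on batched PyTorch tensors.

```python
from typing import List

Vector = List[float]


def scoring_query(group_heads: List[Vector]) -> Vector:
    """Average the query heads that share one KV head into a single scoring query."""
    n = len(group_heads)
    return [sum(head[i] for head in group_heads) / n for i in range(len(group_heads[0]))]


def block_relevance(group_heads: List[Vector], keys: List[Vector], block: int = 128) -> List[float]:
    """Training-free block score: max over each block of (averaged query) . key."""
    q = scoring_query(group_heads)
    token_scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    return [max(token_scores[s:s + block]) for s in range(0, len(keys), block)]
```

The dot product is linear, so scoring with the averaged query gives the same result as averaging the per-head scores. That is why one pass per KV group is enough.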

I flagged the boundary up top: my scorer is a stand-in, not MiniMax's trained Index Branch. What I'll add here is how I made sure the plumbing around it was honest, which came down to a single check: when the selection keeps every block, the output has to match dense attention exactly. It does, on CPU and GPU, on random tensors and on the real model, down to the last digit. When it skips nothing, it is the original model. When it skips, the only difference is the skip.
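Below is a minimal sketch of that check in plain Python, with hypothetical names. `gather_attend` keeps only the selected blocks, as a gather-based sparse path would. `attend`, given a mask, serves as the dense reference. Keeping every block should reproduce dense attention exactly. Keeping a subset should match dense attention with the skipped keys masked out.

```python
import math
from typing import List, Optional, Sequence

Vector = List[float]


def attend(q: Vector, keys: List[Vector], values: List[Vector],
           mask: Optional[List[bool]] = None) -> Vector:
    """Softmax attention for one query; masked-out keys get -inf logits."""
    scale = 1.0 / math.sqrt(len(q))
    logits = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    if mask is not None:
        logits = [x if keep else -math.inf for x, keep in zip(logits, mask)]
    top = max(logits)
    weights = [math.exp(x - top) for x in logits]  # exp(-inf) is exactly 0.0
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


def gather_attend(q: Vector, keys: List[Vector], values: List[Vector],
                  keep_blocks: Sequence[int], block: int) -> Vector:
    """Block-sparse attention: gather only the kept blocks, then attend densely."""
    idx = [i for b in sorted(keep_blocks)
           for i in range(b * block, min((b + 1) * block, len(keys)))]
    return attend(q, [keys[i] for i in idx], [values[i] for i in idx])
```

When every block is kept, the gathered keys are the original keys in their original order. The two paths then run identical arithmetic, which is what makes an exact-match test meaningful rather than merely close.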


Does the quality survive?

Perplexity is the clean way to ask this. It measures how surprised the model is by real text, and it's sensitive to small degradations that a pass/fail task would miss. Lower is better. I ran the model over the same long documents, changing only how many blocks the attention was allowed to see.

  • Dense, full context: 7.045 perplexity — the baseline.

  • MSA k=32 (4,096 tokens): 7.045 — no change; at this length the budget covers everything, so it reduces to dense.

  • MSA k=16 (2,048 tokens): 7.055 — +0.14%, and this is MiniMax's deployed setting.

  • MSA k=8 (1,024 tokens): 7.110 — +0.92%.

  • MSA k=4 (512 tokens): 7.469 — +6.0%.
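This short sketch shows how the percentages above follow from the raw perplexities. The `perplexity` helper assumes natural-log per-token losses, which is the usual convention. It is illustrative only, not the evaluation script behind these numbers.

```python
import math
from typing import Sequence


def perplexity(token_nlls: Sequence[float]) -> float:
    """exp of the mean per-token negative log-likelihood (natural log)."""
    return math.exp(sum(token_nlls) / len(token_nlls))


def pct_change(ppl: float, baseline: float) -> float:
    return (ppl / baseline - 1.0) * 100.0


DENSE = 7.045
for k, ppl in [(32, 7.045), (16, 7.055), (8, 7.110), (4, 7.469)]:
    extra_nats = math.log(ppl / DENSE)  # mean extra loss per token, in nats
    print(f"k={k:>2}: {pct_change(ppl, DENSE):+.2f}%  ({extra_nats:.4f} nats/token)")
```

A perplexity ratio is the exponential of a mean-loss difference. The +0.14% at k=16 therefore corresponds to only about 0.0014 extra nats per token.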

At k=16, the model is 0.14% more surprised by the text while ignoring half the context. Call that effectively lossless. Because the budget stays fixed, the share of context it skips only grows as the context extends past what a 4090 lets me measure here. The degradation is smooth as you cut deeper: k=8 costs about a percent, k=4 costs six. So the second claim holds: quality survives the retrofit, no retraining required.

There's a catch that perplexity alone would have hidden. Swap my max-pool scorer for random block selection and perplexity barely moves (+3.7% at k=16). Yet on the twelve-fact retrieval test described below, random selection scores 0 out of 12. A perplexity-only evaluation would have waved a completely broken selector straight through.

Figure: perplexity vs. budget k for max-pool and random selection, with the dense baseline. Perplexity barely flinches at the failure it hides: random selection looks fine on perplexity but scores 0/12 on retrieval.


Where it breaks, and the floor nobody mentions

“Quality holds” is a claim that begs for a counterexample. I went looking for where it stops holding.

My first probe was the standard one: hide a single fact in a long document, ask for it back. It was useless. MSA retrieved the needle perfectly at every setting I tried, including k=4 at 32x sparsity over 16,000 tokens of context. That's not a flaw in MSA; it's a feature of how it selects. The scoring is query-aware: when you ask about the fact, the query lights up the block that contains it, and a query-aware top-k will almost always grab that block. Single-needle retrieval is the task MSA is built to win.

So I made it harder: twelve distinct facts scattered through the context, retrieve all of them, and swept the budget downward. The cliff showed up immediately.

Figure: multi-needle recall vs. budget k. The budget floor: recall collapses to zero at k=2, where the two forced blocks (local and sink) consume the whole budget, and recovers by k=8.

At k=2, recall is zero. Not degraded, zero. MSA always forces two blocks into the budget regardless of score — the block the query sits in, and the first block of the sequence — because models lean on both. At k=2 those two forced blocks consume the entire budget, leaving no slot for the block that actually holds a fact. The model is structurally blind to its own context. Give it one free slot at k=3 and recall claws back to a third; two free slots at k=4 and it's almost whole.

This is the caveat that should travel with every “quality holds” sentence about block-sparse attention. The real constraint isn't the context length or the number of facts. It's that k has to clear the blocks the method pins for free. MiniMax ships k=16, which leaves fourteen free slots, miles above the floor — which is precisely why their quality holds. A builder turning k down to save compute should know the cliff is there, and where it is.


The design choice I assumed didn't matter

Here's where I was wrong, which turned out to be the best part.

I had a tidy hypothesis. The floor I'd just found — k has to clear the forced blocks — looked structural. It's about the budget arithmetic, not about how you score blocks, so I expected it to reproduce no matter how you picked them. To check, I swapped the scorer and kept everything else identical: MSA's max-over-the-block, a mean-over-the-block, an upper-bound estimator in the style of Quest, and random selection as a control.

The structural floor held exactly as predicted — every scorer is at zero recall at k=2. Everything above the floor was not.

Figure: recall by scorer (max-pool, mean-pool, Quest-style, random) at k=4, 8, 16. Above the floor, the pooling op decides everything: max-pool (MSA) recovers by k=8; mean-pool, Quest-style, and random do not.

MSA's max-pool recovers by k=4 and is perfect by k=8. Mean-pool is flat zero until k=16 and only reaches a coin flip; the Quest-style estimator does the same; random never recovers. The reason is the detail I'd flagged as arbitrary. A fact lives in one sentence inside a 128-token block, mostly filler. Max-pool scores the block by its single best-matching token, so that one sentence makes the whole block light up. Mean-pool averages that one strong token against 127 weak ones and washes the signal out below the selection threshold. The block that holds the answer never makes the cut.
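This two-block toy, with made-up scores, isolates the pooling effect described above. It runs top-1 selection between two blocks. One holds a single strongly matching token; the other is uniformly lukewarm filler. It shows the averaging mechanism only; the real experiment selects top-k over many blocks with real query-key scores.

```python
from typing import Callable, List, Sequence


def pool(token_scores: Sequence[float], block: int, op: Callable) -> List[float]:
    """Reduce each block of per-token scores to one block score."""
    return [op(token_scores[s:s + block]) for s in range(0, len(token_scores), block)]


def mean(xs: Sequence[float]) -> float:
    return sum(xs) / len(xs)


# Illustrative numbers, not measurements:
# block 0 holds the fact: 127 filler tokens at 0.1 and one strong match at 5.0;
# block 1 is lukewarm everywhere: 128 tokens at 0.3, no real match.
scores = [0.1] * 127 + [5.0] + [0.3] * 128

by_max = pool(scores, 128, max)    # block 0 scores 5.0, block 1 scores 0.3
by_mean = pool(scores, 128, mean)  # block 0 about 0.138, block 1 about 0.3

print("max-pool keeps block", by_max.index(max(by_max)))    # 0: the fact
print("mean-pool keeps block", by_mean.index(max(by_mean)))  # 1: the filler
```

One token at 5.0 averaged against 127 tokens at 0.1 lands below a block that never matched anything. That is the washing-out effect in miniature.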

So MSA's max-pooling isn't an implementation footnote. It's load-bearing — the thing that lets sparse attention find a localized fact at all. Change one pooling operation and a method that retrieves twelve out of twelve drops to zero out of twelve at the same budget. (To be precise: my mean-pool and Quest-style scorers are training-free stand-ins, not the full published NSA or Quest systems. This is a controlled test of one variable — the pooling operation — not a bake-off between methods.)


Fewer FLOPs, same wall clock

Back to the claim I started with, the speedup I can't run. I can't reproduce MiniMax's kernel, but I can show why the FLOP reduction doesn't hand you a speedup on its own. I built the genuinely sparse version — the one that gathers only the selected blocks — and timed it against the fused dense attention PyTorch already ships.

  • 2,048 tokens: 0.5x FLOP reduction (MSA does twice the work) — gather is 35x slower than fused dense.

  • 8,192 tokens: 1.8x fewer FLOPs — still 12x slower.

  • 32,768 tokens: 5.3x fewer FLOPs — still 3.5x slower.

Figure: FLOP reduction vs. context length (left) and gather slowdown vs. fused dense (right). The 28.4x compute reduction is real but flattens toward a 32x ceiling, and fewer FLOPs is still slower than fused dense on a 4090.

The sparse version does up to five times fewer floating-point operations at 32k and heads toward MiniMax's 28x at a million, yet it is slower than dense at every length I measured. The dense path is a fused kernel the hardware loves; my sparse path is honest PyTorch with a gather and a Python loop over tiles, and that loop is exactly the fusion MiniMax wrote a custom kernel to eliminate. The gap narrows as context grows, from 35x down to 3.5x, so somewhere past where a 4090 can reach, the arithmetic would win. On a B200, with their kernel, it does. On my desk it does not.

And at short contexts MSA does more FLOPs than dense. At 2,048 tokens it does twice the work. The scoring pass has its own quadratic cost, and a 2,048-token budget is the entire context at that length, so you pay for the machinery and skip nothing. Even at 8k the saving is under 2x. Sparse attention is a long-context tool; using it on short prompts is a straight loss.


What I'd tell a builder

The 28.4x is real and largely beside the point for anyone not holding a B200. What you can actually take from MSA, today, on hardware you own, is narrower and more practical. You can retrofit the selection rule onto an existing model without retraining and keep quality at a sane budget. That part works. But carry the conditions with it:

  • Keep k well above the blocks the method pins for free, or it goes blind to its own context — and the failure is silent and total, not gradual.

  • Keep the max-pool. The obvious-looking simplification to mean-pool quietly destroys retrieval.

  • Don't switch it on under about 8k tokens: at 2k it does twice the work of the dense attention it's replacing, and even at 8k it saves less than 2x.

None of those three are in the paper, because the paper is about a trained model on a datacenter GPU and they aren't problems at that scale. They're problems at mine — which is to say at the scale most people actually work. The headline benchmark answers “is this impressive.” These answer “should I turn it on,” which is the only question I had to begin with.

MSA is open-sourced by MiniMax. Source on GitHub. This is an independent, training-free reproduction on a single RTX 4090 — not MiniMax's trained system or their production benchmark.

Phenx Machine Learning Technologies – Custom AI Solutions Since 2018

info@phenx.io | Cincinnati, OH

© Phenx Machine Learning Technologies Inc. 2018-2025.
