My 2 cents on Fusing GEMM + Top-K + Softmax on SM100
Mixture-of-Experts (MoE) routing is one of the most latency-sensitive operations in modern LLMs. Every forward pass computes a routing score matrix, selects the top-K experts, and softmax-normalises the weights before dispatching tokens. At inference scale this happens millions of times per second. Shaving microseconds here matters.
This post walks through two implementations of a fused GEMM + Top-K + Softmax kernel targeting NVIDIA's Blackwell (SM100) architecture, both achieving 3.8–6.3× speedups over PyTorch eager (and up to ~8.6× over torch.compile).
Code: blackwell_gemm_topk_softmax.cu and gemm_topk_epi.py
The Problem
A standard MoE routing step:
```python
# A: token embeddings (M, K)   B: expert weights (N, K)
scores = A @ B.T                        # (M, N)  ← full GEMM
vals, idx = torch.topk(scores, k=4)     # (M, 4)  ← top-K per token
weights = torch.softmax(vals, dim=-1)   # (M, 4)  ← normalise
```
The naive implementation forces the full (M, N) score matrix to round-trip through global memory between steps. For small N (8–128 experts) this is wasteful: the matrix is tiny but the overhead is constant.
The goal: compute the entire pipeline in a single kernel pass, processing scores while they are still in the on-chip accumulator, never writing them to global memory.
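For reference, the whole pipeline is small enough to state in plain Python. This is a hypothetical helper (`route_tokens` is not from either kernel), written in the compact (M, k) output layout that the quack kernel below emits; the fused kernels compute exactly this, just without ever materialising `scores` in global memory:

```python
import math

def route_tokens(A, B, k, alpha=1.0):
    """Reference semantics of the fused kernel: scores = alpha * A @ B^T,
    then per-row top-k + softmax. A is (M, K), B is (N, K), both as lists
    of rows. Returns (weights, expert_ids), each (M, k)."""
    M, K = len(A), len(A[0])
    N = len(B)
    out_w, out_i = [], []
    for m in range(M):
        # one row of the (M, N) score matrix; the kernels keep this on-chip
        scores = [alpha * sum(A[m][t] * B[n][t] for t in range(K)) for n in range(N)]
        top = sorted(range(N), key=lambda n: scores[n], reverse=True)[:k]
        mx = scores[top[0]]                      # numerically stable softmax
        exps = [math.exp(scores[n] - mx) for n in top]
        s = sum(exps)
        out_w.append([e / s for e in exps])
        out_i.append(top)
    return out_w, out_i
```

Every per-row weight list sums to 1, and the indices are the selected expert IDs.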
Background: The Blackwell GEMM Pipeline
On SM100, a GEMM kernel follows this pipeline:
```
A tiles → SMEM → UMMA (MMA) → TMEM (Tensor Memory, on-chip accumulator)
B tiles → SMEM ↗                        ↓
                             Epilogue (our fusion point)
                                        ↓
                              D → Global Memory
```
TMEM is new in Blackwell: a dedicated high-bandwidth on-chip buffer that holds the MMA accumulator. The epilogue reads from TMEM via a T2R (TMEM-to-register) copy and processes the result before writing to global memory. This is the only window to operate on the full accumulator without a global memory round-trip.
Approach 1: CUTLASS Epilogue Visitor Tree
CUTLASS 3.x uses an Epilogue Visitor Tree (EVT): a compile-time tree of operation nodes that compose into a single epilogue pass. The Colfax Research blog describes it well:
"An epilogue visitor tree is a collection of visitors organized in a tree that collectively operate as a single visitor."
Each node implements two callbacks:
- `visit()` – called per fragment; receives the accumulator, returns a transformed value
- `reduce()` – called once per tile; can do cross-element reduction and modify results in place
Nodes compose as: Sm90EVT<OuterNode, InnerTree>, where the inner tree's output feeds the outer node.
For our use case the tree is:
```
Sm100TopKSoftmaxColReduction          ← outer: top-K selection + masked softmax
└── Sm90LinearCombination             ← inner: alpha * acc
    ├── Sm90ScalarBroadcast(alpha)
    └── Sm90AccFetch
```
This maps directly to D = softmax(top_k(alpha * A @ B^T)).
The SM100 Compatibility Problem
CUTLASS ships LinCombTopKSoftmaxCol for Hopper (SM90). For Blackwell (SM100) it silently produces wrong results. Adding debug prints to get_consumer_store_callbacks() that run the SM90 layout algebra with SM100's actual inputs reveals exactly what goes wrong:
```
[SM90 layout algebra on SM100 inputs]
tiled_copy (T2R) TiledNumThr : 128
lane_layout_MN  size_M=32  size_N=1   ← no lanes in N direction
warp_layout_MN  size_M=4   size_N=1
tCrColReduce layout size   : 16384    ← full tile, N-stride not collapsed

[SM100 fix via args.tCcD]
tCcD total size : 128
element[0] coord (m,n) : (0, 0)       ← row 0, col 0 ✓
element[1] coord (m,n) : (0, 1)       ← row 0, col 1 ✓
```
The SM90 visitor builds tCrTopK's layout by composing three pieces: the UMMA thread partition of a zero-N-stride tensor (gColReduce), the T2R retile, and the accumulator register layout. On SM90 this composition preserves the zero-N-stride: all N elements of the same row map to the same tCrTopK entry, so every insertion updates the shared row accumulator.
On SM100, the UMMA and T2R have different thread partitions, and the composition does not preserve the zero-N-stride. The debug print shows `tCrColReduce layout size = 16384` (the full tile), meaning each of the 128 per-thread elements maps to a separate tCrTopK entry instead of a shared one. Each entry sees only the single element inserted at its specific N position; the "2nd largest" (`top_k_[1]`) stays at −∞.
The butterfly reduction loop `for (j = 1; j < size<1>(lane_layout_MN); j *= 2)` runs zero iterations because `size<1>(lane_layout_MN) = 1`. After `reduce_final()`, every entry still has `min_ = -∞`. The mask check `val >= min_` is then always true: every element passes, and every output is non-zero instead of exactly K per row.
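The degenerate case is easy to model on the host. This is a toy sketch (an assumption-laden simplification, not the CUTLASS source): with `size_N = 1` the butterfly loop body never executes, the per-entry minimum is never reduced away from −∞, and the mask keeps the whole row.

```python
import math

def butterfly_iterations(size_n):
    # mirrors the kernel's loop: for (j = 1; j < size<1>(lane_layout_MN); j *= 2)
    j, iters = 1, 0
    while j < size_n:
        iters += 1
        j *= 2
    return iters

def surviving_elements(row, min_thresh):
    # the epilogue's mask check: val >= min_ selects the survivors
    return sum(1 for v in row if v >= min_thresh)

# SM100: size_N == 1, so lanes never exchange their partial top-k ...
assert butterfly_iterations(1) == 0
# ... min_ stays at -inf, and the whole row passes the mask instead of K elements
row = [0.3, -2.0, 1.7, 0.0]
assert surviving_elements(row, -math.inf) == len(row)
```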
The Fix: Sm100TopKSoftmaxColReduction
Instead of the broken layout algebra, the custom node uses `args.tCcD`: the coordinate tensor that maps each register element to its (m, n) position in the tile. This is computed directly from `thread_t2r.partition_D(identity_coord_tensor)` and is always correct:
```cpp
template <bool ReferenceSrc, class... Args>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args)
{
  // Use args.tCcD directly: correct on both SM90 and SM100.
  // args.tiled_mma / args.tiled_copy produce wrong layouts on SM100 UMMA.
  auto args_tuple = cute::make_tuple(args.tCcD, args.residue_tCcD);
  return ConsumerStoreCallbacks<decltype(args_tuple)>(
      cute::move(args_tuple), params);
}
```
The visit() callback simply converts fp16 → fp32 and passes through. All the work happens in reduce():
```cpp
CUTLASS_DEVICE void
reduce(STensor&&, SyncFn const&, int epi_m, int epi_n, bool,
       VTensor visit_results)
{
  auto tCcD_mn = tCcD(_,_,_,epi_m,epi_n);

  // Pass 1: insertion sort to find top-K among in-bounds elements
  Array<float, TopK> top_k;
  top_k.fill(-cutlass::platform::numeric_limits<float>::infinity());
  for (int epi_v = 0; epi_v < cute::size(visit_results); ++epi_v) {
    auto& frag = visit_results(epi_v);
    for (int i = 0; i < FragmentSize; ++i) {
      auto coord = tCcD_mn(epi_v * FragmentSize + i);
      if (elem_less(coord, residue_tCcD)) {        // skip OOB columns
        float val = frag[i];
        for (int k = 0; k < TopK; ++k) {
          if (val > top_k[k]) {
            for (int l = TopK-1; l > k; --l) top_k[l] = top_k[l-1];
            top_k[k] = val;
            break;
          }
        }
      }
    }
  }

  // Numerically stable softmax denominator
  float max_v   = top_k[0];
  float sum_exp = 0.f;
  for (int k = 0; k < TopK; ++k)
    sum_exp += cutlass::fast_exp(top_k[k] - max_v);
  float min_thresh = top_k[TopK - 1];

  // Pass 2: masked softmax in-place on visit_results
  for (int epi_v = 0; epi_v < cute::size(visit_results); ++epi_v) {
    auto& frag = visit_results(epi_v);
    for (int i = 0; i < FragmentSize; ++i) {
      auto coord = tCcD_mn(epi_v * FragmentSize + i);
      if (elem_less(coord, residue_tCcD)) {
        float val = frag[i];
        frag[i] = (val >= min_thresh)
                  ? cutlass::fast_exp(val - max_v) / sum_exp
                  : 0.f;
      } else {
        frag[i] = 0.f;
      }
    }
  }
}
```
visit_results is a register buffer view: modifying it in place changes what gets written to SMEM and then, via TMA, to global memory. No intermediate global write of D happens.
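Stripped of fragments and coordinates, reduce() computes this per-row function. A scalar sketch; like the kernel, it assumes the top-K scores are distinct, since ties at the K-th value would all pass the threshold mask:

```python
import math

def topk_softmax_row(vals, k, n_valid):
    """Scalar model of the Sm100 reduce() callback: pass 1 finds the top-k
    among in-bounds elements with an insertion sort; pass 2 writes softmax
    weights in place at surviving positions and zero elsewhere."""
    top = [-math.inf] * k
    for i in range(n_valid):               # skip OOB columns, as elem_less does
        v = vals[i]
        for j in range(k):
            if v > top[j]:
                top[j+1:] = top[j:-1]      # shift down, drop the old minimum
                top[j] = v
                break
    max_v = top[0]                         # numerically stable softmax
    sum_exp = sum(math.exp(t - max_v) for t in top)
    thresh = top[-1]                       # the k-th largest acts as the mask
    out = []
    for i, v in enumerate(vals):
        if i < n_valid and v >= thresh:
            out.append(math.exp(v - max_v) / sum_exp)
        else:
            out.append(0.0)
    return out
```

The output is a dense row: softmax weights at the K selected positions, zeros everywhere else, exactly the (M, N) layout the CUTLASS kernel stores.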
Two Silent Traps
The fp16 trap. On SM100, `IsDirectR2S = true` forces `RegisterElementD = ElementD`. If `ElementD = fp16`, loading `tRS_rD` and storing to a `Float32` register tensor without `.to(Float32)` reinterprets the fp16 bit pattern as fp32: the value 37.0 (fp16 bits ≈0x5124) becomes ≈3×10⁻⁴³ in fp32. This is why the kernel configuration requires `ElementD = float`.
The NaN trap. Encoding a column index into the low bits of `-inf` (bit pattern `0xFF800000`) produces a NaN (e.g. `0xFF800001`). NaN poisons `fmax` and every subsequent comparison. The fix: only encode in-bounds elements; leave OOB lanes as clean `-inf`.
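Both traps reproduce on the host with Python's struct module (a sketch independent of the kernels; `'e'`, `'f'`, and `'I'` pack fp16, fp32, and uint32 respectively):

```python
import math
import struct

# Trap 1: reinterpreting fp16 bits as fp32. 37.0 in fp16 is a small bit
# pattern; read as an fp32 bit pattern it lands in the denormal range,
# a tiny positive number rather than 37.0.
fp16_bits = struct.unpack('<H', struct.pack('<e', 37.0))[0]
as_fp32 = struct.unpack('<f', struct.pack('<I', fp16_bits))[0]
assert as_fp32 != 37.0 and 0.0 < as_fp32 < 1e-38   # fp32 denormal territory

# Trap 2: setting a low bit of the -inf pattern (0xFF800000) yields NaN,
# and NaN makes every comparison in the top-k insertion return false.
poisoned = struct.unpack('<f', struct.pack('<I', 0xFF800001))[0]
assert math.isnan(poisoned)
assert not (poisoned > 1.0) and not (poisoned < 1.0)
```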
The FusionCallbacks Override
CUTLASS has a blanket rule: `FusionCallbacks<Sm100TmaWarpSpecialized, ANY_OP>` aliases to the SM90 version. Adding a more-specific partial specialisation from user code overrides it, with no CUTLASS headers modified:
```cpp
// In blackwell_gemm_topk_softmax.cu, injected into the
// cutlass::epilogue::fusion namespace
template<
  int StagesC, int StagesD, int FragmentSize, bool ReuseSmemC, bool DelayTmaStore,
  int TopK, class ElementOutput, class ElementCompute, ...,
  class CtaTile, class EpiTile>
struct FusionCallbacks<
    epilogue::Sm100TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmemC, DelayTmaStore>,
    fusion::LinCombTopKSoftmaxCol<TopK, ElementOutput, ElementCompute, ...>,
    CtaTile, EpiTile>
  : Sm100LinCombTopKSoftmaxCol<...>   // our fixed node
{
  // ... Arguments struct with alpha/beta ...
};
```
C++ partial-specialisation ordering always picks the more specialised match, so this user-side definition wins.
Data Flow
```
A (fp16) × B (fp16)
        ↓
SM100 UMMA + TMEM pipeline
        ↓
tTR_rAcc ← fp32 accumulator in TMEM
        ↓
visit():  alpha * acc → tRS_rD registers (fp32, ElementD=float required)
reduce(): tCcD coords → top-K sort → masked softmax, in place on registers
        ↓
R2S copy: registers → SMEM
        ↓
TMA store: SMEM → D (fp32, global)
           shape (M, N): zeros everywhere except the K selected positions
```
Approach 2: Quack Composable Epilogue (CuTe-DSL)
Quack is a library of high-performance CUDA kernels written in CuTe-DSL. It uses a composable mixin system for epilogues.
Architecture
```
GemmTopKSm100
  ↳ GemmTopKEpiMixin      (overrides epilogue hooks for TopK)
  ↳ GemmSm100             (SM100 mainloop + pipeline)
  ↳ GemmDefaultEpiMixin   (base alpha/beta epilogue)
  ↳ ComposableEpiMixin    (EpiOp dispatch)
```
The ComposableEpiMixin defines a lifecycle equivalent to EVT nodes:
```
# once per CTA tile
epi_begin(...)                 # allocate register buffers, capture coordinates
for each N-subtile:
    epi_begin_loop(...)        # slice coord tensor to current subtile
    epi_visit_subtile(params, epi_loop_tensors, tRS_rD)   # process subtile
epi_end(...)                   # decode, write output
```
`epi_visit_subtile` corresponds to EVT `visit()`; `epi_end` corresponds to EVT `reduce()` plus the final write.
TopKCoordOp
The key composable unit is TopKCoordOp, an EpiOp that manages the top-K accumulator:
```python
class TopKCoordOp(EpiOp):

    @cute.jit
    def begin(self, gemm, param, smem_tensor, ctx):
        tile_M, tile_N = gemm.cta_tile_shape_mnk[:2]
        # partition_for_epilogue_fn maps each register element to its (m, n) coord
        coord = ctx.partition_for_epilogue_fn(
            cute.make_identity_tensor((tile_M, tile_N))
        )
        topk_acc = cute.make_rmem_tensor((gemm.topk_k,), Float32)
        topk_acc.fill(-topk_acc.element_type.inf)   # init running top-k to -inf
        return (topk_acc, coord)

    @cute.jit
    def begin_loop(self, gemm, state, epi_coord):
        topk_acc, coord = state
        # Slice coord to the current (epi_m, epi_n) N-subtile
        coord_sub = cute.group_modes(coord, 3, cute.rank(coord))[None, None, None, epi_coord]
        return (topk_acc, coord_sub)

    @cute.jit
    def end(self, gemm, param, state, epi_tile, tiled_copy_t2r,
            tiled_copy_r2s, tile_coord_mnkl, varlen_manager, tidx):
        mValues, mIndices = param
        topk_acc, coord = state
        # ... decode column indices from low bits, apply softmax, autovec_copy to output
```
The Epilogue Subtile Problem
The SM100 epilogue tile is (128, 32) for a 128×128 CTA: 4 N-subtile iterations. Without accumulation, the last subtile (columns 96–127) would overwrite the first three. The topk_acc register buffer persists across all calls via Python closures over the register tensor:
```python
@cute.jit
def epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC=None):
    # topk_acc persists across all N-subtile calls (updated via merge each time)
    topk_acc, coord_sub = epi_loop_tensors["topk_coord"]

    # tRS_rD may be fp16 on SM100; .to(Float32) is essential
    rD_f32 = cute.make_rmem_tensor(tRS_rD.layout, Float32)
    rD_f32.store(tRS_rD.load().to(Float32))
    if const_expr(params.alpha is not None):
        alpha = utils.load_scalar_or_pointer(params.alpha)
        for i in cutlass.range(cute.size(rD_f32), unroll_full=True):
            rD_f32[i] = rD_f32[i] * alpha

    # IMPORTANT: do NOT encode OOB elements; modifying low bits of -inf creates NaN
    rD_i32 = cute.recast_tensor(rD_f32, Int32)
    for i in cutlass.range(cute.size(rD_i32), unroll_full=True):
        if coord_sub[i][1] < N:
            col = Int32(coord_sub[i][1])
            enc = (~col if rD_f32[i] >= Float32(0) else col) & idx_mask
            rD_i32[i] = (rD_i32[i] & ~idx_mask) | enc
        else:
            rD_f32[i] = -Float32.inf

    # Local top-k for this N-subtile (warp_width=1: one row per thread on SM100)
    subtile_topk = bitonic_topk(rD_f32, k, warp_width=1)
    # Merge into running accumulator; after 4 subtiles: global top-k across all N columns
    bitonic_topk_merge(topk_acc, subtile_topk)
    return None   # decode + write happens in TopKCoordOp.end()
```
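The accumulate-across-subtiles invariant is easy to check on the host. In this sketch, plain sorts stand in for `bitonic_topk` and `bitonic_topk_merge` (the bitonic networks compute the same result, just branchlessly):

```python
import random

def local_topk(vals, k):
    # stand-in for bitonic_topk: descending top-k of one N-subtile
    return sorted(vals, reverse=True)[:k]

def merge_topk(acc, other):
    # stand-in for bitonic_topk_merge: keep the k best of both lists
    k = len(acc)
    return sorted(acc + other, reverse=True)[:k]

# One 128-wide row processed as four 32-column subtiles, as the (128, 32)
# epilogue tile forces. The running accumulator is merged, never overwritten;
# otherwise the results of subtiles 0-2 would be lost.
random.seed(0)
row = [random.random() for _ in range(128)]
k = 4
acc = [float('-inf')] * k                  # mirrors topk_acc.fill(-inf)
for s in range(4):
    subtile = row[32 * s : 32 * (s + 1)]
    acc = merge_topk(acc, local_topk(subtile, k))

assert acc == sorted(row, reverse=True)[:k]   # global top-k across all N columns
```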
Data Flow
```
A (fp16) × B (fp16)
        ↓
SM100 UMMA + TMEM pipeline
        ↓
tTR_rAcc ← fp32 accumulator in TMEM
        ↓
epi_visit_subtile()  ← called 4× (one per 32-column N-subtile):
    fp16 → fp32 conversion (.to(Float32))
    alpha scaling
    column index encoded into low mantissa bits
    bitonic_topk(subtile, k, warp_width=1)      ← intra-thread, no warp shuffle
    bitonic_topk_merge(topk_acc, subtile_topk)  ← accumulate across subtiles
        ↓
epi_end() / TopKCoordOp.end():
    decode column indices from low mantissa bits
    optional softmax over top-k values
    autovec_copy: registers → global memory (direct, no SMEM staging)
        ↓
values  (M, k) fp16   ← only the K selected softmax weights
indices (M, k) int32  ← their column positions (expert IDs)
```
Both CUTLASS and quack compute the same operation. The difference is only in the output layout.
After all 4 subtiles, epi_end calls TopKCoordOp.end(), which decodes the column indices (encoded in the low log₂N mantissa bits), applies softmax, and writes via autovec_copy directly from registers to the compact (M, k) output:
```python
# In TopKCoordOp.end():
vals_store    = cute.tiled_divide(topk_vals_out, (vecsize_out,))
indices_store = cute.tiled_divide(topk_indices, (vecsize_out,))
mVals_store   = cute.tiled_divide(mValues[row_abs, None], (vecsize_out,))
mIdx_store    = cute.tiled_divide(mIndices[row_abs, None], (vecsize_out,))
for i in cutlass.range(k // vecsize_out, unroll_full=True):
    cute.autovec_copy(vals_store[None, i], mVals_store[None, i])
    cute.autovec_copy(indices_store[None, i], mIdx_store[None, i])
```
Column Coordinate via Bitonic Sort with Index Encoding
To know which column a value came from after the bitonic sort, the column index is encoded into the low log₂N mantissa bits before sorting. Since the mantissa perturbation is tiny relative to the value magnitude, the sort order is preserved:
```python
col = Int32(coord_sub[i][1])          # absolute N column from coordinate tensor
enc = (~col if value >= 0 else col)   # invert for positive values so the earlier col wins on ties
rD_i32[i] = (rD_i32[i] & ~idx_mask) | (enc & idx_mask)
```
After sorting, decoding recovers the original column:
```python
enc_idx = Int32(topk_i32[i]) & idx_mask
topk_i32[i] = Int32(topk_i32[i]) & ~idx_mask   # clear low bits → clean float
col = (~enc_idx if value >= 0 else enc_idx) & idx_mask
```
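The encode/decode pair round-trips on the host. A sketch using Python's struct module for the float↔int bit casts; `IDX_BITS = 7` assumes N = 128 experts:

```python
import struct

IDX_BITS = 7                      # log2(N) for N = 128 experts
IDX_MASK = (1 << IDX_BITS) - 1

def f2i(x):
    return struct.unpack('<I', struct.pack('<f', x))[0]

def i2f(x):
    return struct.unpack('<f', struct.pack('<I', x))[0]

def encode(val, col):
    # invert the index for positive values so the earlier column wins ties
    enc = (~col if val >= 0 else col) & IDX_MASK
    return i2f((f2i(val) & ~IDX_MASK) | enc)

def decode(x):
    bits = f2i(x)
    val = i2f(bits & ~IDX_MASK)   # clear low bits → clean float
    enc = bits & IDX_MASK
    col = (~enc if val >= 0 else enc) & IDX_MASK
    return val, col

# The perturbation is at most 2^7 ulps, so well-separated values keep
# their sort order, and sorting the encoded floats sorts the columns too.
enc = [encode(0.9, 3), encode(0.2, 10), encode(0.8, 100)]
top2 = sorted(enc, reverse=True)[:2]
assert [decode(x)[1] for x in top2] == [3, 100]
```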
One Row Per Thread
On SM100 with SM100_TMEM_LOAD_32dp32b128x and a 128×128 tile: 128 threads × 128 fp32 elements each = 16384 total. The T2R copy distributes one complete M row (all N columns) to each epilogue thread. Therefore bitonic_topk with warp_width=1 (no cross-thread shuffle) is correct and sufficient.
Benchmark Results
Hardware: NVIDIA Blackwell (SM100)
Configuration: Top-4, softmax=True, fp16 inputs
| Problem | M | N | K | cutlass_fused | quack_fused | torch_eager | torch_compile |
|---|---|---|---|---|---|---|---|
| MoE-8 routing | 512 | 8 | 128 | 0.0215 ms | 0.0279 ms | 0.1352 ms | 0.1847 ms |
| MoE-16 routing | 512 | 16 | 128 | 0.0236 ms | 0.0282 ms | 0.1372 ms | 0.1826 ms |
| MoE-64 routing | 1024 | 64 | 512 | 0.0323 ms | 0.0284 ms | 0.1362 ms | 0.2097 ms |
| MoE-128 routing | 2048 | 128 | 1024 | 0.0380 ms | 0.0285 ms | 0.1446 ms | 0.2109 ms |
| Large M, N=64 | 4096 | 64 | 2048 | 0.0369 ms | 0.0284 ms | 0.1425 ms | 0.2112 ms |
| Large M, N=128 | 4096 | 128 | 2048 | 0.0410 ms | 0.0281 ms | 0.1604 ms | 0.2106 ms |
Speedup over torch_eager:

| Problem | cutlass_fused | quack_fused |
|---|---|---|
| MoE-8 routing | 6.3× | 4.8× |
| MoE-16 routing | 5.8× | 4.9× |
| MoE-64 routing | 4.2× | 4.8× |
| MoE-128 routing | 3.8× | 5.1× |
| Large M, N=64 | 3.9× | 5.0× |
| Large M, N=128 | 3.9× | 5.7× |
Both fused kernels are 3.8–6.3× faster than PyTorch eager, and up to ~8.6× faster than torch.compile. torch.compile is consistently slower than eager here: it replaces the cuBLAS GEMM with Triton for small N, which loses to cuBLAS on tiny shapes.
Why One is Faster Than the Other
Small N (8–16): cutlass_fused wins
At N=8 the GEMM compute dominates. Both kernels spend the same time on MMA, but the CUTLASS collective builder selects a well-tuned epilogue schedule (pipelined TMA, optimal staging). Quack's bitonic sort + merge across 4 N-subtiles adds register pressure that exceeds the output-write savings when the output is tiny (8 values/row × 512 rows = 4 KB total).
Large N (64–128): quack_fused wins
At N=128, cutlass_fused writes a full dense (M, N) fp32 matrix to global memory, while quack_fused writes only the compact (M, k) values and indices.
| Kernel | Output | Size for (M=2048, N=128, k=4) |
|---|---|---|
| cutlass_fused | Dense (M, N) fp32 | 2048 × 128 × 4 B = 1 MB |
| quack_fused | Compact (M, k) fp16 + int32 | 2048 × 4 × 6 B = 48 KB |
That's a more than 20× difference in output write traffic. The TMA store of the full D matrix becomes the bottleneck at large M×N, while quack's direct register→global autovec_copy of 4 values per row scales independently of N.
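The figure in the table is just byte arithmetic (assuming the output element sizes stated above: fp32 for the dense scatter, fp16 values plus int32 indices for the compact layout):

```python
def dense_bytes(M, N):
    # cutlass_fused: full (M, N) fp32 scatter output
    return M * N * 4

def compact_bytes(M, k):
    # quack_fused: (M, k) fp16 values + (M, k) int32 indices
    return M * k * (2 + 4)

M, N, k = 2048, 128, 4
assert dense_bytes(M, N) == 1 << 20            # 1 MB
assert compact_bytes(M, k) == 48 * 1024        # 48 KB
assert dense_bytes(M, N) / compact_bytes(M, k) > 20
```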
Additionally, quack's timing is nearly flat (~0.028–0.029 ms) across all problem sizes: it is kernel-launch-latency bound, not compute or memory bound. The SM100 GEMM + compact write finishes so quickly that dispatch overhead dominates.
Crossover at N ≈ 32–64
The crossover point is where output-write bandwidth starts to dominate the per-subtile epilogue overhead, roughly at N = 32–64 for these M and K values.
Output Format
Both kernels compute the same thing: softmax(top_k(alpha * A @ B^T)). They differ only in how the result is stored.
cutlass_fused scatters the K softmax weights back into a full (M, N) fp32 matrix, with zeros at all non-selected positions. The epilogue writes this through the standard SMEM → TMA pipeline.
quack_fused skips the scatter entirely. It writes only the K selected values and their column indices, giving compact (M, k) tensors. For MoE dispatch this is the natural format β you need the routing weights and expert IDs, not the full N-dimensional vector.
Conclusion
Both kernels fuse the complete MoE routing pipeline (GEMM accumulation, top-K selection, softmax normalisation) into a single SM100 kernel pass. Neither the (M, N) score matrix nor any intermediate result touches global memory.
Thanks also to Claude, which helped me dig through the CUTLASS template multiverse and understand every nitty-gritty detail while implementing these kernels, and quack my way through understanding EpiOp.
The key engineering insights that make this work on SM100:
- Use `args.tCcD`, not layout algebra: the coordinate tensor is always correct; the UMMA layout arithmetic for the SM90 TopK visitor is not
- `ElementD = float` required: `IsDirectR2S = true` on SM100 makes `RegisterElementD = ElementD`; mismatched types silently corrupt values
- No NaN from encoding `-inf`: only encode in-bounds elements
- Per-subtile accumulation: the epilogue tile is `(128, 32)`, not `(128, 128)`; a running register buffer accumulates top-K across 4 N-subtile iterations
- `warp_width=1`: one complete row per thread on SM100; no cross-warp shuffle needed