Skip to content

FlashAttention and Kernel-Aware Model Optimization

FlashAttention and Kernel-Aware Model Optimization

Section titled “FlashAttention and Kernel-Aware Model Optimization”

Many model optimization ideas fail because they count FLOPs and ignore memory movement. GPUs are fast at arithmetic, but model inference often stalls on HBM bandwidth, cache layout, kernel launch overhead, synchronization, or bad tensor shapes.

Kernel-aware optimization asks:

Does this model architecture map efficiently to the kernels and hardware that will actually run it?

For staff-level work, this is the difference between paper speedup and production speedup. A model with fewer FLOPs can be slower if its operations are irregular. A model with the same FLOPs can be faster if kernels reduce HBM traffic, fuse operations, and keep tensor cores busy.


When kernel-aware optimization comes up, answer in this order:

  1. Bottleneck: compute, HBM bandwidth, kernel launch overhead, synchronization, or communication?
  2. Operator: attention, MLP, layernorm, softmax, sampling, routing, KV cache copy, all-reduce?
  3. Shape: batch size, sequence length, hidden size, head dimension, group size, alignment.
  4. Dtype: BF16, FP16, FP8, INT8, INT4, mixed precision.
  5. Kernel: cuBLAS, FlashAttention, FlashInfer, Triton, CUTLASS, TensorRT-LLM, custom kernel.
  6. System: continuous batching, CUDA graphs, tensor parallelism, KV paging, streaming.

The core idea:

Model architecture
|
v
Tensor shapes + dtype + memory layout
|
v
Kernel selection
|
v
GPU utilization and latency

Architecture and kernels are coupled. You cannot fully optimize one while ignoring the other.


Naive attention materializes large intermediate matrices. For sequence length LL:

QKTRL×LQK^T \in \mathbb{R}^{L \times L}

Materializing this matrix is expensive in memory. FlashAttention avoids writing the full attention matrix to HBM. It tiles the computation and performs online softmax so attention can be computed with much less memory movement.

Simplified:

Naive attention:
read Q,K,V
write scores to HBM
read scores
write softmax to HBM
read softmax,V
write output
FlashAttention:
stream Q,K,V tiles through SRAM
compute online softmax
write final output

The FLOPs are similar. The memory traffic is much lower. That is why FlashAttention can be much faster even without changing the model.

The staff-level point:

FlashAttention is not a different attention algorithm in model semantics. It is a different schedule for the same dense attention that reduces IO.


FlashAttention-1 showed the IO-aware attention idea. FlashAttention-2 improved parallelism and work partitioning. FlashAttention-3 targeted Hopper GPUs with better use of asynchronous execution and FP8 support.

Production takeaways:

  • Use mature attention kernels before writing custom ones.
  • Match kernel generation to hardware.
  • Check support for head dimension, causal masking, sliding windows, paged KV, and dtype.
  • Benchmark at real sequence lengths.

FlashAttention helps most in prefill and long-context dense attention. Decode can still be dominated by KV cache reads and memory bandwidth.


4. FlashInfer, FlashMLA, and Specialized Kernels

Section titled “4. FlashInfer, FlashMLA, and Specialized Kernels”

Modern serving engines often use specialized kernels beyond vanilla FlashAttention:

  • FlashInfer: optimized kernels for LLM serving, paged KV cache, sampling, attention variants, and decode workloads.
  • FlashMLA: kernels designed around MLA-style latent KV attention.
  • Triton/TileLang kernels: useful for custom attention patterns, sparse attention, MoE dispatch, and hardware-specific optimization.
  • TensorRT-LLM kernels: production inference kernels with graph-level optimizations.

The trend from 2025 onward is specialization:

Generic attention kernel
|
+-- dense prefill kernel
+-- paged decode kernel
+-- MLA kernel
+-- sparse attention kernel
+-- sliding-window kernel
+-- FP8 attention kernel

As models adopt MLA, sparse attention, MTP, and MoE, kernels must become model-aware.


Operator fusion combines multiple operations into one kernel to avoid unnecessary reads/writes.

Example unfused transformer path:

read x
RMSNorm -> write
read normed
Linear -> write
read linear output
Activation -> write
read activation
Linear -> write
Residual add -> write

Fused path:

read x once
RMSNorm + linear/activation/residual in fewer kernels
write fewer intermediates

Common fusions:

  • Bias + activation.
  • RMSNorm + residual.
  • LayerNorm + matmul.
  • QKV projection packing.
  • Rotary embedding fused into attention.
  • Dequantization + GEMM.
  • Sampling kernels.
  • MoE routing + dispatch pieces.

Fusion reduces HBM traffic and kernel launch overhead. It can also make debugging harder and reduce flexibility.


Not all dimensions are equally efficient. Tensor cores prefer aligned shapes. Head dimensions, hidden sizes, expert sizes, and quantization group sizes affect kernel choice.

Bad shape choices can cause:

  • Tensor core underutilization.
  • Padding overhead.
  • Kernel fallback.
  • Poor memory coalescing.
  • Extra reshapes/transposes.

Example:

Good:
hidden size aligned to tensor core tile
head dim supported by FlashAttention kernel
KV head count compatible with GQA kernel
Bad:
unusual head dim
awkward expert hidden size
quantization group not kernel-supported
sparse block size mismatched to hardware

Architecture search for efficient models should include hardware shape constraints, not just parameter count.


A dtype is useful only if kernels support it efficiently.

Examples:

  • BF16 is broadly supported on modern accelerators.
  • FP8 can be fast on Hopper/Blackwell-class hardware if kernels and scaling are correct.
  • INT4 weight-only can help decode memory bandwidth but may not help prefill as much.
  • FP8 attention requires careful scaling and numerics.
  • Sparse 2:4 support requires specific layouts and kernels.

Kernel coverage questions:

  • Does attention support this dtype?
  • Does GEMM support this quantization format?
  • Is dequantization fused?
  • Are group sizes supported?
  • Is KV cache dtype supported?
  • Does tensor parallelism still work?
  • Are fallback kernels visible in profiling?

If 20% of layers fall back to slow kernels, the whole optimization may disappoint.


CUDA graphs reduce launch overhead by capturing and replaying a fixed computation graph. This is useful for decode loops, where many small operations repeat.

The challenge is dynamic serving:

  • Variable batch sizes.
  • Variable sequence lengths.
  • Variable accepted tokens in speculative decoding.
  • MoE routing changes.
  • Dynamic KV cache pages.

Serving engines often bucket shapes or capture common decode shapes.

Dynamic requests
|
v
Shape bucketing
|
v
CUDA graph replay for common shapes
|
v
Lower launch overhead

CUDA graphs are not a model optimization, but they can meaningfully reduce per-token overhead.


Graph compilers can fuse, reorder, specialize, and lower operations:

  • torch.compile
  • XLA
  • TensorRT / TensorRT-LLM
  • TVM
  • Triton
  • vendor-specific compilers

Compilation works best when:

  • Shapes are stable.
  • Control flow is limited.
  • Kernels are supported.
  • Dynamic batching is managed.
  • Graph breaks are minimized.

Compilation struggles when:

  • Shapes change constantly.
  • Python control flow remains in the hot path.
  • Custom ops are unsupported.
  • Sparse routing creates dynamic dispatch.
  • Memory allocation is not controlled.

Production rule:

Compilation is not a magic speed button. It is a way to specialize a stable workload.


Do not optimize kernels from vibes. Profile.

Use:

  • Nsight Systems for timeline and CPU/GPU gaps.
  • Nsight Compute for kernel-level occupancy and memory.
  • PyTorch Profiler for operator breakdown.
  • Serving-engine metrics for queueing and batching.
  • Custom spans for prefill/decode/sampling/KV operations.

Checklist:

1. Separate prefill and decode.
2. Identify top kernels by time.
3. Check tensor shapes and dtype.
4. Look for fallback kernels.
5. Measure HBM bandwidth and tensor core utilization.
6. Check kernel launch gaps.
7. Verify batching and graph capture.
8. Re-benchmark end-to-end.

Kernel optimization without system profiling can make the wrong thing faster.


Good model choices often look boring:

  • Use GQA head counts supported by kernels.
  • Choose head dimensions supported by FlashAttention.
  • Align hidden sizes to tensor core tiles.
  • Pick MoE expert sizes that batch well.
  • Avoid exotic activations without fused kernels.
  • Choose quantization formats with serving support.
  • Design sparse patterns around block kernels.
  • Keep routing/layout simple enough for inference engines.

Bad model choices create “paper efficient” architectures that are hard to serve.

Interview phrase:

A model architecture is not production-efficient until its critical operators have fast kernels at the target shapes.


The model is memory-bound or launch-bound.

One unsupported shape or dtype silently falls back to a slow path.

The kernel is fast alone but bad under batching, KV paging, tensor parallelism, or streaming.

Fused kernels are fast but hard to debug or incompatible with new model variants.

Supporting new GPUs, dtypes, shapes, and model variants becomes a team tax.

FP8, low-bit dequantization, or reordered reductions change outputs enough to hurt quality.


  1. FlashAttention, FlashAttention-2, and FlashAttention-3.
    Core IO-aware attention sequence.

  2. FlashInfer.
    Practical LLM serving kernels.

  3. TensorRT-LLM.
    Production inference stack and kernels.

  4. Triton and TileLang.
    Useful for custom GPU kernels and modern attention variants.

  5. DeepSeek-V3 and DeepSeek-V3.2.
    Read as examples of model/kernel/hardware co-design pressure: MLA, MoE, FP8, MTP, and sparse attention.


Kernel-aware optimization is about making the model executable efficiently, not just mathematically smaller.

The checklist:

  • Profile before changing architecture.
  • Separate compute-bound from memory-bound work.
  • Use mature kernels before custom kernels.
  • Choose shapes and dtypes that hit fast paths.
  • Watch fallback kernels.
  • Fuse operations when memory traffic dominates.
  • Use CUDA graphs or compilation when shapes are stable.
  • Benchmark inside the real serving engine.
  • Re-check accuracy and numerics after kernel changes.

The interview answer:

FLOPs are not latency. FlashAttention wins by reducing memory traffic, not by changing attention semantics. Modern LLM optimization is model-kernel co-design: attention type, KV layout, dtype, tensor shape, batching, and compiler support all decide whether the theoretical optimization becomes production speed.