Activation Checkpointing: Trading Compute for Training Memory

Activation checkpointing, also called gradient checkpointing or recomputation, is a training optimization that saves memory by not storing every intermediate activation from the forward pass. During backpropagation, the missing activations are recomputed.

The tradeoff is direct:

Save activation memory, spend extra compute.

This is one of the most important techniques for training large transformers because activation memory grows with batch size, sequence length, hidden size, and layer count. When training long-context models, activation memory can dominate total training memory.


Training memory includes:

  • parameters
  • gradients
  • optimizer states
  • activations
  • temporary buffers
  • communication buffers
  • fragmentation / framework overhead

For Adam-style training, model-state memory (parameters, gradients, and optimizer states) can be large, but it does not grow with batch size or sequence length. Activations do; they scale roughly with:

$$B \cdot L \cdot d_{model} \cdot N_L$$

where:

  • $B$ is microbatch size.
  • $L$ is sequence length.
  • $d_{model}$ is hidden dimension.
  • $N_L$ is number of layers.

Long-context training makes $L$ large, so activation checkpointing often becomes effectively mandatory.
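
A back-of-the-envelope estimator for this scaling, as a minimal sketch. `tensors_per_layer` is a hypothetical constant standing in for how many $d_{model}$-wide values each layer keeps per token for backward (the true value depends on the architecture), and `bytes_per_value=2` assumes 16-bit activations.

```python
def activation_bytes(batch, seq_len, d_model, n_layers,
                     tensors_per_layer=16, bytes_per_value=2):
    """Rough activation memory, proportional to B * L * d_model * N_L.

    tensors_per_layer is a HYPOTHETICAL constant; real models differ, and
    attention adds an extra L^2 term if the full score matrix is materialized.
    """
    return batch * seq_len * d_model * n_layers * tensors_per_layer * bytes_per_value

GiB = 2**30
for seq_len in (4096, 32768):
    print(seq_len, activation_bytes(1, seq_len, 4096, 32) / GiB)
# 4096 16.0
# 32768 128.0
```

Because the estimate is linear in $L$, an 8× longer context needs 8× the activation memory under these assumptions, while model-state memory stays fixed.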


Without checkpointing:

Forward:
layer 1 -> save activations
layer 2 -> save activations
layer 3 -> save activations
layer 4 -> save activations
Backward:
use saved activations for gradients

With checkpointing:

Forward:
layer 1 -> discard many activations
layer 2 -> save checkpoint
layer 3 -> discard many activations
layer 4 -> save checkpoint
Backward:
recompute missing forward activations
compute gradients

The model computes some forward operations twice, but peak memory drops.
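
The mechanism can be sketched end to end on a toy chain of scalar layers. This is a minimal illustration, not a framework implementation: each hypothetical layer computes `tanh(w * x)`, the manual backward multiplies local derivatives, and the checkpointed version keeps only every `k`-th layer input, recomputing each segment just before backpropagating through it.

```python
import math

# Hypothetical toy model: layer i computes tanh(w_i * x) on a scalar.
def make_weights(n):
    return [0.5 + 0.1 * i for i in range(n)]

def layer_forward(w, x, stats):
    stats["forward_calls"] += 1
    return math.tanh(w * x)

def layer_grad(w, x):
    # d/dx tanh(w * x) = w * (1 - tanh(w * x)^2), evaluated at the layer input x
    t = math.tanh(w * x)
    return w * (1.0 - t * t)

def grad_no_checkpoint(weights, x0):
    """Save every layer input. Returns (d out/d x0, forward calls, peak saved values)."""
    stats = {"forward_calls": 0}
    saved = []
    x = x0
    for w in weights:
        saved.append(x)  # keep the activation for the backward pass
        x = layer_forward(w, x, stats)
    grad = 1.0
    for w, xi in zip(reversed(weights), reversed(saved)):
        grad *= layer_grad(w, xi)
    return grad, stats["forward_calls"], len(saved)

def grad_checkpointed(weights, x0, k):
    """Save only every k-th layer input; recompute each segment during backward."""
    stats = {"forward_calls": 0}
    n = len(weights)
    checkpoints = {}
    x = x0
    for i, w in enumerate(weights):
        if i % k == 0:
            checkpoints[i] = x  # segment boundary: keep this input
        x = layer_forward(w, x, stats)  # all other inputs are discarded
    grad = 1.0
    peak = len(checkpoints)
    for start in sorted(checkpoints, reverse=True):
        end = min(start + k, n)
        # Recompute the segment's missing inputs from its checkpoint.
        seg = [checkpoints[start]]
        for i in range(start, end - 1):
            seg.append(layer_forward(weights[i], seg[-1], stats))
        peak = max(peak, len(checkpoints) + len(seg) - 1)
        # Backpropagate through the segment using the recomputed inputs.
        for i in reversed(range(start, end)):
            grad *= layer_grad(weights[i], seg[i - start])
    return grad, stats["forward_calls"], peak

weights = make_weights(16)
print(grad_no_checkpoint(weights, 0.3)[1:])    # (16, 16)
print(grad_checkpointed(weights, 0.3, 4)[1:])  # (28, 7)
```

With 16 layers and `k = 4`, the checkpointed run makes 28 forward calls instead of 16 and holds at most 7 saved values instead of 16, while producing the same gradient.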


Let:

  • $M_A$ be activation memory without checkpointing.
  • $M_C$ be activation memory with checkpointing.
  • $F$ be forward compute.
  • $B$ be backward compute (not the microbatch size $B$ used above).

Without checkpointing:

$$\text{compute} \approx F + B$$

With checkpointing:

$$\text{compute} \approx F + B + R$$

where $R$ is recomputation compute.
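
As a rough worked example, assume the common rule of thumb that the backward pass costs about twice the forward pass ($B \approx 2F$) and that full recomputation replays the whole forward pass once ($R \approx F$). Both are approximations rather than measurements:

$$\frac{F + B + R}{F + B} \approx \frac{F + 2F + F}{F + 2F} = \frac{4}{3}$$

Under these assumptions, full recomputation costs roughly 33% more compute per step. Selective checkpointing replays only part of the forward pass, so $R < F$ and the overhead shrinks accordingly.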

The goal is to reduce memory enough to increase batch size, sequence length, or model size, while keeping throughput acceptable.

Checkpointing is worth it when:

$$\text{value of larger training configuration} > \text{cost of recompute}$$

For LLM training, that is often true.
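
One way to make this inequality operational is to compare measured throughput. A minimal sketch, assuming you have timed a baseline run and a checkpointed run that spends the freed memory on a larger microbatch; the function names and the example numbers are hypothetical.

```python
def tokens_per_second(microbatch, seq_len, step_seconds):
    return microbatch * seq_len / step_seconds

def checkpointing_pays_off(baseline, checkpointed):
    """Each argument is a measured (microbatch, seq_len, step_seconds) tuple."""
    return tokens_per_second(*checkpointed) > tokens_per_second(*baseline)

# Hypothetical measurements: 4x the microbatch at 3x the step time.
print(checkpointing_pays_off((1, 4096, 1.0), (4, 4096, 3.0)))  # True
```

This only captures throughput. When the larger configuration (for example, a longer context) does not fit at all without checkpointing, the comparison is moot and checkpointing wins by default.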


Checkpointing granularity controls how much is saved.

Options:

  • Whole transformer block.
  • Attention sub-block.
  • MLP sub-block.
  • Every $k$ layers.
  • Selective checkpointing for memory-heavy operations.
  • Full recomputation.

Coarse checkpointing saves at block boundaries: less bookkeeping, more recompute.

Fine checkpointing saves inside the block: more control and less recompute, at the cost of more complexity.

Common transformer checkpoint boundaries sit on the residual stream at each sub-block input:

[checkpoint] RMSNorm -> Attention -> Residual -> [checkpoint] RMSNorm -> MLP -> Residual

Granularity should be chosen from profiling, not guesswork.
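
The "every $k$ layers" option has a classic tradeoff that is easy to compute. Under a simplified model (each layer's activations cost one unit, each checkpoint costs one unit, and backward recomputes one $k$-layer segment at a time), peak memory is about $\lceil N_L / k \rceil + k$, minimized near $k \approx \sqrt{N_L}$, at the price of roughly one extra forward pass. The helper names are hypothetical, and real layers are not uniform, which is why profiling still decides.

```python
import math

def peak_units(n_layers, k):
    # ceil(n_layers / k) stored checkpoints, plus k layers' activations
    # live while one segment is recomputed during backward
    return math.ceil(n_layers / k) + k

def best_interval(n_layers):
    # Brute-force the checkpoint interval that minimizes peak memory in this model
    return min(range(1, n_layers + 1), key=lambda k: peak_units(n_layers, k))

n = 64
k = best_interval(n)
print(k, peak_units(n, k), math.isqrt(n))  # 8 16 8
```

For 64 uniform layers this model gives 16 units of peak activation memory instead of roughly 64 without checkpointing.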


Not all activations cost the same. Some are cheap to recompute; others are expensive or numerically sensitive.

Good candidates for recompute:

  • LayerNorm/RMSNorm outputs.
  • MLP intermediate activations.
  • Attention projections.
  • Dropout masks, provided the RNG state is restored so recomputation regenerates the same mask.

More delicate:

  • Attention softmax intermediates.
  • Random operations.
  • Custom kernels.
  • Operations with non-deterministic reductions.
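
Replaying random operations correctly comes down to restoring the generator state before recomputation. The sketch below uses Python's `random.Random` as a stand-in for a framework RNG and a hypothetical `dropout_mask` helper; PyTorch's `torch.utils.checkpoint.checkpoint` performs the equivalent bookkeeping through its `preserve_rng_state` argument, which is on by default.

```python
import random

def dropout_mask(rng, size, p):
    # True = keep the element; each element is dropped with probability p
    return [rng.random() >= p for _ in range(size)]

def forward_and_recompute(seed, size=64, p=0.5):
    rng = random.Random(seed)
    saved_state = rng.getstate()        # snapshot before the checkpointed region
    mask_forward = dropout_mask(rng, size, p)
    rng.random()                        # later ops keep consuming random numbers
    mask_naive = dropout_mask(rng, size, p)   # recompute WITHOUT restoring state
    rng.setstate(saved_state)           # restore the snapshot, then recompute
    mask_replayed = dropout_mask(rng, size, p)
    return mask_forward, mask_naive, mask_replayed
```

Only the replayed mask matches the one used in the original forward pass; the naive recompute silently applies a different mask, so the gradients no longer correspond to the forward computation.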

FlashAttention already avoids storing the full $L \times L$ attention matrix: it keeps only the softmax normalization statistics and recomputes attention blocks in the backward pass. This is effectively an IO-aware recomputation strategy.

Staff-level point:

Activation checkpointing is not only a PyTorch flag. Modern attention kernels and training stacks already make selective recomputation decisions.


Checkpointing interacts with:

  • Tensor parallelism.
  • Pipeline parallelism.
  • Sequence/context parallelism.
  • ZeRO/FSDP.
  • Expert parallelism.

Pipeline parallelism stores activations for in-flight microbatches. More microbatches can mean more activation memory. Checkpointing reduces that pressure but increases recompute inside pipeline stages.
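
A simplified sketch of per-stage activation memory, assuming every in-flight microbatch keeps its activations until its own backward pass and that checkpointing saves only each layer's input. The function name and the unit sizes (`full_per_layer`, `input_per_layer`) are hypothetical.

```python
def stage_activation_units(in_flight, layers_per_stage, full_per_layer,
                           input_per_layer, checkpoint):
    """Peak activation memory on one pipeline stage, in hypothetical units."""
    if not checkpoint:
        # Every in-flight microbatch keeps every layer's full activations.
        return in_flight * layers_per_stage * full_per_layer
    # Keep only each layer's input per in-flight microbatch, plus one layer's
    # full activations while that layer is recomputed during backward.
    return in_flight * layers_per_stage * input_per_layer + full_per_layer

print(stage_activation_units(8, 4, 16, 1, checkpoint=False))  # 512
print(stage_activation_units(8, 4, 16, 1, checkpoint=True))   # 48
```

The checkpointed figure still grows linearly with in-flight microbatches, just with a much smaller per-microbatch cost; the price is that each layer's forward runs twice on this stage, which lengthens stage time.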

Sequence/context parallelism splits sequence activations across devices. Checkpointing and sequence partitioning often combine for long-context training.

FSDP and ZeRO shard parameters, gradients, and optimizer states; they do not shrink activation memory. You still need checkpointing for long sequences.
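
A minimal accounting sketch of that point. It assumes mixed-precision Adam at roughly 16 bytes per parameter (2 for 16-bit weights, 2 for gradients, 12 for fp32 master weights, momentum, and variance), fully sharded across ranks as in ZeRO stage 3 or FSDP full sharding. The function name and the example sizes are hypothetical.

```python
def per_gpu_memory(n_params, world_size, activation_bytes, bytes_per_param=16):
    """Return (sharded model-state bytes, activation bytes) for one GPU.

    Model state (weights, grads, optimizer states) is split across world_size ranks.
    Activations are not: each rank still runs the forward pass on its own microbatch.
    """
    model_state = n_params * bytes_per_param / world_size
    return model_state, activation_bytes

GiB = 2**30
for world_size in (8, 64):
    state, acts = per_gpu_memory(7_000_000_000, world_size, activation_bytes=128 * GiB)
    print(world_size, round(state / GiB, 1), acts / GiB)
```

Going from 8 to 64 ranks cuts the model-state term by 8×, while the activation term stays at 128 GiB per GPU; only checkpointing, a smaller microbatch, or sequence/context parallelism reduces it.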


Common failure modes:

  • Checkpointing is too aggressive, and recompute dominates step time.
  • Random operations are not replayed consistently during recompute.
  • Recompute lengthens pipeline stage time and enlarges pipeline bubbles.
  • Temporary buffers, communication buffers, or fragmentation dominate peak memory, so checkpointing alone does not fix out-of-memory errors.
  • Intermediate activations are not saved, which makes numeric diffing harder.


Before enabling checkpointing broadly:

  • Measure activation memory by layer.
  • Separate parameter memory from activation memory.
  • Profile recompute overhead.
  • Choose checkpoint granularity intentionally.
  • Test determinism.
  • Check interaction with FlashAttention backward.
  • Benchmark tokens/sec, not only max batch size.
  • Validate loss curves after enabling.
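
For the last item, a minimal sketch of a loss-curve comparison between a baseline run and a checkpointed run with the same seed and data order. The function name and the `rel_tol` default are hypothetical; exact equality is only a reasonable expectation when kernels are deterministic and RNG state is replayed during recompute.

```python
import math

def loss_curves_match(baseline, checkpointed, rel_tol=1e-3):
    """Compare per-step losses from two runs; rel_tol is a hypothetical tolerance."""
    if len(baseline) != len(checkpointed):
        return False
    return all(math.isclose(a, b, rel_tol=rel_tol)
               for a, b in zip(baseline, checkpointed))

print(loss_curves_match([2.0, 1.5, 1.2], [2.0, 1.5001, 1.2]))  # True
```

A mismatch beyond the tolerance usually points at RNG replay or a non-deterministic kernel rather than at checkpointing itself, since recomputation should reproduce the same forward values.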

The interview answer:

Activation checkpointing trades extra forward recomputation for lower activation memory. It is essential for large and long-context training, but the right granularity depends on profiling and on interactions with attention kernels, pipeline parallelism, and sequence/context parallelism.