
Flash Attention Memory Footprint: What Your GPU Actually Allocates During Prefill

Magos Veridian
4 min read

Most OOM errors during prefill feel sudden. The sequence length looked reasonable on paper, the model fit at batch size 1 in your benchmark, and then production sends a few long-context requests at once and the node falls over. Understanding exactly what FlashAttention allocates during prefill turns that surprise into something you can reason about and plan for.

Photo by Brett Sayles on Pexels.

FlashAttention (v2 and v3) earns its place by computing attention in tiles with an online softmax rather than materializing the full N×N attention matrix, and by recomputing scores in the backward pass instead of storing them. The big win: peak memory drops from O(N²) to O(N) in activations for the attention computation itself. That O(N) bound is real, but it covers only the fused kernel's working memory. Several other allocations grow with sequence length in ways that are easy to forget.
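The tiling idea is easier to trust once you see the math line up. Below is a minimal NumPy sketch for a single head with no causal mask: it walks over key/value blocks, keeps a running row maximum and row sum (the online softmax), and never holds more than an [N, block] slice of scores. The real kernels also tile the queries and keep tiles in on-chip memory; this sketch only shows that the result matches naive attention. The function names are illustrative, not part of any library.

```python
import numpy as np

def naive_attention(q, k, v):
    """Reference: materializes the full [N, N] score matrix."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    p = np.exp(scores)
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def tiled_attention(q, k, v, block=64):
    """Online softmax over key/value tiles; holds at most [N, block] scores at once."""
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.zeros_like(q)
    row_max = np.full((n, 1), -np.inf)  # running max of scores per query row
    row_sum = np.zeros((n, 1))          # running softmax denominator per row
    for start in range(0, k.shape[0], block):
        kb, vb = k[start:start + block], v[start:start + block]
        s = (q @ kb.T) * scale                                  # [n, block] scores
        new_max = np.maximum(row_max, s.max(axis=-1, keepdims=True))
        correction = np.exp(row_max - new_max)                  # rescale earlier partial sums
        p = np.exp(s - new_max)
        row_sum = row_sum * correction + p.sum(axis=-1, keepdims=True)
        out = out * correction + p @ vb
        row_max = new_max
    return out / row_sum

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((256, 64)) for _ in range(3))
print(np.allclose(tiled_attention(q, k, v, block=32), naive_attention(q, k, v)))  # True
```

The running maximum plus the log of the running sum is the per-row logsumexp that FlashAttention keeps for the backward pass: O(N) values rather than an N×N matrix.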

Here is where memory actually goes during a prefill pass:

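Query, Key, Value projections. Q has shape [batch, heads, seq_len, head_dim]; under grouped-query attention (GQA), K and V have shape [batch, kv_heads, seq_len, head_dim]. Llama 3 70B has 64 query heads, 8 KV heads, and head_dim 128. If K and V used all 64 heads, a single batch-1 prefill at 8192 tokens would cost 3 × 64 × 8192 × 128 × 2 bytes (bfloat16), about 402 MB per layer. With GQA it is (64 + 2 × 8) × 8192 × 128 × 2 bytes, about 168 MB per layer, just for QKV before any computation runs.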
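A small helper makes this arithmetic easy to rerun for other shapes. It is a sketch that counts only the Q, K, and V activation tensors for one layer, and it separates query heads from KV heads so the same function covers a full multi-head layout and Llama 3 70B's grouped-query layout (8 KV heads, as in the KV cache figures). The function name and defaults are illustrative.

```python
def qkv_bytes(seq_len, heads, kv_heads, head_dim, batch=1, dtype_bytes=2):
    """Bytes for one layer's Q, K, and V activations during prefill (bfloat16 by default)."""
    q = batch * heads * seq_len * head_dim * dtype_bytes
    kv = 2 * batch * kv_heads * seq_len * head_dim * dtype_bytes  # K and V
    return q + kv

mha = qkv_bytes(8192, heads=64, kv_heads=64, head_dim=128)  # every head has its own K/V
gqa = qkv_bytes(8192, heads=64, kv_heads=8, head_dim=128)   # Llama 3 70B style GQA
print(mha, gqa)  # 402653184 167772160
```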

The KV cache write. Prefill populates the KV cache for every layer. At 80 layers (Llama 3 70B), 8192 tokens, GQA with 8 KV heads, bfloat16: 2 × 80 × 8 × 8192 × 128 × 2 bytes ≈ 2.7 GB. This is persistent across the request's lifetime, not freed after prefill.
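The same formula as a function, so you can plug in your own layer count, KV heads, and sequence lengths. This is a sketch of the cache's raw tensor size only; paged allocators such as vLLM's round up to whole blocks, so real reservations can be somewhat larger.

```python
def kv_cache_bytes(seq_len, layers, kv_heads, head_dim, batch=1, dtype_bytes=2):
    """Persistent K and V cache across all layers (the leading 2 is K plus V)."""
    return 2 * layers * kv_heads * seq_len * head_dim * batch * dtype_bytes

llama3_70b = kv_cache_bytes(8192, layers=80, kv_heads=8, head_dim=128)
per_token = kv_cache_bytes(1, layers=80, kv_heads=8, head_dim=128)
print(f"{llama3_70b / 1e9:.2f} GB")  # 2.68 GB
print(per_token)                      # 327680 bytes per token
```

The per-token figure is the useful one for capacity planning: every additional prompt token held by a live request costs the same amount, for as long as the request lives.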

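FlashAttention tile buffers. The kernel's working tiles are sized by its block dimensions, not by N², and they live in on-chip shared memory rather than HBM. Beyond its output tensor, the main thing it writes to HBM is per-row softmax statistics: at 64 heads and 8192 tokens in float32, that is 64 × 8192 × 4 bytes, about 2 MB per layer. Not the source of your OOM.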

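Activation memory for the MLP and residuals. If you're training, each transformer block retains its activations for the backward pass, so the cost multiplies by the number of layers. In inference with chunked prefill disabled, activations are freed block by block, but each block's MLP still materializes intermediates that span the whole sequence. For Llama 3 70B (hidden_dim 8192, MLP intermediate size 28672), the gate and up projections at seq_len 8192 come to 2 × 8192 × 28672 × 2 bytes, about 0.94 GB in bfloat16: roughly 1 GB of transient memory per layer while that layer runs, on top of about 134 MB for each [batch, seq_len, hidden_dim] residual-stream tensor. That transient peak scales with both sequence length and batch size, which is where the real pressure comes from.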
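A rough model of the difference between inference and training, as a sketch. It assumes a SwiGLU MLP (as in Llama) whose gate and up outputs dominate transient memory, and it takes Llama 3 70B's intermediate size of 28672 from the published model config. It counts only those two tensors plus one residual-stream tensor per layer, so treat it as a lower bound; real training retains considerably more per block.

```python
def mlp_transient_bytes(seq_len, intermediate, batch=1, dtype_bytes=2):
    """Gate and up projection outputs of a SwiGLU MLP: 2 x [batch, seq_len, intermediate]."""
    return 2 * batch * seq_len * intermediate * dtype_bytes

def residual_bytes(seq_len, hidden, batch=1, dtype_bytes=2):
    """One [batch, seq_len, hidden] residual-stream tensor."""
    return batch * seq_len * hidden * dtype_bytes

def activation_peak(seq_len, layers, hidden, intermediate, training, batch=1):
    """Inference frees each block's activations; training keeps every block's (lower bound)."""
    per_layer = (mlp_transient_bytes(seq_len, intermediate, batch)
                 + residual_bytes(seq_len, hidden, batch))
    return per_layer * layers if training else per_layer

infer = activation_peak(8192, layers=80, hidden=8192, intermediate=28672, training=False)
train = activation_peak(8192, layers=80, hidden=8192, intermediate=28672, training=True)
print(f"inference {infer / 1e9:.2f} GB, training >= {train / 1e9:.1f} GB")
```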

The interaction between these terms is where things get interesting. A 4096-token request might fit fine. Double it to 8192 and the KV cache write plus MLP activations push you past the threshold, even though the FlashAttention kernel's own memory grows only linearly and stays small.

```mermaid
graph TD
    A[Incoming Prefill Request] --> B(QKV Projection)
    B --> C[FlashAttention Kernel]
    C --> D(KV Cache Write)
    C --> E(MLP + Residual Activations)
    D --> F{Memory Pressure Check}
    E --> F
    F -->|fits| G[Continue Decode]
    F -->|over budget| H[OOM / Eviction]
```

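To profile this before it bites you, call torch.cuda.memory_stats() at each layer boundary during a dry run, and reset the counters with torch.cuda.reset_peak_memory_stats() between runs so each peak belongs to one sequence length. The fields active_bytes.all.peak and reserved_bytes.all.peak give you the high-water marks. Compare runs at seq_len 2048, 4096, and 8192 and fit a curve. Every term above, FlashAttention included, grows linearly with sequence length, so linear but steep growth points at the KV cache and activations. If the growth is superlinear, something in the pass is still materializing an N×N tensor (a dense attention mask, or a fallback to a non-fused attention path), and that is worth fixing first.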
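A sketch of the measurement and the curve fit. The fit is a least-squares slope on log(peak) against log(seq_len): about 1.0 means linear growth, clearly above 1.0 means superlinear. Constant allocations such as model weights bias that slope downward, so the measuring helper subtracts the pre-run baseline. run_prefill is a hypothetical callable you supply that performs one prefill of the given length; torch is imported inside the helper so the fit itself runs without a GPU.

```python
import math

def growth_exponent(seq_lens, peak_bytes):
    """Least-squares slope of log(peak) vs log(seq_len); ~1.0 linear, >1.0 superlinear."""
    xs = [math.log(s) for s in seq_lens]
    ys = [math.log(p) for p in peak_bytes]
    mean_x, mean_y = sum(xs) / len(xs), sum(ys) / len(ys)
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var = sum((x - mean_x) ** 2 for x in xs)
    return cov / var

def measure_prefill_peak(run_prefill, seq_len):
    """Peak active bytes above the pre-run baseline, plus the reserved peak, for one dry run."""
    import torch  # lazy import: only needed when actually measuring on a GPU

    torch.cuda.synchronize()
    baseline = torch.cuda.memory_stats()["active_bytes.all.current"]
    torch.cuda.reset_peak_memory_stats()  # peak counters restart from current usage
    with torch.inference_mode():
        run_prefill(seq_len)  # hypothetical: your model's prefill for seq_len tokens
    torch.cuda.synchronize()
    stats = torch.cuda.memory_stats()
    return stats["active_bytes.all.peak"] - baseline, stats["reserved_bytes.all.peak"]
```

In use, something like `peaks = [measure_prefill_peak(my_prefill, n)[0] for n in (2048, 4096, 8192)]` followed by `growth_exponent((2048, 4096, 8192), peaks)`, where `my_prefill` stands in for your own dry-run function.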

Two interventions are worth reaching for first.

Chunked prefill breaks a long sequence into fixed-length chunks (say, 512 tokens) and runs prefill incrementally. vLLM has supported this since v0.4 via --enable-chunked-prefill; SGLang enables it by default. The KV cache still accumulates the full sequence, but MLP activation memory at any moment is bounded by chunk size, not total sequence length. Tail latency goes up slightly; OOM risk drops substantially.
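A back-of-the-envelope simulation of why chunking bounds the transient term but not the cache. This is not vLLM's or SGLang's scheduler; it is a sketch that assumes Llama 3 70B shapes (80 layers, 8 KV heads, head_dim 128, MLP intermediate size 28672) and a SwiGLU MLP whose gate and up outputs dominate transient memory.

```python
def chunk_spans(seq_len, chunk_size):
    """Half-open [start, end) token ranges covering the prompt in order."""
    return [(s, min(s + chunk_size, seq_len)) for s in range(0, seq_len, chunk_size)]

def chunked_prefill_profile(seq_len, chunk_size, layers=80, kv_heads=8, head_dim=128,
                            intermediate=28672, dtype_bytes=2):
    """Per-chunk (kv_cache_bytes, mlp_transient_bytes) under a simple memory model."""
    kv_per_token = 2 * layers * kv_heads * head_dim * dtype_bytes  # K and V, all layers
    profile = []
    for start, end in chunk_spans(seq_len, chunk_size):
        kv = end * kv_per_token                               # cache keeps every token so far
        mlp = 2 * (end - start) * intermediate * dtype_bytes  # gate + up for this chunk only
        profile.append((kv, mlp))
    return profile

profile = chunked_prefill_profile(8192, 512)
print(len(profile), profile[-1])  # 16 (2684354560, 58720256)
```

The last chunk ends with the full 2.68 GB cache, but the MLP transient never exceeds about 59 MB, a sixteenth of the roughly 0.94 GB an unchunked 8192-token prefill would need.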

Sequence bucketing at the scheduler. Pad and batch requests into bins (1024, 2048, 4096, 8192) rather than per-request dynamic shapes. This is ugly but it makes memory usage predictable. You can capacity-plan against known bins and set conservative batch limits per bin, rather than discovering limits empirically in production.
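A minimal sketch of bin assignment and per-bin batch limits, using the bins above. The limits in MAX_BATCH are hypothetical placeholders; real values should come from your own profiling. Each request is padded up to the smallest bin that fits it, and anything longer than the largest bin is rejected rather than silently truncated.

```python
import bisect
from collections import defaultdict

BINS = (1024, 2048, 4096, 8192)
# Hypothetical per-bin batch limits; derive real ones from measured peaks per bin.
MAX_BATCH = {1024: 32, 2048: 16, 4096: 8, 8192: 2}

def bucket_for(seq_len, bins=BINS):
    """Smallest bin that fits seq_len; the request is padded up to that length."""
    i = bisect.bisect_left(bins, seq_len)
    if i == len(bins):
        raise ValueError(f"seq_len {seq_len} exceeds largest bin {bins[-1]}")
    return bins[i]

def schedule(lengths, bins=BINS, max_batch=MAX_BATCH):
    """Group request indices by bin, then split each group at that bin's batch limit."""
    groups = defaultdict(list)
    for idx, n in enumerate(lengths):
        groups[bucket_for(n, bins)].append(idx)
    batches = []
    for b in bins:
        idxs, limit = groups[b], max_batch[b]
        for i in range(0, len(idxs), limit):
            batches.append((b, idxs[i:i + limit]))
    return batches

print(schedule([900, 5000, 8000, 7000, 2048]))
# [(1024, [0]), (2048, [4]), (8192, [1, 2]), (8192, [3])]
```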

One thing worth labeling as speculation: FlashAttention v3's warp-specialization on Hopper GPUs may shift the activation memory profile in ways that are not yet fully documented in production deployments. Profile on your specific hardware and version rather than trusting generic benchmarks from a different GPU generation.

The reverence owed to a well-tuned attention kernel is proportional to how well you understand what it actually allocates. Run the profiler. Fit the curve. Set your bin limits before a long-context request finds the boundary for you.
