Optimizing Training with FlashAttention varlen
Summary
The `varlen` variant of FlashAttention makes Transformer training more efficient by processing a batch of variable-length examples as one concatenated sequence, which removes the need for padding. The standard approach keeps a batch dimension and either pads every sequence to a `max_seqlen` or packs several documents into each fixed-length row. `varlen` instead computes attention only within individual documents. This avoids both "document bleed" (tokens attending across document boundaries, which adds noise) and FLOPs wasted on padding tokens. For instance, with 64 documents of 512 tokens each, `varlen` can cut the number of attention scores computed by 8x compared to a padded batch. The measured benefit is much smaller in practice, because of factors such as windowed attention (which already limits how far each token attends) and the fact that GPUs are most efficient on large matrix multiplications. In the nanochat d24 benchmark, `varlen` still gave a 1.8% speedup.
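The 8x figure is plain arithmetic on attention-score counts. The summary does not spell out the baseline, so the sketch below assumes one setup under which the number falls out exactly: the same 32,768 tokens packed into 8 fixed-length rows of 4,096 tokens, with attention computed over each full row. It counts one L x L score matrix per sequence; causal masking roughly halves both counts and leaves the ratio essentially unchanged.

```python
# Assumed baseline (not stated in the source): the same 64 x 512 = 32,768 tokens
# packed into fixed-length rows of 4,096 tokens, attending across each full row.
NUM_DOCS = 64
DOC_LEN = 512
ROW_LEN = 4096


def attention_scores(seq_lens):
    """Count query-key score entries: one L x L matrix per sequence.

    Causal masking roughly halves both counts, so the ratio barely changes.
    """
    return sum(n * n for n in seq_lens)


total_tokens = NUM_DOCS * DOC_LEN
packed_rows = [ROW_LEN] * (total_tokens // ROW_LEN)  # 8 rows of 4,096 tokens
varlen_docs = [DOC_LEN] * NUM_DOCS                   # 64 documents of 512 tokens

packed = attention_scores(packed_rows)
varlen = attention_scores(varlen_docs)
print(packed, varlen, packed // varlen)  # 134217728 16777216 8
```

The saving grows with the ratio of row length to document length, which is why it is largest when many short documents share a long row.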
Key takeaway
For Machine Learning Engineers optimizing Transformer training pipelines, adopting FlashAttention's `varlen` approach improves efficiency. Refactor your data loaders and model `forward` passes to use a 1D token buffer with `cu_seqlens` marking document boundaries, and keep tensor shapes fixed so the model can be compiled with `torch.compile` using `fullgraph=True` and `dynamic=False`. This reduces computation wasted on padding and cross-document attention and can yield measurable speedups, especially in custom training setups that are already highly optimized.
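A minimal sketch of the data-loader side, in plain Python for clarity. `pack_documents` is a hypothetical helper, not code from the original article. It assumes documents arrive as lists of token IDs and that a document overflowing the buffer is simply truncated; the source does not say how overflow is handled. A real loader would convert `tokens` and `cu_seqlens` to tensors, and the `flash_attn` package's `flash_attn_varlen_func` expects `cu_seqlens` as an int32 tensor.

```python
from typing import List, Tuple


def pack_documents(docs: List[List[int]], max_seq_len: int) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: pack documents into one flat buffer of at most max_seq_len tokens.

    Returns (tokens, cu_seqlens): document i occupies tokens[cu_seqlens[i]:cu_seqlens[i + 1]],
    and cu_seqlens[-1] == len(tokens).
    """
    tokens: List[int] = []
    cu_seqlens = [0]
    for doc in docs:
        if not doc:
            continue  # skip empty documents
        remaining = max_seq_len - len(tokens)
        if remaining == 0:
            break  # buffer is full
        chunk = doc[:remaining]  # assumption: truncate a document that does not fit
        tokens.extend(chunk)
        cu_seqlens.append(len(tokens))
    return tokens, cu_seqlens


docs = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
tokens, cu_seqlens = pack_documents(docs, max_seq_len=8)
print(tokens)      # [1, 2, 3, 4, 5, 6, 7, 8]
print(cu_seqlens)  # [0, 3, 5, 8]
```

The model's `forward` then sees a single sequence of length `len(tokens)` plus `cu_seqlens`, rather than a `(batch, seq_len)` tensor.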
Key insights
FlashAttention `varlen` optimizes Transformer training by processing variable-length sequences as a single buffer, eliminating padding and cross-document attention.
Principles
- Attention should only be calculated within document boundaries.
- Fixed tensor shapes let `torch.compile` compile the model once instead of recompiling whenever a shape changes.
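To make the first principle concrete, the sketch below builds the attention pattern that a causal `varlen` call effectively computes: a block-diagonal causal mask, where a query may attend to a key only if the key is not in the future and belongs to the same document. This is an illustration only. FlashAttention never materializes such a mask; it uses `cu_seqlens` to restrict each document's attention to its own slice. The helper names are hypothetical.

```python
from bisect import bisect_right
from typing import List


def doc_ids_from_cu_seqlens(cu_seqlens: List[int]) -> List[int]:
    """Map every token position in the flat buffer to the index of its document."""
    return [bisect_right(cu_seqlens, pos) - 1 for pos in range(cu_seqlens[-1])]


def mask_matrix(cu_seqlens: List[int]) -> List[List[bool]]:
    """Causal, within-document mask: True where query q may attend to key k."""
    doc_ids = doc_ids_from_cu_seqlens(cu_seqlens)
    n = len(doc_ids)
    return [[k <= q and doc_ids[q] == doc_ids[k] for k in range(n)] for q in range(n)]


# Two documents: tokens 0-2 and tokens 3-4.
for row in mask_matrix([0, 3, 5]):
    print("".join("1" if ok else "." for ok in row))
# 1....
# 11...
# 111..
# ...1.
# ...11
```

The zeros in the lower-left block are exactly the cross-document scores that a packed, fixed-length row would compute and that `varlen` skips.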
Method
Implement `varlen` by flattening inputs into a 1D buffer, passing `cu_seqlens` (cumulative sequence lengths) to the attention call, and fixing `max_num_docs` and `max_seq_len` so that tensor shapes stay static under `torch.compile` with `fullgraph=True` and `dynamic=False`.
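The number of documents per step varies, so `cu_seqlens` would otherwise change length from step to step and force recompilation. One way to keep its shape fixed, sketched below under stated assumptions, is to pad it to exactly `max_num_docs + 1` entries by repeating the final offset, which describes zero-length dummy documents. `pad_cu_seqlens` is a hypothetical helper, and this padding scheme is an assumption rather than the article's confirmed method; check that your attention kernel tolerates zero-length sequences.

```python
from typing import List


def pad_cu_seqlens(cu_seqlens: List[int], max_num_docs: int) -> List[int]:
    """Hypothetical helper: pad cu_seqlens to exactly max_num_docs + 1 entries.

    Extra entries repeat the final offset, i.e. they describe zero-length dummy
    documents (assumed to be tolerated by the attention kernel), so the length
    of cu_seqlens stays constant across training steps.
    """
    num_docs = len(cu_seqlens) - 1
    if num_docs > max_num_docs:
        raise ValueError(f"{num_docs} documents exceed max_num_docs={max_num_docs}")
    return cu_seqlens + [cu_seqlens[-1]] * (max_num_docs - num_docs)


print(pad_cu_seqlens([0, 3, 5, 8], max_num_docs=5))  # [0, 3, 5, 8, 8, 8]
```

If the error fires often, `max_num_docs` is too low for the dataset's document lengths; setting it far too high only adds dummy entries. This trade-off is one reason to tune it to the data.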
In practice
- Use `cu_seqlens` to define document boundaries in the 1D buffer.
- Tune `max_num_docs` to your dataset's document-length distribution.
- Leverage RoPE for position encoding without context extension.
Topics
- FlashAttention
- Transformer Training
- Variable Sequence Lengths
- Computational Efficiency
- PyTorch Compilation
Best for: Machine Learning Engineer, Deep Learning Engineer, AI Engineer
Editorial summary, takeaway, and curation by AIssential. Original article published by Chris McCormick.