AI in Multiple GPUs: ZeRO & FSDP
Summary
The ZeRO (Zero Redundancy Optimizer) framework addresses memory redundancy in Distributed Data Parallelism (DDP) by partitioning model states across multiple GPUs, enabling the training of significantly larger models. DDP improves throughput by splitting each batch across GPUs, but it replicates the full model parameters, gradients, and optimizer states on every GPU, which wastes substantial VRAM for models like GPT-3. ZeRO operates in three stages: ZeRO-1 partitions only the optimizer states, ZeRO-2 partitions optimizer states and gradients, and ZeRO-3 partitions all three components, including the model parameters. For a 7B-parameter model on 8 GPUs, ZeRO-3 reduces model-state memory per GPU from 112 GB in vanilla DDP (7B parameters at 16 bytes each) to 14 GB, an 8x reduction. PyTorch implements ZeRO-3 through its Fully Sharded Data Parallel (FSDP) module, and FSDP2 is the recommended, more optimized version.
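The per-GPU figures above can be reproduced with a short calculation. The sketch below is illustrative only: it assumes mixed-precision Adam at 16 bytes per parameter (2 for fp16 weights, 2 for fp16 gradients, 12 for fp32 optimizer states), counts 1 GB as 1e9 bytes, and ignores activations and buffers. The function name `per_gpu_memory_gb` is hypothetical.

```python
# Assumed bytes per parameter for mixed-precision Adam (not stated in the summary).
PARAM_BYTES = 2    # fp16 weights
GRAD_BYTES = 2     # fp16 gradients
OPTIM_BYTES = 12   # fp32 master weights + Adam momentum + Adam variance


def per_gpu_memory_gb(num_params: float, num_gpus: int, stage: int) -> float:
    """Model-state memory per GPU in GB; stage 0 is plain DDP, 1-3 are ZeRO stages."""
    # Each ZeRO stage shards one more component across the GPUs.
    optim = OPTIM_BYTES / (num_gpus if stage >= 1 else 1)
    grads = GRAD_BYTES / (num_gpus if stage >= 2 else 1)
    params = PARAM_BYTES / (num_gpus if stage >= 3 else 1)
    return num_params * (params + grads + optim) / 1e9


for stage, name in enumerate(["DDP", "ZeRO-1", "ZeRO-2", "ZeRO-3"]):
    print(f"{name}: {per_gpu_memory_gb(7e9, 8, stage):.2f} GB")
# DDP: 112.00 GB
# ZeRO-1: 38.50 GB
# ZeRO-2: 26.25 GB
# ZeRO-3: 14.00 GB
```

Under these assumptions, most of the saving already arrives at ZeRO-1, because the optimizer states account for 12 of the 16 bytes per parameter.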
Key takeaway
For AI architects and Machine Learning Engineers training large language models, implementing ZeRO-3 via PyTorch's FSDP2 is a practical way to overcome VRAM limitations. It cuts the per-GPU memory footprint enough to train models with billions of parameters on existing hardware. The cost is additional inter-GPU communication, which is usually worth paying once a model no longer fits in a single GPU's memory.
Key insights
ZeRO optimizes distributed AI training by sharding model states across GPUs, drastically reducing memory footprint.
Principles
- Memory redundancy is a key bottleneck in DDP.
- Partitioning model states reduces VRAM usage.
- ZeRO buys memory savings at the cost of extra communication.
Method
ZeRO partitions optimizer states, gradients, and model parameters across GPUs. Parameters are gathered just-in-time for computation and then discarded, minimizing peak memory.
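The gather-then-discard pattern can be illustrated with a single-process toy simulation. This is a sketch, not how ZeRO or FSDP is implemented: the "ranks" are plain Python lists, `all_gather` is a stand-in for the real collective, and each toy layer simply multiplies its input by the sum of its weights. All function names here are hypothetical.

```python
def shard(weights, world_size):
    """Split one layer's flat weights into equal contiguous shards, one per rank."""
    assert len(weights) % world_size == 0, "sketch assumes even divisibility"
    size = len(weights) // world_size
    return [weights[r * size:(r + 1) * size] for r in range(world_size)]


def all_gather(shards):
    """Stand-in for the collective: rebuild the full weights from every rank's shard."""
    return [w for rank_shard in shards for w in rank_shard]


def forward(sharded_layers, x):
    """Run the toy model, holding at most one layer's full weights at a time."""
    for shards in sharded_layers:
        full = all_gather(shards)  # gather just in time for this layer
        x = x * sum(full)          # toy layer computation
        del full                   # discard; only the local shards stay resident
    return x


layers = [[1.0, 2.0, 3.0, 4.0], [0.5, 0.5, 1.0, 1.0]]
sharded = [shard(layer, world_size=2) for layer in layers]
print(forward(sharded, 1.0))  # 30.0
```

Between layers, each simulated rank keeps only its 2-element shard of every layer; the full 4-element weight list exists only while that layer is being computed.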
In practice
- Use PyTorch FSDP2 for ZeRO-3 implementation.
- Apply `fully_shard` layer-by-layer.
- Reserve ZeRO for models too large to train comfortably under plain DDP; for smaller models the extra communication buys little.
Topics
- ZeRO Optimizer
- Distributed Data Parallelism
- GPU Memory Optimization
- Large Model Training
- PyTorch FSDP
Best for: Machine Learning Engineer, Deep Learning Engineer, AI Architect
Editorial summary, takeaway, and curation by AIssential. Original article published by Towards Data Science.