Resilient Large-Scale Training: Integrating TorchFT with TorchTitan on AMD GPUs
Summary
AMD has integrated PyTorch's native fault-tolerance framework, TorchFT, with the TorchTitan training framework on its Primus-SaFE Kubernetes platform to enable resilient, checkpoint-less training for large AI models on AMD GPUs. This solution addresses the limitations of traditional checkpoint-and-restart mechanisms, which incur significant overhead and waste computation at scale. The architecture decouples intra-group parallelism (FSDP2, TP, PP) from inter-group fault tolerance, allowing healthy replica groups to continue training independently during failures. A lightweight TorchFT Lighthouse service coordinates membership and step synchronization across groups, facilitating dynamic recovery where failed nodes can quickly rejoin by synchronizing state from a healthy peer. This system was validated on a 4-node cluster of AMD Instinct MI325X GPUs, demonstrating sub-millisecond checkpoint staging and 0.56-second FT checkpoint load times for a Llama 3 8B model.
Key takeaway
For MLOps engineers training large models on AMD GPUs, the TorchFT and TorchTitan integration on Primus-SaFE can improve training stability and efficiency: healthy replica groups keep training through a failure, so less computation is wasted and GPU utilization stays higher. To try it, enable fault tolerance in your training jobs with `--fault_tolerance.enable` and use Primus-SaFE's elastic scaling so that training keeps making progress when hardware fails.
Key insights
Decoupling parallelism from fault tolerance enables resilient, checkpoint-less distributed training for large AI models.
Principles
- Hardware failures are expected at large scale.
- Decouple parallelism from fault tolerance.
- Recovering state from a healthy peer is faster than reloading a checkpoint from storage.
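To see why failures become routine at scale, a back-of-the-envelope estimate helps. The sketch below assumes independent GPU failures with a constant failure rate, which is a simplification. The function names and all input numbers are hypothetical and do not come from the original article.

```python
def cluster_mtbf_hours(per_gpu_mtbf_hours: float, num_gpus: int) -> float:
    """Mean time between failures for the whole job.

    Assumes independent failures at a constant rate, so the failure
    rates of the individual GPUs simply add up.
    """
    return per_gpu_mtbf_hours / num_gpus


def expected_lost_hours_per_failure(
    checkpoint_interval_hours: float, restart_overhead_hours: float
) -> float:
    """Expected work lost per failure under checkpoint-and-restart.

    On average a failure lands halfway through a checkpoint interval,
    and the job then pays a fixed restart cost on top of that.
    """
    return checkpoint_interval_hours / 2 + restart_overhead_hours


# Hypothetical inputs: 50,000 h MTBF per GPU, 1,000 GPUs,
# hourly checkpoints, 15 minutes to restart and reload.
mtbf = cluster_mtbf_hours(50_000, 1_000)
lost = expected_lost_hours_per_failure(1.0, 0.25)
print(f"cluster MTBF: {mtbf} h, lost per failure: {lost} h")
```

Under these illustrative numbers the whole job would stall for a sizeable fraction of an hour every couple of days. That recurring cost is what letting healthy replica groups continue is meant to avoid.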
Method
Integrate TorchFT with TorchTitan on Kubernetes, using a Lighthouse coordinator for membership and step synchronization, allowing healthy replica groups to continue training while failed groups recover and rejoin via peer-to-peer state transfer.
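The flow described above can be illustrated with a toy, single-process model. This is not the TorchFT API: `ReplicaGroup`, `ToyLighthouse`, and `train_step` are hypothetical names invented for this sketch, and a one-element list stands in for model weights. It shows only the coordination idea: form a quorum of healthy groups, bring lagging members up to date from a peer, then step together.

```python
from dataclasses import dataclass, field


@dataclass
class ReplicaGroup:
    replica_id: int
    step: int = 0
    weights: list = field(default_factory=lambda: [0.0])
    healthy: bool = True


class ToyLighthouse:
    """Hypothetical stand-in for a coordinator: tracks membership and max step."""

    def __init__(self, min_replicas: int = 1):
        self.min_replicas = min_replicas

    def quorum(self, groups):
        members = [g for g in groups if g.healthy]
        if len(members) < self.min_replicas:
            raise RuntimeError("not enough healthy replica groups")
        return members, max(g.step for g in members)


def train_step(groups, lighthouse):
    members, max_step = lighthouse.quorum(groups)
    # Peer-to-peer recovery: lagging members copy state from an up-to-date peer.
    source = next(g for g in members if g.step == max_step)
    for g in members:
        if g.step < max_step:
            g.weights = list(source.weights)
            g.step = max_step
    # Every quorum member applies the same (toy) update and advances one step.
    for g in members:
        g.weights = [w + 1.0 for w in g.weights]
        g.step += 1
    return [g.replica_id for g in members]


def simulate_failure_and_rejoin():
    lighthouse = ToyLighthouse(min_replicas=1)
    groups = [ReplicaGroup(0), ReplicaGroup(1)]
    train_step(groups, lighthouse)   # both groups reach step 1
    groups[1].healthy = False        # group 1 fails
    train_step(groups, lighthouse)   # group 0 keeps training alone
    train_step(groups, lighthouse)
    groups[1] = ReplicaGroup(1)      # group 1 restarts with fresh state
    train_step(groups, lighthouse)   # it syncs from group 0, then both step
    return groups
```

In this toy run, group 0 never pauses while group 1 is down, and the restarted group catches up by copying state from its peer instead of reading a checkpoint from storage.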
In practice
- Deploy TorchFT Lighthouse as a Kubernetes service.
- Give each replica group a unique ID and set the group size.
- Use Primus-SaFE for automated workload management.
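As a rough sketch of the per-group configuration, the helper below assembles launch arguments for one replica group. Only `--fault_tolerance.enable` is confirmed by this summary. The `--fault_tolerance.replica_id` and `--fault_tolerance.group_size` flags and the `TORCHFT_LIGHTHOUSE` environment variable are assumptions based on TorchTitan and TorchFT conventions and should be checked against your installed versions. The helper names and the Lighthouse URL are hypothetical.

```python
def replica_group_args(replica_id: int, group_size: int) -> list[str]:
    """Build fault-tolerance CLI arguments for one replica group.

    Assumption: group_size is the number of replica groups, and
    replica_id is this group's unique index in [0, group_size).
    """
    if not 0 <= replica_id < group_size:
        raise ValueError("replica_id must be in [0, group_size)")
    return [
        "--fault_tolerance.enable",                     # confirmed by the source
        f"--fault_tolerance.replica_id={replica_id}",   # assumed flag name
        f"--fault_tolerance.group_size={group_size}",   # assumed flag name
    ]


def replica_group_env(lighthouse_url: str) -> dict[str, str]:
    """Environment pointing a replica group at the Lighthouse service.

    Assumption: the address is read from TORCHFT_LIGHTHOUSE.
    """
    return {"TORCHFT_LIGHTHOUSE": lighthouse_url}


# Hypothetical Kubernetes service address for the Lighthouse.
for rid in range(2):
    print(replica_group_env("http://torchft-lighthouse:29510"),
          " ".join(replica_group_args(rid, group_size=2)))
```

Generating the arguments from a single function keeps replica IDs unique by construction, which matters because every group reports to the same Lighthouse.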
Topics
- Large-Scale Training
- Fault Tolerance
- AMD GPUs
- PyTorch
- Kubernetes
Best for: Machine Learning Engineer, Deep Learning Engineer, MLOps Engineer
Editorial summary, takeaway, and curation by AIssential. Original article published by AMD ROCm Blogs.