
Distributed Training for Large Language Models

Technical Overview

Training a GPT-3-scale model (175 billion parameters) on a single GPU is physically impossible. With mixed-precision Adam, model state alone takes about 16 bytes per parameter, so an NVIDIA A100 80GB holds at most ~5B parameters before any activations are stored. Even if the model could fit, the compute would take decades on one GPU. Distributed training partitions work across hundreds or thousands of GPUs simultaneously, combining multiple orthogonal parallelism strategies to make trillion-parameter models tractable. The field has converged on a "3D parallelism" approach, combining Data Parallelism, Tensor Parallelism, and Pipeline Parallelism, pioneered by projects like Megatron-LM and DeepSpeed.

Prerequisites

  • Understanding of neural network training: forward pass, backward pass, gradient descent
  • Familiarity with PyTorch tensors, autograd, and nn.Module
  • Basic understanding of GPU memory hierarchy and CUDA programming model
  • Knowledge of collective communication primitives (AllReduce, AllGather)
  • Understanding of transformer architecture (attention, MLP layers)

Core Content

Why a Single GPU Is Insufficient

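Memory constraints: GPT-3 has 175B parameters. In FP32 that is 700 GB; even in BF16 it is 350 GB, while an A100 has 80 GB of HBM. Training also needs gradients, optimizer state (for Adam: an FP32 master copy of the weights plus two FP32 moment estimates), and activations saved for the backward pass. In mixed precision the model state alone is about 16 bytes per parameter (2 for BF16 weights, 2 for BF16 gradients, 12 for optimizer state), so a 175B model needs ~2.8 TB of GPU memory before counting activations.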

Compute constraints: GPT-3 training used approximately 3.14×10²³ FLOP. An H100 SXM delivers about 989 TFLOP/s of dense BF16 tensor-core throughput (the often-quoted 1,979 TFLOP/s assumes 2:4 structured sparsity). At 50% MFU (Model FLOP Utilization, a realistic efficiency figure), training would take roughly 7,300 days, about 20 years, on a single H100: economically and practically infeasible.
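Both figures follow from simple arithmetic. The sketch below reproduces them, assuming 16 bytes of model state per parameter (activations ignored) and an H100 dense BF16 peak of 989 TFLOP/s; the function names are illustrative.

```python
# Back-of-the-envelope numbers for a 175B-parameter model.
# Assumptions (illustrative): 2 B BF16 weights + 2 B BF16 grads + 4 B FP32 master
# + 8 B Adam moments = 16 bytes/param; activations ignored; 989 TFLOP/s dense peak.

BYTES_PER_PARAM = 2 + 2 + 4 + 4 + 4
SECONDS_PER_DAY = 86_400

def model_state_bytes(n_params: float) -> float:
    """Weights + gradients + optimizer state, excluding activations."""
    return n_params * BYTES_PER_PARAM

def single_gpu_days(total_flop: float, peak_flops: float, mfu: float) -> float:
    """Wall-clock days if one GPU sustains mfu * peak_flops."""
    return total_flop / (peak_flops * mfu) / SECONDS_PER_DAY

print(f"model state: {model_state_bytes(175e9) / 1e12:.1f} TB")  # 2.8 TB
days = single_gpu_days(3.14e23, 989e12, 0.5)
print(f"one H100 at 50% MFU: {days:,.0f} days (~{days / 365:.0f} years)")
```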

Data Parallelism (DP)

Data Parallelism is the simplest and most widely used approach. Each GPU holds a complete copy of the model. The training minibatch is split across GPUs. Each GPU computes a forward and backward pass on its shard. Gradients are synchronized (AllReduce) across all replicas. All replicas stay in sync.

PyTorch DDP (DistributedDataParallel):

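import os, torch
torch.distributed.init_process_group(backend="nccl")  # launched via torchrun, which sets RANK/WORLD_SIZE/LOCAL_RANK
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
model = MyModel().cuda()  # MyModel: placeholder for your nn.Module
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])  # grads sync automatically on backward()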

DDP internals:

  1. During backward(), DDP's autograd hooks fire as the gradient for each nn.Parameter becomes ready.
  2. Gradients are bucketed (default: 25 MB buckets, set by bucket_cap_mb) to amortize AllReduce overhead.
  3. Each bucket's AllReduce is launched asynchronously once the bucket is full, overlapping with the remaining backward computation.
  4. NCCL performs the AllReduce (ring or tree algorithm, chosen by NCCL), and DDP divides by the world size, so each rank ends with the average gradient.
  5. The optimizer step then updates parameters identically on all ranks.
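The ring AllReduce behind step 4 can be simulated in pure Python to see why every rank ends with the same average. This is a sketch under simplifying assumptions: one Python list stands in for each rank's gradient buffer, there is no real networking, and NCCL's actual algorithm selection and chunking differ. It runs a reduce-scatter phase and then an all-gather phase of N-1 steps each, and counts how many elements each rank sends.

```python
def split(vec, n):
    """Split vec into n contiguous chunks whose sizes differ by at most 1."""
    k, extra = divmod(len(vec), n)
    chunks, start = [], 0
    for i in range(n):
        end = start + k + (1 if i < extra else 0)
        chunks.append(list(vec[start:end]))
        start = end
    return chunks

def ring_allreduce_mean(buffers):
    """Return (per-rank averaged buffers, elements sent by each rank)."""
    n = len(buffers)
    chunks = [split(b, n) for b in buffers]
    sent = [0] * n
    # Phase 1, reduce-scatter: afterwards rank r owns the full sum of chunk (r + 1) % n.
    for step in range(n - 1):
        msgs = [(r, (r - step) % n) for r in range(n)]
        payload = [list(chunks[r][c]) for r, c in msgs]  # snapshot before receives
        for (r, c), data in zip(msgs, payload):
            dst = (r + 1) % n
            chunks[dst][c] = [a + b for a, b in zip(chunks[dst][c], data)]
            sent[r] += len(data)
    # Phase 2, all-gather: pass the completed chunks around the ring.
    for step in range(n - 1):
        msgs = [(r, (r + 1 - step) % n) for r in range(n)]
        payload = [list(chunks[r][c]) for r, c in msgs]
        for (r, c), data in zip(msgs, payload):
            chunks[(r + 1) % n][c] = data
            sent[r] += len(data)
    averaged = [[x / n for chunk in rank_chunks for x in chunk] for rank_chunks in chunks]
    return averaged, sent

avg, sent = ring_allreduce_mean([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]])
print(avg)   # [[3.0, 4.0, 5.0, 6.0], [3.0, 4.0, 5.0, 6.0]]
print(sent)  # [4, 4]: each rank sends 2(N-1)/N of the 4-element buffer
```

Each rank sends about 2(N-1)/N of its buffer regardless of N, which is why the ring algorithm's per-GPU bandwidth cost stays roughly constant as the DP degree grows.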

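Scaling limit: DP only works for models whose full training state fits on a single GPU, so 175B+ models need the strategies below. Communication is also heavy: with ring AllReduce, each GPU sends and receives about 2(N-1)/N times the gradient size per iteration, i.e. roughly 2 × parameter_count × bytes_per_param regardless of N, which becomes prohibitive without high-bandwidth interconnects.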

Model Parallelism

When a model does not fit on one GPU, its parameters must be split across multiple GPUs.

Tensor Parallelism (TP) — Megatron-LM Style

Tensor Parallelism splits individual weight matrices across GPUs. Developed by NVIDIA in Megatron-LM (Shoeybi et al., 2019).

Column-parallel linear layer: for a linear layer Y = XA where X is [B, K] and A is [K, N]:
  • Split A column-wise: A = [A₁ | A₂ | ... | A_p], where each Aᵢ is [K, N/p]
  • Each GPU i holds Aᵢ and computes Yᵢ = X × Aᵢ locally (X must be available on every GPU)
  • The result Yᵢ on GPU i is a [B, N/p] slice of the output; an AllGather is needed only if the full Y is required

Row-parallel linear layer: for Y = XA where X is [B, N]:
  • Split A row-wise: A = [A₁ᵀ | A₂ᵀ | ...]ᵀ, where each Aᵢ is [N/p, K]
  • Split the input to match: X = [X₁ | X₂ | ...], where each Xᵢ is [B, N/p]
  • Each GPU i computes a partial product Yᵢ = Xᵢ × Aᵢ
  • The final Y = Σ Yᵢ requires an AllReduce across TP ranks
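Both splits can be checked numerically without GPUs. The sketch below simulates p tensor-parallel ranks with plain Python lists; the helper names are illustrative, not Megatron-LM APIs. Concatenating the column-parallel outputs, or summing the row-parallel partial outputs (what the AllReduce computes), reproduces the unsplit product.

```python
def matmul(X, A):
    """Dense product of nested lists: [rows, inner] x [inner, cols]."""
    return [[sum(x * a for x, a in zip(row, col)) for col in zip(*A)] for row in X]

def split_cols(M, p):
    """Split matrix M column-wise into p equal blocks (assumes divisibility)."""
    w = len(M[0]) // p
    return [[row[i * w:(i + 1) * w] for row in M] for i in range(p)]

def split_rows(M, p):
    """Split matrix M row-wise into p equal blocks (assumes divisibility)."""
    h = len(M) // p
    return [M[i * h:(i + 1) * h] for i in range(p)]

def column_parallel(X, A, p):
    # Rank i holds a column block A_i and computes Y_i = X A_i locally.
    partial = [matmul(X, A_i) for A_i in split_cols(A, p)]
    # Concatenating along columns (an AllGather) rebuilds Y when needed in full.
    return [sum((Y_i[r] for Y_i in partial), []) for r in range(len(X))]

def row_parallel(X, A, p):
    # Rank i holds the matching column block X_i of the input and row block A_i.
    partial = [matmul(X_i, A_i) for X_i, A_i in zip(split_cols(X, p), split_rows(A, p))]
    # Element-wise sum of the partial outputs is what the AllReduce produces.
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*partial)]

X = [[1, 2, 3, 4], [5, 6, 7, 8]]                                 # [B=2, inner=4]
A = [[(i * 4 + j) % 5 - 2 for j in range(4)] for i in range(4)]  # [inner=4, out=4]
assert column_parallel(X, A, p=2) == matmul(X, A)
assert row_parallel(X, A, p=2) == matmul(X, A)
```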

Transformer TP layout (Megatron-LM):

Attention Layer:
  Q,K,V projections: column-parallel (split heads across GPUs)
  Output projection: row-parallel (AllReduce output)

MLP Layer:
  First linear (up-projection): column-parallel
  Activation (GELU/SwiGLU): local
  Second linear (down-projection): row-parallel (AllReduce output)
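Pairing the two layouts is what keeps the MLP down to one AllReduce: the column-parallel up-projection leaves each rank with a column block of the hidden activations, which is exactly the input block the row-parallel down-projection expects, and an elementwise activation can be applied to each block locally. A pure-Python sketch with p simulated ranks; ReLU stands in for GELU/SwiGLU here to keep the arithmetic exact, an assumption of the sketch.

```python
def matmul(X, A):
    """Dense product of nested lists."""
    return [[sum(x * a for x, a in zip(row, col)) for col in zip(*A)] for row in X]

def relu(M):
    """Elementwise activation (stand-in for GELU)."""
    return [[max(0, v) for v in row] for row in M]

def mlp_reference(X, A, B):
    return matmul(relu(matmul(X, A)), B)

def mlp_tensor_parallel(X, A, B, p):
    """Megatron-style MLP: A split by columns, B split by rows, one AllReduce."""
    w = len(A[0]) // p  # hidden units per rank (assumes divisibility)
    partial_outputs = []
    for i in range(p):  # each iteration plays the role of one TP rank
        A_i = [row[i * w:(i + 1) * w] for row in A]  # column block of the up-projection
        B_i = B[i * w:(i + 1) * w]                   # matching row block of the down-projection
        H_i = relu(matmul(X, A_i))                   # local, no communication
        partial_outputs.append(matmul(H_i, B_i))     # partial [batch, d_model] output
    # The only communication: sum the partial outputs across ranks (AllReduce).
    return [[sum(v) for v in zip(*rows)] for rows in zip(*partial_outputs)]

X = [[1, -2, 3], [0, 4, -1]]                                     # [batch=2, d_model=3]
A = [[(3 * i + j) % 7 - 3 for j in range(4)] for i in range(3)]  # up-projection [3, 4]
B = [[(i + 2 * j) % 5 - 2 for j in range(3)] for i in range(4)]  # down-projection [4, 3]
assert mlp_tensor_parallel(X, A, B, p=2) == mlp_reference(X, A, B)
```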

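Communication cost: 2 AllReduce operations per transformer layer in the forward pass (one after attention, one after the MLP), plus 2 more in the backward pass. Each AllReduce operates on a B × S × H activation tensor, where H is the hidden dimension. For H=8192, B=1, S=2048 in BF16 (2 bytes per element), that tensor is 1×2048×8192×2 bytes = 32 MiB, and a ring AllReduce sends and receives up to about 2× that (~64 MiB) per GPU.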
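The per-AllReduce figure can be reproduced as follows. This sketch assumes BF16 activations (2 bytes), a ring AllReduce in which each rank sends about 2(p-1)/p times the tensor size, and four AllReduces per layer per training step (two forward, two backward); the function names are illustrative.

```python
def tp_allreduce_bytes(batch, seq, hidden, bytes_per_elem=2):
    """Size of the [batch, seq, hidden] activation tensor each TP AllReduce reduces."""
    return batch * seq * hidden * bytes_per_elem

def ring_traffic_per_rank(tensor_bytes, p):
    """Bytes each rank sends (and receives) in a ring AllReduce over p ranks."""
    return 2 * (p - 1) / p * tensor_bytes

MiB = 2**20
t = tp_allreduce_bytes(batch=1, seq=2048, hidden=8192)
print(t / MiB)                                  # 32.0 MiB tensor per AllReduce
print(ring_traffic_per_rank(t, p=8) / MiB)      # 56.0 MiB sent per rank at TP=8
print(4 * ring_traffic_per_rank(t, p=8) / MiB)  # 224.0 MiB per layer per step (fwd + bwd)
```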

Pipeline Parallelism (PP)

Pipeline Parallelism assigns consecutive transformer layers to different GPUs. GPU 0 holds layers 0–7, GPU 1 holds layers 8–15, etc.

Naive pipeline (F-then-B):

GPU0: [F0][F0][F0][F0]  idle  idle  idle  [B0][B0][B0][B0]
GPU1:  idle[F1][F1][F1][F1]  idle  idle  idle[B1][B1][B1]
GPU2:  idle  idle[F2][F2][F2][F2]  idle  idle  idle[B2][B2]
GPU3:  idle  idle  idle[F3][F3][F3][F3]  idle  idle  idle[B3]
         ┗━━━━━━━━━━━━━━━pipeline bubble━━━━━━━━━━━━━━━━━━━┛

Pipeline bubble = (p-1)/p where p is the number of pipeline stages. For p=8, 87.5% of time is wasted.

GPipe (Google, 2018): Splits mini-batch into M micro-batches. Pipeline bubble reduces to (p-1)/(p+M-1). For M=8 micro-batches and p=4 stages, bubble = 3/11 = 27%.

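PipeDream 1F1B schedule (Narayanan et al., SOSP 2019; Megatron-LM uses the synchronous variant with pipeline flushes): after the pipeline fills, each stage alternates one forward micro-batch and one backward micro-batch, so in steady state each GPU is always computing either F or B. The bubble is the same as GPipe's for the same M: (p-1)/M relative to the ideal compute time, i.e. (p-1)/(M+p-1) of total time. The advantage is memory: a stage holds activations for at most p in-flight micro-batches instead of all M.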
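The memory advantage of 1F1B shows up clearly when generating each stage's operation order. A minimal sketch of a synchronous (flushed) 1F1B order, assuming stage s runs p-s-1 warm-up forwards, then alternates F and B, then drains the remaining backwards. It ignores timing and dependencies and only tracks how many micro-batches have activations stashed at once; the function names are illustrative.

```python
def gpipe_order(m):
    """GPipe: all M forwards, then all M backwards (same on every stage)."""
    return [("F", i) for i in range(m)] + [("B", i) for i in range(m)]

def one_f_one_b_order(stage, p, m):
    """Synchronous 1F1B order for pipeline stage `stage` (0-based) of p stages."""
    warmup = min(p - stage - 1, m)
    ops = [("F", i) for i in range(warmup)]
    f, b = warmup, 0
    while f < m:                              # steady state: one forward, one backward
        ops += [("F", f), ("B", b)]
        f, b = f + 1, b + 1
    ops += [("B", i) for i in range(b, m)]    # cool-down: drain remaining backwards
    return ops

def peak_stashed(ops):
    """Maximum number of micro-batches whose activations are held at once."""
    live = peak = 0
    for kind, _ in ops:
        live += 1 if kind == "F" else -1
        peak = max(peak, live)
    return peak

p, m = 4, 8
print(peak_stashed(gpipe_order(m)))                                  # 8 (= M)
print([peak_stashed(one_f_one_b_order(s, p, m)) for s in range(p)])  # [4, 3, 2, 1]
```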

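Virtual pipeline stages (interleaved schedule, Narayanan et al., SC 2021): assign non-contiguous layer sets to each GPU (e.g., GPU 0 gets layers 0–1 and layers 8–9). With v such chunks per GPU, the number of pipeline stages grows v-fold while the GPU count stays the same. The bubble shrinks by a factor of v, to (p-1)/(v·M) relative to the ideal compute time, but point-to-point communication grows by the same factor because activations cross GPU boundaries v times as often.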
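The bubble figures in this section, including the PP=8 numbers quoted later, come from one formula. A sketch assuming uniform stage times and a flushed schedule: with p stages, M micro-batches and v virtual chunks per GPU (v=1 for GPipe and plain 1F1B), the bubble costs (p-1)/v time slots against M slots of useful work, so its share of total time is (p-1)/(v·M + p - 1).

```python
from fractions import Fraction

def bubble_fraction(p, m, v=1):
    """Idle share of total time: p stages, m micro-batches, v chunks per GPU."""
    return Fraction(p - 1, v * m + p - 1)

def bubble_overhead(p, m, v=1):
    """Idle time relative to useful compute time, i.e. (p-1)/(v*m)."""
    return Fraction(p - 1, v * m)

print(bubble_fraction(8, 1))        # 7/8: naive pipeline, a single batch
print(bubble_fraction(4, 8))        # 3/11, about 27%
print(bubble_fraction(8, 8))        # 7/15, about 47%
print(bubble_fraction(8, 32))       # 7/39, about 18%
print(bubble_fraction(8, 32, v=2))  # 7/71, about 10% with 2 interleaved chunks
```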

ZeRO: Zero Redundancy Optimizer

ZeRO (Rajbhandari et al., SC 2020; part of DeepSpeed) observes that in data parallelism, every GPU holds a complete copy of the optimizer states, gradients, and parameters, a massive redundancy. ZeRO partitions these across the data-parallel workers.

ZeRO stages:

                     Memory per GPU in GB (GPT-3 175B, mixed-precision Adam, N = DP degree)
                     ┌──────────────────────────────────────────────┐
ZeRO-0 (Baseline DP) │ Params 350 + Grads 350 + Opt 2100 = 2800 GB  │
                     ├──────────────────────────────────────────────┤
ZeRO-1               │ Partition optimizer states across N GPUs     │
                     │ Per-GPU = 350 + 350 + 2100/N                 │
                     ├──────────────────────────────────────────────┤
ZeRO-2               │ Partition optimizer states + gradients       │
                     │ Per-GPU = 350 + (350 + 2100)/N               │
                     ├──────────────────────────────────────────────┤
ZeRO-3               │ Partition params + gradients + opt states    │
                     │ Per-GPU = (350 + 350 + 2100)/N = 2800/N      │
                     │ At N = 1024 GPUs: ~2.7 GB per GPU            │
                     └──────────────────────────────────────────────┘
Params and gradients in BF16 (2 bytes each); Opt = FP32 master weights + Adam first and second moments (12 bytes per parameter). Activations are excluded.
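The per-GPU numbers can be computed for any model size and data-parallel degree N. A minimal sketch, assuming 2 bytes per parameter for BF16 weights, 2 for BF16 gradients and 12 for optimizer state, and ignoring activations, communication buffers and fragmentation; the function is illustrative, not a DeepSpeed API.

```python
def zero_bytes_per_gpu(n_params, stage, n_gpus,
                       param_bytes=2, grad_bytes=2, optim_bytes=12):
    """Per-GPU model-state memory under ZeRO stage 0-3 (activations excluded)."""
    p = n_params * param_bytes
    g = n_params * grad_bytes
    o = n_params * optim_bytes
    if stage == 0:
        return p + g + o                  # every replica holds everything
    if stage == 1:
        return p + g + o / n_gpus         # optimizer states sharded
    if stage == 2:
        return p + (g + o) / n_gpus       # + gradients sharded
    if stage == 3:
        return (p + g + o) / n_gpus       # + parameters sharded
    raise ValueError("stage must be 0, 1, 2 or 3")

GB = 1e9
for stage in range(4):
    print(stage, round(zero_bytes_per_gpu(175e9, stage, 1024) / GB, 1))
# prints, one per line: 0 2800.0, 1 702.1, 2 352.4, 3 2.7
```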

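ZeRO-3 forward pass: each GPU holds only 1/N of the parameters. Before computing each layer, an AllGather assembles that layer's full parameters on every GPU; after the layer, each GPU frees the shards it does not own. The backward pass gathers the parameters again, and gradients are reduce-scattered so that each GPU keeps only the gradient shard for its own parameters. Per iteration this moves about 3× the parameter volume (AllGather in forward, AllGather in backward, ReduceScatter), versus about 2× for standard DP's gradient AllReduce, so ZeRO-3 trades roughly 1.5× more communication for its memory savings.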
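A toy, single-process simulation of one ZeRO-3 data-parallel step shows that sharding does not change the math: each rank all-gathers the full parameters, computes gradients on its own data, and a reduce-scatter leaves each rank with the averaged gradient for only the parameters it owns, which it updates locally. This sketch assumes plain SGD on the loss ½‖w - x‖² (gradient w - x) and gathers the whole model at once; real ZeRO-3 gathers layer by layer, frees the gathered copy, and gathers again during backward.

```python
def shard(vec, n):
    """Equal contiguous shards (assumes len(vec) % n == 0)."""
    size = len(vec) // n
    return [vec[i * size:(i + 1) * size] for i in range(n)]

def all_gather(shards):
    return [x for s in shards for x in s]

def reduce_scatter_mean(per_rank_vectors):
    """Element-wise mean across ranks, then rank r keeps only shard r."""
    n = len(per_rank_vectors)
    mean = [sum(vals) / n for vals in zip(*per_rank_vectors)]
    return shard(mean, n)

def grad(w, x):
    """Gradient of 0.5 * ||w - x||^2 with respect to w."""
    return [wi - xi for wi, xi in zip(w, x)]

def zero3_sgd_step(param_shards, data, lr):
    full = all_gather(param_shards)                 # every rank materializes full params
    local_grads = [grad(full, x) for x in data]     # rank r uses its own data x
    grad_shards = reduce_scatter_mean(local_grads)  # rank r gets the mean grad for its shard
    return [[w - lr * g for w, g in zip(ws, gs)]    # each rank updates only its shard
            for ws, gs in zip(param_shards, grad_shards)]

w = [1.0, 2.0, 3.0, 4.0]
data = [[0.0, 0.0, 0.0, 0.0], [2.0, 2.0, 2.0, 2.0]]  # one sample per rank, 2 ranks
new_shards = zero3_sgd_step(shard(w, 2), data, lr=0.5)
print(new_shards)  # [[1.0, 1.5], [2.0, 2.5]], same as an unsharded DP step
```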

ZeRO-Infinity (2021): Extends ZeRO-3 to NVMe SSD storage. Parameters not in use are offloaded to CPU DRAM or NVMe, enabling models larger than aggregate GPU memory.

3D Parallelism

Combining DP + TP + PP enables training models at scales requiring thousands of GPUs.

3D Parallelism Layout (TP=4, PP=4, DP=2 → 32 GPUs total; only pipeline stages 0 and 1 shown)

Pipeline Stage 0 (layers 0-7)          Pipeline Stage 1 (layers 8-15)
┌─────────────────────────┐            ┌─────────────────────────┐
│  TP Rank 0 │ TP Rank 1  │            │  TP Rank 0 │ TP Rank 1  │
│  TP Rank 2 │ TP Rank 3  │            │  TP Rank 2 │ TP Rank 3  │
│  (DP Replica A)         │  ────────▶ │  (DP Replica A)         │
└─────────────────────────┘            └─────────────────────────┘
             ‖ gradient sync                        ‖ gradient sync
             ‖ (AllReduce)                          ‖ (AllReduce)
┌─────────────────────────┐            ┌─────────────────────────┐
│  TP Rank 0 │ TP Rank 1  │            │  TP Rank 0 │ TP Rank 1  │
│  TP Rank 2 │ TP Rank 3  │            │  TP Rank 2 │ TP Rank 3  │
│  (DP Replica B)         │  ────────▶ │  (DP Replica B)         │
└─────────────────────────┘            └─────────────────────────┘

Dimension assignments (Megatron-DeepSpeed convention):
  • TP (innermost): GPUs within a node, connected by NVLink (highest bandwidth)
  • PP: across nodes in the same rack (high-bandwidth InfiniBand)
  • DP (outermost): across racks (lower bandwidth, but the gradient AllReduce can overlap with compute)

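Example 175B configuration: TP=8, PP=16, DP=16 → 2,048 total GPUs, so each DP replica spans 128 GPUs (8 TP × 16 PP). OpenAI did not publish GPT-3's exact parallelism layout; TP=8 with PP=16 matches the 175B configuration benchmarked by Narayanan et al. (SC 2021).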
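How a global rank maps onto the three dimensions is a convention each framework fixes. A minimal sketch assuming the ordering described above: TP varies fastest (so consecutive ranks share a node), then PP, with DP outermost. Other frameworks may order the dimensions differently, and the function names are illustrative.

```python
def rank_to_coords(rank, tp, pp, dp):
    """Map a global rank to (dp_idx, pp_idx, tp_idx); TP is the fastest-varying axis."""
    assert 0 <= rank < tp * pp * dp
    tp_idx = rank % tp
    pp_idx = (rank // tp) % pp
    dp_idx = rank // (tp * pp)
    return dp_idx, pp_idx, tp_idx

def coords_to_rank(dp_idx, pp_idx, tp_idx, tp, pp):
    return (dp_idx * pp + pp_idx) * tp + tp_idx

def tp_group(rank, tp, pp, dp):
    """Ranks sharing one layer's weight shards (same dp_idx and pp_idx)."""
    d, p, _ = rank_to_coords(rank, tp, pp, dp)
    return [coords_to_rank(d, p, t, tp, pp) for t in range(tp)]

def dp_group(rank, tp, pp, dp):
    """Ranks holding the same shard in different replicas; they AllReduce gradients."""
    _, p, t = rank_to_coords(rank, tp, pp, dp)
    return [coords_to_rank(d, p, t, tp, pp) for d in range(dp)]

TP, PP, DP = 8, 16, 16                   # 2,048 GPUs
print(rank_to_coords(1000, TP, PP, DP))  # (7, 13, 0)
print(tp_group(1000, TP, PP, DP))        # [1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007]
print(dp_group(5, TP, PP, DP)[:3])       # [5, 133, 261]: stride of TP * PP = 128
```

If ranks 8k through 8k+7 are placed on the same 8-GPU node, every TP group stays inside one NVLink domain, which is the point of making TP the innermost dimension.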

Gradient Checkpointing (Activation Recomputation)

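During the backward pass, PyTorch needs the intermediate activations saved during the forward pass. Activation memory grows linearly with micro-batch size B and hidden size H, and faster than linearly with sequence length S because of the attention scores. At GPT-3 scale it reaches several GB per layer for a single 2,048-token sequence, quickly filling GPU memory.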
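For a concrete estimate, Korthikanti et al. (2022) give the activation memory of one transformer layer, with 16-bit activations and no recomputation or model parallelism, as s·b·h·(34 + 5·a·s/h) bytes, where s is sequence length, b micro-batch size, h hidden size and a the number of attention heads. A sketch applying that formula to GPT-3-like dimensions (h=12288, a=96, 96 layers), which are assumed here for illustration:

```python
def layer_activation_bytes(s, b, h, a):
    """Per-layer activation bytes, s*b*h*(34 + 5*a*s/h) (Korthikanti et al., 2022)."""
    # Multiplying 5*a*s/h through by s*b*h gives 5*a*s*s*b, which keeps integer math.
    return 34 * s * b * h + 5 * a * s * s * b

per_layer = layer_activation_bytes(s=2048, b=1, h=12288, a=96)
print(per_layer / 1e9)       # about 2.87 GB per layer for one 2048-token sequence
print(96 * per_layer / 1e9)  # about 275 GB across 96 layers, without recomputation
```

The 5·a·s²·b term is the attention-score part, which is why doubling the sequence length more than doubles the estimate.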

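Gradient checkpointing: split the network into segments and keep only each segment's input (the "checkpoint") during the forward pass, discarding the activations inside the segment. During the backward pass, rerun the segment's forward from its checkpoint to regenerate those activations just before they are needed.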

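import torch
from torch.utils.checkpoint import checkpoint

class TwoLayer(torch.nn.Module):
    def __init__(self, d=512):
        super().__init__()
        self.layer1, self.layer2 = torch.nn.Linear(d, d), torch.nn.Linear(d, d)

    def forward(self, x):
        x = checkpoint(self.layer1, x, use_reentrant=False)  # keeps only x; layer1 is rerun in backward()
        x = checkpoint(self.layer2, x, use_reentrant=False)
        return x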

Memory vs compute tradeoff: for a model with L layers, checkpointing every √L layers reduces activation memory from O(L×B×S×H) to O(√L×B×S×H). The cost is roughly one extra forward pass per iteration, about 33% more total compute, since the backward pass costs about twice the forward. For large transformers the memory saving is often worth the compute cost.
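The √L rule comes from a simple cost model: with a checkpoint every k layers, the forward pass keeps about L/k segment inputs, and the backward pass re-materializes at most one k-layer segment at a time, so peak storage is roughly L/k + k layer-activations, minimized near k = √L. A sketch under that idealized model (uniform layers, one memory unit per layer's activations):

```python
import math

def peak_activation_units(num_layers, k):
    """Idealized peak: ceil(L/k) stored segment inputs + k layers recomputed at once."""
    return math.ceil(num_layers / k) + k

def best_segment_length(num_layers):
    """Segment length k that minimizes peak_activation_units."""
    return min(range(1, num_layers + 1),
               key=lambda k: peak_activation_units(num_layers, k))

L = 64
k = best_segment_length(L)
print(L)                               # 64 units without checkpointing
print(k, peak_activation_units(L, k))  # 8 16, i.e. about 2*sqrt(L) units
# Extra compute: each segment's forward runs twice, about one extra forward pass.
```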

Mixed Precision Training

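Training in FP32 doubles the memory of weights and activations relative to FP16/BF16, and on tensor-core GPUs FP32 matrix math is far slower (A100: 19.5 TFLOP/s FP32 vs 312 TFLOP/s dense BF16). Mixed precision keeps an FP32 "master weight" copy for optimizer updates while doing the forward and backward passes in reduced precision.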

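Classic mixed-precision workflow with FP16 (Micikevicius et al., 2018; NVIDIA Apex O2 style). PyTorch's torch.autocast differs in that it keeps the weights in FP32 and casts them per operation.

FP32 master weights ──copy──▶ FP16 working weights
                                     │
                               Forward pass (FP16)
                                     │
                               Loss × S (loss scaling)
                                     │
                               Backward pass (FP16)
                                     │
                          FP16 gradients ──unscale (÷S)──▶ FP32 gradients
                                     │
                          FP32 optimizer update
                                     │
                          Update FP32 master weights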

Loss scaling: FP16's largest finite value is 65504 and its smallest positive (subnormal) value is about 6×10⁻⁸, so many small gradient values underflow to zero. Loss scaling multiplies the loss by a large factor S (e.g., 2¹⁵) before backward; by the chain rule every gradient is scaled by S as well, shifting it into the representable FP16 range. Gradients are divided by S (in FP32) before the optimizer update. Dynamic loss scaling (NVIDIA AMP) adjusts S automatically: it lowers S when gradients overflow and raises it again after a run of overflow-free steps.
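Underflow and loss scaling can be reproduced with Python's built-in half-precision struct format ("e"). Because scaling the loss by S scales every gradient by S, the sketch scales a gradient value directly. DynamicLossScaler is a hypothetical, simplified helper that mimics the policy described above (halve on overflow and skip the step, double after a run of clean steps); it is not the implementation of PyTorch's GradScaler, whose defaults are a 2¹⁶ initial scale and a 2,000-step growth interval.

```python
import math
import struct

def to_fp16(x):
    """Round a Python float to IEEE half precision; return +/-inf on overflow."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)

class DynamicLossScaler:
    """Simplified dynamic loss scaling (hypothetical helper for illustration)."""

    def __init__(self, scale=2.0**15, growth_interval=2000):
        self.scale = scale
        self.growth_interval = growth_interval
        self.clean_steps = 0

    def unscale_or_skip(self, scaled_grads):
        """Return unscaled gradients, or None if the optimizer step must be skipped."""
        if any(math.isinf(g) or math.isnan(g) for g in scaled_grads):
            self.scale /= 2          # back off; caller skips this optimizer step
            self.clean_steps = 0
            return None
        unscaled = [g / self.scale for g in scaled_grads]  # done in FP32 in practice
        self.clean_steps += 1
        if self.clean_steps == self.growth_interval:
            self.scale *= 2          # try a larger scale after enough clean steps
            self.clean_steps = 0
        return unscaled

print(to_fp16(1e-8))                        # 0.0: the raw gradient underflows in FP16
scaler = DynamicLossScaler()
scaled = to_fp16(1e-8 * scaler.scale)       # about 3.3e-4, representable in FP16
print(scaler.unscale_or_skip([scaled]))     # close to 1e-8 again (within ~0.05%)
print(scaler.unscale_or_skip([to_fp16(1e5)]), scaler.scale)  # None 16384.0
```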

BF16 (bfloat16, supported by tensor cores from Ampere onward): same 8-bit exponent as FP32, but only 7 mantissa bits vs FP16's 10. Because its dynamic range matches FP32, BF16 training usually does not need loss scaling. It is the preferred training format on A100 and H100. Dense BF16 tensor-core throughput on H100 SXM is about 989 TFLOP/s (1,979 TFLOP/s with 2:4 structured sparsity).
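BF16 is the upper 16 bits of an FP32 value, which is why it keeps FP32's exponent range but only 7 explicit mantissa bits. A sketch that converts a float to BF16 with round-to-nearest-even on the FP32 bit pattern; NaN and infinity handling are omitted for brevity, and the function name is illustrative.

```python
import struct

def to_bf16(x):
    """Round x to bfloat16 (returned as a Python float); NaN/inf handling omitted."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # FP32 bit pattern
    bits += 0x7FFF + ((bits >> 16) & 1)                  # round to nearest, ties to even
    bits &= 0xFFFF0000                                   # keep sign, 8-bit exponent, 7-bit mantissa
    return struct.unpack("<f", struct.pack("<I", bits))[0]

print(to_bf16(1e-8))         # roughly 1e-08: no underflow, unlike FP16
print(to_bf16(1.0 + 2**-9))  # 1.0: smaller than BF16's 2**-7 spacing near 1.0
print(to_bf16(1.0 + 2**-7))  # 1.0078125
```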

Historical Context

An early large-scale example of model parallelism was Google's DistBelief system, used in 2012 to train the Google Brain "cat neuron" network on 16,000 CPU cores. The modern era begins with Megatron-LM (NVIDIA, 2019), which trained an 8.3B-parameter GPT-2 variant on 512 GPUs using TP. GPT-3 (OpenAI, 2020) was trained on V100 GPUs in a cluster provided by Microsoft, using model parallelism both within matrix multiplies and across layers (full details unpublished). DeepSpeed ZeRO (Microsoft, 2020) was later combined with Megatron-LM to train Megatron-Turing NLG 530B (2021, joint NVIDIA/Microsoft). By 2022, 3D parallelism with Megatron-DeepSpeed had become a widely used standard.

Production Examples

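GPT-NeoX-20B (EleutherAI, 2022): 20B parameters, trained on 96 A100 40GB GPUs using ZeRO-1 plus Megatron-style tensor and pipeline parallelism (TP=2, PP=4), on The Pile dataset. At release it was the largest dense autoregressive model with publicly available weights, and its training code is public.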

LLaMA 2 70B (Meta, 2023): Trained on 2,000 A100 80GB GPUs. FSDP (Fully Sharded Data Parallel, PyTorch's ZeRO-3 equivalent) + TP=8. Used activation checkpointing for the attention layers.

Falcon 180B (TII, 2023): 180B parameters, trained on 4,096 A100 GPUs. TP=8, PP=8, DP=64. Training required custom NCCL communication scheduling to hide PP bubble.

Debugging Notes

Loss divergence: Common causes: learning rate too high, gradient explosion (check gradient norms), NaN from FP16 overflow (check loss scaling), ZeRO-3 AllGather returning wrong parameters (check communicator grouping).

Pipeline stalls: If PP send/receive between stages blocks, check that micro-batch count is sufficient (M >> p). Use NCCL_DEBUG=INFO to identify blocked collectives. Torch profiler (torch.profiler.profile) shows timeline of forward/backward/communication.

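Incorrect gradients after enabling gradient checkpointing: the recomputed forward pass must reproduce the original dropout masks. torch.utils.checkpoint saves and restores the RNG state by default (preserve_rng_state=True), but custom recomputation code must do the same, and tensor-parallel regions need consistent per-rank RNG streams (Megatron-LM keeps a separate RNG state tracker for them). To confirm that DP replicas have not drifted apart, compare a checksum of the parameters across ranks after a step.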
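Why the RNG state matters can be shown without a GPU: a recomputed dropout mask only matches the original if the generator is rewound to the state it had during the first forward pass. This sketch uses Python's random module as a stand-in for the CUDA RNG that torch.utils.checkpoint saves and restores; dropout_mask is a hypothetical helper.

```python
import random

def dropout_mask(rng, n, p_drop=0.1):
    """Bernoulli keep-mask of length n drawn from rng."""
    return [rng.random() >= p_drop for _ in range(n)]

rng = random.Random(1234)
saved_state = rng.getstate()          # what checkpointing stashes before the forward
forward_mask = dropout_mask(rng, 64)  # mask used in the original forward pass

# Recomputing without restoring: the generator has advanced, so the mask changes.
naive_mask = dropout_mask(rng, 64)

# Recomputing after restoring the saved state reproduces the original mask.
rng.setstate(saved_state)
recomputed_mask = dropout_mask(rng, 64)

print(recomputed_mask == forward_mask)  # True
print(naive_mask == forward_mask)       # almost certainly False
```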

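ZeRO-3 slow AllGather: if parameter AllGathers dominate iteration time, increase stage3_prefetch_bucket_size (and, memory permitting, stage3_max_live_parameters) in the DeepSpeed config so parameters are fetched ahead of use in fewer, larger rounds. The similarly named allgather_bucket_size (default 5e8 elements) applies to ZeRO stages 1 and 2. Either way, this trades memory for fewer, larger communication rounds.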
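A ZeRO-3 configuration exposing these knobs might look like the sketch below, written as a Python dict to be saved as the DeepSpeed JSON config. The keys exist in DeepSpeed's zero_optimization section, but the values are illustrative placeholders, not tuned recommendations; check the DeepSpeed configuration documentation for your version's defaults.

```python
import json

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "stage3_prefetch_bucket_size": 2e8,         # params prefetched ahead of use (illustrative)
        "stage3_max_live_parameters": 2e9,          # cap on gathered params kept resident (illustrative)
        "stage3_param_persistence_threshold": 1e6,  # params smaller than this stay unpartitioned (illustrative)
    },
}

print(json.dumps(ds_config, indent=2))  # e.g. save as ds_config.json for the launcher
```

Raising the prefetch and live-parameter limits lets more parameters be gathered before they are needed, at the cost of holding more full-size parameters in GPU memory at once.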

Security Implications

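Distributed training jobs share gradient data across all workers. A malicious worker (e.g., in a federated learning setting) can submit poisoned gradients to corrupt the global model. Clipping each worker's update to a maximum norm before aggregation bounds any single worker's contribution but does not prevent coordinated attacks. Note that torch.nn.utils.clip_grad_norm_ under DDP runs after the AllReduce, so it clips the already-aggregated gradient rather than individual contributions. In private clusters, the main risks are compromised nodes and physical access to the hardware. Checkpoint files contain the full model weights and must be stored securely (encryption at rest for proprietary models); PyTorch checkpoints are pickle-based, so loading an untrusted one can execute arbitrary code, which torch.load(..., weights_only=True) or the safetensors format mitigates.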
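A minimal sketch of bounding each worker's influence: clip every worker's update to an L2 norm of at most max_norm before averaging. This is per-worker clipping as used in robust or federated aggregation, not what torch.nn.utils.clip_grad_norm_ does under DDP; the helper names are illustrative.

```python
import math

def clip_to_norm(vec, max_norm):
    """Scale vec down so that its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(v * v for v in vec))
    factor = min(1.0, max_norm / norm) if norm > 0 else 1.0
    return [v * factor for v in vec]

def clipped_mean(worker_grads, max_norm):
    """Average of per-worker gradients after clipping each one."""
    clipped = [clip_to_norm(g, max_norm) for g in worker_grads]
    return [sum(vals) / len(clipped) for vals in zip(*clipped)]

honest = [[0.5, 0.0], [0.5, 0.0], [0.5, 0.0]]
poisoned = [[1000.0, 0.0]]
print([sum(v) / 4 for v in zip(*(honest + poisoned))])  # [250.375, 0.0]: attacker dominates
print(clipped_mean(honest + poisoned, max_norm=1.0))    # [0.625, 0.0]: influence bounded
```

Each worker can still move the mean by up to max_norm divided by the number of workers, so several colluding workers that each stay under the limit add their effects together. This is why clipping alone does not stop coordinated attacks.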

Performance Implications

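MFU (Model FLOP Utilization): the key efficiency metric, the fraction of the hardware's peak FLOP/s spent on the model's own arithmetic. The PaLM paper (2022) reported 46.2% MFU for PaLM 540B on TPU v4 and estimated about 21% for GPT-3. Modern systems target 50–55%. Below 40% indicates significant overhead from communication, memory transfers, or load imbalance.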
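MFU is commonly estimated by assuming about 6·N FLOPs per token for a dense N-parameter model (2·N forward, 4·N backward), an approximation that ignores attention-score FLOPs. A sketch under that assumption; the throughput and cluster numbers in the example are hypothetical inputs, not measurements from any real run.

```python
def mfu(n_params, tokens_per_second, n_gpus, peak_flops_per_gpu):
    """Model FLOP utilization using the ~6 * N FLOPs-per-token approximation."""
    achieved = 6 * n_params * tokens_per_second
    return achieved / (n_gpus * peak_flops_per_gpu)

# Hypothetical job: 70B dense model on 1,024 A100s (312 TFLOP/s dense BF16 peak),
# processing 350,000 tokens per second across the whole cluster.
print(f"{mfu(70e9, 3.5e5, 1024, 312e12):.1%}")  # 46.0%
```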

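Optimal TP degree: TP needs an AllReduce after every attention and MLP block, so it belongs within a node. NVLink provides 600 GB/s per GPU on A100 and 900 GB/s on H100 (through NVSwitch), which makes TP=8 efficient. TP beyond 8 typically requires inter-node AllReduce over InfiniBand (a 400 Gb/s NIC is about 50 GB/s), roughly an order of magnitude less bandwidth. On 8-GPU nodes, TP=8 is the practical maximum.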

Batch size scaling: Larger global batch sizes reduce gradient noise but require more GPUs for same iteration time. GPT-3 used batch size 3.2M tokens (sequence × batch). Increasing DP degree keeps per-GPU batch fixed while increasing global batch.
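The global batch is the product of the per-GPU micro-batch size, the number of gradient-accumulation steps (the M micro-batches when pipeline parallelism is used) and the DP degree. A sketch with illustrative numbers; 1,536 sequences of 2,048 tokens is assumed here as one way to land near GPT-3's 3.2M-token batch.

```python
def global_batch(micro_batch, grad_accum_steps, dp_degree, seq_len):
    """Return (sequences, tokens) consumed per optimizer step."""
    sequences = micro_batch * grad_accum_steps * dp_degree
    return sequences, sequences * seq_len

print(global_batch(micro_batch=1, grad_accum_steps=8, dp_degree=192, seq_len=2048))
# (1536, 3145728): about 3.1M tokens per step
print(global_batch(micro_batch=1, grad_accum_steps=8, dp_degree=384, seq_len=2048))
# (3072, 6291456): doubling DP doubles the global batch at a fixed per-GPU batch
```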

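PP bubble overhead: for PP=8, the GPipe bubble is 7/15 ≈ 47% of total time at M=8. Increasing M to 32 reduces it to 7/39 ≈ 18%. 1F1B has the same bubble for the same M; its benefit is that activation memory is bounded by p in-flight micro-batches rather than M, which makes a large M affordable. Practical target: <10% bubble overhead.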

Failure Modes and Real Incidents

Incident: Gradient overflow in FP16 training (common): FP16 gradients overflow to Inf/NaN when the loss scale is too high or gradients genuinely explode. Modern frameworks (PyTorch AMP) detect this, skip the optimizer step, and halve the loss scale. Symptoms: the loss spikes or becomes NaN, steps are skipped frequently, and the loss scale keeps shrinking. If overflows persist for 1000+ steps, no updates are applied and training is effectively stalled.

Incident: ZeRO-3 parameter mismatch (DeepSpeed bug, 2021): A bug in DeepSpeed ZeRO-3 caused AllGather to return stale parameters under certain communication ordering conditions. Models trained with this version produced incorrect results that were reproducible but wrong. Fixed in DeepSpeed 0.6.0.

Incident: PP pipeline stall from slow node (reported by various groups): in synchronous training, a single GPU running at 95% of normal speed makes the other stages wait for its micro-batches, so the whole job runs at that GPU's pace. Detection: per-stage timing profiling. Mitigation: health checks and automatic node replacement in training frameworks such as MegaScale (ByteDance).

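Incident: NVLink link failure mid-training: a single failed NVLink link reduces intra-node bandwidth asymmetrically, and traffic that falls back to PCIe runs roughly an order of magnitude slower (PCIe 4.0 x16 offers about 32 GB/s per direction, versus 300 GB/s per direction of NVLink on A100). Symptom: training throughput drops suddenly. Detection: nvidia-smi nvlink --status and an NCCL bandwidth benchmark such as all_reduce_perf from nccl-tests.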

Modern Usage

PyTorch FSDP (Fully Sharded Data Parallel, 2022): PyTorch's native ZeRO-3-style implementation, upstreamed from FairScale's FSDP, which simplifies deploying sharded training. Used by Meta for LLaMA training.

Megatron-Core (2023): Modular refactoring of Megatron-LM into a library usable with any training framework. Provides TP, PP, and sequence parallelism primitives.

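Sequence Parallelism: with TP alone, the activations of LayerNorm and dropout are still replicated on every TP rank. Megatron-style sequence parallelism (Korthikanti et al., 2022) shards those activations along the sequence dimension across the TP ranks, further reducing per-GPU activation memory. For very long contexts (128K+ tokens), context parallelism goes further and splits the attention computation itself along the sequence across GPUs.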
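The sequence dimension can be sharded for LayerNorm and dropout because these operations act on each token independently, so each rank can process its slice of tokens with no communication. Megatron-style sequence parallelism then inserts an all-gather before each tensor-parallel block and a reduce-scatter after it, in place of the AllReduce. A pure-Python check of the independence property; the helper names are illustrative and the LayerNorm omits the learned scale and bias.

```python
import math

def layer_norm(tokens, eps=1e-5):
    """Normalize each token vector independently (no learned scale or bias)."""
    out = []
    for t in tokens:
        mean = sum(t) / len(t)
        var = sum((x - mean) ** 2 for x in t) / len(t)
        out.append([(x - mean) / math.sqrt(var + eps) for x in t])
    return out

def sequence_parallel_layer_norm(tokens, p):
    """Each of p simulated ranks normalizes its own contiguous slice of the sequence."""
    size = len(tokens) // p                        # assumes divisibility
    shards = [tokens[i * size:(i + 1) * size] for i in range(p)]
    local = [layer_norm(s) for s in shards]        # no communication needed
    return [t for shard in local for t in shard]   # all-gather along the sequence

seq = [[float((7 * i + 3 * j) % 11) for j in range(8)] for i in range(16)]  # [S=16, H=8]
assert sequence_parallel_layer_norm(seq, p=4) == layer_norm(seq)
```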

Expert Parallelism (EP) for MoE: Mixtral 8×7B, GPT-4 (reportedly), and others use Mixture of Experts layers in which each token is routed to a small subset of expert FFNs (Mixtral routes each token to 2 of 8). Expert Parallelism distributes the experts across GPUs and uses an All-to-All collective to send each token to the GPUs hosting its chosen experts and back.
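The routing step can be sketched in pure Python. Each token picks its top-2 experts by gate score, and tokens are then bucketed by the expert (here, the GPU) that will process them; those buckets are what the All-to-All exchanges. The sketch assumes one expert per GPU, fixed gate scores in place of a learned router, and no capacity limit or load-balancing loss. The names are hypothetical.

```python
def top_k(scores, k=2):
    """Indices of the k largest scores (ties broken by lower index)."""
    return sorted(range(len(scores)), key=lambda e: (-scores[e], e))[:k]

def dispatch(gate_scores, num_experts, k=2):
    """Per-expert token lists; with one expert per GPU these are the All-to-All sends."""
    buckets = {e: [] for e in range(num_experts)}
    for token, scores in enumerate(gate_scores):
        chosen = top_k(scores, k)
        total = sum(scores[e] for e in chosen)
        for e in chosen:
            # Each copy carries its renormalized gate weight for combining outputs later.
            buckets[e].append((token, scores[e] / total))
    return buckets

gates = [  # 4 tokens x 4 experts, e.g. softmax outputs of a router
    [0.1, 0.6, 0.2, 0.1],
    [0.5, 0.1, 0.1, 0.3],
    [0.25, 0.25, 0.4, 0.1],
    [0.0, 0.2, 0.3, 0.5],
]
for expert, tokens in dispatch(gates, num_experts=4).items():
    print(expert, [t for t, _ in tokens])
# 0 [1, 2]
# 1 [0]
# 2 [0, 2, 3]
# 3 [1, 3]
```

Uneven bucket sizes like these are why production MoE systems add capacity limits and auxiliary load-balancing losses.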

Future Directions

  • Interleaved pipeline schedules: interleaved and bidirectional schedules, such as DeepSeek's DualPipe (DeepSeek-V3, 2024), aim to shrink pipeline bubbles further and overlap communication with computation
  • Automatic parallelism: systems like Alpa (OSDI 2022) and FlexFlow automatically search the parallelism strategy space, reducing manual tuning
  • Communication-computation overlap: Improved NCCL support for pipelining collective communication with transformer layer compute
  • Heterogeneous training: Mixing GPU types (A100 + H100) in the same job using capability-aware task scheduling

Exercises

  1. Memory calculation: For a 70B-parameter transformer model using Adam optimizer with BF16 parameters and FP32 optimizer states, calculate the per-GPU memory requirement for: (a) ZeRO-0 on 1 GPU, (b) ZeRO-2 on 64 GPUs, (c) ZeRO-3 on 512 GPUs. Include parameters, gradients, and optimizer states.

  2. Pipeline bubble analysis: Implement a pipeline schedule simulator in Python. Given pipeline_stages=8 and micro_batches=16, simulate the GPipe and 1F1B schedules and compute the bubble fraction. Plot utilization as a function of micro-batch count.

  3. TP implementation: Implement column-parallel and row-parallel linear layers in PyTorch using torch.distributed. Verify that the output matches a single-GPU reference implementation for a [batch=4, seq=128, hidden=512] input.

  4. DDP gradient sync profiling: Profile a PyTorch DDP training run with torch.profiler. Identify what fraction of time is spent in NCCL AllReduce vs forward pass vs backward pass. Experiment with bucket sizes (bucket_cap_mb) and observe the effect on throughput.

  5. ZeRO-3 communication overhead: Given a 13B-parameter model trained with ZeRO-3 on 128 GPUs, each with a 400 Gbps NIC, calculate the minimum theoretical time per iteration spent on AllGather + ReduceScatter operations. Compare to a typical forward+backward time of 200ms.

References

  • Shoeybi et al., "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism," 2019
  • Rajbhandari et al., "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models," SC 2020
  • Narayanan et al., "Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM," SC 2021
  • Huang et al., "GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism," NeurIPS 2019
  • Narayanan et al., "PipeDream: Generalized Pipeline Parallelism for DNN Training," SOSP 2019
  • Micikevicius et al., "Mixed Precision Training," ICLR 2018
  • Zhao et al., "PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel," VLDB 2023
  • Lepikhin et al., "GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding," ICLR 2021