Lecture 7:分布式并行训练基础
Stanford CS336 Language Modeling from Scratch | Spring 2026 | Lecture 7: Parallelism,时长 01:21:02。
从单 GPU 到多 GPU:为什么需要并行#
上一讲的主题是单 GPU 内部的并行——通过编写 kernel 来充分利用 SM、shared memory、寄存器等硬件资源,让一块 GPU 跑得更快。本讲将视角从单 GPU 内部扩展到多 GPU 之间的并行:如何调度成百上千块 GPU 协同训练一个大语言模型。
从硬件视角看,单 GPU 的分层结构(L1 cache / shared memory → L2 cache → HBM)现在被嵌入到一个更大的层次体系中。在单 GPU 时代,HBM 的带宽是我们抱怨的瓶颈;而在多 GPU 时代,HBM 反而变成了”快的那一层”——真正的瓶颈转移到了 GPU 之间的互连网络。
核心挑战不变:计算单元(ALU / Tensor Core)离数据很远。单 GPU 场景下,“远”指的是 HBM 到 SM 的距离;多 GPU 场景下,“远”变成了跨节点甚至跨数据中心的距离。因此,同一个设计原则贯穿两讲:编排计算以避免数据搬运瓶颈。
使用多 GPU 的两个动机#
-
内存不够:模型的参数、激活值、梯度和优化器状态无法放进单块 GPU 的 HBM。以 B200 的 192 GB HBM 为例,一个万亿参数模型显然无法在单卡上训练。
-
速度不够:即使模型能塞进单卡,也可能希望用更多 GPU 来切分工作负载、缩短训练时间。
两者之间存在 trade-off:将工作分散到更多 GPU 上意味着更多的通信开销。因此需要计算”多用几块卡带来的额外算力”与”GPU 间数据搬运的代价”之间的平衡。
集合通信原语(Collective Operations)#
集合通信是分布式编程的基石,其历史可以追溯到 1980 年代的 MPI 标准——并非为大语言模型训练而发明,但至今仍是 LLM 训练的核心原语。“集合”(collective) 的含义是:你只需指定一个多设备间的通信模板,而不需要手动管理每对 GPU 之间的点对点消息传递。
基本术语#
- Rank:一个参与通信的设备(通常是一块 GPU),用整数编号(0, 1, 2, 3, …)
- World size:参与通信的设备总数
基础操作(热身)#
以下三个操作在训练的核心路径中不常直接出现,但它们是理解后续核心操作的基础。
Broadcast:从一个 rank(如 rank 0)向所有 rank 复制同一份张量。典型用途是在初始化时,rank 0 加载 checkpoint 后广播给其他所有 rank。
1# Input: rank0 = [0, 1, 2, 3]2# Output: rank0 = rank1 = rank2 = rank3 = [0, 1, 2, 3]Scatter:将 rank 0 上的一个大张量按 world size 切分,每个 rank 获得一个分片。Scatter 本身不常直接使用,但它是理解 reduce-scatter 的跳板。
1# Input: rank0 = [0, 1, 2, 3]2# Output: rank0 = [0], rank1 = [1], rank2 = [2], rank3 = [3]Gather:scatter 的逆操作——将各 rank 上的分片收集到一个指定 rank(如 rank 0),拼接成完整张量。Gather 是理解 all-gather 的跳板。
1# Input: rank0 = [0], rank1 = [1], rank2 = [2], rank3 = [3]2# Output: rank0 = [0, 1, 2, 3]Reduce:类似 gather,但不是拼接,而是对各 rank 的数据施加一个归约操作(如 sum、max、min)。可以把 gather 看成归约操作为”拼接”的特殊 reduce。
1# Input (sum): rank0 = [0], rank1 = [1], rank2 = [2], rank3 = [3]2# Output: rank0 = [6] (0+1+2+3)核心操作(训练中的主力)#
All-gather:对所有 rank 执行 gather,而不仅仅是 rank 0。操作完成后,每个 rank 都持有完整的拼接结果。名字中的 “all” 表示”输出到所有 rank”。
1# Input: rank0 = [0], rank1 = [1], rank2 = [2], rank3 = [3]2# Output: rank0 = rank1 = rank2 = rank3 = [0, 1, 2, 3]All-gather 的典型场景:每个 rank 持有参数的一个分片(shard),forward pass 之前需要 all-gather 拼出完整参数。
Reduce-scatter:对每个维度分别执行归约,然后将结果分散到不同 rank。具体来说,如果每个 rank 持有一个长度为 W(world size)的向量,reduce-scatter 会对第 i 个分量在所有 rank 上求和,然后将结果放到 rank i 上。
1# Input: rank0 = [0,1,2,3], rank1 = [1,2,3,4], rank2 = [2,3,4,5], rank3 = [3,4,5,6]2# Output: rank0 = [6] (0+1+2+3)3# rank1 = [10] (1+2+3+4)4# rank2 = [14] (2+3+4+5)5# rank3 = [18] (3+4+5+6)Reduce-scatter 的典型场景:backward pass 结束后,各 rank 上的梯度需要求和,然后分布式存储归约后的结果。
All-reduce = reduce-scatter + all-gather。先归约,再让所有 rank 拿到完整的归约结果。这是数据并行(DDP)中最核心的操作。
1# Input: rank0 = [0,1,2,3], rank1 = [1,2,3,4], rank2 = [2,3,4,5], rank3 = [3,4,5,6]2# Output: rank0 = rank1 = rank2 = rank3 = [6, 10, 14, 18]All-to-all:最通用的集合操作。每个 rank 指定要发送给每个其他 rank 的数据。输入张量的第 i 个分量会被发送到 rank i。如果所有 rank 发送的数据量均衡,all-to-all 在数学上等价于对输入矩阵做转置。
All-to-all 的典型场景是 Mixture of Experts (MoE) 训练:每个 rank 持有一部分 data 和一部分 expert,需要根据 router 的动态决策将 token 路由到正确的 expert 所在的 rank。负载均衡的目标正是让 all-to-all 的通信量尽可能均匀。
记忆术语的窍门#
- Reduce:施加关联交换操作(sum / max / min)
- Scatter 与 Gather 互逆:scatter 分发,gather 集中
- All- 前缀:输出目标是所有 rank
GPU 互连硬件拓扑#
理解了集合通信的抽象语义之后,需要知道这些操作跑在什么硬件上——不同层级的互连带宽差距巨大,直接决定了并行策略的选择。
三层互连层次#
GPU 集群的互连呈严格分层结构,从快到慢依次为:
| 层级 | 互连技术 | 典型带宽 | 覆盖范围 |
|---|---|---|---|
| 节点内 | NVLink → NVSwitch | 1.8 TB/s(NVLink 5.0) | 同一节点内 8 块 GPU |
| Pod 内 | Infiniband(经 PCIe → HCA → IB 线缆) | ~0.05 TB/s | ~256 节点 |
| 跨 Pod | Ethernet(经 PCIe → CPU) | 更低 | 整个数据中心 |
作为对比,B200 的 HBM 带宽为 8 TB/s。因此 NVLink 大约是 HBM 的 1/4 速度——在单 GPU 优化语境下很慢,但在多 GPU 语境下已经算快了。
NVLink 与 NVSwitch#
典型配置是 8 块 GPU 通过 NVLink 连接到一个 NVSwitch。NVSwitch 起到交换机的作用:从编程角度看,同一 NVLink domain 内的任意两块 GPU 可以直接通信,硬件自动处理路由。
Nvidia 的 NVL72 配置更进一步:将 9 个 tray(每个 tray 8 块 GPU)通过 NVSwitch 互连,形成一个包含 72 块 GPU 的单一 NVLink domain。这意味着 72 块 GPU 之间都享有 NVLink 级别的高带宽互连。
Infiniband#
当集群规模超出 NVLink domain 的容量后,节点之间通过 Infiniband 连接。Infiniband 的一个关键特性是支持 RDMA(Remote Direct Memory Access):一块 GPU 可以直接读写另一块 GPU 的显存,无需经过 CPU。这避免了传统 Ethernet 路径中从 GPU → CPU kernel socket buffer → TCP 封包 → NIC → 网络的多次拷贝和延迟。
传统 Ethernet 与 RoCE#
标准 Ethernet 不支持 RDMA——数据必须经过 CPU 中转,引入大量延迟。但 RoCE(RDMA over Converged Ethernet) 技术让 Ethernet 也能绕过 CPU 直接传输。Meta 的论文显示他们在探索用 RoCE 替代 Infiniband 来训练 LLaMA 系列模型——这是对 Infiniband 高成本的一种替代方案。
RDMA 的本质#
RDMA 是一个功能性概念,而非具体硬件:它描述的是”一块 GPU 直接读写另一块 GPU 显存”的能力。实现 RDMA 的硬件手段有多种——NVLink/NVSwitch、Infiniband、RoCE 都可以提供 RDMA 能力。
NCCL:集合操作的底层实现#
NCCL(Nvidia Collective Communications Library) 是将集合操作翻译为底层 GPU 间通信的库。当你调用 all_reduce 时,NCCL 负责:
- 探测 GPU 互连的拓扑结构
- 规划最优的消息传递路径(ring topology 或 tree topology)
- 启动专门的 通信 kernel 在 GPU 上执行数据收发
通信操作和计算操作一样,最终都是 GPU kernel——只不过这些 kernel 的工作是在 GPU 之间搬运数据,而非做矩阵运算。
PyTorch 分布式编程与性能基准#
torch.distributed 接口#
PyTorch 的 torch.distributed 库为集合操作提供了干净的高层接口,屏蔽了底层 NCCL 的细节。它支持多种后端:GPU 环境使用 NCCL 后端,CPU 环境使用 Gloo 后端。该库也提供了 FSDP 等高级封装,但本课程选择从原语出发手动实现。
编程模型#
分布式程序的基本模式是 SPMD(Single Program, Multiple Data):同一段代码被复制到多个进程上运行,每个进程拥有不同的 rank(0 到 world_size - 1),在同一段代码中根据 rank 执行不同逻辑。
1def main(rank, world_size, ...):2 # 每个进程都执行这段代码,但 rank 不同3 setup(rank, world_size) # 配置 master_addr/port,初始化进程组4 ...5 cleanup()6
7# 启动 world_size 个进程8spawn(main, world_size=4, ...)初始化时设置的 master_addr 和 master_port 仅用于进程间的元数据协调,实际的数据传输走 NCCL/NVLink 通道。
同步机制#
由于各进程异步运行,执行顺序不确定。dist.barrier() 提供同步点——所有进程必须到达 barrier 后才能继续。过多的 barrier 会降低并行效率,但对于正确性保证是必要的。
All-reduce 代码示例#
1data = torch.arange(4, dtype=torch.float32) + rank # 每个 rank 不同2dist.all_reduce(data, op=dist.ReduceOp.SUM)3# 操作原地修改 data,所有 rank 得到相同结果Reduce-scatter 与 All-gather 代码示例#
1# Reduce-scatter: 输出与输入分离2output = torch.zeros(1)3dist.reduce_scatter_tensor(output, input, op=dist.ReduceOp.SUM)4
5# All-gather: 将 reduce-scatter 的结果收集到所有 rank6activations = [torch.zeros_like(x) for _ in range(world_size)]7dist.all_gather_into_tensor(output, x)通过先做 reduce-scatter 再做 all-gather,可以验证其结果与直接做 all-reduce 完全一致——这从实验上证实了 all-reduce = reduce-scatter + all-gather 的等价关系。
异步通信#
将 async_op=True 传入集合操作后,调用会立即返回,通信在后台进行。这为通信与计算重叠提供了基础——例如在等待 all-reduce 完成的同时加载下一个 batch 的数据。需要确认通信完成时,调用 wait() 或 barrier()。
需要注意两层异步性的叠加:CUDA kernel 本身相对 CPU 是异步的(需要 torch.cuda.synchronize()),而分布式操作又在进程间引入了另一层异步(需要 dist.barrier())。在 benchmark 时,必须先 synchronize() 再 barrier()——顺序不能反。原因是:如果先调用 barrier,各进程可能在自己的 CUDA kernel 还未完成时就通过了 barrier(因为 CUDA 操作相对 Python 是异步的,barrier 调用瞬间返回),此时 barrier 形同虚设,各进程依然在各自独立地等待 CUDA kernel 完成,并未真正同步。
性能基准#
对 100M 个 float32 元素(约 400 MB)的 all-reduce 操作进行 benchmark,在 4 块 GPU 上测得约 1.6 ms 的延迟。
有效带宽(bus bandwidth) 的计算方式:
bandwidth=W×duration2×size×WW−1
其中 W 是 world size,factor 2 来自 all-reduce 同时包含 reduce 和 gather 两个方向的数据搬运,WW−1 是因为 W 个 rank 之间有 W−1 步归约操作。
当 W 较大时,WW−1≈1,有效带宽近似为 duration2×size。实测约 400 GB/s,且这个数值:
- 不依赖 world size——扩展到更多 GPU 时带宽不变
- 不依赖拓扑——NCCL 自动选择 ring 或 tree 拓扑
All-reduce 搬运的数据量是 reduce-scatter 的 2 倍(因为它同时包含 reduce 和 gather 两个方向),但也需要 2 倍的时间。两个 2× 相消,因此 all-reduce 和 reduce-scatter 的有效带宽相同。
数据并行(Data Parallelism)#
本讲的所有并行策略都以 MLP(多层感知机) 为演示对象,而非完整的 Transformer。这个选择并非偷懒——MLP 正是 Transformer 中实际的计算瓶颈,因此 MLP 上的并行模式具有很强的代表性,而完整 Transformer 只是增加了更多 bookkeeping。
数据并行是最直观、最优雅的并行策略。它的核心思想是:切分数据,复制模型。每个 rank 持有完整的模型参数,但只处理全局 batch 的一个子集。
切分策略#
假设全局 batch size 为 B,world size 为 W,则每个 rank 分到的 local batch size 为 B/W。对于一个 B×D 的数据矩阵(D 是特征维度),按行切分,每个 rank 获得 WB×D 的数据切片。
1local_batch_size = batch_size // world_size2start = rank * local_batch_size3data = full_data[start : start + local_batch_size].to(rank)实际中,每个 rank 应该直接加载自己那部分数据,而非先全量加载再切分。
训练流程#
数据并行的训练循环与单 GPU 训练几乎完全相同,唯一的差异是在 backward pass 和 optimizer step 之间插入一步梯度同步:
1for step in range(num_steps):2 # Forward: 每个 rank 用自己的 local batch3 x = data4 for layer in range(num_layers):5 x = relu(x @ params[layer])6 loss = x.sum()7
8 # Backward9 loss.backward()10
11 # === 数据并行的关键一步 ===12 for param in params:13 dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)14
15 # Optimizer step16 optimizer.step()All-reduce 之后,所有 rank 上的梯度完全相同(因为是全局平均),进而参数更新也完全相同。这保证了训练过程中各 rank 的模型参数始终保持一致。
DDP 的优美之处#
数据并行对模型架构完全透明——它不关心 forward pass 的具体实现,只在梯度层面做同步。无论是 MLP、Transformer 还是其他架构,DDP 的改动都是同样的一行 all-reduce。这种模块化设计使得 DDP 成为最容易实现和调试的并行策略。
状态总结#
| 量 | 各 rank 间关系 |
|---|---|
| 数据(loss) | 不同(各自处理不同子集) |
| 梯度(all-reduce 前) | 不同 |
| 梯度(all-reduce 后) | 相同 |
| 参数 | 始终相同 |
约束条件#
- 全局 batch size 至少为 world size(否则无法切分)
- 理想情况下 batch size 是 world size 的整数倍(否则需要 padding)
DDP 的局限性#
DDP 要求每个 rank 持有完整的模型参数。如果模型太大以至于单卡放不下参数 + 梯度 + 优化器状态,DDP 就不够用了。这正是下一讲的 FSDP / ZeRO 要解决的问题——它们将 all-reduce 拆成 reduce-scatter + all-gather,在这两步之间插入参数分片逻辑,使得每个 rank 只需持有参数的一个分片。
通信与计算重叠的优化#
上面的实现在整个 backward pass 完成后才统一做 all-reduce,但实际上可以做得更高效:backward pass 中每一层的梯度计算完成后,立即发起该层梯度的 all-reduce(使用 async),同时继续计算前面层的梯度。这样通信和计算可以重叠,显著减少总等待时间。这也是 Assignment 2 中需要探索的优化。
张量并行(Tensor Parallelism)#
与数据并行切分 batch 维度不同,张量并行切分的是模型的参数矩阵本身——每个 rank 持有每一层参数的一个列切片(column shard)。这种方式也被称为 column tensor parallel(也存在按行切分的 row tensor parallel,但本讲不展开)。
切分策略#
对于一个 D×D 的参数矩阵,张量并行将其按列切分为 W 份,每个 rank 持有 D×WD 的参数切片。为了简化演示,假设所有 rank 都持有完整的输入数据(实际中张量并行常与数据并行结合,此时每个 rank 不一定持有全部数据)。
1local_num_dim = num_dim // world_size2# 每个 rank 持有所有层的参数切片3params[layer].shape = (num_dim, local_num_dim) # 而非 (num_dim, num_dim)Forward Pass#
每个 rank 用完整的输入 X(shape 为 B×D)乘以自己持有的参数切片,得到局部激活值(shape 为 B×WD)。由于非线性激活(如 ReLU)是逐元素操作,可以直接在局部激活值上应用。
但在进入下一层之前,必须将各 rank 的局部激活值拼接成完整的 B×D 张量——否则下一层的矩阵乘法维度不匹配。这一步通过 all-gather 实现:
1for layer in range(num_layers):2 x = relu(x @ params[layer]) # x: (B, local_D)3
4 # All-gather: 拼接各 rank 的局部激活值5 activations = [torch.zeros(batch_size, local_num_dim) for _ in range(world_size)]6 dist.all_gather_into_tensor(output, x)7 x = torch.cat(activations, dim=1) # x: (B, D)All-gather 与 Reduce-scatter 的对偶性#
张量并行有一个优雅的数学结构:
- Forward pass:每层需要 all-gather(拼接局部激活值)
- Backward pass:对应地需要 reduce-scatter(归约梯度并分片存储)
需要注意的是,backward pass 中的 reduce-scatter 不是 PyTorch autograd 自动完成的——在本课程的从零实现中,需要手动调用 dist.reduce_scatter_tensor()。工业框架(如 PyTorch 的 tensor parallel 模块)会自动处理这一步,但理解底层机制正是本课程的目的。
这种 all-gather / reduce-scatter 的对偶关系是分布式训练中反复出现的模式。
与数据并行的对比#
| 维度 | 数据并行 | 张量并行 |
|---|---|---|
| 切分对象 | batch 维度 | 参数矩阵的列维度 |
| 模型侵入性 | 无(模型不可感知) | 高(需要改写 forward/backward) |
| 通信频率 | 每个 training step 一次 | 每一层 forward/backward 各一次 |
| 通信量 | 梯度大小 | 激活值大小(通常更大) |
| 数学基础 | 梯度的可加性 | 矩阵乘法可按列拆分 |
张量并行的通信量明显更大——每一层都要做一次 all-gather,而数据并行只在整个 backward 结束后做一次 all-reduce。因此张量并行强烈依赖高带宽互连。
流水线并行(Pipeline Parallelism)#
流水线并行的切分维度是网络的深度——将模型的不同层分配给不同 rank。每个 rank 持有一个层子集的完整参数(包括完整维度),但只负责这些层的计算。
切分策略#
1local_num_layers = num_layers // world_size2# rank 0 持有 layer 0~3, rank 1 持有 layer 4~7, ...3# 每层参数维度不变: (num_dim, num_dim)Forward Pass 与点对点通信#
与数据并行和张量并行使用集合操作不同,流水线并行使用点对点的 send/recv:
1for micro_batch in micro_batches:2 if rank > 0:3 dist.recv(x, src=rank - 1) # 从上一级接收激活值4
5 for layer in local_layers:6 x = relu(x @ params[layer])7
8 if rank < world_size - 1:9 dist.send(x, dst=rank + 1) # 发送激活值给下一级Rank 0 接收原始数据,依次通过自己负责的层后将激活值发送给 rank 1,rank 1 处理后再发给 rank 2,以此类推。
Pipeline Bubble 问题#
流水线并行的核心问题是 pipeline bubble:当 rank 0 在处理数据时,rank 1、2、3 都在闲等;当 rank 3 在处理时,rank 0、1、2 又在闲等。这导致大量 GPU 时间被浪费在等待上。
Micro-batching:减少 Bubble#
解决 bubble 的关键技术是 micro-batching:将全局 batch 进一步切分为多个 micro-batch。每个 micro-batch 体积更小,处理更快,可以更快地传递给下一级。这样各 rank 可以交替处理不同的 micro-batch,大幅减少闲置时间。
通信与计算重叠#
流水线并行的结构天然适合做通信/计算重叠。上面展示的 naive 版本使用阻塞的 send/recv,并未实现通信与计算的重叠。要实现重叠,需要将 send/recv 替换为非阻塞的 isend/irecv(加 “i” 前缀),并增加额外的代码来管理异步状态——这并不像听上去那么简单,但它是让流水线并行真正高效的关键。
并行策略的选择与权衡#
五种并行切分维度#
| 并行策略 | 切分维度 | 核心集合操作 |
|---|---|---|
| Data parallelism | batch 维度 | all-reduce(DDP)/ all-gather + reduce-scatter(FSDP/ZeRO) |
| Tensor parallelism | 参数矩阵的宽度维度 | all-gather(forward)+ reduce-scatter(backward) |
| Pipeline parallelism | 网络深度维度 | point-to-point send/recv |
| Sequence parallelism | 序列长度维度 | 专用通信(用于并行化 attention) |
| Expert parallelism | expert 维度(MoE) | all-to-all |
硬件约束决定策略组合#
并行策略的选择强烈依赖硬件拓扑。一个核心原则是:通信量大的并行策略必须跑在高带宽互连上。
- 张量并行:每层 forward/backward 都需要通信,通信量大 → 只应在 NVLink/NVSwitch domain 内使用(节点内 8 块或 NVL72 的 72 块 GPU)
- 数据并行:每个 training step 只通信一次梯度 → 可以跨 Infiniband 节点使用
- 流水线并行:通信量最小(只传激活值,且可以用 micro-batching 分摊)→ 甚至可以跨低速网络使用。去中心化训练方案中,GPU 可能分布在地理上遥远的位置,这时流水线并行是唯一可行的选择
典型的组合策略#
在大规模训练中,通常将多种并行策略嵌套组合:
- 节点内(NVLink domain):tensor parallelism
- 节点间(Infiniband):data parallelism / FSDP
- 跨 pod(低速网络):pipeline parallelism
Critical Batch Size#
数据并行理论上可以无限扩展——只要不断增大 batch size 即可。但存在一个 critical batch size:超过该值后,继续增大 batch size 对训练收敛的边际收益递减,额外的 GPU 算力被浪费。此时应该切换到张量并行或其他策略来利用多余的 GPU。
冗余计算 vs 通信的 Trade-off#
贯穿本讲所有并行策略的一个 meta-pattern 是 “重算/冗余存储 vs 通信”的 trade-off。在单 GPU 上,这表现为 activation checkpointing(重新计算激活值以节省显存);在多 GPU 上,这表现为数据并行的冗余参数存储——每个 rank 都持有完整参数,存在冗余存储和冗余计算,但换来了通信量的减少:不需要在 forward/backward 中搬运优化器状态。FSDP/ZeRO 则选择消除这种冗余,代价是引入更频繁的通信。可以将”存在另一块 GPU 上”看成是单 GPU 存储层次(L1 → L2 → HBM)的自然延伸。
JAX/TPU 的替代路径#
PyTorch 采用显式编程的方式:开发者手动指定使用哪种集合操作。JAX + TPU 提供了一种声明式的替代方案:开发者只需定义模型和 sharding strategy(哪些张量的哪个维度被切分到哪些设备),编译器自动推导所需的通信操作。这种方式更简洁,但屏蔽了底层细节——本课程选择 PyTorch 的显式路径正是为了让学生理解从原语到训练的完整链条。
部分内容可能已过时