Lecture 7:分布式并行训练基础

5530 字
28 分钟
Lecture 7:分布式并行训练基础
文章摘要

Stanford CS336 Language Modeling from Scratch | Spring 2026 | Lecture 7: Parallelism,时长 01:21:02。

从单 GPU 到多 GPU:为什么需要并行#

上一讲的主题是单 GPU 内部的并行——通过编写 kernel 来充分利用 SM、shared memory、寄存器等硬件资源,让一块 GPU 跑得更快。本讲将视角从单 GPU 内部扩展到多 GPU 之间的并行:如何调度成百上千块 GPU 协同训练一个大语言模型。

从硬件视角看,单 GPU 的分层结构(L1 cache / shared memory → L2 cache → HBM)现在被嵌入到一个更大的层次体系中。在单 GPU 时代,HBM 的带宽是我们抱怨的瓶颈;而在多 GPU 时代,HBM 反而变成了”快的那一层”——真正的瓶颈转移到了 GPU 之间的互连网络。

核心挑战不变:计算单元(ALU / Tensor Core)离数据很远。单 GPU 场景下,“远”指的是 HBM 到 SM 的距离;多 GPU 场景下,“远”变成了跨节点甚至跨数据中心的距离。因此,同一个设计原则贯穿两讲:编排计算以避免数据搬运瓶颈

使用多 GPU 的两个动机#

  1. 内存不够:模型的参数、激活值、梯度和优化器状态无法放进单块 GPU 的 HBM。以 B200 的 192 GB HBM 为例,一个万亿参数模型显然无法在单卡上训练。

  2. 速度不够:即使模型能塞进单卡,也可能希望用更多 GPU 来切分工作负载、缩短训练时间。

两者之间存在 trade-off:将工作分散到更多 GPU 上意味着更多的通信开销。因此需要计算”多用几块卡带来的额外算力”与”GPU 间数据搬运的代价”之间的平衡。

集合通信原语(Collective Operations)#

集合通信是分布式编程的基石,其历史可以追溯到 1980 年代的 MPI 标准——并非为大语言模型训练而发明,但至今仍是 LLM 训练的核心原语。“集合”(collective) 的含义是:你只需指定一个多设备间的通信模板,而不需要手动管理每对 GPU 之间的点对点消息传递。

基本术语#

  • Rank:一个参与通信的设备(通常是一块 GPU),用整数编号(0, 1, 2, 3, …)
  • World size:参与通信的设备总数

基础操作(热身)#

以下三个操作在训练的核心路径中不常直接出现,但它们是理解后续核心操作的基础。

Broadcast:从一个 rank(如 rank 0)向所有 rank 复制同一份张量。典型用途是在初始化时,rank 0 加载 checkpoint 后广播给其他所有 rank。

# Input: rank0 = [0, 1, 2, 3]
# Output: rank0 = rank1 = rank2 = rank3 = [0, 1, 2, 3]

Scatter:将 rank 0 上的一个大张量按 world size 切分,每个 rank 获得一个分片。Scatter 本身不常直接使用,但它是理解 reduce-scatter 的跳板。

# Input: rank0 = [0, 1, 2, 3]
# Output: rank0 = [0], rank1 = [1], rank2 = [2], rank3 = [3]

Gather:scatter 的逆操作——将各 rank 上的分片收集到一个指定 rank(如 rank 0),拼接成完整张量。Gather 是理解 all-gather 的跳板。

# Input: rank0 = [0], rank1 = [1], rank2 = [2], rank3 = [3]
# Output: rank0 = [0, 1, 2, 3]

Reduce:类似 gather,但不是拼接,而是对各 rank 的数据施加一个归约操作(如 sum、max、min)。可以把 gather 看成归约操作为”拼接”的特殊 reduce。

# Input (sum): rank0 = [0], rank1 = [1], rank2 = [2], rank3 = [3]
# Output: rank0 = [6] (0+1+2+3)

核心操作(训练中的主力)#

All-gather:对所有 rank 执行 gather,而不仅仅是 rank 0。操作完成后,每个 rank 都持有完整的拼接结果。名字中的 “all” 表示”输出到所有 rank”。

# Input: rank0 = [0], rank1 = [1], rank2 = [2], rank3 = [3]
# Output: rank0 = rank1 = rank2 = rank3 = [0, 1, 2, 3]

All-gather 的典型场景:每个 rank 持有参数的一个分片(shard),forward pass 之前需要 all-gather 拼出完整参数。

Reduce-scatter:对每个维度分别执行归约,然后将结果分散到不同 rank。具体来说,如果每个 rank 持有一个长度为 WW(world size)的向量,reduce-scatter 会对第 ii 个分量在所有 rank 上求和,然后将结果放到 rank ii 上。

# Input: rank0 = [0,1,2,3], rank1 = [1,2,3,4], rank2 = [2,3,4,5], rank3 = [3,4,5,6]
# Output: rank0 = [6] (0+1+2+3)
# rank1 = [10] (1+2+3+4)
# rank2 = [14] (2+3+4+5)
# rank3 = [18] (3+4+5+6)

Reduce-scatter 的典型场景:backward pass 结束后,各 rank 上的梯度需要求和,然后分布式存储归约后的结果。

All-reduce = reduce-scatter + all-gather。先归约,再让所有 rank 拿到完整的归约结果。这是数据并行(DDP)中最核心的操作。

# Input: rank0 = [0,1,2,3], rank1 = [1,2,3,4], rank2 = [2,3,4,5], rank3 = [3,4,5,6]
# Output: rank0 = rank1 = rank2 = rank3 = [6, 10, 14, 18]

All-to-all:最通用的集合操作。每个 rank 指定要发送给每个其他 rank 的数据。输入张量的第 ii 个分量会被发送到 rank ii。如果所有 rank 发送的数据量均衡,all-to-all 在数学上等价于对输入矩阵做转置

All-to-all 的典型场景是 Mixture of Experts (MoE) 训练:每个 rank 持有一部分 data 和一部分 expert,需要根据 router 的动态决策将 token 路由到正确的 expert 所在的 rank。负载均衡的目标正是让 all-to-all 的通信量尽可能均匀。

记忆术语的窍门#

  • Reduce:施加关联交换操作(sum / max / min)
  • ScatterGather 互逆:scatter 分发,gather 集中
  • All- 前缀:输出目标是所有 rank

GPU 互连硬件拓扑#

理解了集合通信的抽象语义之后,需要知道这些操作跑在什么硬件上——不同层级的互连带宽差距巨大,直接决定了并行策略的选择。

三层互连层次#

GPU 集群的互连呈严格分层结构,从快到慢依次为:

层级互连技术典型带宽覆盖范围
节点内NVLink → NVSwitch1.8 TB/s(NVLink 5.0)同一节点内 8 块 GPU
Pod 内Infiniband(经 PCIe → HCA → IB 线缆)~0.05 TB/s~256 节点
跨 PodEthernet(经 PCIe → CPU)更低整个数据中心

作为对比,B200 的 HBM 带宽为 8 TB/s。因此 NVLink 大约是 HBM 的 1/4 速度——在单 GPU 优化语境下很慢,但在多 GPU 语境下已经算快了。

典型配置是 8 块 GPU 通过 NVLink 连接到一个 NVSwitch。NVSwitch 起到交换机的作用:从编程角度看,同一 NVLink domain 内的任意两块 GPU 可以直接通信,硬件自动处理路由。

Nvidia 的 NVL72 配置更进一步:将 9 个 tray(每个 tray 8 块 GPU)通过 NVSwitch 互连,形成一个包含 72 块 GPU 的单一 NVLink domain。这意味着 72 块 GPU 之间都享有 NVLink 级别的高带宽互连。

Infiniband#

当集群规模超出 NVLink domain 的容量后,节点之间通过 Infiniband 连接。Infiniband 的一个关键特性是支持 RDMA(Remote Direct Memory Access):一块 GPU 可以直接读写另一块 GPU 的显存,无需经过 CPU。这避免了传统 Ethernet 路径中从 GPU → CPU kernel socket buffer → TCP 封包 → NIC → 网络的多次拷贝和延迟。

传统 Ethernet 与 RoCE#

标准 Ethernet 不支持 RDMA——数据必须经过 CPU 中转,引入大量延迟。但 RoCE(RDMA over Converged Ethernet) 技术让 Ethernet 也能绕过 CPU 直接传输。Meta 的论文显示他们在探索用 RoCE 替代 Infiniband 来训练 LLaMA 系列模型——这是对 Infiniband 高成本的一种替代方案。

RDMA 的本质#

RDMA 是一个功能性概念,而非具体硬件:它描述的是”一块 GPU 直接读写另一块 GPU 显存”的能力。实现 RDMA 的硬件手段有多种——NVLink/NVSwitch、Infiniband、RoCE 都可以提供 RDMA 能力。

NCCL:集合操作的底层实现#

NCCL(Nvidia Collective Communications Library) 是将集合操作翻译为底层 GPU 间通信的库。当你调用 all_reduce 时,NCCL 负责:

  1. 探测 GPU 互连的拓扑结构
  2. 规划最优的消息传递路径(ring topology 或 tree topology)
  3. 启动专门的 通信 kernel 在 GPU 上执行数据收发

通信操作和计算操作一样,最终都是 GPU kernel——只不过这些 kernel 的工作是在 GPU 之间搬运数据,而非做矩阵运算。

PyTorch 分布式编程与性能基准#

torch.distributed 接口#

PyTorch 的 torch.distributed 库为集合操作提供了干净的高层接口,屏蔽了底层 NCCL 的细节。它支持多种后端:GPU 环境使用 NCCL 后端,CPU 环境使用 Gloo 后端。该库也提供了 FSDP 等高级封装,但本课程选择从原语出发手动实现。

编程模型#

分布式程序的基本模式是 SPMD(Single Program, Multiple Data):同一段代码被复制到多个进程上运行,每个进程拥有不同的 rank(0 到 world_size - 1),在同一段代码中根据 rank 执行不同逻辑。

def main(rank, world_size, ...):
# 每个进程都执行这段代码,但 rank 不同
setup(rank, world_size) # 配置 master_addr/port,初始化进程组
...
cleanup()
# 启动 world_size 个进程
spawn(main, world_size=4, ...)

初始化时设置的 master_addrmaster_port 仅用于进程间的元数据协调,实际的数据传输走 NCCL/NVLink 通道。

同步机制#

由于各进程异步运行,执行顺序不确定。dist.barrier() 提供同步点——所有进程必须到达 barrier 后才能继续。过多的 barrier 会降低并行效率,但对于正确性保证是必要的。

All-reduce 代码示例#

data = torch.arange(4, dtype=torch.float32) + rank # 每个 rank 不同
dist.all_reduce(data, op=dist.ReduceOp.SUM)
# 操作原地修改 data,所有 rank 得到相同结果

Reduce-scatter 与 All-gather 代码示例#

# Reduce-scatter: 输出与输入分离
output = torch.zeros(1)
dist.reduce_scatter_tensor(output, input, op=dist.ReduceOp.SUM)
# All-gather: 将 reduce-scatter 的结果收集到所有 rank
activations = [torch.zeros_like(x) for _ in range(world_size)]
dist.all_gather_into_tensor(output, x)

通过先做 reduce-scatter 再做 all-gather,可以验证其结果与直接做 all-reduce 完全一致——这从实验上证实了 all-reduce = reduce-scatter + all-gather 的等价关系。

异步通信#

async_op=True 传入集合操作后,调用会立即返回,通信在后台进行。这为通信与计算重叠提供了基础——例如在等待 all-reduce 完成的同时加载下一个 batch 的数据。需要确认通信完成时,调用 wait()barrier()

需要注意两层异步性的叠加:CUDA kernel 本身相对 CPU 是异步的(需要 torch.cuda.synchronize()),而分布式操作又在进程间引入了另一层异步(需要 dist.barrier())。在 benchmark 时,必须先 synchronize()barrier()——顺序不能反。原因是:如果先调用 barrier,各进程可能在自己的 CUDA kernel 还未完成时就通过了 barrier(因为 CUDA 操作相对 Python 是异步的,barrier 调用瞬间返回),此时 barrier 形同虚设,各进程依然在各自独立地等待 CUDA kernel 完成,并未真正同步。

性能基准#

对 100M 个 float32 元素(约 400 MB)的 all-reduce 操作进行 benchmark,在 4 块 GPU 上测得约 1.6 ms 的延迟。

有效带宽(bus bandwidth) 的计算方式:

bandwidth=2×size×W1WW×duration\text{bandwidth} = \frac{2 \times \text{size} \times \frac{W-1}{W}}{W \times \text{duration}}

其中 WW 是 world size,factor 2 来自 all-reduce 同时包含 reduce 和 gather 两个方向的数据搬运,W1W\frac{W-1}{W} 是因为 WW 个 rank 之间有 W1W-1 步归约操作。

WW 较大时,W1W1\frac{W-1}{W} \approx 1,有效带宽近似为 2×sizeduration\frac{2 \times \text{size}}{\text{duration}}。实测约 400 GB/s,且这个数值:

  • 不依赖 world size——扩展到更多 GPU 时带宽不变
  • 不依赖拓扑——NCCL 自动选择 ring 或 tree 拓扑

All-reduce 搬运的数据量是 reduce-scatter 的 2 倍(因为它同时包含 reduce 和 gather 两个方向),但也需要 2 倍的时间。两个 2× 相消,因此 all-reduce 和 reduce-scatter 的有效带宽相同。

数据并行(Data Parallelism)#

本讲的所有并行策略都以 MLP(多层感知机) 为演示对象,而非完整的 Transformer。这个选择并非偷懒——MLP 正是 Transformer 中实际的计算瓶颈,因此 MLP 上的并行模式具有很强的代表性,而完整 Transformer 只是增加了更多 bookkeeping。

数据并行是最直观、最优雅的并行策略。它的核心思想是:切分数据,复制模型。每个 rank 持有完整的模型参数,但只处理全局 batch 的一个子集。

切分策略#

假设全局 batch size 为 BB,world size 为 WW,则每个 rank 分到的 local batch size 为 B/WB/W。对于一个 B×DB \times D 的数据矩阵(DD 是特征维度),按行切分,每个 rank 获得 BW×D\frac{B}{W} \times D 的数据切片。

local_batch_size = batch_size // world_size
start = rank * local_batch_size
data = full_data[start : start + local_batch_size].to(rank)

实际中,每个 rank 应该直接加载自己那部分数据,而非先全量加载再切分。

训练流程#

数据并行的训练循环与单 GPU 训练几乎完全相同,唯一的差异是在 backward pass 和 optimizer step 之间插入一步梯度同步

for step in range(num_steps):
# Forward: 每个 rank 用自己的 local batch
x = data
for layer in range(num_layers):
x = relu(x @ params[layer])
loss = x.sum()
# Backward
loss.backward()
# === 数据并行的关键一步 ===
for param in params:
dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
# Optimizer step
optimizer.step()

All-reduce 之后,所有 rank 上的梯度完全相同(因为是全局平均),进而参数更新也完全相同。这保证了训练过程中各 rank 的模型参数始终保持一致。

DDP 的优美之处#

数据并行对模型架构完全透明——它不关心 forward pass 的具体实现,只在梯度层面做同步。无论是 MLP、Transformer 还是其他架构,DDP 的改动都是同样的一行 all-reduce。这种模块化设计使得 DDP 成为最容易实现和调试的并行策略。

状态总结#

各 rank 间关系
数据(loss)不同(各自处理不同子集)
梯度(all-reduce 前)不同
梯度(all-reduce 后)相同
参数始终相同

约束条件#

  • 全局 batch size 至少为 world size(否则无法切分)
  • 理想情况下 batch size 是 world size 的整数倍(否则需要 padding)

DDP 的局限性#

DDP 要求每个 rank 持有完整的模型参数。如果模型太大以至于单卡放不下参数 + 梯度 + 优化器状态,DDP 就不够用了。这正是下一讲的 FSDP / ZeRO 要解决的问题——它们将 all-reduce 拆成 reduce-scatter + all-gather,在这两步之间插入参数分片逻辑,使得每个 rank 只需持有参数的一个分片。

通信与计算重叠的优化#

上面的实现在整个 backward pass 完成后才统一做 all-reduce,但实际上可以做得更高效:backward pass 中每一层的梯度计算完成后,立即发起该层梯度的 all-reduce(使用 async),同时继续计算前面层的梯度。这样通信和计算可以重叠,显著减少总等待时间。这也是 Assignment 2 中需要探索的优化。

张量并行(Tensor Parallelism)#

与数据并行切分 batch 维度不同,张量并行切分的是模型的参数矩阵本身——每个 rank 持有每一层参数的一个列切片(column shard)。这种方式也被称为 column tensor parallel(也存在按行切分的 row tensor parallel,但本讲不展开)。

切分策略#

对于一个 D×DD \times D 的参数矩阵,张量并行将其按列切分为 WW 份,每个 rank 持有 D×DWD \times \frac{D}{W} 的参数切片。为了简化演示,假设所有 rank 都持有完整的输入数据(实际中张量并行常与数据并行结合,此时每个 rank 不一定持有全部数据)。

local_num_dim = num_dim // world_size
# 每个 rank 持有所有层的参数切片
params[layer].shape = (num_dim, local_num_dim) # 而非 (num_dim, num_dim)

Forward Pass#

每个 rank 用完整的输入 XX(shape 为 B×DB \times D)乘以自己持有的参数切片,得到局部激活值(shape 为 B×DWB \times \frac{D}{W})。由于非线性激活(如 ReLU)是逐元素操作,可以直接在局部激活值上应用。

但在进入下一层之前,必须将各 rank 的局部激活值拼接成完整的 B×DB \times D 张量——否则下一层的矩阵乘法维度不匹配。这一步通过 all-gather 实现:

for layer in range(num_layers):
x = relu(x @ params[layer]) # x: (B, local_D)
# All-gather: 拼接各 rank 的局部激活值
activations = [torch.zeros(batch_size, local_num_dim) for _ in range(world_size)]
dist.all_gather_into_tensor(output, x)
x = torch.cat(activations, dim=1) # x: (B, D)

All-gather 与 Reduce-scatter 的对偶性#

张量并行有一个优雅的数学结构:

  • Forward pass:每层需要 all-gather(拼接局部激活值)
  • Backward pass:对应地需要 reduce-scatter(归约梯度并分片存储)

需要注意的是,backward pass 中的 reduce-scatter 不是 PyTorch autograd 自动完成的——在本课程的从零实现中,需要手动调用 dist.reduce_scatter_tensor()。工业框架(如 PyTorch 的 tensor parallel 模块)会自动处理这一步,但理解底层机制正是本课程的目的。

这种 all-gather / reduce-scatter 的对偶关系是分布式训练中反复出现的模式。

与数据并行的对比#

维度数据并行张量并行
切分对象batch 维度参数矩阵的列维度
模型侵入性无(模型不可感知)高(需要改写 forward/backward)
通信频率每个 training step 一次每一层 forward/backward 各一次
通信量梯度大小激活值大小(通常更大)
数学基础梯度的可加性矩阵乘法可按列拆分

张量并行的通信量明显更大——每一层都要做一次 all-gather,而数据并行只在整个 backward 结束后做一次 all-reduce。因此张量并行强烈依赖高带宽互连

流水线并行(Pipeline Parallelism)#

流水线并行的切分维度是网络的深度——将模型的不同层分配给不同 rank。每个 rank 持有一个层子集的完整参数(包括完整维度),但只负责这些层的计算。

切分策略#

local_num_layers = num_layers // world_size
# rank 0 持有 layer 0~3, rank 1 持有 layer 4~7, ...
# 每层参数维度不变: (num_dim, num_dim)

Forward Pass 与点对点通信#

与数据并行和张量并行使用集合操作不同,流水线并行使用点对点的 send/recv

for micro_batch in micro_batches:
if rank > 0:
dist.recv(x, src=rank - 1) # 从上一级接收激活值
for layer in local_layers:
x = relu(x @ params[layer])
if rank < world_size - 1:
dist.send(x, dst=rank + 1) # 发送激活值给下一级

Rank 0 接收原始数据,依次通过自己负责的层后将激活值发送给 rank 1,rank 1 处理后再发给 rank 2,以此类推。

Pipeline Bubble 问题#

流水线并行的核心问题是 pipeline bubble:当 rank 0 在处理数据时,rank 1、2、3 都在闲等;当 rank 3 在处理时,rank 0、1、2 又在闲等。这导致大量 GPU 时间被浪费在等待上。

Micro-batching:减少 Bubble#

解决 bubble 的关键技术是 micro-batching:将全局 batch 进一步切分为多个 micro-batch。每个 micro-batch 体积更小,处理更快,可以更快地传递给下一级。这样各 rank 可以交替处理不同的 micro-batch,大幅减少闲置时间。

通信与计算重叠#

流水线并行的结构天然适合做通信/计算重叠。上面展示的 naive 版本使用阻塞的 send/recv并未实现通信与计算的重叠。要实现重叠,需要将 send/recv 替换为非阻塞的 isend/irecv(加 “i” 前缀),并增加额外的代码来管理异步状态——这并不像听上去那么简单,但它是让流水线并行真正高效的关键。

并行策略的选择与权衡#

五种并行切分维度#

并行策略切分维度核心集合操作
Data parallelismbatch 维度all-reduce(DDP)/ all-gather + reduce-scatter(FSDP/ZeRO)
Tensor parallelism参数矩阵的宽度维度all-gather(forward)+ reduce-scatter(backward)
Pipeline parallelism网络深度维度point-to-point send/recv
Sequence parallelism序列长度维度专用通信(用于并行化 attention)
Expert parallelismexpert 维度(MoE)all-to-all

硬件约束决定策略组合#

并行策略的选择强烈依赖硬件拓扑。一个核心原则是:通信量大的并行策略必须跑在高带宽互连上

  • 张量并行:每层 forward/backward 都需要通信,通信量大 → 只应在 NVLink/NVSwitch domain 内使用(节点内 8 块或 NVL72 的 72 块 GPU)
  • 数据并行:每个 training step 只通信一次梯度 → 可以跨 Infiniband 节点使用
  • 流水线并行:通信量最小(只传激活值,且可以用 micro-batching 分摊)→ 甚至可以跨低速网络使用。去中心化训练方案中,GPU 可能分布在地理上遥远的位置,这时流水线并行是唯一可行的选择

典型的组合策略#

在大规模训练中,通常将多种并行策略嵌套组合

  • 节点内(NVLink domain):tensor parallelism
  • 节点间(Infiniband):data parallelism / FSDP
  • 跨 pod(低速网络):pipeline parallelism

Critical Batch Size#

数据并行理论上可以无限扩展——只要不断增大 batch size 即可。但存在一个 critical batch size:超过该值后,继续增大 batch size 对训练收敛的边际收益递减,额外的 GPU 算力被浪费。此时应该切换到张量并行或其他策略来利用多余的 GPU。

冗余计算 vs 通信的 Trade-off#

贯穿本讲所有并行策略的一个 meta-pattern 是 “重算/冗余存储 vs 通信”的 trade-off。在单 GPU 上,这表现为 activation checkpointing(重新计算激活值以节省显存);在多 GPU 上,这表现为数据并行的冗余参数存储——每个 rank 都持有完整参数,存在冗余存储和冗余计算,但换来了通信量的减少:不需要在 forward/backward 中搬运优化器状态。FSDP/ZeRO 则选择消除这种冗余,代价是引入更频繁的通信。可以将”存在另一块 GPU 上”看成是单 GPU 存储层次(L1 → L2 → HBM)的自然延伸。

JAX/TPU 的替代路径#

PyTorch 采用显式编程的方式:开发者手动指定使用哪种集合操作。JAX + TPU 提供了一种声明式的替代方案:开发者只需定义模型和 sharding strategy(哪些张量的哪个维度被切分到哪些设备),编译器自动推导所需的通信操作。这种方式更简洁,但屏蔽了底层细节——本课程选择 PyTorch 的显式路径正是为了让学生理解从原语到训练的完整链条。

Lecture 7:分布式并行训练基础
https://www.xwysyy.cn/posts/cs336/lec07/
作者
xwysyy
发布于
2026-05-17
许可协议
CC BY-NC-SA 4.0
© 2026 xwysyy. All Rights Reserved.
Powered by Astro & Firefly

文章目录