Lecture 2:PyTorch、einops 与资源核算
Stanford CS336 Language Modeling from Scratch | Spring 2026 | Lecture 2: PyTorch (einops),时长 01:17:25。
从 Scaling Law 预测到资源核算#
训练语言模型的核心目标可以表述为:在有限的资源预算(计算、内存、有时还有数据)下,训练出尽可能好的模型。要实现这一目标,首先需要理解”效率”的定量含义——我们需要能够回答下面这类问题:
在 1024 块 H100 上,训练一个 70B 参数的模型、跑 15 万亿 token,需要多长时间?
这类问题的回答框架非常直接:先计算总 FLOPs(6×70×109×15×1012),再查硬件 spec sheet 得到 FLOP/s,乘以一个经验性的利用率系数 MFU(约 0.5),除出天数——答案是大约 143 天。
再看一个内存维度的问题:用 AdamW 在 8 块 H100 上,最大能训多大的模型? 每块 H100 有 80GB HBM,使用 AdamW 时每个参数需要 2+2+4+4=12 字节(参数 bf16 + 梯度 bf16 + 一阶矩 fp32 + 二阶矩 fp32),所以 80×109×8/12≈53×109,即约 530 亿参数——当然这还没有计入 activation 的开销(取决于 batch size 和序列长度),属于粗略的 back-of-the-envelope 估算。
在正式进入 resource accounting 的细节之前,值得一提的是 Marin 项目的一次实际验证。该项目使用 1023 FLOPs 的预算进行了一次训练运行,其结果与基于 IsoFLOPs 曲线拟合的 scaling law 预测高度吻合——预测 loss 与实际 loss 之差仅在 0.05 以内。

图中每条曲线是一条 IsoFLOPs 曲线(在固定计算预算下,用一系列较小模型做 runs,找到 compute-optimal 点)。虚线是拟合的 Compute-Optimal Frontier。标注的 Δloss 表示预测值与实际值之间的偏差。将这条 frontier 外推到 GPT-5 等量级的计算预算,可以得到相应的 loss 估计——当然外推的可靠性取决于 scaling law 在该区间的适用性。
这门课的核心心态是:写每一行代码时,都要思考其性能特征——不是要精确计算每一个数字,而是要建立起对 FLOPs 和内存量级的直觉判断。
Tensor 与浮点精度#
Tensor:一切的基础构建块#
在深度学习训练中,参数(parameters)、梯度(gradients)、优化器状态(optimizer states)、数据(data)和激活值(activations) 全部以 tensor 的形式存储。Tensor 囊括了向量、矩阵以及任意阶的多维数组。以 DeepSeek v3.2 模型为例,模型本身就是一组 tensor,每个 tensor 有各自的 shape 和 precision。
Tensor 的内存占用计算非常直接:元素个数 × 每个元素的字节数。例如一个 4×8 的 float32 矩阵占用 32×4=128 字节。但不要小看这个乘法——GPT-3 中 feedforward 层的单个权重矩阵(dmodel×4dmodel)就占约 2.3 GB(float32 下)。
float32:传统基准#

float32(又称 fp32 或 single precision)使用 32 位表示一个浮点数:1 位符号 + 8 位指数(exponent)+ 23 位尾数(mantissa/fraction)。8 位指数提供了很大的动态范围(dynamic range),23 位尾数提供了较高的精度(resolution)。
“单精度”这个名字来自科学计算的传统——float32 曾是浮点数的”标准配置”,float64(双精度)是需要更高精度时的升级选项。但在深度学习中,方向是反过来的:即便 32 位也嫌多,因为深度学习中的计算对精度的需求远低于传统数值仿真。
在 PyTorch 中创建 tensor 时,默认 dtype 就是 float32。创建后默认在 CPU 上——要获得 GPU 加速,需要显式地 .to('cuda') 或 cuda_if_available()。
float16:动态范围不足#
将 float32 “砍半”得到 float16:1 位符号 + 5 位指数 + 10 位尾数,共 16 位。内存直接减半,理论上计算速度翻倍。但问题在于 5 位指数的动态范围太小。一个直观的例子:用 float16 表示 1×10−8 时,结果直接变成 0——发生了 underflow。
如果直接用 fp16 训练(早期确实有人这么做),会频繁遇到 underflow、overflow 和 NaN,训练极不稳定。
bfloat16:深度学习的甜蜜点#

bfloat16(brain floating point 16)由 Google Brain 在 2018 年提出,核心思路是:保持 16 位的总长度不变,但将部分 mantissa 的位数让给 exponent——结果是 1 位符号 + 8 位指数 + 7 位尾数。
这意味着 bfloat16 与 float32 有相同的动态范围(都是 8 位指数),代价是 resolution 更差(7 位 vs 23 位尾数)。在深度学习中这个 trade-off 通常是划算的:训练过程中的梯度更新本身就是 stochastic 的,对 resolution 的敏感度远低于对动态范围的依赖。不过需要注意,即便是 bf16 也并非完全没有风险——这正是混合精度训练(optimizer states 保持 fp32)存在的原因。用 bfloat16 表示 1×10−8 不会 underflow——值为 1.001×10−8,虽然精度有限但至少是个有意义的非零值。
三种格式的 torch.finfo 对比:
| 格式 | resolution | 最小值 | 最大值 | 每值字节 |
|---|---|---|---|---|
| float32 | 10−6 | −3.4×1038 | 3.4×1038 | 4 |
| float16 | 10−3 | −65504 | 65504 | 2 |
| bfloat16 | 10−2 | −3.4×1038 | 3.4×1038 | 2 |
fp8 与 fp4:更激进的低精度#

fp8 已经标准化(Micikevicius+ 2022),有两个变体:E4M3(4 位指数 + 3 位尾数,范围 [−448,448])和 E5M2(5 位指数 + 2 位尾数,范围 [−57344,57344]),分别适用于需要更高 resolution 或更大动态范围的场景。NVIDIA 的 Transformer Engine 已支持 fp8 训练。
更极端的是 nvfp4(NVIDIA 2025),每个值仅 4 位。4 bit 能表示的值非常有限——全部值可以列在一行:−6,−4,−3,−2,−1.5,−1.0,−0.5,0.0,0.5,1.0,1.5,2,3,4,6。但 fp4 采用 block scaling 机制:在每个 block 内,所有值共享一个 scale factor(用更多位表示)。这样单个值看到的有效动态范围大于 4 位,只是同一 block 内的值不能有极端的相对差异。Nemotron 3 Super 已用 fp4 完成训练。
需要注意的是,fp8/fp4 的很多操作由 NVIDIA 软件栈在底层处理,用户并不能简单地 torch.zeros(..., dtype=torch.fp4) 来创建 tensor。
另外,训练(training)和推理(inference)的精度需求不同。推理阶段可以把 bf16 训好的模型 quantize 到 1-2 bit——这比直接用 1 bit 训练要容易得多。据目前所知,还没有人用 1 bit 精度从头训练出有意义的语言模型。
混合精度训练#
实践中的标准做法是 mixed precision training:
- bf16 用于参数(parameters)、激活值(activations)和梯度(gradients)
- fp32 用于优化器状态(optimizer states)
PyTorch 提供了 AMP(Automatic Mixed Precision)库来自动管理精度切换:MatMul 等运算会被 cast 到 bf16(安全且快速),而 exponentiation 等对精度敏感的运算会保持 fp32。使用时只需将代码包裹在 torch.cuda.amp.autocast 上下文中。
einops:用命名维度思考张量操作#
动机:告别”minus 2, minus 1”#
传统 PyTorch 代码中,张量操作依赖数字索引来指定维度——比如 y.transpose(-2, -1) 来转置最后两个维度,或 x.sum(dim=-1) 来沿最后一个维度求和。这种写法容易出错且难以阅读:当你看到 torch.matmul(x, y.transpose(-2, -1)) 时,需要在心里推演 -2 和 -1 分别对应什么。
einops(Einstein Operations)通过命名维度来解决这个问题。其灵感来源于 Einstein 求和约定(Einstein summation convention),核心思想是:给每个维度一个有意义的名字,让操作的语义直接可读。
einsum:通用矩阵乘法#
einsum 可以理解为”带有良好记账的通用矩阵乘法”。其工作方式是:
- 为输入 tensor 的每个维度命名
- 指定输出 tensor 的维度
- 未出现在输出中的维度自动被 sum 掉
一个简单的矩阵乘法 X∈R3×4,Y∈R4×3:
1# 传统写法2z = x @ y3
4# einsum 写法5z = einsum(x, y, "seq1 hidden, hidden seq2 -> seq1 seq2")einsum 的语义是:对 seq1、hidden、seq2 三个变量的所有取值做枚举,将 xseq1,hidden⋅yhidden,seq2 的乘积累加到 zseq1,seq2 中。hidden 没有出现在输出中,所以被 sum 掉——这正是矩阵乘法的定义。
einsum 的优势在复杂场景中尤其明显。考虑一个 batched attention score 计算,输入 X∈R2×3×4(batch × seq × hidden)和 Y∈R2×4×3:
1# 传统写法——需要理解 @ 对前导维度做 batch、transpose 作用于哪两个轴2z = torch.matmul(x, y.transpose(-2, -1))3
4# einsum 写法——语义自明5z = einsum(x, y, "batch seq1 hidden, batch seq2 hidden -> batch seq1 seq2")注意 einsum 中 不需要显式 transpose——维度名已经隐含了对齐关系。如果 Y 的维度顺序是 batch hidden seq2 而非 batch seq2 hidden,那才是不需要 transpose 的情况;这里 hidden 在不同位置出现,einsum 自动处理了对齐。
当 batch 维度很多时(比如 batch + sequence + head 三层嵌套),可以用省略号 ... 来代替所有 batch 维度:
1z = einsum(x, y, "... seq1 hidden, ... seq2 hidden -> ... seq1 seq2")这使得代码具有模块性——不需要知道传入 tensor 具体有几个 batch 维度。
reduce:通用聚合#
reduce 是 sum、mean、max、min 的泛化。传统写法 x.sum(dim=-1) 在 einops 中变为:
1result = reduce(x, "... hidden -> ...", "sum")语义是:hidden 维度在输出中不出现,所以对它做 sum(或 mean/max/min)。这样维度名的存在让操作意图一目了然,不需要去数 dim=-1 到底是哪个维度。
从性能角度看,reduce 只是语法糖——最终调用的是相同的底层 primitive,没有额外开销。
rearrange:维度的拆分与合并#
rearrange 处理的是维度的 reshape 操作,尤其是一个维度实际上是两个维度的乘积的情况。这在 multi-head attention 中很常见:将 hidden dimension 拆分为 heads × head_dim。
1# x 的 shape 是 (3, 8),其中 8 = 2 heads × 4 hidden_per_head2x = rearrange(x, "... (heads hidden1) -> ... heads hidden1", heads=2)3# 现在 x 的 shape 是 (3, 2, 4)括号语法 (heads hidden1) 表示这个维度可以分解为 heads 和 hidden1 的乘积。需要指定其中一个的大小(heads=2),另一个自动推导。
做完 per-head 的变换后,可以用 rearrange 合并回去:
1x = rearrange(x, "... heads hidden2 -> ... (heads hidden2)")关于拆分的顺序(row-major vs column-major),rearrange 中括号内变量的书写顺序决定了拆分方式——写在前面的变量变化更慢(类似 row-major)。
einops 需要一定的学习成本,但一旦习惯了用命名维度思考,所有的 transpose、reduce、reshape 操作都变得直觉化,不再需要与数字索引搏斗。
FLOPs 计算与硬件规格#
FLOPs 与 FLOP/s:一字之差#
在讨论计算量时,“FLOPs”一词有两个含义,必须区分清楚:
- FLOPs(小写 s)= floating-point operations,计算总量的度量。例如”GPT-3 训练用了 3.25×1023 FLOPs”。
- FLOP/s = floating-point operations per second,硬件速度的度量。例如”H100 的 bf16 峰值是 989 TFLOP/s”。
有时 FLOP/s 也被写成大写 FLOPS,但本文一律用 /s 标注以避免混淆。
一个 FLOP 定义为一次基本浮点运算(加法或乘法)。GPU 还能做其他运算,但加法和乘法是绝对主体,占据了几乎所有的计算时间。
MatMul 的 FLOPs:2BDK#
矩阵乘法(MatMul)的 FLOPs 计算是整个 resource accounting 的核心。对于 X∈RB×D 乘以 W∈RD×K:
FLOPs=2×B×D×K
因子 2 来自:每个输出元素需要 D 次乘法和 D−1 次加法,近似为 2D 次运算,共 B×K 个输出元素。严格来说加法是 D−1 次而非 D 次,但在实际规模下这个差异可以忽略。
将 B 理解为数据点/token 数、D×K 理解为参数量,这个公式可以改写为:
FLOPsforward=2×∣tokens∣×∣parameters∣
其他元素级操作(elementwise addition、ReLU 等)的 FLOPs 只是矩阵的大小(m×n),远小于 MatMul 的 O(m×n×k)。所以在矩阵足够大的前提下,只需要关注 MatMul 的 FLOPs。
关于是否有亚立方的矩阵乘法算法(如 Strassen)可以利用:实际中 GPU 上的矩阵乘法优化主要依靠与硬件的 co-design(内存层次、tiling 等),而非渐近更优的算法。同样,虽然直觉上加法应该比乘法快,但在现代硬件设计中二者的速度基本相同。
H100 规格与 MFU#
NVIDIA H100 的 spec sheet 标注 bf16 峰值为 1979 TFLOP/s——但有个关键细节:这个数字包含了 structured sparsity(2:4 稀疏)。对于 dense 计算,需要除以 2,即 约 989 TFLOP/s。
有了 FLOP/s 的理论值和实际 benchmark 值,就可以定义 MFU(Model FLOPs Utilization):
MFU=spec sheet FLOP/s(dense)实际 FLOP/s
MFU 衡量的是:你的代码在多大程度上”兑现”了硬件的承诺。经验值:
- MFU ≈ 0.5 对于现代大模型训练,算是不错的水平
- 纯 MatMul 可以达到 0.7-0.8
- MFU ≈ 0.1 说明有严重问题需要排查
MFU 低于 1.0 的原因将在后续的算术强度分析中变得清晰——本质上是因为 memory bandwidth 的限制。
直觉量级#
为了建立 FLOPs 的量级直觉:8 块 H100 跑一周,假设 MFU = 1(理想情况),大约能完成 8×989×1012×7×86400≈5×1021 FLOPs。这就是为什么 1022 到 1025 量级的训练需要数百到数千块 GPU 跑数周到数月。
计算 MFU 的实际流程:
- 根据模型结构和 batch size 计算逻辑 FLOPs
- 在 GPU 上 benchmark wall clock time(注意要用
torch.cuda.synchronize()确保异步操作完成) - 实际 FLOP/s = 逻辑 FLOPs / wall clock time
- MFU = 实际 FLOP/s / spec sheet FLOP/s
benchmark 时一个常见的坑是忘记 cuda.synchronize()——因为 CUDA 操作是异步的,torch.matmul 之类的调用会立即返回,如果不加同步屏障就会误以为计算很快。正确做法是在操作前后各加一次 synchronize,并多次运行取平均。
算术强度与 Roofline 分析#
计算不只是”做运算”#
到目前为止我们只关注了 FLOPs 本身,但实际运行时间还取决于另一个因素:数据在 HBM(高带宽内存)和计算核心之间的搬运开销。
GPU 的计算模型可以简化为三步:(1) 将输入 tensor 从 HBM 送到计算核心;(2) 执行计算;(3) 将结果写回 HBM。因此一次操作的总时间取决于两个因素:
- Accelerator speed(FLOP/s):计算核心的运算速度
- Memory bandwidth(bytes/s):HBM 与计算核心之间的数据传输速度
H100 的两个关键参数:
- bf16 dense FLOP/s = 1979×1012/2≈989 TFLOP/s
- Memory bandwidth = 3.35 TB/s
实际中,communication 和 computation 可以重叠(pipeline)。在理想重叠假设下,总时间 = max(通信时间, 计算时间)。
逐例分析:从 ReLU 到 MatMul#
ReLU:极度 memory bound
对长度为 n(bf16)的向量做 ReLU:
- 搬运字节数:读 x(2n)+ 写 y(2n)= 4n bytes
- FLOPs:n(每个元素做一次比较)
- 通信时间 = 4n/3.35×1012≈1.2×10−6 s(n=106 时)
- 计算时间 = n/989×1012≈1.0×10−9 s
通信时间是计算时间的 1000 倍。GPU 的计算核心大部分时间都在等数据到来。
GELU:仍然 memory bound
GELU 的每个元素涉及约 20 次浮点运算(包含 tanh、乘法等),FLOPs 是 ReLU 的 20 倍。但搬运字节数不变(仍是 4n)。所以虽然 GELU 的 FLOPs 多了 20 倍,实际运行时间和 ReLU 完全一样——因为瓶颈在搬运而不在计算。这是一个反直觉的结论:看起来复杂得多的运算,实际上并不更慢。
向量点积:memory bound
x⋅w,x,w∈Rn(bf16):
- 搬运字节数:2n+2n+2=4n+2 bytes
- FLOPs:2n−1(n 次乘法 + n−1 次加法)
- 算术强度 ≈ 0.5(每搬运 1 byte 只做 0.5 次运算)
矩阵-向量乘:仍然 memory bound
Wx,W∈Rn×n,x∈Rn:
- 搬运字节数:2n+2n2+2n bytes(读 x、读 W、写 y)
- FLOPs:n×(2n−1)≈2n2
- 算术强度 ≈ 1(刚刚超过 0.5)
仍远低于 H100 的阈值,依然 memory bound。这一点预示了 Transformer 推理的性能特征:推理时是逐 token 生成,每次是一个 vector 与权重矩阵的乘法,因此天然 memory bound。
矩阵-矩阵乘(MatMul):终于 compute bound
XW,X,W∈Rn×n(n=1024):
- 搬运字节数:2n2+2n2+2n2=6n2 bytes
- FLOPs:2n3
- 算术强度 = 2n3/6n2=n/3≈340
当 n=1024 时,算术强度约 340,超过了 H100 的加速器强度 295。这意味着 MatMul 终于是 compute bound——GPU 的计算核心在满载工作。
算术强度:统一的判断框架#
加速器强度(accelerator intensity) 定义为硬件 spec 的 FLOP/s 除以 bytes/s:
Iaccelerator=bytes/sFLOP/s=3.35×1012989×1012≈295 FLOPs/byte
含义是:H100 每搬运 1 byte 可以做 295 次浮点运算。
算法的算术强度(arithmetic intensity) 定义为该算法执行的 FLOPs 除以搬运的 bytes:
Ialgorithm=bytes movedFLOPs
判断规则非常简单:
- Ialgorithm<Iaccelerator → memory bound(搬运是瓶颈)
- Ialgorithm>Iaccelerator → compute bound(计算是瓶颈)
| 操作 | 算术强度 | 判定 |
|---|---|---|
| ReLU | 0.25 | memory bound |
| GELU | ~5 | memory bound |
| 向量点积 | ~0.5 | memory bound |
| 矩阵-向量乘 | ~1 | memory bound |
| 矩阵-矩阵乘 (n=1024) | ~340 | compute bound |
矩阵乘法的算术强度随矩阵规模增长(∝n/3)。直觉非常清晰:读入和写出的数据量是 O(n2),但计算量是 O(n3),二者之比就是 O(n)——矩阵越大,每搬运一个 byte 能做的”有用功”越多。这就是为什么大 batch size 和大隐藏维度对 GPU 利用率如此重要。如果算术强度低于 accelerator intensity,缩小矩阵并不会加速计算(因为时间被搬运占据),而一旦超过阈值,GPU 才真正在”干活”。
Roofline 图#

Roofline 图的横轴是算术强度(log scale),纵轴是实际实现的 FLOP/s(log scale)。对于给定的加速器:
- 左侧斜线区域(bandwidth bound):算术强度低于阈值时,实际 FLOP/s 受限于 memory bandwidth,呈线性增长(斜率 = bandwidth)。不同的 bandwidth 等级(BW1,BW2)对应不同的斜线。
- 右侧水平区域(compute bound):算术强度超过阈值后,实际 FLOP/s 趋近于加速器峰值,不再随算术强度增加而提升。
图中 Algo 1 落在斜线区域(bandwidth bound at both BW1 and BW2),Algo 2 落在水平区域(compute bound)。中间区域可能出现”在 BW1 下 bandwidth bound 但在 BW2 下 compute bound”的情况。
这解释了为什么 MFU 通常只有 0.5 左右:虽然 Transformer 的主体是 MatMul(compute bound),但中间穿插的 elementwise 操作(layernorm、activation functions、attention softmax 等)都是 memory bound 的,它们拉低了整体的 MFU。
另外需要注意,算术强度也依赖精度。同样的操作在 bf16 和 fp32 下有不同的搬运字节数和不同的硬件 FLOP/s,因此 roofline 图会随精度变化。
训练过程的资源分析#
运行样例:深层网络#

为了具体分析训练的资源消耗,考虑一个简单但有代表性的深层网络:输入 X∈RB×D(B 个数据点,每个 D 维),经过 L 层,每层包含一个 D×D 的线性变换和一个 elementwise ReLU 激活。输出与输入同维。
1class DeepNetwork(nn.Module):2 def __init__(self, dim, num_layers):3 self.blocks = nn.ModuleList([Block(dim) for _ in range(num_layers)])4
5 def forward(self, x):6 for block in self.blocks:7 x = block(x) # linear + ReLU8 return x参数总量 = D2×L(每层一个 D×D 权重矩阵)。
前向传播的 FLOPs#
每层的前向传播就是一次 B×D 乘以 D×D 的 MatMul,FLOPs = 2×B×D×D=2BD2。L 层合计:
FLOPsforward=L×2BD2=2B×(D2L)=2×∣tokens∣×∣params∣
反向传播的 FLOPs:正好是前向的两倍#
反向传播通过 chain rule 传播梯度。聚焦第 l 层:hl=hl−1⋅Wl(忽略 ReLU 简化分析)。反向时需要计算两个梯度:
- 对输入的梯度(用于向前一层继续传播):
∂hl−1∂L=einsum(∇hl,Wl,"batch out, in out -> batch in")
这是一个 MatMul,FLOPs = 2BD2。
- 对参数的梯度(用于更新权重):
∂Wl∂L=einsum(∇hl,hl−1,"batch out, batch in -> in out")
也是一个 MatMul,FLOPs = 2BD2。
注意这里 einsum 的优势:不需要记忆哪个有 transpose——维度名已经编码了正确的对齐方式。三个 einsum(前向 + 两个反向)的 FLOPs 完全相同(都是 2BD2),因为 FLOPs 只取决于三个维度的乘积,与哪些维度是 batch、哪些被 sum 无关。
所以每层的反向 FLOPs = 2×2BD2=4BD2,恰好是前向的 2 倍。
6nd 公式#
对整个网络:
- 前向 FLOPs = 2×B×∣params∣
- 反向 FLOPs = 4×B×∣params∣
- 一个训练步的总 FLOPs = 6×B×∣params∣
这就是广泛使用的 6nd 公式(n = 参数量,d = token 数 / batch size)的来源。
对 Transformer 而言,这个公式在 context length 不太长的情况下是一个很好的近似。当 context length 很大时,attention 的 O(seq2) 项会贡献额外的 FLOPs,不在这个线性近似中。
训练内存分解#
训练时 GPU 内存需要同时容纳四个部分:
| 组成部分 | 每参数字节数 | 说明 |
|---|---|---|
| 参数 | 2(bf16) | 模型权重本身 |
| 梯度 | 2(bf16) | 与参数 shape 相同 |
| 优化器状态 | 4(AdaGrad/fp32)或 8(Adam/fp32) | 见下文详述 |
| 激活值 | 2×B×D×L(bf16) | 每层存 B×D 的激活(反向传播需要) |
理解优化器状态的内存开销,需要先理解几种优化器之间的关系。AdaGrad(2011 年提出)可以看作介于 SGD 和 Adam 之间的算法:SGD 只用梯度本身;Momentum(SGD 的变体)跟踪梯度的一阶矩(即梯度的指数移动平均);AdaGrad 跟踪梯度的二阶矩(即历史梯度平方和);Adam 则是将一阶矩和二阶矩结合。这解释了为什么不同优化器的 per-parameter 内存开销不同:AdaGrad 需要 4 bytes(存一个 fp32 的二阶矩),而 Adam 需要 4+4=8 bytes(分别存一阶和二阶矩)。
优化器状态使用 fp32 而非 bf16 是出于稳定性考虑:累积的平方梯度经过多步累加后,bf16 的精度不够用。Adam 的 8 bytes/param 意味着优化器状态往往是最大的内存消费者——比参数本身还要大 4 倍。
不过,优化器状态虽然占内存大,但不是 speed 的瓶颈——optimizer step 主要是 elementwise 操作(memory bound 但很快),不涉及大矩阵乘法。它的影响主要体现在:限制了能放进 HBM 的最大模型大小。
将四部分加起来,以 Adam 为例:每个参数需要 2+2+8=12 bytes,加上 activation 的开销(取决于 B 和 L)。这就是开头的 back-of-the-envelope 计算的来源。
内存优化技术#
训练大模型时,activation memory(∝B×D×L)往往是 OOM 的直接原因,因为它随 batch size 线性增长。有两个经典技术来缓解这一问题。
梯度累积(Gradient Accumulation)#
训练通常需要较大的 batch size 来保证梯度估计的稳定性(存在一个 critical batch size,由后续课程讨论)。但大 batch size 意味着大 activation memory。
梯度累积的思路非常简单:将一个大 batch 拆分为若干 micro-batch,对每个 micro-batch 分别做 forward + backward 并累加梯度(不 zero-out),每积累够 batch_size/micro_batch_size 步后再执行一次 optimizer step 并清零梯度。
1for i, micro_batch in enumerate(micro_batches):2 loss = model(micro_batch)3 loss.backward() # 梯度累加到 .grad4 if (i + 1) % accumulation_steps == 0:5 optimizer.step()6 optimizer.zero_grad()这个简单的代码改动使得 activation memory 只需容纳一个 micro-batch,而梯度效果等价于大 batch。代价是每个 micro-batch 的 forward/backward 需要串行执行,但总 FLOPs 不变。
激活值检查点(Activation Checkpointing)#
默认情况下,前向传播会保存所有中间层的激活值(pre-ReLU 和 post-ReLU 都保存),以便反向传播使用。这使得 activation memory = 2×B×D×L(每层两个激活 tensor)。
推理时不需要梯度,所以只需要保存当前层的激活——内存 O(BD),与层数 L 无关。但训练时必须保存所有层的激活供 backward 使用——除非我们愿意重新计算。
Activation checkpointing(又称 gradient checkpointing 或 rematerialization)的核心思想是:前向传播时只保存部分层的激活(checkpoint),反向传播时从最近的 checkpoint 重新计算缺失的激活。这是经典的 时间换空间 trade-off。
在 PyTorch 中实现非常简单——用 torch.utils.checkpoint 包裹需要 checkpoint 的层:
1from torch.utils.checkpoint import checkpoint2
3def forward(self, x):4 for block in self.blocks:5 x = checkpoint(block, x) # 不保存 block 内部的中间激活6 return x如果对每个 block(linear + ReLU)做 checkpoint,pre-ReLU 的激活不再保存,activation memory 直接减半。反向到该 block 时,从保存的 block 输入重新做一次 forward 得到 pre-ReLU 激活,然后继续 backward——额外计算开销约 33%(重新算一次 forward)。
三种策略的对比:
| 策略 | Activation Memory | 重计算开销 |
|---|---|---|
| 不做 checkpoint | O(BDL) | 0 |
| 每层 checkpoint | O(BD⋅L)(减半常数) | ~33% |
| 每 L 层 checkpoint | O(BD⋅L) | O(L) 倍 |
| 极端:只存输入 | O(BD) | O(L) 倍(L2 总 FLOPs) |
最后一种策略虽然内存最省,但重计算开销是 L2(每层 backward 时都要从头 forward 到该层)。一个平衡的选择是 每 L 层设一个 checkpoint:activation memory 降为 O(BDL),重计算开销也是 O(L),二者平衡。
部分内容可能已过时