Lecture 2:PyTorch、einops 与资源核算

6306 字
32 分钟
Lecture 2:PyTorch、einops 与资源核算
文章摘要

Stanford CS336 Language Modeling from Scratch | Spring 2026 | Lecture 2: PyTorch (einops),时长 01:17:25。

从 Scaling Law 预测到资源核算#

训练语言模型的核心目标可以表述为:在有限的资源预算(计算、内存、有时还有数据)下,训练出尽可能好的模型。要实现这一目标,首先需要理解”效率”的定量含义——我们需要能够回答下面这类问题:

在 1024 块 H100 上,训练一个 70B 参数的模型、跑 15 万亿 token,需要多长时间?

这类问题的回答框架非常直接:先计算总 FLOPs(6×70×109×15×10126 \times 70 \times 10^9 \times 15 \times 10^{12}),再查硬件 spec sheet 得到 FLOP/s,乘以一个经验性的利用率系数 MFU(约 0.5),除出天数——答案是大约 143 天

再看一个内存维度的问题:用 AdamW 在 8 块 H100 上,最大能训多大的模型? 每块 H100 有 80GB HBM,使用 AdamW 时每个参数需要 2+2+4+4=122 + 2 + 4 + 4 = 12 字节(参数 bf16 + 梯度 bf16 + 一阶矩 fp32 + 二阶矩 fp32),所以 80×109×8/1253×10980 \times 10^9 \times 8 / 12 \approx 53 \times 10^9,即约 530 亿参数——当然这还没有计入 activation 的开销(取决于 batch size 和序列长度),属于粗略的 back-of-the-envelope 估算。

在正式进入 resource accounting 的细节之前,值得一提的是 Marin 项目的一次实际验证。该项目使用 102310^{23} FLOPs 的预算进行了一次训练运行,其结果与基于 IsoFLOPs 曲线拟合的 scaling law 预测高度吻合——预测 loss 与实际 loss 之差仅在 0.05 以内。

图1:Delphi Scaling Suite 实验结果
图1:Delphi Scaling Suite 实验结果

图中每条曲线是一条 IsoFLOPs 曲线(在固定计算预算下,用一系列较小模型做 runs,找到 compute-optimal 点)。虚线是拟合的 Compute-Optimal Frontier。标注的 Δloss\Delta\text{loss} 表示预测值与实际值之间的偏差。将这条 frontier 外推到 GPT-5 等量级的计算预算,可以得到相应的 loss 估计——当然外推的可靠性取决于 scaling law 在该区间的适用性。

这门课的核心心态是:写每一行代码时,都要思考其性能特征——不是要精确计算每一个数字,而是要建立起对 FLOPs 和内存量级的直觉判断。

Tensor 与浮点精度#

Tensor:一切的基础构建块#

在深度学习训练中,参数(parameters)、梯度(gradients)、优化器状态(optimizer states)、数据(data)和激活值(activations) 全部以 tensor 的形式存储。Tensor 囊括了向量、矩阵以及任意阶的多维数组。以 DeepSeek v3.2 模型为例,模型本身就是一组 tensor,每个 tensor 有各自的 shape 和 precision。

Tensor 的内存占用计算非常直接:元素个数 × 每个元素的字节数。例如一个 4×84 \times 8 的 float32 矩阵占用 32×4=12832 \times 4 = 128 字节。但不要小看这个乘法——GPT-3 中 feedforward 层的单个权重矩阵(dmodel×4dmodeld_{\text{model}} \times 4d_{\text{model}})就占约 2.3 GB(float32 下)。

float32:传统基准#

图2:IEEE 754 float32 位布局
图2:IEEE 754 float32 位布局

float32(又称 fp32 或 single precision)使用 32 位表示一个浮点数:1 位符号 + 8 位指数(exponent)+ 23 位尾数(mantissa/fraction)。8 位指数提供了很大的动态范围(dynamic range),23 位尾数提供了较高的精度(resolution)。

“单精度”这个名字来自科学计算的传统——float32 曾是浮点数的”标准配置”,float64(双精度)是需要更高精度时的升级选项。但在深度学习中,方向是反过来的:即便 32 位也嫌多,因为深度学习中的计算对精度的需求远低于传统数值仿真。

在 PyTorch 中创建 tensor 时,默认 dtype 就是 float32。创建后默认在 CPU 上——要获得 GPU 加速,需要显式地 .to('cuda')cuda_if_available()

float16:动态范围不足#

将 float32 “砍半”得到 float16:1 位符号 + 5 位指数 + 10 位尾数,共 16 位。内存直接减半,理论上计算速度翻倍。但问题在于 5 位指数的动态范围太小。一个直观的例子:用 float16 表示 1×1081 \times 10^{-8} 时,结果直接变成 0——发生了 underflow。

如果直接用 fp16 训练(早期确实有人这么做),会频繁遇到 underflow、overflow 和 NaN,训练极不稳定。

bfloat16:深度学习的甜蜜点#

图3:bfloat16 位布局与 underflow 对比
图3:bfloat16 位布局与 underflow 对比

bfloat16(brain floating point 16)由 Google Brain 在 2018 年提出,核心思路是:保持 16 位的总长度不变,但将部分 mantissa 的位数让给 exponent——结果是 1 位符号 + 8 位指数 + 7 位尾数

这意味着 bfloat16 与 float32 有相同的动态范围(都是 8 位指数),代价是 resolution 更差(7 位 vs 23 位尾数)。在深度学习中这个 trade-off 通常是划算的:训练过程中的梯度更新本身就是 stochastic 的,对 resolution 的敏感度远低于对动态范围的依赖。不过需要注意,即便是 bf16 也并非完全没有风险——这正是混合精度训练(optimizer states 保持 fp32)存在的原因。用 bfloat16 表示 1×1081 \times 10^{-8} 不会 underflow——值为 1.001×1081.001 \times 10^{-8},虽然精度有限但至少是个有意义的非零值。

三种格式的 torch.finfo 对比:

格式resolution最小值最大值每值字节
float3210610^{-6}3.4×1038-3.4 \times 10^{38}3.4×10383.4 \times 10^{38}4
float1610310^{-3}65504-6550465504655042
bfloat1610210^{-2}3.4×1038-3.4 \times 10^{38}3.4×10383.4 \times 10^{38}2

fp8 与 fp4:更激进的低精度#

图4:FP8 两种变体与 FP4 格式
图4:FP8 两种变体与 FP4 格式

fp8 已经标准化(Micikevicius+ 2022),有两个变体:E4M3(4 位指数 + 3 位尾数,范围 [448,448][-448, 448])和 E5M2(5 位指数 + 2 位尾数,范围 [57344,57344][-57344, 57344]),分别适用于需要更高 resolution 或更大动态范围的场景。NVIDIA 的 Transformer Engine 已支持 fp8 训练。

更极端的是 nvfp4(NVIDIA 2025),每个值仅 4 位。4 bit 能表示的值非常有限——全部值可以列在一行:6,4,3,2,1.5,1.0,0.5,0.0,0.5,1.0,1.5,2,3,4,6-6, -4, -3, -2, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2, 3, 4, 6。但 fp4 采用 block scaling 机制:在每个 block 内,所有值共享一个 scale factor(用更多位表示)。这样单个值看到的有效动态范围大于 4 位,只是同一 block 内的值不能有极端的相对差异。Nemotron 3 Super 已用 fp4 完成训练。

需要注意的是,fp8/fp4 的很多操作由 NVIDIA 软件栈在底层处理,用户并不能简单地 torch.zeros(..., dtype=torch.fp4) 来创建 tensor。

另外,训练(training)和推理(inference)的精度需求不同。推理阶段可以把 bf16 训好的模型 quantize 到 1-2 bit——这比直接用 1 bit 训练要容易得多。据目前所知,还没有人用 1 bit 精度从头训练出有意义的语言模型。

混合精度训练#

实践中的标准做法是 mixed precision training

  • bf16 用于参数(parameters)、激活值(activations)和梯度(gradients)
  • fp32 用于优化器状态(optimizer states)

PyTorch 提供了 AMP(Automatic Mixed Precision)库来自动管理精度切换:MatMul 等运算会被 cast 到 bf16(安全且快速),而 exponentiation 等对精度敏感的运算会保持 fp32。使用时只需将代码包裹在 torch.cuda.amp.autocast 上下文中。

einops:用命名维度思考张量操作#

动机:告别”minus 2, minus 1”#

传统 PyTorch 代码中,张量操作依赖数字索引来指定维度——比如 y.transpose(-2, -1) 来转置最后两个维度,或 x.sum(dim=-1) 来沿最后一个维度求和。这种写法容易出错且难以阅读:当你看到 torch.matmul(x, y.transpose(-2, -1)) 时,需要在心里推演 -2-1 分别对应什么。

einops(Einstein Operations)通过命名维度来解决这个问题。其灵感来源于 Einstein 求和约定(Einstein summation convention),核心思想是:给每个维度一个有意义的名字,让操作的语义直接可读。

einsum:通用矩阵乘法#

einsum 可以理解为”带有良好记账的通用矩阵乘法”。其工作方式是:

  1. 为输入 tensor 的每个维度命名
  2. 指定输出 tensor 的维度
  3. 未出现在输出中的维度自动被 sum 掉

一个简单的矩阵乘法 XR3×4,YR4×3X \in \mathbb{R}^{3 \times 4}, Y \in \mathbb{R}^{4 \times 3}

# 传统写法
z = x @ y
# einsum 写法
z = einsum(x, y, "seq1 hidden, hidden seq2 -> seq1 seq2")

einsum 的语义是:对 seq1hiddenseq2 三个变量的所有取值做枚举,将 xseq1,hiddenyhidden,seq2x_{\text{seq1}, \text{hidden}} \cdot y_{\text{hidden}, \text{seq2}} 的乘积累加到 zseq1,seq2z_{\text{seq1}, \text{seq2}} 中。hidden 没有出现在输出中,所以被 sum 掉——这正是矩阵乘法的定义。

einsum 的优势在复杂场景中尤其明显。考虑一个 batched attention score 计算,输入 XR2×3×4X \in \mathbb{R}^{2 \times 3 \times 4}(batch × seq × hidden)和 YR2×4×3Y \in \mathbb{R}^{2 \times 4 \times 3}

# 传统写法——需要理解 @ 对前导维度做 batch、transpose 作用于哪两个轴
z = torch.matmul(x, y.transpose(-2, -1))
# einsum 写法——语义自明
z = einsum(x, y, "batch seq1 hidden, batch seq2 hidden -> batch seq1 seq2")

注意 einsum 中 不需要显式 transpose——维度名已经隐含了对齐关系。如果 YY 的维度顺序是 batch hidden seq2 而非 batch seq2 hidden,那才是不需要 transpose 的情况;这里 hidden 在不同位置出现,einsum 自动处理了对齐。

当 batch 维度很多时(比如 batch + sequence + head 三层嵌套),可以用省略号 ... 来代替所有 batch 维度:

z = einsum(x, y, "... seq1 hidden, ... seq2 hidden -> ... seq1 seq2")

这使得代码具有模块性——不需要知道传入 tensor 具体有几个 batch 维度。

reduce:通用聚合#

reduce 是 sum、mean、max、min 的泛化。传统写法 x.sum(dim=-1) 在 einops 中变为:

result = reduce(x, "... hidden -> ...", "sum")

语义是:hidden 维度在输出中不出现,所以对它做 sum(或 mean/max/min)。这样维度名的存在让操作意图一目了然,不需要去数 dim=-1 到底是哪个维度。

从性能角度看,reduce 只是语法糖——最终调用的是相同的底层 primitive,没有额外开销。

rearrange:维度的拆分与合并#

rearrange 处理的是维度的 reshape 操作,尤其是一个维度实际上是两个维度的乘积的情况。这在 multi-head attention 中很常见:将 hidden dimension 拆分为 heads × head_dim

# x 的 shape 是 (3, 8),其中 8 = 2 heads × 4 hidden_per_head
x = rearrange(x, "... (heads hidden1) -> ... heads hidden1", heads=2)
# 现在 x 的 shape 是 (3, 2, 4)

括号语法 (heads hidden1) 表示这个维度可以分解为 headshidden1 的乘积。需要指定其中一个的大小(heads=2),另一个自动推导。

做完 per-head 的变换后,可以用 rearrange 合并回去:

x = rearrange(x, "... heads hidden2 -> ... (heads hidden2)")

关于拆分的顺序(row-major vs column-major),rearrange 中括号内变量的书写顺序决定了拆分方式——写在前面的变量变化更慢(类似 row-major)。

einops 需要一定的学习成本,但一旦习惯了用命名维度思考,所有的 transpose、reduce、reshape 操作都变得直觉化,不再需要与数字索引搏斗。

FLOPs 计算与硬件规格#

FLOPs 与 FLOP/s:一字之差#

在讨论计算量时,“FLOPs”一词有两个含义,必须区分清楚:

  • FLOPs(小写 s)= floating-point operations,计算总量的度量。例如”GPT-3 训练用了 3.25×10233.25 \times 10^{23} FLOPs”。
  • FLOP/s = floating-point operations per second硬件速度的度量。例如”H100 的 bf16 峰值是 989 TFLOP/s”。

有时 FLOP/s 也被写成大写 FLOPS,但本文一律用 /s 标注以避免混淆。

一个 FLOP 定义为一次基本浮点运算(加法或乘法)。GPU 还能做其他运算,但加法和乘法是绝对主体,占据了几乎所有的计算时间。

MatMul 的 FLOPs:2BDK#

矩阵乘法(MatMul)的 FLOPs 计算是整个 resource accounting 的核心。对于 XRB×DX \in \mathbb{R}^{B \times D} 乘以 WRD×KW \in \mathbb{R}^{D \times K}

FLOPs=2×B×D×K\text{FLOPs} = 2 \times B \times D \times K

因子 2 来自:每个输出元素需要 DD 次乘法和 D1D-1 次加法,近似为 2D2D 次运算,共 B×KB \times K 个输出元素。严格来说加法是 D1D-1 次而非 DD 次,但在实际规模下这个差异可以忽略。

BB 理解为数据点/token 数、D×KD \times K 理解为参数量,这个公式可以改写为:

FLOPsforward=2×tokens×parameters\text{FLOPs}_{\text{forward}} = 2 \times |\text{tokens}| \times |\text{parameters}|

其他元素级操作(elementwise addition、ReLU 等)的 FLOPs 只是矩阵的大小(m×nm \times n),远小于 MatMul 的 O(m×n×k)O(m \times n \times k)。所以在矩阵足够大的前提下,只需要关注 MatMul 的 FLOPs

关于是否有亚立方的矩阵乘法算法(如 Strassen)可以利用:实际中 GPU 上的矩阵乘法优化主要依靠与硬件的 co-design(内存层次、tiling 等),而非渐近更优的算法。同样,虽然直觉上加法应该比乘法快,但在现代硬件设计中二者的速度基本相同。

H100 规格与 MFU#

NVIDIA H100 的 spec sheet 标注 bf16 峰值为 1979 TFLOP/s——但有个关键细节:这个数字包含了 structured sparsity(2:4 稀疏)。对于 dense 计算,需要除以 2,即 约 989 TFLOP/s

有了 FLOP/s 的理论值和实际 benchmark 值,就可以定义 MFU(Model FLOPs Utilization)

MFU=实际 FLOP/sspec sheet FLOP/s(dense)\text{MFU} = \frac{\text{实际 FLOP/s}}{\text{spec sheet FLOP/s(dense)}}

MFU 衡量的是:你的代码在多大程度上”兑现”了硬件的承诺。经验值:

  • MFU ≈ 0.5 对于现代大模型训练,算是不错的水平
  • 纯 MatMul 可以达到 0.7-0.8
  • MFU ≈ 0.1 说明有严重问题需要排查

MFU 低于 1.0 的原因将在后续的算术强度分析中变得清晰——本质上是因为 memory bandwidth 的限制。

直觉量级#

为了建立 FLOPs 的量级直觉:8 块 H100 跑一周,假设 MFU = 1(理想情况),大约能完成 8×989×1012×7×864005×10218 \times 989 \times 10^{12} \times 7 \times 86400 \approx 5 \times 10^{21} FLOPs。这就是为什么 102210^{22}102510^{25} 量级的训练需要数百到数千块 GPU 跑数周到数月。

计算 MFU 的实际流程:

  1. 根据模型结构和 batch size 计算逻辑 FLOPs
  2. 在 GPU 上 benchmark wall clock time(注意要用 torch.cuda.synchronize() 确保异步操作完成)
  3. 实际 FLOP/s = 逻辑 FLOPs / wall clock time
  4. MFU = 实际 FLOP/s / spec sheet FLOP/s

benchmark 时一个常见的坑是忘记 cuda.synchronize()——因为 CUDA 操作是异步的,torch.matmul 之类的调用会立即返回,如果不加同步屏障就会误以为计算很快。正确做法是在操作前后各加一次 synchronize,并多次运行取平均。

算术强度与 Roofline 分析#

计算不只是”做运算”#

到目前为止我们只关注了 FLOPs 本身,但实际运行时间还取决于另一个因素:数据在 HBM(高带宽内存)和计算核心之间的搬运开销

GPU 的计算模型可以简化为三步:(1) 将输入 tensor 从 HBM 送到计算核心;(2) 执行计算;(3) 将结果写回 HBM。因此一次操作的总时间取决于两个因素:

  • Accelerator speed(FLOP/s):计算核心的运算速度
  • Memory bandwidth(bytes/s):HBM 与计算核心之间的数据传输速度

H100 的两个关键参数:

  • bf16 dense FLOP/s = 1979×1012/29891979 \times 10^{12} / 2 \approx 989 TFLOP/s
  • Memory bandwidth = 3.35 TB/s

实际中,communication 和 computation 可以重叠(pipeline)。在理想重叠假设下,总时间 = max(通信时间, 计算时间)

逐例分析:从 ReLU 到 MatMul#

ReLU:极度 memory bound

对长度为 nn(bf16)的向量做 ReLU:

  • 搬运字节数:读 xx2n2n)+ 写 yy2n2n)= 4n4n bytes
  • FLOPs:nn(每个元素做一次比较)
  • 通信时间 = 4n/3.35×10121.2×1064n / 3.35 \times 10^{12} \approx 1.2 \times 10^{-6} s(n=106n = 10^6 时)
  • 计算时间 = n/989×10121.0×109n / 989 \times 10^{12} \approx 1.0 \times 10^{-9} s

通信时间是计算时间的 1000 倍。GPU 的计算核心大部分时间都在等数据到来。

GELU:仍然 memory bound

GELU 的每个元素涉及约 20 次浮点运算(包含 tanh\tanh、乘法等),FLOPs 是 ReLU 的 20 倍。但搬运字节数不变(仍是 4n4n)。所以虽然 GELU 的 FLOPs 多了 20 倍,实际运行时间和 ReLU 完全一样——因为瓶颈在搬运而不在计算。这是一个反直觉的结论:看起来复杂得多的运算,实际上并不更慢。

向量点积:memory bound

xwx \cdot wx,wRnx, w \in \mathbb{R}^n(bf16):

  • 搬运字节数:2n+2n+2=4n+22n + 2n + 2 = 4n + 2 bytes
  • FLOPs:2n12n - 1nn 次乘法 + n1n-1 次加法)
  • 算术强度 ≈ 0.5(每搬运 1 byte 只做 0.5 次运算)

矩阵-向量乘:仍然 memory bound

WxWxWRn×nW \in \mathbb{R}^{n \times n}xRnx \in \mathbb{R}^n

  • 搬运字节数:2n+2n2+2n2n + 2n^2 + 2n bytes(读 xx、读 WW、写 yy
  • FLOPs:n×(2n1)2n2n \times (2n - 1) \approx 2n^2
  • 算术强度 ≈ 1(刚刚超过 0.5)

仍远低于 H100 的阈值,依然 memory bound。这一点预示了 Transformer 推理的性能特征:推理时是逐 token 生成,每次是一个 vector 与权重矩阵的乘法,因此天然 memory bound。

矩阵-矩阵乘(MatMul):终于 compute bound

XWXWX,WRn×nX, W \in \mathbb{R}^{n \times n}n=1024n = 1024):

  • 搬运字节数:2n2+2n2+2n2=6n22n^2 + 2n^2 + 2n^2 = 6n^2 bytes
  • FLOPs:2n32n^3
  • 算术强度 = 2n3/6n2=n/33402n^3 / 6n^2 = n/3 \approx 340

n=1024n = 1024 时,算术强度约 340,超过了 H100 的加速器强度 295。这意味着 MatMul 终于是 compute bound——GPU 的计算核心在满载工作。

算术强度:统一的判断框架#

加速器强度(accelerator intensity) 定义为硬件 spec 的 FLOP/s 除以 bytes/s:

Iaccelerator=FLOP/sbytes/s=989×10123.35×1012295 FLOPs/byteI_{\text{accelerator}} = \frac{\text{FLOP/s}}{\text{bytes/s}} = \frac{989 \times 10^{12}}{3.35 \times 10^{12}} \approx 295 \text{ FLOPs/byte}

含义是:H100 每搬运 1 byte 可以做 295 次浮点运算。

算法的算术强度(arithmetic intensity) 定义为该算法执行的 FLOPs 除以搬运的 bytes:

Ialgorithm=FLOPsbytes movedI_{\text{algorithm}} = \frac{\text{FLOPs}}{\text{bytes moved}}

判断规则非常简单:

  • Ialgorithm<IacceleratorI_{\text{algorithm}} < I_{\text{accelerator}}memory bound(搬运是瓶颈)
  • Ialgorithm>IacceleratorI_{\text{algorithm}} > I_{\text{accelerator}}compute bound(计算是瓶颈)
操作算术强度判定
ReLU0.25memory bound
GELU~5memory bound
向量点积~0.5memory bound
矩阵-向量乘~1memory bound
矩阵-矩阵乘 (n=1024n=1024)~340compute bound

矩阵乘法的算术强度随矩阵规模增长(n/3\propto n/3)。直觉非常清晰:读入和写出的数据量是 O(n2)O(n^2),但计算量是 O(n3)O(n^3),二者之比就是 O(n)O(n)——矩阵越大,每搬运一个 byte 能做的”有用功”越多。这就是为什么大 batch size 和大隐藏维度对 GPU 利用率如此重要。如果算术强度低于 accelerator intensity,缩小矩阵并不会加速计算(因为时间被搬运占据),而一旦超过阈值,GPU 才真正在”干活”。

Roofline 图#

图5:Roofline 分析图
图5:Roofline 分析图

Roofline 图的横轴是算术强度(log scale),纵轴是实际实现的 FLOP/s(log scale)。对于给定的加速器:

  • 左侧斜线区域(bandwidth bound):算术强度低于阈值时,实际 FLOP/s 受限于 memory bandwidth,呈线性增长(斜率 = bandwidth)。不同的 bandwidth 等级(BW1,BW2BW_1, BW_2)对应不同的斜线。
  • 右侧水平区域(compute bound):算术强度超过阈值后,实际 FLOP/s 趋近于加速器峰值,不再随算术强度增加而提升。

图中 Algo 1 落在斜线区域(bandwidth bound at both BW1BW_1 and BW2BW_2),Algo 2 落在水平区域(compute bound)。中间区域可能出现”在 BW1BW_1 下 bandwidth bound 但在 BW2BW_2 下 compute bound”的情况。

这解释了为什么 MFU 通常只有 0.5 左右:虽然 Transformer 的主体是 MatMul(compute bound),但中间穿插的 elementwise 操作(layernorm、activation functions、attention softmax 等)都是 memory bound 的,它们拉低了整体的 MFU。

另外需要注意,算术强度也依赖精度。同样的操作在 bf16 和 fp32 下有不同的搬运字节数和不同的硬件 FLOP/s,因此 roofline 图会随精度变化。

训练过程的资源分析#

运行样例:深层网络#

图6:深层网络架构图
图6:深层网络架构图

为了具体分析训练的资源消耗,考虑一个简单但有代表性的深层网络:输入 XRB×DX \in \mathbb{R}^{B \times D}BB 个数据点,每个 DD 维),经过 LL 层,每层包含一个 D×DD \times D 的线性变换和一个 elementwise ReLU 激活。输出与输入同维。

class DeepNetwork(nn.Module):
def __init__(self, dim, num_layers):
self.blocks = nn.ModuleList([Block(dim) for _ in range(num_layers)])
def forward(self, x):
for block in self.blocks:
x = block(x) # linear + ReLU
return x

参数总量 = D2×LD^2 \times L(每层一个 D×DD \times D 权重矩阵)。

前向传播的 FLOPs#

每层的前向传播就是一次 B×DB \times D 乘以 D×DD \times D 的 MatMul,FLOPs = 2×B×D×D=2BD22 \times B \times D \times D = 2BD^2LL 层合计:

FLOPsforward=L×2BD2=2B×(D2L)=2×tokens×params\text{FLOPs}_{\text{forward}} = L \times 2BD^2 = 2B \times (D^2 L) = 2 \times |\text{tokens}| \times |\text{params}|

反向传播的 FLOPs:正好是前向的两倍#

反向传播通过 chain rule 传播梯度。聚焦第 ll 层:hl=hl1Wlh_l = h_{l-1} \cdot W_l(忽略 ReLU 简化分析)。反向时需要计算两个梯度:

  1. 对输入的梯度(用于向前一层继续传播):

Lhl1=einsum(hl,Wl,"batch out, in out -> batch in")\frac{\partial \mathcal{L}}{\partial h_{l-1}} = \text{einsum}(\nabla h_l, W_l, \text{"batch out, in out -> batch in"})

这是一个 MatMul,FLOPs = 2BD22BD^2

  1. 对参数的梯度(用于更新权重):

LWl=einsum(hl,hl1,"batch out, batch in -> in out")\frac{\partial \mathcal{L}}{\partial W_l} = \text{einsum}(\nabla h_l, h_{l-1}, \text{"batch out, batch in -> in out"})

也是一个 MatMul,FLOPs = 2BD22BD^2

注意这里 einsum 的优势:不需要记忆哪个有 transpose——维度名已经编码了正确的对齐方式。三个 einsum(前向 + 两个反向)的 FLOPs 完全相同(都是 2BD22BD^2),因为 FLOPs 只取决于三个维度的乘积,与哪些维度是 batch、哪些被 sum 无关

所以每层的反向 FLOPs = 2×2BD2=4BD22 \times 2BD^2 = 4BD^2,恰好是前向的 2 倍

6nd 公式#

对整个网络:

  • 前向 FLOPs = 2×B×params2 \times B \times |\text{params}|
  • 反向 FLOPs = 4×B×params4 \times B \times |\text{params}|
  • 一个训练步的总 FLOPs = 6×B×params6 \times B \times |\text{params}|

这就是广泛使用的 6nd 公式nn = 参数量,dd = token 数 / batch size)的来源。

对 Transformer 而言,这个公式在 context length 不太长的情况下是一个很好的近似。当 context length 很大时,attention 的 O(seq2)O(\text{seq}^2) 项会贡献额外的 FLOPs,不在这个线性近似中。

训练内存分解#

训练时 GPU 内存需要同时容纳四个部分:

组成部分每参数字节数说明
参数2(bf16)模型权重本身
梯度2(bf16)与参数 shape 相同
优化器状态4(AdaGrad/fp32)或 8(Adam/fp32)见下文详述
激活值2×B×D×L2 \times B \times D \times L(bf16)每层存 B×DB \times D 的激活(反向传播需要)

理解优化器状态的内存开销,需要先理解几种优化器之间的关系。AdaGrad(2011 年提出)可以看作介于 SGD 和 Adam 之间的算法:SGD 只用梯度本身;Momentum(SGD 的变体)跟踪梯度的一阶矩(即梯度的指数移动平均);AdaGrad 跟踪梯度的二阶矩(即历史梯度平方和);Adam 则是将一阶矩和二阶矩结合。这解释了为什么不同优化器的 per-parameter 内存开销不同:AdaGrad 需要 4 bytes(存一个 fp32 的二阶矩),而 Adam 需要 4+4=84 + 4 = 8 bytes(分别存一阶和二阶矩)。

优化器状态使用 fp32 而非 bf16 是出于稳定性考虑:累积的平方梯度经过多步累加后,bf16 的精度不够用。Adam 的 8 bytes/param 意味着优化器状态往往是最大的内存消费者——比参数本身还要大 4 倍。

不过,优化器状态虽然占内存大,但不是 speed 的瓶颈——optimizer step 主要是 elementwise 操作(memory bound 但很快),不涉及大矩阵乘法。它的影响主要体现在:限制了能放进 HBM 的最大模型大小。

将四部分加起来,以 Adam 为例:每个参数需要 2+2+8=122 + 2 + 8 = 12 bytes,加上 activation 的开销(取决于 BBLL)。这就是开头的 back-of-the-envelope 计算的来源。

内存优化技术#

训练大模型时,activation memory(B×D×L\propto B \times D \times L)往往是 OOM 的直接原因,因为它随 batch size 线性增长。有两个经典技术来缓解这一问题。

梯度累积(Gradient Accumulation)#

训练通常需要较大的 batch size 来保证梯度估计的稳定性(存在一个 critical batch size,由后续课程讨论)。但大 batch size 意味着大 activation memory。

梯度累积的思路非常简单:将一个大 batch 拆分为若干 micro-batch,对每个 micro-batch 分别做 forward + backward 并累加梯度(不 zero-out),每积累够 batch_size/micro_batch_size\text{batch\_size} / \text{micro\_batch\_size} 步后再执行一次 optimizer step 并清零梯度。

for i, micro_batch in enumerate(micro_batches):
loss = model(micro_batch)
loss.backward() # 梯度累加到 .grad
if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()

这个简单的代码改动使得 activation memory 只需容纳一个 micro-batch,而梯度效果等价于大 batch。代价是每个 micro-batch 的 forward/backward 需要串行执行,但总 FLOPs 不变。

激活值检查点(Activation Checkpointing)#

默认情况下,前向传播会保存所有中间层的激活值(pre-ReLU 和 post-ReLU 都保存),以便反向传播使用。这使得 activation memory = 2×B×D×L2 \times B \times D \times L(每层两个激活 tensor)。

推理时不需要梯度,所以只需要保存当前层的激活——内存 O(BD)O(BD),与层数 LL 无关。但训练时必须保存所有层的激活供 backward 使用——除非我们愿意重新计算。

Activation checkpointing(又称 gradient checkpointing 或 rematerialization)的核心思想是:前向传播时只保存部分层的激活(checkpoint),反向传播时从最近的 checkpoint 重新计算缺失的激活。这是经典的 时间换空间 trade-off。

在 PyTorch 中实现非常简单——用 torch.utils.checkpoint 包裹需要 checkpoint 的层:

from torch.utils.checkpoint import checkpoint
def forward(self, x):
for block in self.blocks:
x = checkpoint(block, x) # 不保存 block 内部的中间激活
return x

如果对每个 block(linear + ReLU)做 checkpoint,pre-ReLU 的激活不再保存,activation memory 直接减半。反向到该 block 时,从保存的 block 输入重新做一次 forward 得到 pre-ReLU 激活,然后继续 backward——额外计算开销约 33%(重新算一次 forward)。

三种策略的对比:

策略Activation Memory重计算开销
不做 checkpointO(BDL)O(BDL)0
每层 checkpointO(BDL)O(BD \cdot L)(减半常数)~33%
L\sqrt{L} 层 checkpointO(BDL)O(BD \cdot \sqrt{L})O(L)O(\sqrt{L})
极端:只存输入O(BD)O(BD)O(L)O(L) 倍(L2L^2 总 FLOPs)

最后一种策略虽然内存最省,但重计算开销是 L2L^2(每层 backward 时都要从头 forward 到该层)。一个平衡的选择是 L\sqrt{L} 层设一个 checkpoint:activation memory 降为 O(BDL)O(BD\sqrt{L}),重计算开销也是 O(L)O(\sqrt{L}),二者平衡。

Lecture 2:PyTorch、einops 与资源核算
https://www.xwysyy.cn/posts/cs336/lec02/
作者
xwysyy
发布于
2026-05-17
许可协议
CC BY-NC-SA 4.0
© 2026 xwysyy. All Rights Reserved.
Powered by Astro & Firefly

文章目录