Lecture 5:GPU 与 TPU 硬件原理

10906 字
55 分钟
Lecture 5:GPU 与 TPU 硬件原理
文章摘要

Stanford CS336 Language Modeling from Scratch | Spring 2026 | Lecture 5: GPUs, TPUs,时长 1:18:39。

计算规模驱动:为什么需要理解 GPU#

语言模型的进步本质上由计算资源驱动——更快的硬件、更高的利用率、更多的芯片加上更好的并行化策略,三者共同推动了模型能力的增长。理解系统层面的细节,是高效使用这些资源的前提。

如果这是 1990 年代的讨论,焦点将完全在 时钟频率 上:CPU 串行执行指令,通过缩小晶体管尺寸来提升时钟速度——这就是 Dennard Scaling。这种范式在 2000 年代左右触顶:晶体管数量虽然还在按摩尔定律增长,但更小的晶体管不再意味着更快的时钟,物理限制(功耗墙、漏电流)使得频率提升停滞。

在频率无法继续提升的世界里,如何让语言模型继续 scale?答案是 并行计算。GPU 和 TPU 代表了这一范式:不再让单条指令跑得更快,而是让成百上千条指令同时执行。如果把 NVIDIA GPU 的浮点运算能力按年份画成曲线,会看到从 K20、M40 时代较为平缓的增长,到 P100/V100 之后出现超指数增长——这不是靠频率,而是靠并行度和专用硬件单元共同实现的。

从 2017 年 V100 引入 Tensor Core 开始,GPU 的计算能力曲线开始急剧上扬。之后的几代硬件(A100、H100)通过结构化稀疏性和更低精度数值格式(FP8 等)进一步推高了峰值 FLOPS。这些硬件演进的细节将贯穿整节课的讨论。

图1:GPU 并行计算的超指数增长——Bill Dally HotChips 演讲中展示的 1000x/10 年增长曲线,增长来自数值表示(~16x)、复杂指令(~12.5x)、工艺制程(~2.5x)和稀疏性(~2x)的叠加
图1:GPU 并行计算的超指数增长——Bill Dally HotChips 演讲中展示的 1000x/10 年增长曲线,增长来自数值表示(~16x)、复杂指令(~12.5x)、工艺制程(~2.5x)和稀疏性(~2x)的叠加

一个关键的趋势是:计算吞吐的增长速度远快于内存带宽和通信带宽的增长速度。灰色曲线(计算 FLOPS)增长最快,绿色曲线(内存带宽)增长缓慢,蓝色曲线(设备间互联带宽)更慢。这意味着随着时间推移,内存和通信将越来越成为瓶颈。这就是为什么本节课讨论的大部分优化技术都是 内存优化——在现代硬件上,如何减少数据搬运次数比如何加速计算更加关键。

GPU 硬件模型#

CPU 与 GPU 的哲学差异#

CPU 的设计目标是 快速串行执行:处理复杂的分支逻辑、条件判断和控制流。因此 CPU 拥有庞大的控制单元和少量 ALU,追求的是极低延迟——从收到指令到完成指令的时间尽可能短。

GPU 则完全不同,它追求的是 吞吐量(throughput)。向 GPU 派发一个任务后,该任务可能需要较长时间才能完成(高延迟),但 GPU 可以同时处理海量任务,聚合吞吐量远超 CPU。实现这一点的基础是 GPU 拥有数百个轻量级计算单元,它们全部并行执行。

流式多处理器(SM)#

GPU 的基本计算单元是 SM(Streaming Multiprocessor)。SM 类似于一个独立的”核心”,拥有自己的子组件和可访问的内存层级。每个 SM 内部包含多个流式处理器(Streaming Processors),可以并行执行不同的线程。关键数字:A100 的 die 设计包含 128 个 SM,实际启用 108 个(SXM4 版本);H100 拥有 132 个 SM。这些 SM 可以独立编程、独立执行不同的任务。

图2:GPU 执行单元解剖——左侧为单个 SM 的内部结构(包含多个 Streaming Processor),右侧为 GA100 全芯片布局(128 个 SM)
图2:GPU 执行单元解剖——左侧为单个 SM 的内部结构(包含多个 Streaming Processor),右侧为 GA100 全芯片布局(128 个 SM)

内存层级#

GPU 的内存层级是理解性能优化的核心:

内存类型特性A100 延迟(周期数)
寄存器(Registers)最快、最局部,存储内存地址等~1
L1 / 共享内存(Shared Memory)SM 内部,块级共享20-30
L2 缓存跨 SM,较慢~200
全局内存 / HBMDRAM,容量大但慢200-300

L1/共享内存约 20-30 个周期可完成读取;命中 L2 缓存则慢得多;如果需要访问全局内存(HBM),延迟是 L1 的约 10 倍。这个巨大的差异是所有 GPU 内存优化的根本动机:尽可能把数据留在共享内存中操作,减少对全局内存的访问。

图3:GPU 内存层级实物——左上表格为 A100 各级内存的访问延迟(周期数),右侧为 NVIDIA GA100 die shot,展示 SM、L2 缓存分区和 HBM 控制器的物理布局。SRAM 成本约为 DRAM 的 100 倍,但速度快约 8 倍
图3:GPU 内存层级实物——左上表格为 A100 各级内存的访问延迟(周期数),右侧为 NVIDIA GA100 die shot,展示 SM、L2 缓存分区和 HBM 控制器的物理布局。SRAM 成本约为 DRAM 的 100 倍,但速度快约 8 倍

补充细节:共享内存和 L1 缓存的区别在于——共享内存是程序员可控的(通过代码显式管理其中存什么),而 L1 缓存是硬件自动管理的。另外,L2 缓存的延迟差异主要来自物理距离和互联方式,而非底层存储技术的差异(L1/L2/共享内存都是 SRAM 实现)。全局内存之所以慢得多,是因为它使用 DRAM 技术。

为什么不把整个芯片都做成 SRAM?三个原因:(1)SRAM 制造成本极高;(2)SRAM 必须物理上靠近计算单元(信号传播距离受限);(3)SRAM 需要持续供电才能保持数据——这使得大面积 SRAM 非常耗能。从总体拥有成本(TCO)和能耗角度看,大规模 SRAM 是不经济的。

编程模型:Thread、Block、Warp#

GPU 编程模型有三个关键概念:

  1. Thread(线程):最基本的执行单元。所有线程执行相同的指令(SIMT 模型——Single Instruction, Multiple Threads),只是操作的数据不同。这种设计简化了编程模型:你不需要为每个线程编写不同的程序,只需定义一组指令和输入,GPU 负责在所有输入上并行执行。

  2. Block(线程块):一组线程的集合。Block 的关键意义在于——一个 Block 保证运行在同一个 SM 上。由于 SM 拥有共享内存,Block 内的所有线程可以通过共享内存进行数据交换和复用。这在后续讨论 Tiling 时至关重要。

  3. Warp:GPU 调度的基本单位,由 32 个连续编号的线程组成。Warp 内的所有线程同步执行相同的指令。调度器以 Warp 为单位决定哪组线程下一步执行。Warp 粒度的调度降低了调度器的开销。

一个重要的澄清:当说”所有线程执行相同指令”时,这指的是同一个 Warp 内的线程,而非整个 Block 的所有线程。

内存模型总结#

GPU 设备代码可以访问的内存层级从最快到最慢:

  • 寄存器:存储内存地址等元数据,极快但容量极小
  • 本地内存(Local Memory):线程私有,位于 SM 内部
  • 共享内存(Shared Memory):Block 内所有线程可共享访问,用于线程间数据传递和数据复用
  • 全局内存(Global Memory):即 HBM/DRAM,容量大但延迟高
  • 常量内存(Constant Memory):只读,实践中使用较少
  • 主机内存(Host Memory):CPU 侧内存,可用于 GPU 显存不足时的数据卸载

图4:GPU 内存模型——Thread 拥有私有 Registers,Block 内共享 Shared Memory,跨 Block 通信必须经过 Global Memory(慢速)
图4:GPU 内存模型——Thread 拥有私有 Registers,Block 内共享 Shared Memory,跨 Block 通信必须经过 Global Memory(慢速)

核心原则:一旦跳出共享内存范围,速度就会急剧下降。 因此,整节课的优化策略都围绕同一个目标——通过合理的 Block 组织来减少全局内存读取次数。

GPU 线程的轻量级特性#

GPU 的线程非常轻量——可以随时被暂停和恢复。调度器可以决定”当前 Warp 由于等待数据而阻塞,切换到另一个就绪的 Warp 执行”。这种快速切换能力使得 GPU 能在某些任务阻塞时维持高吞吐量。

TPU 与 GPU 的趋同演化#

架构相似性#

TPU(Tensor Processing Unit)是 Google 设计的 ML 专用加速器,与 GPU 在高层架构上惊人地相似——这是一种 趋同演化(convergent evolution) 现象。如果目标是构建一个能效比高的矩阵乘法加速器,最终设计会收敛到类似的解决方案。

TPU 与 GPU 共享以下核心结构:

  • 专用的矩阵乘法电路(底层都使用 脉动阵列(systolic array) 来流式传入数据做矩阵乘法)
  • 可做并行向量操作的组件
  • 分层内存结构:慢速大容量内存(HBM)+ 快速小容量本地内存(SMEM/共享内存)

Google 的 Jax Book 中提供了 GPU 概念到 TPU 概念的精确映射表:GPU 的 Tensor Core(矩阵乘法单元)对应 TPU 的 MXU(Matrix Multiply Unit);GPU 的 SM 对应 TPU 的 TensorCore(注意此处命名冲突,见下文)。

关键命名冲突#

一个极易混淆的命名问题:TPU 将其流式处理器称为”Tensor Core”,而 GPU 将其矩阵乘法单元称为”Tensor Core”——同名但含义完全不同。在 TPU 语境中,Tensor Core = 一个处理器(类比 GPU 的 SM);在 GPU 语境中,Tensor Core = 矩阵乘法硬件单元。需要根据上下文判断所指。

核心差异:少而大 vs 多而小#

GPU 和 TPU 的关键设计分歧在于矩阵乘法单元的 数量与大小

指标GPU (H100)TPU
流式处理器数量132 个 SM2 个 TensorCore
矩阵乘法单元数量528 个8 个

GPU 依赖大量小型矩阵乘法单元,提供更高的编程灵活性——可以将不同的小 matmul 分配给不同单元。TPU 则依赖少量大型矩阵乘法单元,更加”锁定”于大规模矩阵运算。一个直观的例子:某 TPU 实验中 batch size 扫描只能到 64 就停了,因为 TensorCore 拒绝接受维度小于 64 的输入——硬件物理上要求输入必须足够大。

真正的差异在网络#

一个值得注意的硬件差异是:TPU 的等效 L2 缓存比 GPU 快得多,这来自 Google 在芯片设计上的不同权衡——他们在硅片面积分配上更多地倾向于快速缓存,这是 TPU 的一个卖点。

如果要问 GPU 和 TPU 之间最大的实质性差异是什么,答案不在单芯片架构(两者的芯片本质上都是矩阵乘法器),而在 网络互联。TPU Pod 的互联拓扑和通信模式与 GPU 集群的 NVLink/InfiniBand 方案有根本不同,但这超出了本节课的范围。

结论:核心概念可迁移#

由于 GPU 和 TPU 都依赖相同的两个基本要素——内存层级(快速内存 vs 慢速内存)和矩阵乘法单元——本节课讲授的所有优化概念在两种硬件间完全可迁移。高效分配内存的方式有限,高效做矩阵乘法的方式也有限,因此无论细节设计如何不同,最终都会抵达相同的优化原理。

Tensor Core 与矩阵乘法的特权地位#

在 V100 之前,GPU 上做矩阵乘法需要手动编程——早期的研究者发现可以巧妙地利用图形着色器来执行矩阵乘法运算,通过特定的渲染设置来获得更快的 matmul。这是一种很 hacker 的方式。

从 V100(2017 年)开始,NVIDIA 引入了 Tensor Core——专门为矩阵乘法设计的硬件电路。一旦这个专用硬件存在,矩阵乘法就成为了机器学习中 唯一的特权操作:Tensor Core 上矩阵乘法的吞吐量比 GPU 上任何其他浮点运算快 超过 10 倍

这一硬件现实有深远的架构含义:任何能够随计算规模增长的近未来 ML 架构,都必然包含矩阵乘法。因为这是唯一能有效利用 GPU 峰值算力的操作。可并行但非矩阵乘法的操作(如逐元素操作、归一化等)与 matmul 之间存在巨大的吞吐量鸿沟。

GPU 各组件的增长速率差异#

不同 GPU 组件的性能增长速率存在显著差异:

  • 计算 FLOPS(灰色曲线):增长最快,几乎呈超指数趋势
  • 内存带宽(绿色曲线):增长相对缓慢
  • 设备间通信带宽(蓝色曲线):增长最慢

图5:计算、内存、互联的增长速率差异——HW FLOPS 20 年增长 60000x(每 2 年 3x),DRAM 带宽仅增长 100x(每 2 年 1.6x),互联带宽增长 30x(每 2 年 1.4x)
图5:计算、内存、互联的增长速率差异——HW FLOPS 20 年增长 60000x(每 2 年 3x),DRAM 带宽仅增长 100x(每 2 年 1.6x),互联带宽增长 30x(每 2 年 1.4x)

这种不对称增长意味着:在 GPU 发展早期,计算和内存带宽差距不大,内存瓶颈不显著。但随着年代推移,计算与内存之间的剪刀差越来越大。这解释了为什么现代 GPU 优化几乎等同于内存优化——计算是”免费的”(相对于内存访问而言),如何减少数据搬运才是关键。

这个趋势在推理场景中更为极端。推理比训练更受内存带宽限制,这催生了诸如 prefill-decode 解耦 的方案:将计算密集的 prefill 阶段放在一种芯片上(优化算力),将内存带宽密集的 decode 阶段放在另一种芯片上(优化带宽)。Step-fun 3 甚至将不同层类型(attention vs MLP)路由到不同的加速器上。

本节课的心智模型#

至此应该形成的心智模型是:

  1. GPU 是大规模并行设备,同时执行海量指令
  2. 计算增长远快于内存,因此内存是真正的瓶颈
  3. 所有操作都必须尊重内存层级——尽量在共享内存中完成计算,避免反复访问全局内存
  4. 矩阵乘法是唯一能充分利用 GPU 峰值算力的操作

Roofline 模型与性能瓶颈诊断#

在进入具体优化技巧之前,需要建立一个分析框架来判断”我的程序为什么慢”。这就是 Roofline 模型

考虑一个简单的实验:对两个方阵做矩阵乘法,x 轴是矩阵维度,y 轴是吞吐量(每秒处理的元素数)。直觉上,更大的矩阵意味着更多工作量,应该能更好地利用硬件,因此吞吐量应该随维度增大而上升。实际测量结果大致符合这一预期,但中间穿插着各种诡异的周期性下降和参差不齐的性能波动。理解这些异常行为需要 Roofline 模型和后续的六个优化技巧。

Roofline 模型的两个区域#

经典 Roofline 模型描述了两个截然不同的性能区域:

  1. 内存受限区(Memory-bound):对角线上升部分。当每次内存读取对应的计算量不够多时,计算单元无法被完全利用——它们在等待数据到来。此时,增加问题规模(提高计算强度)可以提升吞吐量。

  2. 计算受限区(Compute-bound):水平平坦部分。当计算强度超过某个阈值后,计算单元已被完全饱和,此时再增加工作量不会提升吞吐量——瓶颈已经从内存切换到了计算本身。

算术强度(Arithmetic Intensity) 定义为每次内存访问对应的计算量(FLOPs/Byte)。它是连接这两个区域的关键指标。目标是使自己的工作负载位于计算受限区——一旦到达平坦区域,硬件利用率已经是最优的。

目标:让代码跑在 Roofline 的平坦区#

要使 GPU 代码高效运行,需要:

  • 提升算术强度——让每次内存读取对应更多的计算
  • 避免停留在对角线区域——那意味着在浪费计算能力

接下来的六个优化技巧本质上都服务于同一个目标:通过各种手段提升算术强度或减少内存访问次数,将工作负载从内存受限区推向计算受限区。

技巧一:避免控制分歧(Control Divergence)#

第一个优化技巧与 GPU 的 SIMT 执行模型直接相关。在 CPU 上,if-else 语句的行为很直观:根据条件选择一个分支执行,然后继续。但在 GPU 上,由于 Warp 内所有线程必须执行相同的指令,if 语句的行为完全不同。

GPU 上 if 语句的真实执行方式#

当 GPU 代码中出现条件分支时,Warp 内的所有线程会执行两个分支,但不满足条件的线程会 mask 掉该分支的计算结果(相当于空转)。执行流程如下:

  1. 线程遇到 if 条件,部分线程满足条件、部分不满足
  2. 先执行 if 分支——不满足条件的线程空闲等待
  3. 再执行 else 分支——满足 if 条件的线程空闲等待
  4. 两个分支都执行完毕后,线程重新汇合

图6:控制分歧执行示意——上半部分为 SIMT 模型(Instruction Decoder 向所有 CUDA Core 广播同一指令),下半部分展示 if/else 分支的时间线:两组线程交替执行,另一组空闲等待
图6:控制分歧执行示意——上半部分为 SIMT 模型(Instruction Decoder 向所有 CUDA Core 广播同一指令),下半部分展示 if/else 分支的时间线:两组线程交替执行,另一组空闲等待

这被称为 控制分歧(Control Divergence)。它的代价是:执行一个 if-else 语句,GPU 实际上消耗了执行两个分支的时间,其中大量计算资源在空转。

实践中的应对策略#

这就是为什么 GPU 代码中应该极力避免 if 语句。一个典型的例子是 ReLU 实现:

  • 低效写法(if 语句)if x > 0: output = x else: output = 0——会导致控制分歧,部分线程空等
  • 高效写法(乘法掩码)output = x * (x > 0)——所有线程同时执行乘法操作,无分歧

掩码乘法可以一次完成所有线程的计算,而 if 语句可能需要两个操作周期。这是 GPU 编程中的一个基本原则:用算术操作替代条件分支

不过,控制分歧是与后续的内存优化相对独立的话题。接下来的五个技巧都围绕一个主线展开:如何最小化内存数据搬运。

技巧二:低精度计算(Low Precision)#

图7:低精度驱动的 FLOPS 增长——从 K20X 的 Scalar FP32 到 H100 的 FP8 Transformer Engine,单芯片推理性能 10 年增长 1000x
图7:低精度驱动的 FLOPS 增长——从 K20X 的 Scalar FP32 到 H100 的 FP8 Transformer Engine,单芯片推理性能 10 年增长 1000x

低精度计算是硬件厂商当前投入最多精力的优化方向。如果观察 GPU FLOPS 的超指数增长曲线,相当大一部分增长来自数值格式的降级——从 FP32 到 BF16 再到 INT8,每次减半比特数就意味着减半内存搬运量,从而缓解内存瓶颈。

基本原理#

以一个对长度为 nn 的向量执行 ReLU 为例。在 FP32 下:每个元素需要一次读取(4 bytes)和一次写入(4 bytes),总共 8 bytes 的内存访问,但只做了 1 次 FLOP。算术强度 = 1/8 FLOP/byte,极低。如果切换到 FP16/BF16,每次读写只需 2 bytes,总共 4 bytes/FLOP——内存开销直接减半。

实际实现:混合精度训练#

实际的低精度矩阵乘法不会将所有环节都保持低精度。典型的流程是:

  1. 输入矩阵以低精度(如 BF16)存储
  2. 相乘步骤在低精度下执行
  3. 累加步骤(部分和的求和)在 FP32 精度下执行
  4. 输出结果可能以 FP32 写出

低精度训练之所以是”黑色艺术”,不是因为降低精度本身困难,而是需要仔细判断 哪些操作可以安全降精度、哪些必须保持高精度。经过多年的经验积累,社区形成了一些共识:矩阵乘法中的权重和激活值通常可以用低精度;softmax、指数运算等可能需要 FP32 或至少 BF16;第一层和最后一层通常难以量化。最后一层的原因较为清晰——它直接驱动 loss,是一阶因子,量化误差会导致不稳定和显著的 loss 增加。第一层为何难以量化,目前缺乏明确的直觉解释。

FP8 与 MXFP8#

BF16 之后的下一个前沿是 FP8。一旦降到 8 bit,就不再有单一的标准格式:

  • E4M3:4 bit 指数 + 3 bit 尾数——更高精度但范围较小
  • E5M2:5 bit 指数 + 2 bit 尾数——更大范围但精度较低

两种格式用于不同场景,没有万能选择。

FP8 训练面临的核心问题是:只有 4 bit 指数意味着动态范围极其有限,很容易溢出或下溢。因此必须引入 缩放因子(Scaling Factor)——一个 FP32 的乘数,用于将数值调整到 FP8 能表示的范围内。

图8:FP8/MXFP8 数值格式——左侧对比 FP16、BF16、FP8 E4M3、FP8 E5M2 的 bit 布局;右侧展示 MXFP8 的多缩放因子结构(每 32 元素一个 E8M0 缩放因子)和前向/后向 pipeline 中的量化位置
图8:FP8/MXFP8 数值格式——左侧对比 FP16、BF16、FP8 E4M3、FP8 E5M2 的 bit 布局;右侧展示 MXFP8 的多缩放因子结构(每 32 元素一个 E8M0 缩放因子)和前向/后向 pipeline 中的量化位置

传统 FP8 对整个矩阵使用单一缩放因子。但一个矩阵的不同区域可能有非常不同的数值范围——序列的某些位置激活值远大于其他位置。于是出现了 MXFP8(Microscaling FP8)

  • 每 32 个元素使用一个独立的缩放因子
  • 缩放因子本身是 E8M0 格式(8 bit 全为指数,即纯粹的 2 的幂次)
  • 元素使用 E4M3 格式(更多尾数位提升局部精度)

MXFP8 的转置难题#

MXFP8 引入了一个非直觉的复杂性:由于缩放因子按固定的 1-out-of-32 模式排列(沿行方向),转置矩阵不再具有相同的缩放模式。传统的矩阵转置是零成本操作,但在 MXFP8 下,转置后必须重新量化整个矩阵以适配新的缩放模式。

实际实现的解决方案出人意料地暴力:为每个量化矩阵同时存储原始版本和转置版本。即每次量化时生成两份副本——一份原始排列,一份转置后重新量化的排列。需要转置时直接使用预存的转置版本。

MXFP8 训练的实际收益#

在实际训练中使用 MXFP8:

  • 只对安全的层做量化(跳过第一层和最后一层)
  • 主要针对矩阵乘法操作量化(收益最大)
  • 每次量化需要生成原始 + 转置两个副本
  • 实际加速约 20-30%(不是理论上的 2x,因为量化/反量化操作有开销)
  • 低精度不仅节省内存带宽,也带来近线性的 算力收益(量化数值相乘本身更快),但量化/反量化的额外开销会稀释综合收益

MXFP4:极端低精度的前沿#

更极端的探索是 MXFP4——4 bit 浮点,整个数值范围可以在一张幻灯片上列完:-6 到 +6 之间的离散值。结构:

  • 每 16 个元素共享一个缩放因子
  • 缩放因子为 E4M3 的 FP8

目前已有论文证明 FP4 训练的可行性,但尚未有在大规模生产模型上成功应用的公开报告。这可能是下一代模型训练的方向。

量化与推理的关系#

训练和推理的量化策略有所不同。当前的最优实践涉及多种技术的组合:

  • 训练时使用 量化感知训练(Quantization-Aware Training)
  • 训练后使用 后训练量化(Post-Training Quantization)
  • 推理时的缩放因子可能通过某种拟合来确定

这仍然是一个活跃的研究领域——即使是工业界团队也在持续进行”量化科学”的探索。

一个相关的方向是 结构化稀疏性(Structured Sparsity)。NVIDIA 硬件支持 2:4 结构化稀疏,已有成功案例(如 MoE 可视为一种结构稀疏操作)。不过,从经验结果看,结构化稀疏矩阵带来的计算收益与其导致的表示能力损失相互抵消——在实际大规模训练中,这一方向尚未充分兑现其理论潜力。

技巧三:算子融合(Operator Fusion)#

图9:GPU 工厂类比——左侧为小工厂(Compute)配窄传送带(Memory),右侧计算规模扩大但传送带不变,瓶颈转移到内存带宽
图9:GPU 工厂类比——左侧为小工厂(Compute)配窄传送带(Memory),右侧计算规模扩大但传送带不变,瓶颈转移到内存带宽

算子融合是一个概念简单但效果显著的优化。用工厂类比来理解:GPU 有一个”内存仓库”和一个”计算工厂”,中间由一条”窄传送带”(内存带宽)连接。如果你的计算流程涉及多个操作,朴素实现会让数据在仓库和工厂之间反复往返——每次操作都是一次完整的”取材料 → 加工 → 送回仓库”循环。

问题:多次全局内存往返#

考虑一个简单的计算 sin2(x)+cos2(x)\sin^2(x) + \cos^2(x)。PyTorch 的计算图会产生以下操作节点:输入 xx → sin → cos → square(两次) → add → 输出。

朴素执行模式下,每个操作节点 都是一个独立的 CUDA kernel:读取全局内存中的输入,执行计算,将结果写回全局内存。下一个节点再从全局内存读取上一步的输出……如此循环。每一步都付出了全局内存读写的延迟代价。

解决方案:融合为单一 kernel#

融合的思路很自然:既然这些操作只是对相同数据的连续变换,为什么不把它们合并成一个 kernel?

融合后的执行流程:

  1. 一次从全局内存读取 xx
  2. 在 SM 的共享内存/寄存器中完成整条计算链(sin、cos、square、add)
  3. 一次将最终结果写回全局内存

内存访问次数从多次读写缩减为恰好一读一写。

编译器自动融合#

对于这类简单的元素级操作链,PyTorch 的编译器(torch.compile)或 Jax 的编译器(jax.jit/XLA)可以自动完成融合——它们会识别出计算图中可以合并的操作序列,自动生成一个融合的 CUDA kernel。

但更复杂的融合(涉及矩阵乘法 + 后续操作的融合,或需要特殊内存访问模式的融合)可能需要手动实现。Flash Attention 就是这种高级融合的典型案例——编译器无法自动推导出最优的融合策略,需要人工设计算法。

技巧四:重计算(Recomputation)#

重计算是反直觉的——为什么要重复已经做过的计算?但在计算极度廉价、内存访问极度昂贵的 GPU 世界中,这个权衡是划算的。

标准反向传播的内存开销#

图10:标准 backprop 的内存开销——前向 pass 存储中间激活值 s2、s1(3 次写入),后向 pass 读回这些激活值(3 次读取),总计 8 次内存访问
图10:标准 backprop 的内存开销——前向 pass 存储中间激活值 s2、s1(3 次写入),后向 pass 读回这些激活值(3 次读取),总计 8 次内存访问

标准的反向传播需要在前向过程中保存中间激活值(activations),以便在反向传播时使用。以一个简单的三层 sigmoid 网络为例:

前向阶段:读取 xx → 计算 s1=σ(x)s_1 = \sigma(x) → 存储 s1s_1 → 计算 s2=σ(s1)s_2 = \sigma(s_1) → 存储 s2s_2 → 计算输出

反向阶段:读取梯度 doutd_{out} → 读取 s2s_2 → 计算反向 → 读取 s1s_1 → 计算反向 → 输出 dxd_x

总计内存访问:前向 1 读 + 3 写,反向 3 读 + 1 写 = 8 次内存访问。每个中间激活值都需要写入全局内存保存,后续反向时再读出——代价高昂。

重计算策略#

重计算的核心思想:不保存中间激活值,在反向传播需要时重新计算它们

改进后的流程:

  • 前向阶段:读取 xx → 计算 sigmoid 链 → 只写出最终输出。中间激活值 s1s_1s2s_2 不保存
  • 反向阶段:读取 doutd_{out}xx → 重新前向计算 sigmoid 链(在 SM 内部即时得到 s1s_1s2s_2) → 用这些临时值计算梯度 → 写出 dxd_x

总计内存访问:前向 1 读 + 1 写,反向 2 读 + 1 写 = 5 次内存访问

从 8 次降到 5 次——减少了 37.5% 的内存访问,代价仅仅是多做了一次前向计算(那些 sigmoid 运算)。在计算资源远超内存带宽的现代 GPU 上,这笔交易非常划算:用冗余的计算换取更少的内存搬运。

适用场景#

重计算在以下条件下最有效:

  • 操作本身计算量小但中间结果多(如连续的逐元素操作)
  • 训练过程中需要保存大量激活值
  • 硬件的计算/内存比率很高(即计算远比内存搬运便宜)

这也是 Flash Attention 后向传播中使用的关键技巧——丢弃 n2n^2 大小的注意力矩阵中间结果,在反向时逐块重新计算。

技巧五:合并内存访问(Coalesced Memory Access)#

这个优化与 DRAM 的物理结构直接相关。理解 DRAM 的 burst 特性后,可以通过调整数据访问模式获得显著加速。

DRAM 的 Burst 特性#

全局内存使用的 DRAM 在物理上由网格状的存储单元组成。当访问某个地址时,DRAM 不会只返回该地址的一个值——它会将同一行(burst section)中的所有值一并返回。这是因为 DRAM 的寻址开销主要在”行选择”阶段(激活电压、使能放大器),一旦选中了某一行,读出该行内的多个元素几乎是免费的。

图11:DRAM Burst Section——地址空间被划分为连续的 burst section(彩色块),访问任一位置时同一 section 内的所有位置免费送达。左下为 DRAM 物理单元网格示意
图11:DRAM Burst Section——地址空间被划分为连续的 burst section(彩色块),访问任一位置时同一 section 内的所有位置免费送达。左下为 DRAM 物理单元网格示意

典型的 burst 大小为 128 bytes。这意味着一次内存访问可以”免费”带回 128 bytes 的连续数据——前提是这些数据在同一个连续内存块中。

合并访问(Coalesced Access)的定义#

当一个 Warp 中的所有线程的内存访问地址落在同一个 burst section 内时,称这次访问为 合并的(coalesced)。此时,一次 DRAM 读取就能满足整个 Warp 的数据需求。

反之,如果 Warp 中的线程访问分散在不同 burst section 中的地址,则需要多次 DRAM 读取——每个 burst section 都要单独激活一次。

矩阵访问的行主序问题#

这个概念在矩阵运算中特别重要。考虑一个 4×4 的行主序(row-major)矩阵,内存中的线性排列顺序是第一行、第二行、第三行、第四行依次排列。

沿列方向读取(不合并):如果 4 个线程分别读取矩阵的第一列(即第 1 行第 1 个、第 2 行第 1 个、第 3 行第 1 个、第 4 行第 1 个),这些元素在内存中是不连续的——每个元素位于不同的 burst section 中。结果:为了读 4 个元素,触发了 4 次独立的 burst 读取,大量带回的数据被浪费。

沿行方向读取(合并):如果 4 个线程分别读取矩阵同一行的 4 个连续元素,这些元素在内存中连续排列,属于同一个 burst section。结果:一次 burst 读取就能满足所有线程的需求。

助记规则:在行主序矩阵中,线程沿主轴(行方向)移动是不合并的。原因在于:如果多个线程各自沿行方向向右推进(每个线程处理自己行中的元素),这些线程在同一时刻访问的地址分散在不同行的不同位置,彼此间隔整行的距离——落在不同的 burst section 中。反之,如果同一 Warp 的所有线程在同一时刻访问同一行内的连续列元素,这些地址在内存中紧密排列,属于同一个 burst section——这才是合并的。

实践意义#

合并访问可以带来数倍的内存读取效率提升。这也是为什么在 GPU 编程中需要关注矩阵的存储顺序(row-major vs column-major),以及为什么某些看似无关的矩阵排列选择会对性能产生巨大影响。

技巧六:分块(Tiling)#

分块是本节课中影响最大的优化技巧,也是理解 Flash Attention 的关键。它的核心思想是:将大矩阵切分为小块(tile),依次加载到共享内存中操作,最大化数据在快速内存中的复用次数。

动机:矩阵乘法中的重复读取#

在一个 n×nn \times n 的矩阵乘法中,如果不做任何优化,输入矩阵的每个元素会被从全局内存读取 nn 次——因为它参与了输出矩阵中一整行(或一整列)的计算。这是巨大的内存浪费。

图13:朴素矩阵乘法的内存访问模式——右侧表格展示每个 thread 的访问顺序,注意 M₀₀ 和 N₁₀ 等元素被不同 thread 重复读取多次,且访问不合并
图13:朴素矩阵乘法的内存访问模式——右侧表格展示每个 thread 的访问顺序,注意 M₀₀ 和 N₁₀ 等元素被不同 thread 重复读取多次,且访问不合并

自然的想法是:既然同一个元素要被多次使用,能否读取一次后将其保留在快速内存中,反复使用?

Tiling 算法#

图12:Tiling 算法示意——将矩阵 M 和 N 切分为子矩阵,按 phase 加载到共享内存中:(1) 加载 M₀₀ 和 N₀₀ 到 SHM → (2) 计算部分和 → (3) 加载下一组 tile,重复直到完成
图12:Tiling 算法示意——将矩阵 M 和 N 切分为子矩阵,按 phase 加载到共享内存中:(1) 加载 M₀₀ 和 N₀₀ 到 SHM → (2) 计算部分和 → (3) 加载下一组 tile,重复直到完成

将输入矩阵切分为 t×tt \times t 的子矩阵(tile)。算法流程:

  1. 将第一组 tile(如 M00M_{00}N00N_{00})从全局内存加载到共享内存
  2. 在共享内存中执行子矩阵乘法,将结果累加到输出 tile
  3. 加载下一组 tile(如 M01M_{01}N10N_{10}),继续累加
  4. 所有相关 tile 处理完毕后,将输出 tile 写回全局内存

数学分析#

  • 无 tiling:每个输入元素从全局内存读取 nn
  • 有 tiling(tile 大小 tt:每个输入元素从全局内存读取 n/tn/t 次,在共享内存中被访问 tt

全局内存访问次数降低了 tt 倍。极端情况:如果 t=nt = n(整个矩阵一次性放入共享内存),则每个元素只从全局内存读取 1 次——但这受限于共享内存的容量。实际中,tile 大小受 SM 共享内存容量限制(A100 上每个 SM 约 192 KB),需要在 tile 大小和 SM 资源之间权衡。

Tile 大小的选择与陷阱#

Tile 大小并非越大越好——它与矩阵维度之间的整除关系至关重要。

理想情况:矩阵维度恰好是 tile 大小的整数倍。例如 tile 大小 128×128 处理 256×256 的矩阵,得到 4 个完美的 tile。

灾难情况:矩阵维度比整数倍多出 1。例如 tile 大小 128×128 处理 257×257 的矩阵——会产生两条极细的”边界 tile”,几乎没有数据但仍占用完整的 SM 资源。这导致大量 SM 在空转。

PyTorch 的 torch.compile 配合 max_autotune 选项时,会花费相当长的时间自动尝试各种 tile 大小配置,通过 benchmarking 找到对当前矩阵维度最优的 tile 方案。

Tiling 与 Coalescing 的交互#

Tile 的对齐还会影响合并访问的效率。理想情况下,tile 的行宽与 burst section 的大小对齐——此时读取一个 tile 的一行只需要恰好对应数量的 burst 读取。

图14:Tile 对齐 vs 不对齐——左侧 Aligned Layout 中 burst section 恰好覆盖一个 tile 行(One Nice Tile),右侧 Unaligned Layout 中 tile 行跨越两个 burst section(Two Bad Tiles),内存读取量翻倍
图14:Tile 对齐 vs 不对齐——左侧 Aligned Layout 中 burst section 恰好覆盖一个 tile 行(One Nice Tile),右侧 Unaligned Layout 中 tile 行跨越两个 burst section(Two Bad Tiles),内存读取量翻倍

但如果矩阵维度有偏移(如维度为 257 而非 256),则 tile 的行起始地址可能与 burst section 边界错位。此时读取一行 tile 需要触发两个 burst section——内存读取量翻倍。解决方案是 padding:在矩阵维度上填充到合适的对齐边界。

Andrej Karpathy 在 NanoGPT Speedrun 中的一个著名发现就体现了这一点:将词表大小从 50257 增加到 50304(补齐到 128 的倍数)后,获得了 25% 的加速。看似”增加了无用计算”,实际上是通过 padding 实现了 tile 对齐和合并访问的双重优化。

实用建议#

矩阵维度的最佳实践:

  • 能被 32 整除(warp 大小对齐)
  • 最好能被 128 或 256 整除(常见 tile 大小对齐)
  • 不需要追求更高的 2 的幂——一旦能被 burst section 整除,额外的整除性不提供进一步收益

波量化效应(Wave Quantization)#

Tiling 还会产生另一个更微妙的性能陷阱。以 A100(108 个 SM)为例,使用 256×128 的 tile 大小:

  • 矩阵维度 1792:产生 98 个 tile。98 < 108 SM,所有 tile 在一轮(一个”wave”)内并行处理完毕——全部 SM 都有活干
  • 矩阵维度 1793:产生 120 个 tile。120 > 108,第一轮处理 108 个 tile 后,还剩 12 个 tile。这 12 个 tile 需要启动第二轮——但此时只有 12 个 SM 在工作,其余 96 个 SM 完全空闲

图15:波量化效应——左侧曲线显示 1792→1793 处吞吐量骤降。右侧数学分析:tile 大小 256×128 时,1792 产生 98 tiles(< 108 SMs,一轮完成),1793 产生 120 tiles(> 108 SMs,需要两轮,第二轮仅 12 个 SM 工作)
图15:波量化效应——左侧曲线显示 1792→1793 处吞吐量骤降。右侧数学分析:tile 大小 256×128 时,1792 产生 98 tiles(< 108 SMs,一轮完成),1793 产生 120 tiles(> 108 SMs,需要两轮,第二轮仅 12 个 SM 工作)

仅仅增加一个维度,就从”一轮搞定”变成”需要两轮,且第二轮利用率只有 11%“。这就是 波量化——tile 总数与 SM 数量的整除关系决定了硬件利用率是否出现断崖式下降。在实测的 matmul 吞吐量图中,那些周期性的突然下降正是波量化效应的体现。

对 matmul 吞吐量曲线的完整解释#

至此,可以完整解释开头那张”矩阵维度 vs 吞吐量”的诡异曲线:

  1. 对角线上升趋势:小矩阵计算强度不足,位于 Roofline 的内存受限区。矩阵变大后计算强度提高,进入计算受限区
  2. 不同曲线层:按矩阵维度的最大因子着色——只能被 1 整除的(蓝色)性能最差(对齐/合并完全失效),能被 32 整除的(紫色)性能最好(完美对齐)
  3. 周期性骤降:波量化效应——tile 数恰好超过 SM 数时触发

Flash Attention:所有技巧的综合应用#

图16:Flash Attention 性能对比——左侧柱状图:PyTorch 标准实现 vs FlashAttention 的各操作耗时分解;中间表格:标准实现 HBM 读写 40.3 GB vs FlashAttention 仅 4.4 GB,运行时间从 41.7ms 降至 7.3ms;右侧:Block Size 对 HBM 访问量和运行时间的影响曲线
图16:Flash Attention 性能对比——左侧柱状图:PyTorch 标准实现 vs FlashAttention 的各操作耗时分解;中间表格:标准实现 HBM 读写 40.3 GB vs FlashAttention 仅 4.4 GB,运行时间从 41.7ms 降至 7.3ms;右侧:Block Size 对 HBM 访问量和运行时间的影响曲线

Flash Attention 是前面所有优化技巧的集大成之作——它将 tiling、fusion 和 recomputation 三种思想融合在一个高度优化的 attention kernel 中,实现了 attention 计算延迟的数量级降低和内存占用从 O(n2)O(n^2)O(n)O(n) 的缩减。

Attention 的计算结构#

标准 attention 包含三个矩阵乘法和一个全局 softmax:

Attention(Q,K,V)=softmax(QKTd)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right) V

操作分解:

  1. 计算 QKTQK^T——矩阵乘法
  2. 对结果施加 softmax——逐行全局归一化
  3. VV 相乘——矩阵乘法

矩阵乘法部分,运用 tiling 即可高效实现。真正的难点在于 softmax 是一个全局操作——它需要看到一整行的所有值才能计算归一化因子,这似乎使得分块处理变得不可能。

核心创新:在线 Softmax(Online Softmax)#

标准 softmax 的计算方式:

softmax(xi)=eximax(x)jexjmax(x)\text{softmax}(x_i) = \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}}

这需要两轮遍历——第一轮找最大值,第二轮计算指数和与归一化。这要求整行数据同时可见。

在线 softmax 的关键观察:可以 逐块增量计算 softmax。算法维护两个运行状态:

  • 当前已见的最大值 mm
  • 当前的指数累加和 \ell

当处理新的一块数据时:

  1. 如果新块中出现了更大的值 mnew>mm_{new} > m,则更新最大值
  2. 对之前的累加结果乘以修正因子 emoldmnewe^{m_{old} - m_{new}} 来补偿最大值的变化
  3. 将新块的指数贡献累加到 \ell

这样,softmax 可以一块一块地处理,而无需一次性看到整行数据。每个 tile 处理完毕后,只需要保存极少量的中间状态(当前最大值和累加和)——这些可以存放在寄存器中,开销几乎为零。

Flash Attention 的完整算法#

图17:Flash Attention tiling 架构——左侧金字塔展示三级内存层级及其带宽(SRAM 19 TB/s、HBM 1.5 TB/s、CPU DRAM 12.8 GB/s);右侧为 KQV tiling 的内外循环结构,虚线框内操作在 SRAM 中完成
图17:Flash Attention tiling 架构——左侧金字塔展示三级内存层级及其带宽(SRAM 19 TB/s、HBM 1.5 TB/s、CPU DRAM 12.8 GB/s);右侧为 KQV tiling 的内外循环结构,虚线框内操作在 SRAM 中完成

将 tiling、online softmax 和 fusion 三者结合:

  1. Tiled MatMul:将 QQKK 矩阵分块,按 tile 执行内积。结果保留在 SRAM(共享内存)中,不写回 HBM
  2. 在 tile 内计算 softmax:利用在线 softmax 算法,逐 tile 更新指数和与归一化因子。所有中间结果保持在共享内存中
  3. Fusion:整条计算链(matmul → exp → softmax → matmul with V)在一次 kernel 调用中完成,中间结果不离开 SRAM
  4. 最终输出:所有 tile 处理完毕后,执行最终归一化并与 VV 相乘,结果写回 HBM

从 Flash Attention 2 的可视化中可以清晰看到:虚线框内的操作全部在 SRAM 中执行(tiled matmul、exp 计算、running softmax),蓝色方框是驻留在 HBM 中的大矩阵(QQKKVV 和输出 OO)。数据只在开始时从 HBM 读入 SRAM、在结束时从 SRAM 写回 HBM——中间的全部计算都在快速内存中完成。

Recomputation 在 Flash Attention 中的作用#

如果保存前向过程的中间激活值(即 n×nn \times n 的 attention 矩阵),内存需求是 O(n2)O(n^2)——对于长序列这是不可接受的。Flash Attention 的解决方案:

  • 前向时:不保存完整的 n2n^2 attention 矩阵,只保存每行的 softmax 统计量(最大值和累加和,O(n)O(n) 存储)
  • 反向时:利用保存的 softmax 统计量,重新逐 tile 计算 attention 矩阵的对应部分,即时计算出反向传播所需的值

这正是”重计算”技巧的经典应用:用廉价的重复计算换取昂贵的 O(n2)O(n^2) 内存。

Flash Attention 的性能来源#

Flash Attention 的加速本质上来自 HBM 访问次数的亚二次方缩减。标准 attention 实现需要 O(n2)O(n^2) 次 HBM 读写(写入和读取完整的 attention 矩阵),而 Flash Attention 通过 tiling + fusion 将 HBM 访问降低到接近线性。在 Roofline 的视角下,Flash Attention 通过大幅提升算术强度(同样的 FLOPS,更少的内存访问),将 attention 从深度内存受限区推向了计算受限区的边界。

总结#

Flash Attention 是一个极好的教学案例——它完美展示了如何将 GPU 系统层面的知识转化为实际的算法创新:

  • Tiling:将 attention 矩阵按 tile 分块处理
  • Online Softmax:使全局归一化操作可以逐 tile 增量执行
  • Fusion:整条计算链在单一 kernel 中完成,中间结果不离开 SRAM
  • Recomputation:反向传播时重新计算 attention 值而非存储 O(n2)O(n^2) 的中间结果
Lecture 5:GPU 与 TPU 硬件原理
https://www.xwysyy.cn/posts/cs336/lec05/
作者
xwysyy
发布于
2026-05-17
许可协议
CC BY-NC-SA 4.0
© 2026 xwysyy. All Rights Reserved.
Powered by Astro & Firefly

文章目录