Lecture 6:GPU Kernel 编程与 Triton 实战
Stanford CS336 Language Modeling from Scratch | Spring 2026 | Lecture 6: Kernels, Triton, XLA,时长 1:26:41。
GPU 硬件架构:内存层级与计算单元#
理解 GPU kernel 编程的第一步,是在脑中建立一幅清晰的硬件层级图。以 NVIDIA 的 GPU 为例——从 M100 到 H100 再到 B200,每一代架构都在这同一套基本框架上做量级提升。
SM:流式多处理器#
一块 GPU 芯片上排布着约 100 到 200 个 Streaming Multiprocessor (SM),这个数量在代际之间变化不大。每个 SM 是一个相对独立的计算单元,内部包含一组寄存器、共享内存/L1 缓存,以及 Tensor Core 等专用计算硬件。以 B200 为例,每个 SM 拥有 65,536 个寄存器,总计 256 KB 的寄存器文件——这一数字同样在历代间保持相对稳定。
四级内存层级#

GPU 的内存体系呈金字塔状分布,自顶向底分别是:
-
寄存器(Registers):位于 SM 内部,速度最快,容量最小。每个 SM 约 256 KB(B200),每个线程最多可使用 255 个寄存器。寄存器是线程私有的——一个线程的寄存器对其他线程不可见。
-
L1 缓存 / 共享内存(Shared Memory):同样位于 SM 内部,大小与寄存器在同一数量级。L1 和 Shared Memory 实际上是同一块物理存储,只是 L1 由硬件自动管理,而 Shared Memory 由程序员显式控制。Shared Memory 是线程块(Thread Block)级别共享的——同一个线程块内的所有线程可以通过它进行通信。
-
L2 缓存:不再是每个 SM 独有,而是整个芯片共享的全局缓存。容量比 L1/Shared Memory 大一些,但带宽相应降低。
-
HBM(High Bandwidth Memory):即所谓的”全局内存”或”显存”。B200 配备了高达数十 GB 的 HBM,这也是代际间增长最显著的参数。所有 SM 共享同一块 HBM,它是数据在 GPU 上的主要驻留位置。
速度与容量的此消彼长#
这四级内存遵循一条清晰的规律:越靠近计算单元,速度越快,但容量越小。寄存器的访问延迟只有个位数时钟周期;L1/Shared Memory 稍慢;L2 再慢一档;HBM 的访问延迟可达数百个时钟周期。然而 B200 的 HBM 带宽已经达到 8 TB/s——在绝对意义上并不慢,只是相比寄存器而言差距悬殊。
这套层级的核心启示是:大容量内存(HBM)慢而远,小容量内存(寄存器、L1)快而近。整个 kernel 编程的核心目标,就是尽可能让计算发生在快速内存中,减少对 HBM 的读写次数。
新一代硬件的补充特性#
H100 和 B200 引入了 Thread Block Cluster(线程块集群),允许多个线程块之间进行有限的分布式内存共享。B200 还新增了 Tensor Memory,专门服务于 Tensor Core,其层级位于寄存器和 Shared Memory 之间。这些特性在编程时通常对程序员透明,但在硬件内部发挥着作用。对于入门级的 kernel 编程,暂时可以忽略它们。
GPU 编程模型:线程、线程块与网格#
有了硬件层级的认知,接下来需要理解如何在这套硬件上组织计算。GPU 的编程模型提供了一层优雅的抽象,将底层硬件细节封装起来——编写正确的程序只需理解这层抽象,但要编写高性能的程序,则必须深入硬件。
三级并行层次#

GPU 的并行计算组织为三个层级:
线程(Thread) 是最小的执行单元。每个线程执行同一段代码,但作用于数据的不同部分。可以把一个线程想象成一支笔,负责在输出矩阵上填写一个(或几个)元素。
线程块(Thread Block),也称为 CTA(Concurrent Thread Array),是一组线程的集合。同一个线程块内的线程被调度到同一个 SM 上执行,它们共享该 SM 的 Shared Memory,可以彼此通信和同步。线程块是协作的基本单位。
网格(Grid) 是所有线程块的集合。当发起一个 kernel 调用时,实际上就是在启动一个由大量线程块组成的网格,让它们并行地完成整个计算任务。
为什么需要线程块?#
一个自然的疑问是:为什么不能直接用一个扁平的线程网格,让每个线程各自处理一个数据元素?
对于 element-wise 操作(如 GELU 激活函数),这种模式确实够用——每个线程独立读取一个元素、计算、写回,线程之间不需要任何通信。
但对于 需要跨元素聚合的操作(如 softmax、矩阵乘法),线程之间必须交换中间结果。如果线程间的通信只能通过 HBM 完成——每个线程写回 HBM,另一个线程再从 HBM 读取——那代价将极其高昂。
线程块的存在正是为了解决这个问题。同一线程块内的线程共享 Shared Memory,可以在不经过 HBM 的情况下高效通信。线程块的典型工作模式是:从 HBM 批量读入一块数据到 Shared Memory → 在 Shared Memory 中完成需要线程间协作的计算 → 将结果批量写回 HBM。
编程模型 vs 硬件现实#
编程模型和硬件之间存在一个重要的张力关系。编程模型提供了清晰的抽象:你定义线程块的数量和每个线程块的行为,线程块内部的代码读起来就像普通的 Python/C。如果只关心正确性,这些就足够了。
但性能高度依赖硬件。Warp 的调度方式、寄存器的分配策略、内存访问的合并模式、Shared Memory 的 bank 结构——这些硬件层面的细节直接决定了同一个逻辑正确的 kernel 是跑得飞快还是慢得离谱。这也是为什么要专门学习 GPU kernel 编程:目的不是写出正确的代码(PyTorch 已经帮你做了),而是榨取硬件性能。
与内存层级的映射#
三级并行层次与内存层级之间存在清晰的对应关系:
- HBM 对所有线程全局可见
- Shared Memory 在线程块级别共享
- 寄存器 是线程私有的
这种映射关系是理解 Triton 编程的基础——在 Triton 中,程序员的思维粒度正是线程块级别:一个线程块从 HBM 加载数据、在 Shared Memory 中操作、最后写回 HBM。
硬件细节与性能陷阱#
编程模型提供了简洁的抽象,但性能却由硬件的各种约束和特性决定。以下五个硬件细节是 kernel 性能调优中最常遇到的考量。
Warp 与控制分歧#
在编程模型中,线程是最小的执行单元,线程块是协作的基本单位。但在硬件层面,还存在一个中间层级——Warp。每个线程块内的线程被划分为若干 warp,每个 warp 固定包含 32 个线程。例如,一个包含 64 个线程的线程块会被拆分为两个 warp(线程 0-31 和线程 32-63)。
Warp 的核心特性是 SIMT(Single Instruction, Multiple Threads):同一个 warp 中的所有 32 个线程在每个时钟周期必须执行完全相同的指令。这意味着如果代码中出现分支——例如 if condition: A else: B——而 warp 中的不同线程需要走不同的分支,那么硬件会先让所有线程执行分支 A(不满足条件的线程空转),再执行分支 B(满足条件的线程空转)。这就是控制分歧(Control Divergence),它会将并行执行退化为串行执行,严重影响性能。
核心准则:在 warp 内部避免分支。尽可能让同一 warp 中的线程走相同的执行路径。
Warp 调度与延迟隐藏#
Warp 的另一个重要特性是 零开销上下文切换。每个 SM 同时维护多个 resident warp,warp 调度器可以在这些 warp 之间无代价切换——这与 CPU 的上下文切换代价高昂形成鲜明对比。
这种设计专门用于隐藏内存访问延迟。当一个 warp 执行到 tl.load(即从 HBM 读数据)时,它会阻塞等待数百个时钟周期——这类似于 CPU 中的 trap call。此时 SM 不会闲等,warp 调度器立即切换到另一个准备好的 warp 去做 Tensor Core 运算或其他计算。等 HBM 数据到达后,调度器再将执行权交回原来的 warp,继续执行 load 之后的指令。
Warp Occupancy 与寄存器压力#
Occupancy 衡量的是 SM 上实际运行的 warp 数量与硬件支持的最大 warp 数量之比。以 B200 为例,硬件约束如下:
- 每个 SM 最多 65,536 个寄存器
- 每个线程最多使用 255 个寄存器
- 每个 SM 最多同时运行 64 个 warp
考虑一个具体例子:假设线程块有 128 个线程(即 4 个 warp),每个线程使用 160 个寄存器。那么一个线程块消耗的寄存器总量为 128×160=20,480,因此一个 SM 最多容纳 ⌊65,536/20,480⌋=3 个线程块,对应 3×4=12 个 warp,occupancy 为 12/64=18.75%。
但 occupancy 并非越高越好。占用率低可能意味着每个线程有更多寄存器可用、能做更多工作。Thread Coarsening(线程粗化)就是一种有意降低 occupancy 的策略:让每个线程处理多个元素(比如 8 个而非 1 个),用更少的线程完成同样的工作。如果线程本身很轻量,这种”增肥”反而能减少调度开销、提升性能。
Bank Conflict:共享内存的隐形瓶颈#
Shared Memory 在物理上被划分为 32 个 bank,每个 bank 宽 4 字节。在每个时钟周期内,每个 bank 最多被一个线程访问(访问同一地址的情况除外)。如果多个线程试图访问同一个 bank 的不同地址,这些访问就必须串行化——这就是 Bank Conflict。
最坏情况是 32-way bank conflict:32 个线程同时访问同一个 bank 的不同行(比如访问一个矩阵的同一列),此时原本应并行完成的 32 次访问被完全串行化。
对于 element-wise 操作,bank conflict 通常不是问题,因为可以自由选择访问顺序。但对于矩阵乘法,不可避免地需要同时访问行和列。一种解决方案是 Swizzling——通过特殊的内存布局重排,使得列访问时不同线程分散到不同 bank。
Memory Coalescing:HBM 访问的合并#
当一个 warp 中的 32 个线程访问 HBM 时,硬件会尝试将这些访问**合并(Coalesce)**为若干次 cache line 事务(每次 128 字节)。
Full Coalescing(完全合并)发生在所有线程访问的地址落在同一个 cache line 内时——例如线程 0 访问 M[0][0],线程 1 访问 M[0][1],以此类推。此时一次 cache line 读取就能满足全部 32 个线程的需求。
Poor Coalescing(不良合并)发生在线程按列访问矩阵时——线程 0 访问 M[0][0],线程 1 访问 M[1][0]。此时每个线程的访问落在不同的 cache line 上,硬件需要发起大量独立的内存事务,大部分读取的数据被浪费。
Memory coalescing 针对的是 HBM 访问,而 bank conflict 针对的是 Shared Memory 访问——两者是不同层级的约束,但背后的思想一致:让并行线程的内存访问模式与硬件的存储组织方式对齐。
Block Occupancy:尾部效应#
线程块在逻辑上可以定义任意多个,但物理上只有有限个 SM(如 B200 的 148 个)可以同时执行。如果启动 160 个线程块,前 148 个会被分配到全部 SM 上并行执行,剩余 12 个必须等待。问题在于,第二波只有 12 个线程块在运行,其余 136 个 SM 处于空闲——这就是尾部效应(Tail Effect),导致大量硬件资源浪费。
实践建议:让线程块的数量是 SM 数量的整数倍,避免最后一波出现严重的资源浪费。或者调整线程块大小(block size),使线程块总数能更均匀地分配到所有 SM。
一个相关问题是:能否让多个 block 共用一个 SM 来减轻尾部效应?答案取决于每个 block 的资源消耗——如果一个 block 已经用满了 SM 上的 Tensor Core,再往上放一个 block 并不会加速。真正的解决方案是调整 block size(从而改变 block 总数),而不是试图共用 SM。
Benchmarking 与 Profiling#
在动手优化 kernel 之前,必须先建立对性能现状的精确认知。Benchmarking 和 profiling 是两个互补的工具:前者告诉你总共花了多久,后者告诉你时间花在了哪里。
Benchmarking:端到端计时#
Benchmarking 将一个操作的性能蒸馏为一个数字——执行时间。虽然它不揭示瓶颈所在,但因为度量的就是最终关心的指标(延迟),加上可以方便地观察性能如何随输入规模变化,因此非常实用。
正确的 GPU benchmarking 需要注意几个陷阱:
Warm-up:GPU 上的某些操作是懒编译的(如 torch.compile 的首次调用会触发 Triton 编译)。如果不做 warm-up,编译时间会被计入测量结果,严重失真。通常在正式计时前先运行几次操作,确保所有编译和缓存都已完成。
CUDA Events 计时:正确的计时方式是使用 CUDA Events,而非 Python 的 time.time():
1start = torch.cuda.Event(enable_timing=True)2end = torch.cuda.Event(enable_timing=True)3
4start.record()5op() # 实际计算6end.record()7
8torch.cuda.synchronize() # 等待 GPU 完成9elapsed = start.elapsed_time(end) # 毫秒关键在于 torch.cuda.synchronize()——GPU 操作是异步的,如果不显式等待,CPU 侧的计时会在 GPU 还在运算时就结束了。
多次运行取统计量:单次运行的结果存在方差。通常重复多次取平均值;严格时应观察分布、取 P95 或中位数。
Scaling 分析:一个有价值的做法是改变输入维度,绘制时间-维度曲线。以矩阵乘法为例,理论上时间应随维度呈立方增长,但实际观察会发现:维度从 2 增长到约 2000 时,计算时间几乎保持常数——这是因为 GPU 的大量并行计算单元在小矩阵上严重欠载,直到矩阵足够大才能充分利用硬件。这种 “floor” 效应说明 GPU 天生就是为大矩阵设计的。
Profiling:看清引擎盖下面的东西#
Profiling 比 benchmarking 多了一层维度——它告诉你时间的分布。但 profiling 的价值不仅在于定位瓶颈,更在于揭示高层代码实际触发了什么底层操作。在 PyTorch 这样的高层框架中,你写 A + B,下面可能调用了名为 vectorized_elementwise_kernel<CUDAFunctor_add> 的 CUDA kernel;你写 A @ B,下面调用的是 CUTLASS 库中一个针对特定架构和 tile 大小优化的 matmul kernel。
PyTorch 内置了 profiler(torch.profiler),使用方式很直接:
1with torch.profiler.profile() as prof:2 op()3print(prof.key_averages().table())Profiler 揭示的几个关键信息:
-
实际调用了哪些 CUDA kernel:每个 kernel 都有一个长名字,包含了架构、数据类型、tile 大小等信息。例如
cutlass_sm100_...f32_f32_64x64x16表示这是一个针对 SM100(Blackwell)架构、FP32 精度、64×64×16 tile 的 CUTLASS matmul kernel。 -
不同输入维度会触发不同的 kernel:同样是
torch.matmul,4096×4096 的矩阵会调用64x64x16tile 的 kernel,128×128 的矩阵则调用32x32x16的 kernel。PyTorch 的运行时会根据输入大小选择不同的底层实现。 -
时间分布一目了然:如果一个
naive_gelu实现的 profiler 输出显示了binary_functor、unary、add、tanh等多个独立的 kernel,那就说明这个实现没有做 kernel fusion——每个 PyTorch 操作对应一次独立的 kernel 调用,每次都要在 HBM 和 SM 之间搬运数据。
案例分析:GELU 的三种实现与 Kernel Fusion#
GELU(Gaussian Error Linear Unit)是 Transformer 中常用的激活函数,其 tanh 近似形式为:
GELU(x)=0.5⋅x⋅(1+tanh(2/π⋅(x+0.044715⋅x3)))
虽然数学上只是一个 element-wise 操作,但三种不同的实现方式在性能上天差地别,这个例子清晰地揭示了 kernel fusion 的核心思想。
三匹赛马#
朴素 PyTorch 实现:直接将上述公式翻译为 PyTorch 表达式——0.5 * x * (1 + torch.tanh(...))。这段代码完全正确,但在 profiler 下暴露出严重的性能问题:它触发了多个独立的 CUDA kernel——binary_functor(乘法)、unary(tanh)、add 等。每个 kernel 都要从 HBM 读取输入、计算、写回 HBM,下一个 kernel 再从 HBM 读取上一步的输出。对于一个 element-wise 操作而言,计算量微乎其微,时间几乎全花在 HBM 的反复读写上。
PyTorch 内置实现(torch.nn.functional.gelu):PyTorch 标准库中有人专门为 GELU 手写了一个 CUDA kernel(gelu_cuda_kernel)。这个 kernel 将整个 GELU 计算融合为一次 kernel 调用——从 HBM 读取一次、做完全部运算、写回一次。速度比朴素实现快得多。
torch.compile 编译版本:对朴素 PyTorch 实现调用 torch.compile,编译器会自动分析计算图,识别出所有操作都是 element-wise 的,然后自动生成一个融合的 Triton kernel。Profiler 显示编译后的版本只调用了一个 Triton kernel。
性能对比:在实际测量中,朴素实现耗时约 3.75ms,内置版快得多,编译版也显著快于朴素版,但略慢于内置版。这个结果值得注意——torch.compile 自动生成的 Triton kernel 虽然已经完成了 fusion,但仍不及 NVIDIA 工程师手写的 CUDA kernel。而且这种差距是硬件相关的:在之前的硬件上,两者性能更为接近。
性能差异的根源:Kernel Fusion#
这三种实现之间的性能差异,本质上就是 kernel fusion(算子融合)的有无。
朴素实现中,计算图的每个节点——乘法、tanh、加法——各自对应一次独立的 kernel 启动(profiler 中可以看到 binary_functor、add、tanh 等各自独立的 kernel)。每次 kernel 启动的代价并不在于计算本身(GELU 的计算极其简单),而在于 HBM ↔ SM 的数据搬运。假设张量有 N 个元素,朴素实现需要约 5N 次 HBM 读取和 3N 次 HBM 写入——因为中间结果在每个 kernel 之间必须经过 HBM 中转。
而融合后的单 kernel 实现只需 N 次读和 N 次写。对于这种 memory-bound(内存瓶颈型)操作——计算密度极低、时间主要由内存带宽决定——减少 HBM 访问次数就是最直接的加速手段。
torch.compile 的意义#
torch.compile 的工作机制是:追踪 PyTorch 代码的计算图,识别可以融合的操作序列,然后自动生成 Triton kernel 代码。这意味着:
- 你不需要手写 CUDA kernel 就能获得接近手写的性能
- 编译器可以做到的不止是简单的 element-wise 融合——它能识别更复杂的模式
- 但编译器不是万能的——对于高度优化的操作(如经过精心调优的 GELU CUDA kernel),编译器生成的代码可能仍有差距
关键洞察:在 PyTorch 中看似简单的一行代码(
A + B、gelu(x))背后,可能触发了截然不同的底层执行路径。理解这一点是写出高性能代码的前提。
Triton 编程入门#
Triton 是由 OpenAI 开发的 GPU 编程语言,如今已经成为 PyTorch 生态中 kernel 编写的主流选择。它在 CUDA(线程级编程)和 PyTorch(张量级编程)之间找到了一个精妙的平衡点——以线程块为思维粒度。
CUDA vs Triton:思维粒度的差异#
在 CUDA 中,程序员思考的是”每个线程做什么”。你为每个线程编写一段代码,线程通过自己的 ID 定位自己负责的数据元素。对于 element-wise 操作,这很自然——每个线程处理一个元素。但一旦涉及需要线程间通信的操作(如 softmax 的行归约、矩阵乘法的 tile 累积),程序员就必须手动管理 Shared Memory 的分配、线程同步、数据搬运——这些 bookkeeping 工作繁琐且容易出错。
在 Triton 中,程序员思考的是”每个线程块做什么”。一个线程块从 HBM 加载一块数据、在 Shared Memory 中操作、最后写回 HBM。线程级别的细节——线程如何分配、如何同步、数据放在寄存器还是 Shared Memory——全部由 Triton 编译器自动决定。
这使得 Triton 代码在结构上与 PyTorch 非常相似:一旦数据被加载到”块”中,就可以用类似 NumPy 的向量化语法进行操作。
第一个 Triton Kernel:Element-wise Value Copy#
以一个最简单的例子入手——将一个 8192 维向量的每个元素做简单变换。
Python 侧的调用代码:
1X = torch.randn(8192, device='cuda')2Y = torch.empty_like(X)3BLOCK_SIZE = 10244num_blocks = X.numel() // BLOCK_SIZE # = 85
6# 启动 kernel:[num_blocks] 定义了 grid 的形状7triton_value_kernel[num_blocks](X, Y, X.numel(), BLOCK_SIZE)注意两个关键点:(1)输出张量 Y 必须预先分配——Triton kernel 不像 PyTorch 那样返回新张量,而是直接写入指定的内存地址;(2)[num_blocks] 指定了 grid 中线程块的数量。
Kernel 代码:
1@triton.jit2def triton_value_kernel(X_ptr, Y_ptr, num_elements, BLOCK_SIZE: tl.constexpr):3 # 1. 确定自己的身份4 pid = tl.program_id(0) # 我是第几个 block?5
6 # 2. 计算负责的数据范围7 start = pid * BLOCK_SIZE8 offsets = start + tl.arange(0, BLOCK_SIZE) # 如 [1024, 1025, ..., 2047]9 mask = offsets < num_elements # 处理尾部不完整 block10
11 # 3. 从 HBM 加载数据12 x = tl.load(X_ptr + offsets, mask=mask)13
14 # 4. 计算(这里可以是任何 element-wise 操作)15 y = x # 或 y = gelu(x) 等16
17 # 5. 写回 HBM18 tl.store(Y_ptr + offsets, y, mask=mask)每个 Triton kernel 都遵循这个范式:wake up → 确定身份(program_id)→ 计算数据范围 → load → compute → store。
几个要点:
X_ptr和Y_ptr是指针(内存地址的整数表示),不是 PyTorch 张量。X_ptr + offsets是指针算术,得到一组内存地址。tl.arange(0, BLOCK_SIZE)返回一个向量[0, 1, ..., BLOCK_SIZE-1],概念上就是 NumPy 的np.arange。- Mask 是 Triton 代码中的常见模式:因为数据总量不一定整除 BLOCK_SIZE,最后一个 block 可能越界。Mask 为
False的位置不会被读写。 tl.load和tl.store是与 HBM 交互的唯一接口——load 将数据从 HBM 拉入 Shared Memory / 寄存器,store 将结果推回 HBM。
从 Triton 到 PTX:编译器做了什么#
Triton 代码并不会直接在 GPU 上执行——它先被编译为 PTX(Parallel Thread Execution),NVIDIA GPU 的中间汇编语言。PTX 才是实际在硬件上运行的指令序列。
查看上述 kernel 的 PTX 输出可以观察到几个有趣的细节:
-
PTX 操作的是线程级别,而不是线程块级别。Triton 的 block-level 抽象已经被编译器”展开”了。
-
Thread Coarsening 自动发生:编译器发现每个线程的工作太轻量,于是自动让每个线程处理 8 个元素(而非 1 个)。PTX 中可以看到 8 组相似的 load-compute-store 指令块。这就是前面提到的”线程粗化”策略——用更少的线程做更多的工作。
-
寄存器操作:
ld.global从 HBM 加载数据到寄存器(%r是整数寄存器,%f是浮点寄存器),算术操作在寄存器上完成,st.global写回 HBM。 -
线程身份识别:PTX 中通过
%ctaid.x(线程块索引,对应 Triton 的program_id)和%tid.x(线程在块内的索引)来定位每个线程。 -
所有线程执行同一段 PTX 代码——编译只发生一次,不同线程通过不同的
%ctaid.x和%tid.x值来区分自己的行为。
值得注意的是,PTX 仍然不是最终形态——SM 调度器如何将 warp 分配到具体的 SM、如何管理 warp 间的切换,这些都是硬件层面的行为,在 PTX 中不可见也不可控。
CUDA 与 Triton 的权衡#
对于简单的 element-wise 操作,CUDA 甚至比 Triton 更直观——每个线程一个元素,不需要考虑 block、offset、mask。但 Triton 的优势在复杂操作中显现:
- 需要跨元素通信时(softmax、matmul),CUDA 程序员必须手动管理 Shared Memory 和同步屏障,Triton 编译器自动处理
- Triton 的 block-level 抽象天然与 tiling 模式对齐——后面会看到这如何简化 matmul 的实现
- Triton 编译器可以自动做 thread coarsening、寄存器分配等底层优化
当然,Triton 也有局限——如果需要利用最新硬件的特殊功能(如 B200 的 Tensor Memory),Triton 可能还不支持,此时只能退回 CUDA 甚至 PTX。确实有人会手写 PTX——当你认为自己能比编译器做得更好时。NVIDIA 的编译器已经相当成熟,通常不需要这样做;但对于一些开发程度较低的加速器平台,有时确实需要更多的手动介入。不过对于绝大多数 Transformer 相关的 kernel,Triton 已经足够。
此外还有其他替代方案:CUTLASS(NVIDIA 的线性代数模板库,性能极致但代码复杂)、CuTe 等 DSL,它们在抽象层级和控制粒度上各有取舍。
Triton 实战:从 Softmax 到 Matmul#
掌握了 element-wise kernel 的基本范式后,接下来的三个例子将逐步引入更复杂的计算模式——行归约、tiling、二维 tiling——最终到达 matmul 这个 kernel 编程的”殿堂级”目标。
Level 1:Softmax(行归约,行可放入单个 Block)#
Softmax 是第一个超越 element-wise 的操作。它对矩阵的每一行独立执行”指数化 → 归一化”:
softmax(xi)=∑jexj−max(x)exi−max(x)
其中减去 max(x) 是为了数值稳定性。
朴素 PyTorch 实现的代价:如果用标准 PyTorch 逐步实现 softmax——先 max(),再减法,再 exp(),再 sum(),再除法——每一步都是一个独立 kernel,总共约 5MN 次 HBM 读取和 3MN 次 HBM 写入(M 是行数,N 是列数)。而最优方案只需 MN 次读和 MN 次写。
Triton 方案:将每一行映射为一个线程块。这是一个自然的划分——softmax 的行之间相互独立,无需跨行通信,因此不同行的线程块可以完全并行执行。
1@triton.jit2def softmax_kernel(X_ptr, Y_ptr, row_stride, N, BLOCK_SIZE: tl.constexpr):3 row = tl.program_id(0) # 我负责第几行4 col_offsets = tl.arange(0, BLOCK_SIZE)5 mask = col_offsets < N6
7 # 计算输入数据的内存地址8 x_ptrs = X_ptr + row * row_stride + col_offsets9 x = tl.load(x_ptrs, mask=mask, other=float('-inf'))10
11 # 核心计算:和普通 PyTorch 代码几乎一样12 x_max = tl.max(x, axis=0)13 x = x - x_max14 numerator = tl.exp(x)15 denominator = tl.sum(numerator, axis=0)16 y = numerator / denominator17
18 # 写回19 y_ptrs = Y_ptr + row * row_stride + col_offsets20 tl.store(y_ptrs, y, mask=mask)调用时,grid 大小就是行数 M,BLOCK_SIZE 取列数 N 向上取到最近的 2 的幂。
这里有一个优雅的特性:一旦数据被加载到 block 的 Shared Memory 中,后续所有操作(max、减法、exp、sum、除法)都不再涉及 HBM。整行数据只需一次 HBM 读 + 一次 HBM 写。mask=mask, other=float('-inf') 处理不完整 block 时,被 mask 掉的位置填充 −∞,确保它们在 softmax 中贡献为 0。
如果要按列做 softmax 而非按行,只需修改 stride 参数——将 row_stride 换成列方向的 stride,让指针沿列方向跳转即可。Triton 的指针算术模型使得行/列方向的切换非常自然。
关键洞察:当所有需要交互的数据能放入一个 block 时,Triton 代码读起来和普通 PyTorch 几乎一样。线程间的通信、Shared Memory 的管理、同步——全部由编译器处理。
Level 2:Row Sum(行不可放入单个 Block → Tiling)#
当行的长度超过 block 容量时——例如列数 N = 4096 但 BLOCK_SIZE = 1024——不能再将整行放入一个 block。此时需要引入 tiling(分块迭代)。
这里以 row sum(行求和)为简化例子。每个 block 仍然负责一行,但现在需要迭代地处理该行的多个 tile:
1@triton.jit2def row_sum_kernel(X_ptr, Y_ptr, row_stride, N, BLOCK_SIZE: tl.constexpr):3 row = tl.program_id(0)4 accumulator = tl.zeros([BLOCK_SIZE], dtype=tl.float32)5
6 # 迭代所有 tile7 for start in range(0, N, BLOCK_SIZE):8 offsets = start + tl.arange(0, BLOCK_SIZE)9 mask = offsets < N10 x = tl.load(X_ptr + row * row_stride + offsets, mask=mask, other=0.0)11 accumulator += x # 逐 tile 累加12
13 # 对 accumulator 向量做最终归约14 result = tl.sum(accumulator, axis=0)15 tl.store(Y_ptr + row, result)执行过程可以这样理解:假设一行有 12 个元素,BLOCK_SIZE = 4,则该行被分为 3 个 tile。Block 中的 4 个线程各自维护一个累加器:
- Tile 0(列 0-3):线程 0 读取
3,线程 1 读取1,线程 2 读取4,线程 3 读取1 - Tile 1(列 4-7):线程 0 累加
5→ 得到8,线程 1 累加9→ 得到10,… - Tile 2(列 8-11):继续累加…
- 最终对 4 个累加器执行
tl.sum,得到整行的总和
注意与前面 softmax 的关键区别:在 softmax 中,整行数据构成一个 block,不同列之间的”交互”(如求 max、求 sum)在 block 内部完成。而在 tiling 模式下,一行数据跨越多个 tile,每个 tile 被顺序加载到 Shared Memory 中处理。累加器(accumulator)在多个 tile 的迭代间持续存在——它驻留在 Shared Memory 或寄存器中(由 Triton 编译器自动决定,程序员不能显式指定),不需要写回 HBM。一般来说,当 block size 较小时 accumulator 可以放在寄存器中;当 block size 足够大时,accumulator 必须放在 Shared Memory 中。
还有一个容易混淆的概念区分:在 element-wise kernel(如 GELU)中,一个长向量也被切分为多个 block,但那里每个 block 是独立的、互不通信。而在 tiling 模式下,多个 tile 属于同一个 block,这个 block 需要迭代处理所有 tile 来完成一行的归约。Block 是并行的最小调度单位;tile 是一个 block 内串行处理的数据分片。
Level 3:矩阵乘法(二维 Tiling)#
矩阵乘法是 kernel 编程的核心场景。给定 A∈RM×K 和 B∈RK×N,计算 C=A×B∈RM×N。
朴素方案及其缺陷#
最直接的思路:为 C 的每个元素分配一个线程。计算 C[m][n] 时,线程遍历 k=0,1,…,K−1,从 HBM 读取 A[m][k] 和 B[k][n],累积乘积。
这个方案正确但极其低效。HBM 读取次数为 O(MKN)(每个 (m,k,n) 组合都需要两次读取),而计算量也是 O(MKN),所以算术强度(Arithmetic Intensity)= 计算量 / 内存传输量 ≈ 常数。一个常数的算术强度意味着操作完全是 memory-bound 的——计算单元大部分时间在等数据。
问题的根源在于冗余读取。举个具体例子:计算 C[1][0] 时需要读取 A 的第 1 行的所有元素(A[1][0],A[1][1],A[1][2]);计算 C[1][1] 时又要读取同样的 A 第 1 行——在朴素方案中,这一行被完整读了两次。类似地,B 的每一列也被不同的线程反复读取。
理想方案:全部放入 Shared Memory#
如果能把整个 A 和 B 都加载到 Shared Memory 中,那么 HBM 读取次数降为 O(MK+KN)(各读一次),算术强度提升到 O(N) 量级——这是理想情况。但问题在于 Shared Memory 容量有限(通常几十到上百 KB),大矩阵根本放不下。
实际方案:Tiled Matmul#
Tiled matmul 是朴素方案和理想方案之间的精妙折中:全局看像朴素方案(逐步积累),局部看像理想方案(tile 内无冗余读取)。
具体做法:
- 将输出矩阵 C 划分为若干 二维 tile,每个 tile 的大小为
BLOCK_M × BLOCK_N - 每个 tile 由一个线程块负责计算
- 计算一个 tile 时,沿 K 维迭代:每步加载 A 的一个
BLOCK_M × BLOCK_K子块和 B 的一个BLOCK_K × BLOCK_N子块到 Shared Memory - 在 Shared Memory 中做这两个小矩阵的乘法(这一步是”理想方案”的局部化),结果累加到 accumulator
- 迭代完所有 K 方向的 tile 后,将 accumulator 写回 HBM
1@triton.jit2def matmul_kernel(A_ptr, B_ptr, C_ptr,3 M, N, K,4 stride_am, stride_ak,5 stride_bk, stride_bn,6 stride_cm, stride_cn,7 BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):8 # 确定身份:我负责 C 的哪个 tile9 pid_m = tl.program_id(0)10 pid_n = tl.program_id(1)11
12 # 计算 A、B 中要读取的行/列范围13 rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # A 的行索引14 rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) # B 的列索引15 rk = tl.arange(0, BLOCK_K) # K 维索引16
17 # 初始化 accumulator(Shared Memory 中的 BLOCK_M × BLOCK_N 矩阵)18 acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)19
20 # 沿 K 维迭代所有 tile21 for k_start in range(0, K, BLOCK_K):22 # 加载 A 的一个 tile (BLOCK_M × BLOCK_K)23 a = tl.load(A_ptr + rm[:, None] * stride_am + (k_start + rk[None, :]) * stride_ak)24 # 加载 B 的一个 tile (BLOCK_K × BLOCK_N)25 b = tl.load(B_ptr + (k_start + rk[:, None]) * stride_bk + rn[None, :] * stride_bn)26 # 在 Shared Memory 中做 tile-level matmul 并累加27 acc += tl.dot(a, b)28
29 # 可选:kernel fusion——在写回前施加 element-wise 非线性30 acc = tl.maximum(acc, 0) # ReLU31
32 # 写回 HBM33 tl.store(C_ptr + rm[:, None] * stride_cm + rn[None, :] * stride_cn, acc)这里 tl.dot(a, b) 是 Triton 中的 tile-level matmul——一旦数据在 Shared Memory 中,就可以直接调用这个原语,不需要关心底层的线程分配和同步。
算术强度分析:每个 tile 需要从 HBM 读取 O(BLOCK_M⋅K+K⋅BLOCK_N) 的数据(每步读入的 A tile 和 B tile),执行 O(BLOCK_M⋅BLOCK_N⋅K) 次乘加运算。算术强度约为 O(tile_size),虽然达不到理想的 O(N),但只要 tile 足够大,就能充分利用计算单元。
Kernel Fusion 的附带好处#
在 matmul kernel 中融合 element-wise 激活函数几乎是”免费”的。因为 C 的 tile 已经计算完毕并驻留在 Shared Memory 中,对它施加 ReLU(或 GELU 等)只需在 tl.store 之前加一行代码——不会产生额外的 HBM 读写。如果不融合,matmul 和 ReLU 就是两个独立 kernel,中间需要 C 经过 HBM 一次完整的写入-读取循环。
这正是实际训练中常见的 “linear layer = matmul + activation” 模式,也是为什么手写或编译优化的 fused kernel 在实践中如此重要。
Strides:多维张量的线性化#
Triton kernel 操作的是一维的连续内存。要在这片内存中定位二维矩阵的 [row, col] 元素,需要知道 stride——沿每个维度移动一步所需的内存偏移量。对于一个 M×N 的行主序矩阵:
stride_row = N(向下移一行 = 跳过 N 个元素)stride_col = 1(向右移一列 = 移动 1 个元素)
地址计算:ptr + row * stride_row + col * stride_col。转置矩阵的 stride 正好翻转。在 Triton kernel 中,所有指针运算都基于 stride,这使得同一个 kernel 可以透明地处理行主序和列主序(转置)矩阵。
从四个例子看递进的复杂度#
回顾全部四个例子,可以清晰地看到一条从简单到复杂的路径:
- Element-wise(GELU):最简单——每个 block 独立处理一片数据,block 之间无通信,block 内也无跨元素交互
- Softmax(行归约,行 ≤ block):引入 block 内的归约操作(max、sum),但整行数据仍能放入一个 block
- Row Sum(行归约,行 > block → tiling):一个 block 必须迭代处理多个 tile,引入 block 内的 for 循环和累加器
- Matmul(二维 tiling):二维 tile 划分 + 沿 K 维迭代 + Shared Memory 中的 tile-level matmul
这四步递进构成了理解 Flash Attention 等高级 kernel 的基础——Flash Attention 本质上就是将 attention 计算中的 softmax 和 matmul 融合在一起,在 tiling 框架下同时处理 Q、K、V 三个矩阵。
部分内容可能已过时