FlashAttention
从 GPU 内存层级出发,拆解 Tiling、Online Softmax、Recomputation 三大核心技术
FlashAttention 论文总览#
FlashAttention 系列共有四篇工作,本文聚焦的是 2022 年发表的第一篇,第一作者是 Tri Dao。这篇工作使 Tri Dao 成为大模型 infra 领域的标志性人物。
当代大模型的核心架构是 Transformer,而 Transformer 最核心的 mechanism 就是 Attention。如何让 Attention 变得更快,一直是业界的核心难题。Tri Dao 从一个被所有人忽视的角度攻克了这个问题。
核心观察#
Transformer 处理长上下文时存在根本性限制:序列长度增长导致计算变慢、内存占用急剧增大,原因在于 Attention 的 quadratic(平方级)复杂度。业界此前的主流方法是牺牲模型质量来换取复杂度降低,包括 Sparse Approximation(如 Sliding Window Attention)和 Low-Rank Approximation(如 Linear Attention、MLA)。但即便这些方法在理论上降低了 FLOP 数,真实的 wall-clock time 并没有显著减少。
Tri Dao 发现了一个被所有人忽视的原理:优化目标不应该是减少计算复杂度,而是减少对 GPU 内存的读取和写入。他提出将 Attention 算法变得 IO-Aware,即对 GPU 不同层级内存之间的读写开销有感知。
这里需要区分一个概念:GQA(Grouped-Query Attention)并不属于近似算法,因为它只是减少了 KV 的维度,仍然是精确算法。
核心贡献#
FlashAttention 算法有以下关键特性:
- Exact Algorithm:对模型质量没有任何损失,不是近似计算
- IO-Aware:对不同层级内存的读写有感知
- Tiling 技术:将大矩阵分块加载到 SRAM 中计算,减少 HBM 读写次数
- Recomputation:不存储巨大的中间矩阵 S 和 P,需要时重新计算
- 最优性证明:在实际硬件参数范围内,FlashAttention 的 IO 复杂度达到了理论下界
实际效果#
- BERT 模型:端到端 15% 的加速
- GPT(Decoder-Only):3 倍的速度提升
- 内存使用从 O(N2) 降至 O(N),使得长上下文训练成为现实
- 在 Path-X Challenge(序列长度 16K)上实现了从零到非零的突破,Path-256(序列长度 64K)上也是第一个超越 random baseline 的序列模型
- 在 Vision、长文档等领域实现了显著突破
FlashAttention 已经开源,对业界产生了巨大推动力。
GPU 内存层级与 Attention 的瓶颈所在#
要理解 FlashAttention 为什么有效,首先需要理解 GPU 的内存层级架构,这是整篇论文的立足点。
GPU 内存分为三个层级:
| 层级 | 速度 | 大小 | 特点 |
|---|---|---|---|
| SRAM(on-chip) | 19 TB/s | ~20 MB | 离 Register 最近,每个 SM 的 shared memory |
| HBM(GPU 显存) | 1.5 TB/s | 40 GB(A100) | 最常用的参数存储,后续发展到 80GB(A100 80G)、140GB(H100)、数百 GB(Blackwell) |
| DRAM(CPU 内存) | 12.8 GB/s | >1 TB | 速度比 HBM 慢约 100 倍,训练中通常不涉及 |
关键数字:SRAM 比 HBM 快约 10 倍,但容量只有 HBM 的 1/2000。大量数据只能存在 HBM 中,但每次从 HBM 读写都有巨大的延迟代价。
GPU Execution Model#
GPU 的基本执行单元是 Thread,每个 Thread 执行一个具体的 operation。一个重要的实用技术是 Kernel Fusion,即把多个 element-wise operation 合并为一个 kernel。
以一个简单的线性函数 y=ax+b 为例:
- 不做 Fusion:从 HBM 读 x,计算 ax,写回 HBM;再读出来,计算 +b,写回 HBM。共 2 次读、2 次写。
- 做 Fusion:从 HBM 读 x,计算 ax+b,写回 HBM。共 1 次读、1 次写。
FlashAttention 的核心思路就是把整个 Attention 计算 fuse 成一个 kernel,让算法变得 IO-Aware。
Compute-Bound vs Memory-Bound#
GPU 上的操作可以分为两类:
- Compute-Bound:算力是瓶颈,HBM 访问时间远小于计算时间。典型例子是大通道卷积、大维度矩阵乘法。
- Memory-Bound:内存读写是瓶颈,计算时间远小于 HBM 访问时间。典型例子是 elementwise 操作(activation、dropout、softmax、batch norm、layer norm)和 reduction 操作。
标准 Attention 实现中,除了两次矩阵乘法(Q·K 和 P·V)是 compute-bound 外,softmax、masking、dropout 都是 memory-bound 操作。每次执行这些操作都需要从 HBM 读数据、写回数据,造成大量的内存访问开销。
为什么减少 FLOP 不等于减少时间#
以前的 approximate attention 方法聚焦于减少计算量(FLOP),但实际测得的 wall-clock time 并没有同步减少。原因在于这些方法忽视了 memory access 才是真正的瓶颈。标准 Attention 的每一步中间结果(S 矩阵、P 矩阵)都需要反复在 SRAM 和 HBM 之间搬运,这些 memory-bound 操作占据了大部分时间。
FlashAttention 的核心突破就在于:不去减少 FLOP,而是减少 HBM 的读写次数。通过增加少量计算(recomputation),换来了内存访问次数的数量级减少。
标准 Attention 的内存读写分析#
在深入 FlashAttention 算法之前,需要精确理解标准 Attention 实现中到底发生了多少次 HBM 读写。
Attention 公式回顾#
给定输入矩阵 Q,K,V∈RN×d,Attention 的计算为:
S=QK⊤∈RN×N,P=softmax(S)∈RN×N,O=PV∈RN×d
其中 softmax 按行计算。以 GPT-2 为例,N=1024,d=64。由于 N≫d,中间矩阵 S 和 P 都是 N×N 的巨大矩阵,存储开销为 O(N2)。
此外,实际实现中还需要在 S 上 apply masking,在 P 上 apply dropout,每一次都是一个额外的 memory-bound 操作。
Algorithm 0: 标准实现的三次读写#
标准 Attention 的 PyTorch 实现分为以下步骤:
- 第一次读写:从 HBM 逐 block 读取 Q 和 K,计算 S=QK⊤,将 N×N 的 S 矩阵写回 HBM
- 第二次读写:从 HBM 读取 S,计算 P=softmax(S),将 N×N 的 P 矩阵写回 HBM
- 第三次读写:从 HBM 读取 P 和 V,计算 O=PV,将 O 写回 HBM
整个过程中,三次大块读取、三次大块写入。其中 S 和 P 这两个 N×N 的中间矩阵需要被完整地写入 HBM 后再读出,而它们本身并不是最终输出,只是中间产物。
PyTorch 的这种逐步实现导致每一步都产生一次 HBM round-trip,wall-clock time 中大部分消耗在内存搬运上,而不是矩阵乘法本身。
FlashAttention 的核心思路#
理解了三次读写的问题,FlashAttention 的目标就呼之欲出了:
能否将三次大块读写合并为一次读取 + 一次写入?
具体来说:一次性读入 Q、K、V 的分块,在 SRAM 中完成所有中间计算(S、softmax、dropout、P⋅V),然后只把最终的 O 写回 HBM,全程不将 S 和 P 这两个 N×N 矩阵写入 HBM。
实现这个目标需要解决两个技术难题:
- Tiling:softmax 是一个 global 操作(需要整行的 max 和 sum),如何在只看到一个 block 的情况下做分块 softmax?
- Recomputation:backward pass 需要 S 和 P,如果不存它们,怎么算梯度?
这两个问题分别由 Online Softmax 和 Recomputation 技术解决。
Tiling 的数学基础: Safe Softmax 与 Online Softmax#
Tiling 和 Recomputation 都不是 FlashAttention 发明的新技术,它们在 database、image processing、numerical linear algebra 等领域早已被广泛使用。FlashAttention 的贡献在于将它们应用到 Attention 计算中,并证明了最优性。
在讲 tiling 之前,需要先理解一个前置知识:如何在不看到整行数据的情况下计算 softmax。
Softmax 的数值稳定性问题#
标准 softmax 公式:
softmax(x)j=∑iexiexj
直接计算的问题在于 exj 可能极大(指数增长),导致数值溢出。
Safe Softmax (Stable Softmax)#
解决方法是利用指数函数的性质:对所有元素减去同一个常数后再取指数,结果不变。具体定义:
m(x):=maxixi
f(x):=[ex1−m(x),…,exB−m(x)]
ℓ(x):=∑if(x)i
softmax(x):=ℓ(x)f(x)
减去 m(x) 后,指数的最大值变为 e0=1,所有值都落在 (0,1] 区间内,避免了溢出。这个性质成立的数学基础是:对 x 的每一项加减任意常数 c,softmax 的结果不变:
softmax(x)=softmax(x−c)
关键挑战:Softmax 是 Global 操作#
Safe softmax 需要 m(x)=maxixi,这是一个需要看到整行所有元素的 global 操作。如果要做 tiling(把输入分成多个 block 逐块处理),就必须解决一个问题:在只看到部分数据时,如何正确计算 softmax?
Online Softmax: 分块计算的数学推导#
将向量 x 分成两部分 x=[x(1),x(2)],其中 x(1),x(2)∈RB。目标是分别处理 x(1) 和 x(2),最后合并得到完整的 softmax。
第一步:分解 m(x)
m(x)=m([x(1),x(2)])=max(m(x(1)),m(x(2)))
全局最大值等于两个 block 各自最大值中更大的那一个。
第二步:分解 f(x)
对 x(1) 部分,先用局部最大值 m(x(1)) 计算局部的 f(x(1)),然后通过一个修正因子将其转换到全局基准:
f(x)=[em(x(1))−m(x)⋅f(x(1)),em(x(2))−m(x)⋅f(x(2))]
具体来说,f(x(1)) 中每一项的计算是基于 m(x(1)) 做的减法,现在需要补偿到全局 m(x) 的基准上,所以乘以 em(x(1))−m(x)。
第三步:分解 ℓ(x)
ℓ(x)=em(x(1))−m(x)⋅ℓ(x(1))+em(x(2))−m(x)⋅ℓ(x(2))
softmax(x)=ℓ(x)f(x)
核心 insight:通过维护两个额外的统计量 m(x)(当前已见数据的最大值)和 ℓ(x)(当前已见数据的归一化因子),可以在逐 block 处理数据时,持续更新这两个值,最终得到与一次性计算完全相同的 softmax 结果。
推广到多个 Block#
上述两块的推导可以自然推广到多个 block。处理第 j 个 block 时:
- 计算当前 block 的局部 m~ij 和 P~ij
- 更新全局最大值 minew=max(mi,m~ij)
- 用修正因子 emi−minew 和 em~ij−minew 将之前的累积结果和当前 block 的结果 rescale 到统一基准
- 更新 ℓi 和输出 Oi
这正是 FlashAttention Algorithm 1 的核心循环体。
FlashAttention 核心算法#
有了 Online Softmax 的数学基础和 Recomputation 的思路,可以将它们合成一个统一的 fused kernel,即 FlashAttention 的 Algorithm 1。
Recomputation: 用计算换内存#
在标准 Attention 的 backward pass 中,需要用到中间矩阵 S 和 P,它们都是 N×N 的矩阵,内存开销为 O(N2)。
FlashAttention 的策略是:不存储 S 和 P。forward pass 中只存储 Q、K、V(都是 N×d 的,d≪N)以及 softmax 的归一化因子 m 和 ℓ(都是 O(N) 的向量)。backward pass 需要 S 和 P 时,从 Q、K、V 重新计算。
这看似增加了计算量(FLOP),但由于减少了 HBM 的读写次数,实际的 wall-clock time 反而下降了。这就是 recomputation 的精髓:用额外的 FLOP 换取更少的 memory access,从而实现 speed up。
Algorithm 1: FlashAttention Forward Pass#
输入:Q,K,V∈RN×d(存储在 HBM 中),on-chip SRAM 大小为 M。
Step 1: 设定 block size
Bc=⌈4dM⌉,Br=min(⌈4dM⌉,d)
Block size 由 SRAM 容量 M 和 head dimension d 决定。Bc 是 K/V 的 block 行数,Br 是 Q 的 block 行数。
由于 M 有限,实际的 tiling 不是正方形的:K/V 的 block 可能比 Q 的 block 更宽(Bc>Br),Q 需要被切成更多的 block。这是一个重要的 observation:图示中为了直观画成正方形,实际是长方形的 tiling。
Step 2: 初始化
O=0N×d,ℓ=0N,m=(−∞)N(均在 HBM 中)
Step 3: 分块
- 将 Q 分成 Tr=⌈N/Br⌉ 个 block
- 将 K、V 分成 Tc=⌈N/Bc⌉ 个 block
- 将 O、ℓ、m 也做对应分块
Step 4-5: 双层循环
外循环 (j = 1 to T_c): // 遍历 K/V 的每个 block 从 HBM 加载 K_j, V_j 到 SRAM 内循环 (i = 1 to T_r): // 遍历 Q 的每个 block 从 HBM 加载 Q_i, O_i, ℓ_i, m_i 到 SRAM
// 在 SRAM 中计算(不写回中间结果到 HBM) S_ij = Q_i · K_j^T // 当前 tile 的 score m̃_ij = rowmax(S_ij) // 当前 tile 的行最大值 P̃_ij = exp(S_ij - m̃_ij) // 当前 tile 的 softmax 分子 ℓ̃_ij = rowsum(P̃_ij) // 当前 tile 的行和
// 更新全局统计量 m_i^new = max(m_i, m̃_ij) ℓ_i^new = e^{m_i - m_i^new} · ℓ_i + e^{m̃_ij - m_i^new} · ℓ̃_ij
// 更新输出(rescale 旧值 + 加入新贡献) O_i ← diag(ℓ_i^new)^{-1} · (diag(ℓ_i) · e^{m_i - m_i^new} · O_i + e^{m̃_ij - m_i^new} · P̃_ij · V_j)
// 写回更新后的 ℓ_i, m_i 到 HBM ℓ_i ← ℓ_i^new, m_i ← m_i^new关键点:外循环加载一次 K/V block,内循环遍历所有 Q block。每次内循环中,用 Online Softmax 的 rescale 技术更新输出 O、最大值 m 和归一化因子 ℓ。整个过程中,S 和 P 矩阵只在 SRAM 中以 tile 的形式短暂存在,永远不写入 HBM。
内存分析#
FlashAttention 的 forward pass 需要 O(N2d) FLOPs(与标准 Attention 相同),但只需要 O(N) 的 additional memory(存储 ℓ 和 m),而不是标准实现的 O(N2)。这从根本上解决了长上下文的 out-of-memory 问题。
IO 复杂度对比#
| 标准 Attention | FlashAttention | |
|---|---|---|
| HBM 访问次数 | Θ(Nd+N2) | O(N2d2M−1) |
| 额外内存 | O(N2) | O(N) |
| 端到端 speedup | baseline | 最高 7.6x |
由于 d2/M 在实际硬件中远小于 1(例如 d=128,M=100KB 时,d2/M≈1/6),FlashAttention 的 HBM 访问次数比标准实现少数倍到数十倍。
IO Complexity 分析与最优性证明#
FlashAttention 不仅在实践中快,论文还从理论上证明了它在 IO 复杂度意义下是最优的。
Theorem 1: 正确性与资源开销#
Algorithm 1 返回 O=softmax(QK⊤)V,使用 O(N2d) FLOPs,且只需要 O(N) 的 additional memory(不含输入输出本身)。
Proposition 2: HBM 访问上界#
标准 Attention 需要 Θ(Nd+N2) 次 HBM 访问。FlashAttention 只需要:
O(MN2d2) 次 HBM 访问
对比关键在于 d2/M 这个因子。以实际硬件参数为例:
- d=64:d2=4096,M=100KB,d2/M≈1/24,即 HBM 访问减少约 24 倍
- d=128:d2=16384,M=100KB,d2/M≈1/6,即减少约 6 倍
Block Size 的 Sweet Spot#
Block size 越大,每次 SRAM 中能处理的数据越多,HBM 访问次数越少。但 block size 不能无限增大,因为它受限于 SRAM 容量。而且当 block size 大到一定程度后(超过约 256),runtime 的瓶颈从 memory-bound 转变为 compute-bound 或其他因素,继续增大 block size 的收益递减。
这意味着存在一个 sweet spot:block size 足够大以充分利用 SRAM,但不超过 SRAM 容量限制。
Theorem 3: 最优性下界(Lower Bound)#
论文最硬核的理论贡献之一是 Proposition 3:对于所有精确 Attention 算法,在 SRAM 大小 M 满足 d≤M≤Nd 的合理范围内,HBM 访问次数的下界为 Ω(N2d2M−1)。
这意味着 FlashAttention 的 IO 复杂度已经达到了理论最优,不存在一个精确 Attention 算法能在 IO 复杂度上比 FlashAttention 快一个数量级。
Block-Sparse FlashAttention#
论文进一步将 FlashAttention 扩展到 sparse attention 场景。给定一个 block 形式的 mask 矩阵 M~∈{0,1}N×N,计算:
S=QK⊤,P=softmax(S⊙1M~),O=PV
其中 S⊙1M~ 表示将 mask 为 0 的位置设为 −∞。
Proposition 4:Block-Sparse FlashAttention 的 IO 复杂度为原始 FlashAttention 的 s 倍,其中 s 是 sparsity ratio(非零 block 的比例)。例如 s=0.1 意味着只有 10% 的 block 是非零的,IO 复杂度再降低 10 倍。
这为业界后续的 sparse attention 优化(如 Sliding Window + Global Attention 的组合)提供了高效的底层实现。
实验结果与实际影响#
训练速度#
FlashAttention 在多个主流模型和框架上实现了显著的端到端加速:
| 模型 / 框架 | Baseline | FlashAttention | 加速比 |
|---|---|---|---|
| BERT-large(8×A100) | Nvidia MLPerf 1.1: 20.0±1.5 min | 17.4±1.4 min | 15% faster |
| GPT-2 small(HuggingFace) | 9.5 days (1.0×) | 2.7 days (3.5×) | 3.5× |
| GPT-2 small(Megatron-LM) | 4.7 days (2.0×) | 2.7 days (3.5×) | 1.7× |
| GPT-2 medium(HuggingFace) | 21.0 days (1.0×) | 6.2 days (3.4×) | 3.4× |
| 长上下文场景 | baseline | 2.4× | |
| Long Range Arena | baseline | 最高 2.8× |
BERT 训练时间(17 分钟 vs Nvidia 的 20 分钟)值得强调:FlashAttention 在 BERT 上甚至击败了 Nvidia 自家的 MLPerf 实现,这在当时是相当大胆的 claim。
对于 Decoder-Only 架构(GPT 系列),FlashAttention 的加速更为显著,达到了 3-3.5 倍。这种质的飞跃来自于 GPT 的 auto-regressive 特性使得序列长度对性能的影响更大。
模型质量: 长上下文的突破#
FlashAttention 将 GPT-2 的上下文长度从 1K 扩展到 4K,perplexity 降低了 0.7,在两个长文档分类任务上提升了 6.4 个百分点。
长上下文带来的不仅是速度,更重要的是模型能力的质变:更多信息可以被输入到 Transformer 中,模型可以做更长距离的推理。
Path-X 与 Path-256: 从零到一的突破#
Path-X Challenge 是一个经典的长序列理解任务:在一张图片中判断两个点之间是否存在一条连接路径。这个任务需要模型处理极长的像素序列。
- Path-X(序列长度 16K):FlashAttention 使得 Transformer 在这个任务上第一次超越了 random baseline(50% 准确率)
- Path-256(序列长度 64K):同样是第一个实现 better-than-random 的序列模型
这个结果的意义在于:此前没有任何 Transformer 模型能在这些超长序列任务上做到 better-than-random。FlashAttention 通过使长上下文成为可能,实现了 0 到 1 的突破。
Attention Benchmark#
论文对 FlashAttention 的 attention kernel 本身做了详细的 profiling 和 footprint analysis:
- Memory footprint 随序列长度线性增长(而非平方增长)
- Block-Sparse FlashAttention 的 runtime 也随序列长度线性增长
- 在所有测试的序列长度上,FlashAttention 都快于所有已有的 approximate attention baseline
Backward Pass 的实现复杂度与相关工作对比#
论文正文对 forward pass 做了完整展开,但 backward pass(计算梯度)在正文中几乎一笔带过,实际的完整推导和实现在 Appendix B 中,复杂度相当高。
值得注意的是,Rabe & Stauss 等人此前也做了 Attention 的内存优化工作,他们的方法减少了 memory footprint(内存占用量)。FlashAttention 与这些工作的核心区别在于:FlashAttention 不仅减少了内存占用,更关键的是减少了 memory access(内存访问次数)。这一差异直接体现在 wall-clock time 上。前者只是让你不 OOM,后者让你真正跑得更快。
关于 FlashAttention 2 的预告#
FlashAttention 1 的算法中,外循环遍历 KV block,内循环遍历 Q block。FlashAttention 2 将这两个循环对换了(外循环遍历 Q,内循环遍历 KV),这个看似简单的改动带来了进一步的性能优化。
Limitations 与 Future Directions#
论文坦诚地指出了三个局限性和未来方向:
- Compiling to CUDA:FlashAttention 用 CUDA 手写 kernel,对 Python 开发者不友好。如果想修改 attention 的实现(比如加入新的 bias 或 mask),需要直接改 CUDA 代码。论文提出了一个方向:是否能用高层语言(如 Halide 风格的 DSL)编写 IO-aware 的 attention 实现,自动编译到 CUDA。
- IO-Aware Deep Learning:IO-Aware 的思路不仅适用于 Attention,理论上也可以延伸到 MLP 层、Embedding 层等 Transformer 的其他模块。整个深度学习框架能否变得 IO-Aware?
- Multi-GPU IO-Aware Methods:FlashAttention 的 IO 分析针对的是单 GPU 内部的内存层级。当扩展到多 GPU 时,还需要考虑 GPU 间的数据传输(NVLink、PCIe),IO 分析需要增加一层。
代码库结构与实现解析#
FlashAttention 的代码仓库(flash-attention)包含四个版本(1-4),以下聚焦 FlashAttention v1(flash-attention-1-v1.0.9)的代码结构。
仓库概览#
FlashAttention v1 的两个核心目录:
flash_attn/:Python 层接口,负责将用户的 PyTorch tensor 传递给底层 CUDA kernelcsrc/:C++ 和 CUDA 层实现,包含真正的高性能 kernel
Python 层: flash_attention.py#
这个文件定义了两个核心 class,都继承自 nn.Module:
FlashAttention:最直接的接口。forward 方法接收 QKV tensor,调用底层的 flash_attn_unpadded_qkvpacked_func,返回 attention output。它是一个对底层 CUDA function 的 wrapper。
FlashMHA(Multi-Head Attention):在 FlashAttention 之上再包了一层。它接收 embedding 而不是 QKV,先帮用户计算 QKV projection,再调用 FlashAttention。
Python 层: flash_attn_interface.py#
这个文件是 Python 和 CUDA 之间的桥梁。它定义了四个继承自 torch.autograd.Function 的 class:
| Class | 输入形式 | 使用场景 |
|---|---|---|
FlashAttnQKVPackedFunc | QKV 打包为一个矩阵 | QKV 形状相同时 |
FlashAttnKVPackedFunc | KV 打包,Q 单独 | Cross-attention |
FlashAttnFunc | Q、K、V 三个矩阵分开 | 通用场景 |
FlashAttnSplitFunc | 同上,head_dim > 128 时使用 | 大 head dimension |
每个 class 都实现了 forward 和 backward 两个静态方法:
forward 的核心逻辑:
- 调用
flash_attn_cuda.fwd(底层 CUDA function) - 将中间值(softmax_lse、rng_state 等)存入
ctx(autograd context),供 backward 使用 - 返回 attention output
backward 的核心逻辑:
- 从
ctx取出 forward 时保存的中间值 - 调用
flash_attn_cuda.bwd - 返回 Q、K、V 的梯度
torch.autograd.Function 的 apply 方法会自动建立计算图,使得 FlashAttention 可以无缝嵌入 PyTorch 的自动微分系统。
C++ 层: csrc/flash_attn/#
C++ 层包含一个关键文件 fmha_api.cpp,它定义了两组核心 function:
参数设置:set_params_fprop(forward)和 set_params_dgrad(backward)。它们的工作是将 PyTorch tensor 的指针、stride、形状等信息填入一个 params 结构体,供 CUDA kernel 使用。stride 是指在内存中从一个元素跳到下一个元素需要跳过的字节数。
Dispatch function:run_fmha_fwd 和 run_fmha_bwd。它们根据 head dimension 的大小分发到不同的 kernel:
- d≤32:走一种 kernel
- 32<d≤64:走另一种 kernel
- 64<d≤128:走第三种 kernel
dispatch 完成后,主要流程是:torch check(验证维度是否匹配、形状是否正确)→ set params → launch CUDA kernel → return result。
CUDA Kernel 层#
真正的硬核实现在 hdm32/、hdm64/、hdm128/ 三个目录中,分别对应三种 head dimension 范围。每个目录中包含纯 C++/CUDA 的 kernel 实现。
核心 class 名为 fmha_kernel,它使用了大量 NVIDIA 底层的 CUDA primitives:
- Softmax:在 SRAM 中实现的分块 safe softmax
- LSE(Log-Sum-Exp):softmax 的归一化因子
- GEMM(General Matrix Multiply):矩阵乘法,这是 attention 中 compute-bound 的核心操作
这些 kernel 将 NVIDIA 提供的底层 primitives 重新组合,形成了 fused multi-head attention 的 forward 和 backward kernel。这两个 kernel 文件是整个 FlashAttention 代码库中最底层、最硬核的部分。
Block Sparse Attention#
除了标准的 FlashAttention,代码库中还实现了 BlockSparseAttention,同样包含 forward 和 backward 两个 kernel。Block sparse 版本允许用户指定一个 block-level 的 sparsity mask,跳过不需要计算的 block,进一步减少 IO 和计算。
部分内容可能已过时
评论