FlashAttention

5414 字
27 分钟
FlashAttention
文章摘要

从 GPU 内存层级出发,拆解 Tiling、Online Softmax、Recomputation 三大核心技术

FlashAttention 论文总览#

FlashAttention 系列共有四篇工作,本文聚焦的是 2022 年发表的第一篇,第一作者是 Tri Dao。这篇工作使 Tri Dao 成为大模型 infra 领域的标志性人物。

当代大模型的核心架构是 Transformer,而 Transformer 最核心的 mechanism 就是 Attention。如何让 Attention 变得更快,一直是业界的核心难题。Tri Dao 从一个被所有人忽视的角度攻克了这个问题。

核心观察#

Transformer 处理长上下文时存在根本性限制:序列长度增长导致计算变慢、内存占用急剧增大,原因在于 Attention 的 quadratic(平方级)复杂度。业界此前的主流方法是牺牲模型质量来换取复杂度降低,包括 Sparse Approximation(如 Sliding Window Attention)和 Low-Rank Approximation(如 Linear Attention、MLA)。但即便这些方法在理论上降低了 FLOP 数,真实的 wall-clock time 并没有显著减少。

Tri Dao 发现了一个被所有人忽视的原理:优化目标不应该是减少计算复杂度,而是减少对 GPU 内存的读取和写入。他提出将 Attention 算法变得 IO-Aware,即对 GPU 不同层级内存之间的读写开销有感知。

这里需要区分一个概念:GQA(Grouped-Query Attention)并不属于近似算法,因为它只是减少了 KV 的维度,仍然是精确算法。

核心贡献#

FlashAttention 算法有以下关键特性:

  1. Exact Algorithm:对模型质量没有任何损失,不是近似计算
  2. IO-Aware:对不同层级内存的读写有感知
  3. Tiling 技术:将大矩阵分块加载到 SRAM 中计算,减少 HBM 读写次数
  4. Recomputation:不存储巨大的中间矩阵 S 和 P,需要时重新计算
  5. 最优性证明:在实际硬件参数范围内,FlashAttention 的 IO 复杂度达到了理论下界

实际效果#

  • BERT 模型:端到端 15% 的加速
  • GPT(Decoder-Only):3 倍的速度提升
  • 内存使用从 O(N2)O(N^2) 降至 O(N)O(N),使得长上下文训练成为现实
  • 在 Path-X Challenge(序列长度 16K)上实现了从零到非零的突破,Path-256(序列长度 64K)上也是第一个超越 random baseline 的序列模型
  • 在 Vision、长文档等领域实现了显著突破

FlashAttention 已经开源,对业界产生了巨大推动力。

GPU 内存层级与 Attention 的瓶颈所在#

要理解 FlashAttention 为什么有效,首先需要理解 GPU 的内存层级架构,这是整篇论文的立足点。

GPU 内存分为三个层级:

层级速度大小特点
SRAM(on-chip)19 TB/s~20 MB离 Register 最近,每个 SM 的 shared memory
HBM(GPU 显存)1.5 TB/s40 GB(A100)最常用的参数存储,后续发展到 80GB(A100 80G)、140GB(H100)、数百 GB(Blackwell)
DRAM(CPU 内存)12.8 GB/s>1 TB速度比 HBM 慢约 100 倍,训练中通常不涉及

关键数字:SRAM 比 HBM 快约 10 倍,但容量只有 HBM 的 1/2000。大量数据只能存在 HBM 中,但每次从 HBM 读写都有巨大的延迟代价。

GPU Execution Model#

GPU 的基本执行单元是 Thread,每个 Thread 执行一个具体的 operation。一个重要的实用技术是 Kernel Fusion,即把多个 element-wise operation 合并为一个 kernel。

以一个简单的线性函数 y=ax+by = ax + b 为例:

  • 不做 Fusion:从 HBM 读 xx,计算 axax,写回 HBM;再读出来,计算 +b+b,写回 HBM。共 2 次读、2 次写。
  • 做 Fusion:从 HBM 读 xx,计算 ax+bax + b,写回 HBM。共 1 次读、1 次写。

FlashAttention 的核心思路就是把整个 Attention 计算 fuse 成一个 kernel,让算法变得 IO-Aware。

Compute-Bound vs Memory-Bound#

GPU 上的操作可以分为两类:

  • Compute-Bound:算力是瓶颈,HBM 访问时间远小于计算时间。典型例子是大通道卷积、大维度矩阵乘法。
  • Memory-Bound:内存读写是瓶颈,计算时间远小于 HBM 访问时间。典型例子是 elementwise 操作(activation、dropout、softmax、batch norm、layer norm)和 reduction 操作。

标准 Attention 实现中,除了两次矩阵乘法(Q·K 和 P·V)是 compute-bound 外,softmax、masking、dropout 都是 memory-bound 操作。每次执行这些操作都需要从 HBM 读数据、写回数据,造成大量的内存访问开销。

为什么减少 FLOP 不等于减少时间#

以前的 approximate attention 方法聚焦于减少计算量(FLOP),但实际测得的 wall-clock time 并没有同步减少。原因在于这些方法忽视了 memory access 才是真正的瓶颈。标准 Attention 的每一步中间结果(S 矩阵、P 矩阵)都需要反复在 SRAM 和 HBM 之间搬运,这些 memory-bound 操作占据了大部分时间。

FlashAttention 的核心突破就在于:不去减少 FLOP,而是减少 HBM 的读写次数。通过增加少量计算(recomputation),换来了内存访问次数的数量级减少。

标准 Attention 的内存读写分析#

在深入 FlashAttention 算法之前,需要精确理解标准 Attention 实现中到底发生了多少次 HBM 读写。

Attention 公式回顾#

给定输入矩阵 Q,K,VRN×d\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{N \times d},Attention 的计算为:

S=QKRN×N,P=softmax(S)RN×N,O=PVRN×d\mathbf{S} = \mathbf{Q}\mathbf{K}^\top \in \mathbb{R}^{N \times N}, \quad \mathbf{P} = \text{softmax}(\mathbf{S}) \in \mathbb{R}^{N \times N}, \quad \mathbf{O} = \mathbf{P}\mathbf{V} \in \mathbb{R}^{N \times d}

其中 softmax 按行计算。以 GPT-2 为例,N=1024N = 1024d=64d = 64。由于 NdN \gg d,中间矩阵 S\mathbf{S}P\mathbf{P} 都是 N×NN \times N 的巨大矩阵,存储开销为 O(N2)O(N^2)

此外,实际实现中还需要在 S\mathbf{S} 上 apply masking,在 P\mathbf{P} 上 apply dropout,每一次都是一个额外的 memory-bound 操作。

Algorithm 0: 标准实现的三次读写#

标准 Attention 的 PyTorch 实现分为以下步骤:

  1. 第一次读写:从 HBM 逐 block 读取 Q\mathbf{Q}K\mathbf{K},计算 S=QK\mathbf{S} = \mathbf{Q}\mathbf{K}^\top,将 N×NN \times NS\mathbf{S} 矩阵写回 HBM
  2. 第二次读写:从 HBM 读取 S\mathbf{S},计算 P=softmax(S)\mathbf{P} = \text{softmax}(\mathbf{S}),将 N×NN \times NP\mathbf{P} 矩阵写回 HBM
  3. 第三次读写:从 HBM 读取 P\mathbf{P}V\mathbf{V},计算 O=PV\mathbf{O} = \mathbf{P}\mathbf{V},将 O\mathbf{O} 写回 HBM

整个过程中,三次大块读取、三次大块写入。其中 S\mathbf{S}P\mathbf{P} 这两个 N×NN \times N 的中间矩阵需要被完整地写入 HBM 后再读出,而它们本身并不是最终输出,只是中间产物。

PyTorch 的这种逐步实现导致每一步都产生一次 HBM round-trip,wall-clock time 中大部分消耗在内存搬运上,而不是矩阵乘法本身。

FlashAttention 的核心思路#

理解了三次读写的问题,FlashAttention 的目标就呼之欲出了:

能否将三次大块读写合并为一次读取 + 一次写入?

具体来说:一次性读入 Q\mathbf{Q}K\mathbf{K}V\mathbf{V} 的分块,在 SRAM 中完成所有中间计算(S\mathbf{S}、softmax、dropout、PV\mathbf{P} \cdot \mathbf{V}),然后只把最终的 O\mathbf{O} 写回 HBM,全程不将 S\mathbf{S}P\mathbf{P} 这两个 N×NN \times N 矩阵写入 HBM。

实现这个目标需要解决两个技术难题:

  1. Tiling:softmax 是一个 global 操作(需要整行的 max 和 sum),如何在只看到一个 block 的情况下做分块 softmax?
  2. Recomputation:backward pass 需要 S\mathbf{S}P\mathbf{P},如果不存它们,怎么算梯度?

这两个问题分别由 Online Softmax 和 Recomputation 技术解决。

Tiling 的数学基础: Safe Softmax 与 Online Softmax#

Tiling 和 Recomputation 都不是 FlashAttention 发明的新技术,它们在 database、image processing、numerical linear algebra 等领域早已被广泛使用。FlashAttention 的贡献在于将它们应用到 Attention 计算中,并证明了最优性。

在讲 tiling 之前,需要先理解一个前置知识:如何在不看到整行数据的情况下计算 softmax。

Softmax 的数值稳定性问题#

标准 softmax 公式:

softmax(x)j=exjiexi\text{softmax}(x)_j = \frac{e^{x_j}}{\sum_i e^{x_i}}

直接计算的问题在于 exje^{x_j} 可能极大(指数增长),导致数值溢出。

Safe Softmax (Stable Softmax)#

解决方法是利用指数函数的性质:对所有元素减去同一个常数后再取指数,结果不变。具体定义:

m(x):=maxixim(x) := \max_i x_i

f(x):=[ex1m(x),,exBm(x)]f(x) := \left[e^{x_1 - m(x)}, \ldots, e^{x_B - m(x)}\right]

(x):=if(x)i\ell(x) := \sum_i f(x)_i

softmax(x):=f(x)(x)\text{softmax}(x) := \frac{f(x)}{\ell(x)}

减去 m(x)m(x) 后,指数的最大值变为 e0=1e^0 = 1,所有值都落在 (0,1](0, 1] 区间内,避免了溢出。这个性质成立的数学基础是:对 xx 的每一项加减任意常数 cc,softmax 的结果不变:

softmax(x)=softmax(xc)\text{softmax}(x) = \text{softmax}(x - c)

关键挑战:Softmax 是 Global 操作#

Safe softmax 需要 m(x)=maxixim(x) = \max_i x_i,这是一个需要看到整行所有元素的 global 操作。如果要做 tiling(把输入分成多个 block 逐块处理),就必须解决一个问题:在只看到部分数据时,如何正确计算 softmax?

Online Softmax: 分块计算的数学推导#

将向量 xx 分成两部分 x=[x(1),x(2)]x = [x^{(1)}, x^{(2)}],其中 x(1),x(2)RBx^{(1)}, x^{(2)} \in \mathbb{R}^B。目标是分别处理 x(1)x^{(1)}x(2)x^{(2)},最后合并得到完整的 softmax。

第一步:分解 m(x)m(x)

m(x)=m([x(1),x(2)])=max(m(x(1)),m(x(2)))m(x) = m([x^{(1)}, x^{(2)}]) = \max(m(x^{(1)}), m(x^{(2)}))

全局最大值等于两个 block 各自最大值中更大的那一个。

第二步:分解 f(x)f(x)

x(1)x^{(1)} 部分,先用局部最大值 m(x(1))m(x^{(1)}) 计算局部的 f(x(1))f(x^{(1)}),然后通过一个修正因子将其转换到全局基准:

f(x)=[em(x(1))m(x)f(x(1)),em(x(2))m(x)f(x(2))]f(x) = \left[e^{m(x^{(1)}) - m(x)} \cdot f(x^{(1)}), \quad e^{m(x^{(2)}) - m(x)} \cdot f(x^{(2)})\right]

具体来说,f(x(1))f(x^{(1)}) 中每一项的计算是基于 m(x(1))m(x^{(1)}) 做的减法,现在需要补偿到全局 m(x)m(x) 的基准上,所以乘以 em(x(1))m(x)e^{m(x^{(1)}) - m(x)}

第三步:分解 (x)\ell(x)

(x)=em(x(1))m(x)(x(1))+em(x(2))m(x)(x(2))\ell(x) = e^{m(x^{(1)}) - m(x)} \cdot \ell(x^{(1)}) + e^{m(x^{(2)}) - m(x)} \cdot \ell(x^{(2)})

softmax(x)=f(x)(x)\text{softmax}(x) = \frac{f(x)}{\ell(x)}

核心 insight:通过维护两个额外的统计量 m(x)m(x)(当前已见数据的最大值)和 (x)\ell(x)(当前已见数据的归一化因子),可以在逐 block 处理数据时,持续更新这两个值,最终得到与一次性计算完全相同的 softmax 结果。

推广到多个 Block#

上述两块的推导可以自然推广到多个 block。处理第 jj 个 block 时:

  1. 计算当前 block 的局部 m~ij\tilde{m}_{ij}P~ij\tilde{P}_{ij}
  2. 更新全局最大值 minew=max(mi,m~ij)m_i^{\text{new}} = \max(m_i, \tilde{m}_{ij})
  3. 用修正因子 emiminewe^{m_i - m_i^{\text{new}}}em~ijminewe^{\tilde{m}_{ij} - m_i^{\text{new}}} 将之前的累积结果和当前 block 的结果 rescale 到统一基准
  4. 更新 i\ell_i 和输出 Oi\mathbf{O}_i

这正是 FlashAttention Algorithm 1 的核心循环体。

FlashAttention 核心算法#

有了 Online Softmax 的数学基础和 Recomputation 的思路,可以将它们合成一个统一的 fused kernel,即 FlashAttention 的 Algorithm 1。

Recomputation: 用计算换内存#

在标准 Attention 的 backward pass 中,需要用到中间矩阵 S\mathbf{S}P\mathbf{P},它们都是 N×NN \times N 的矩阵,内存开销为 O(N2)O(N^2)

FlashAttention 的策略是:不存储 S\mathbf{S}P\mathbf{P}。forward pass 中只存储 Q\mathbf{Q}K\mathbf{K}V\mathbf{V}(都是 N×dN \times d 的,dNd \ll N)以及 softmax 的归一化因子 mm\ell(都是 O(N)O(N) 的向量)。backward pass 需要 S\mathbf{S}P\mathbf{P} 时,从 Q\mathbf{Q}K\mathbf{K}V\mathbf{V} 重新计算。

这看似增加了计算量(FLOP),但由于减少了 HBM 的读写次数,实际的 wall-clock time 反而下降了。这就是 recomputation 的精髓:用额外的 FLOP 换取更少的 memory access,从而实现 speed up

Algorithm 1: FlashAttention Forward Pass#

输入Q,K,VRN×d\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{N \times d}(存储在 HBM 中),on-chip SRAM 大小为 MM

Step 1: 设定 block size

Bc=M4d,Br=min(M4d,d)B_c = \left\lceil \frac{M}{4d} \right\rceil, \quad B_r = \min\left(\left\lceil \frac{M}{4d} \right\rceil, d\right)

Block size 由 SRAM 容量 MM 和 head dimension dd 决定。BcB_c 是 K/V 的 block 行数,BrB_r 是 Q 的 block 行数。

由于 MM 有限,实际的 tiling 不是正方形的:K/V 的 block 可能比 Q 的 block 更宽(Bc>BrB_c > B_r),Q 需要被切成更多的 block。这是一个重要的 observation:图示中为了直观画成正方形,实际是长方形的 tiling。

Step 2: 初始化

O=0N×d,=0N,m=()N(均在 HBM 中)\mathbf{O} = \mathbf{0}_{N \times d}, \quad \ell = \mathbf{0}_N, \quad m = (-\infty)_N \quad \text{(均在 HBM 中)}

Step 3: 分块

  • Q\mathbf{Q} 分成 Tr=N/BrT_r = \lceil N/B_r \rceil 个 block
  • K\mathbf{K}V\mathbf{V} 分成 Tc=N/BcT_c = \lceil N/B_c \rceil 个 block
  • O\mathbf{O}\ellmm 也做对应分块

Step 4-5: 双层循环

外循环 (j = 1 to T_c): // 遍历 K/V 的每个 block
从 HBM 加载 K_j, V_j 到 SRAM
内循环 (i = 1 to T_r): // 遍历 Q 的每个 block
从 HBM 加载 Q_i, O_i, ℓ_i, m_i 到 SRAM
// 在 SRAM 中计算(不写回中间结果到 HBM)
S_ij = Q_i · K_j^T // 当前 tile 的 score
m̃_ij = rowmax(S_ij) // 当前 tile 的行最大值
P̃_ij = exp(S_ij - m̃_ij) // 当前 tile 的 softmax 分子
ℓ̃_ij = rowsum(P̃_ij) // 当前 tile 的行和
// 更新全局统计量
m_i^new = max(m_i, m̃_ij)
ℓ_i^new = e^{m_i - m_i^new} · ℓ_i + e^{m̃_ij - m_i^new} · ℓ̃_ij
// 更新输出(rescale 旧值 + 加入新贡献)
O_i ← diag(ℓ_i^new)^{-1} · (diag(ℓ_i) · e^{m_i - m_i^new} · O_i + e^{m̃_ij - m_i^new} · P̃_ij · V_j)
// 写回更新后的 ℓ_i, m_i 到 HBM
ℓ_i ← ℓ_i^new, m_i ← m_i^new

关键点:外循环加载一次 K/V block,内循环遍历所有 Q block。每次内循环中,用 Online Softmax 的 rescale 技术更新输出 O\mathbf{O}、最大值 mm 和归一化因子 \ell。整个过程中,S\mathbf{S}P\mathbf{P} 矩阵只在 SRAM 中以 tile 的形式短暂存在,永远不写入 HBM。

内存分析#

FlashAttention 的 forward pass 需要 O(N2d)O(N^2 d) FLOPs(与标准 Attention 相同),但只需要 O(N)O(N) 的 additional memory(存储 \ellmm),而不是标准实现的 O(N2)O(N^2)。这从根本上解决了长上下文的 out-of-memory 问题。

IO 复杂度对比#

标准 AttentionFlashAttention
HBM 访问次数Θ(Nd+N2)\Theta(Nd + N^2)O(N2d2M1)O(N^2 d^2 M^{-1})
额外内存O(N2)O(N^2)O(N)O(N)
端到端 speedupbaseline最高 7.6x

由于 d2/Md^2 / M 在实际硬件中远小于 1(例如 d=128d = 128M=100KBM = 100\text{KB} 时,d2/M1/6d^2/M \approx 1/6),FlashAttention 的 HBM 访问次数比标准实现少数倍到数十倍。

IO Complexity 分析与最优性证明#

FlashAttention 不仅在实践中快,论文还从理论上证明了它在 IO 复杂度意义下是最优的。

Theorem 1: 正确性与资源开销#

Algorithm 1 返回 O=softmax(QK)V\mathbf{O} = \text{softmax}(\mathbf{Q}\mathbf{K}^\top)\mathbf{V},使用 O(N2d)O(N^2 d) FLOPs,且只需要 O(N)O(N) 的 additional memory(不含输入输出本身)。

Proposition 2: HBM 访问上界#

标准 Attention 需要 Θ(Nd+N2)\Theta(Nd + N^2) 次 HBM 访问。FlashAttention 只需要:

O(N2d2M) 次 HBM 访问O\left(\frac{N^2 d^2}{M}\right) \text{ 次 HBM 访问}

对比关键在于 d2/Md^2 / M 这个因子。以实际硬件参数为例:

  • d=64d = 64d2=4096d^2 = 4096M=100KBM = 100\text{KB}d2/M1/24d^2/M \approx 1/24,即 HBM 访问减少约 24 倍
  • d=128d = 128d2=16384d^2 = 16384M=100KBM = 100\text{KB}d2/M1/6d^2/M \approx 1/6,即减少约 6 倍

Block Size 的 Sweet Spot#

Block size 越大,每次 SRAM 中能处理的数据越多,HBM 访问次数越少。但 block size 不能无限增大,因为它受限于 SRAM 容量。而且当 block size 大到一定程度后(超过约 256),runtime 的瓶颈从 memory-bound 转变为 compute-bound 或其他因素,继续增大 block size 的收益递减。

这意味着存在一个 sweet spot:block size 足够大以充分利用 SRAM,但不超过 SRAM 容量限制。

Theorem 3: 最优性下界(Lower Bound)#

论文最硬核的理论贡献之一是 Proposition 3:对于所有精确 Attention 算法,在 SRAM 大小 MM 满足 dMNdd \leq M \leq Nd 的合理范围内,HBM 访问次数的下界为 Ω(N2d2M1)\Omega(N^2 d^2 M^{-1})

这意味着 FlashAttention 的 IO 复杂度已经达到了理论最优,不存在一个精确 Attention 算法能在 IO 复杂度上比 FlashAttention 快一个数量级。

Block-Sparse FlashAttention#

论文进一步将 FlashAttention 扩展到 sparse attention 场景。给定一个 block 形式的 mask 矩阵 M~{0,1}N×N\tilde{\mathbf{M}} \in \{0, 1\}^{N \times N},计算:

S=QK,P=softmax(S1M~),O=PV\mathbf{S} = \mathbf{Q}\mathbf{K}^\top, \quad \mathbf{P} = \text{softmax}(\mathbf{S} \odot \mathbf{1}_{\tilde{\mathbf{M}}}), \quad \mathbf{O} = \mathbf{P}\mathbf{V}

其中 S1M~\mathbf{S} \odot \mathbf{1}_{\tilde{\mathbf{M}}} 表示将 mask 为 0 的位置设为 -\infty

Proposition 4:Block-Sparse FlashAttention 的 IO 复杂度为原始 FlashAttention 的 ss 倍,其中 ss 是 sparsity ratio(非零 block 的比例)。例如 s=0.1s = 0.1 意味着只有 10% 的 block 是非零的,IO 复杂度再降低 10 倍。

这为业界后续的 sparse attention 优化(如 Sliding Window + Global Attention 的组合)提供了高效的底层实现。

实验结果与实际影响#

训练速度#

FlashAttention 在多个主流模型和框架上实现了显著的端到端加速:

模型 / 框架BaselineFlashAttention加速比
BERT-large(8×A100)Nvidia MLPerf 1.1: 20.0±1.5 min17.4±1.4 min15% faster
GPT-2 small(HuggingFace)9.5 days (1.0×)2.7 days (3.5×)3.5×
GPT-2 small(Megatron-LM)4.7 days (2.0×)2.7 days (3.5×)1.7×
GPT-2 medium(HuggingFace)21.0 days (1.0×)6.2 days (3.4×)3.4×
长上下文场景baseline2.4×
Long Range Arenabaseline最高 2.8×

BERT 训练时间(17 分钟 vs Nvidia 的 20 分钟)值得强调:FlashAttention 在 BERT 上甚至击败了 Nvidia 自家的 MLPerf 实现,这在当时是相当大胆的 claim。

对于 Decoder-Only 架构(GPT 系列),FlashAttention 的加速更为显著,达到了 3-3.5 倍。这种质的飞跃来自于 GPT 的 auto-regressive 特性使得序列长度对性能的影响更大。

模型质量: 长上下文的突破#

FlashAttention 将 GPT-2 的上下文长度从 1K 扩展到 4K,perplexity 降低了 0.7,在两个长文档分类任务上提升了 6.4 个百分点。

长上下文带来的不仅是速度,更重要的是模型能力的质变:更多信息可以被输入到 Transformer 中,模型可以做更长距离的推理。

Path-X 与 Path-256: 从零到一的突破#

Path-X Challenge 是一个经典的长序列理解任务:在一张图片中判断两个点之间是否存在一条连接路径。这个任务需要模型处理极长的像素序列。

  • Path-X(序列长度 16K):FlashAttention 使得 Transformer 在这个任务上第一次超越了 random baseline(50% 准确率)
  • Path-256(序列长度 64K):同样是第一个实现 better-than-random 的序列模型

这个结果的意义在于:此前没有任何 Transformer 模型能在这些超长序列任务上做到 better-than-random。FlashAttention 通过使长上下文成为可能,实现了 0 到 1 的突破。

Attention Benchmark#

论文对 FlashAttention 的 attention kernel 本身做了详细的 profiling 和 footprint analysis:

  • Memory footprint 随序列长度线性增长(而非平方增长)
  • Block-Sparse FlashAttention 的 runtime 也随序列长度线性增长
  • 在所有测试的序列长度上,FlashAttention 都快于所有已有的 approximate attention baseline

Backward Pass 的实现复杂度与相关工作对比#

论文正文对 forward pass 做了完整展开,但 backward pass(计算梯度)在正文中几乎一笔带过,实际的完整推导和实现在 Appendix B 中,复杂度相当高。

值得注意的是,Rabe & Stauss 等人此前也做了 Attention 的内存优化工作,他们的方法减少了 memory footprint(内存占用量)。FlashAttention 与这些工作的核心区别在于:FlashAttention 不仅减少了内存占用,更关键的是减少了 memory access(内存访问次数)。这一差异直接体现在 wall-clock time 上。前者只是让你不 OOM,后者让你真正跑得更快。

关于 FlashAttention 2 的预告#

FlashAttention 1 的算法中,外循环遍历 KV block,内循环遍历 Q block。FlashAttention 2 将这两个循环对换了(外循环遍历 Q,内循环遍历 KV),这个看似简单的改动带来了进一步的性能优化。

Limitations 与 Future Directions#

论文坦诚地指出了三个局限性和未来方向:

  1. Compiling to CUDA:FlashAttention 用 CUDA 手写 kernel,对 Python 开发者不友好。如果想修改 attention 的实现(比如加入新的 bias 或 mask),需要直接改 CUDA 代码。论文提出了一个方向:是否能用高层语言(如 Halide 风格的 DSL)编写 IO-aware 的 attention 实现,自动编译到 CUDA。
  2. IO-Aware Deep Learning:IO-Aware 的思路不仅适用于 Attention,理论上也可以延伸到 MLP 层、Embedding 层等 Transformer 的其他模块。整个深度学习框架能否变得 IO-Aware?
  3. Multi-GPU IO-Aware Methods:FlashAttention 的 IO 分析针对的是单 GPU 内部的内存层级。当扩展到多 GPU 时,还需要考虑 GPU 间的数据传输(NVLink、PCIe),IO 分析需要增加一层。

代码库结构与实现解析#

FlashAttention 的代码仓库(flash-attention)包含四个版本(1-4),以下聚焦 FlashAttention v1(flash-attention-1-v1.0.9)的代码结构。

仓库概览#

FlashAttention v1 的两个核心目录:

  • flash_attn/:Python 层接口,负责将用户的 PyTorch tensor 传递给底层 CUDA kernel
  • csrc/:C++ 和 CUDA 层实现,包含真正的高性能 kernel

Python 层: flash_attention.py#

这个文件定义了两个核心 class,都继承自 nn.Module

FlashAttention:最直接的接口。forward 方法接收 QKV tensor,调用底层的 flash_attn_unpadded_qkvpacked_func,返回 attention output。它是一个对底层 CUDA function 的 wrapper。

FlashMHA(Multi-Head Attention):在 FlashAttention 之上再包了一层。它接收 embedding 而不是 QKV,先帮用户计算 QKV projection,再调用 FlashAttention

Python 层: flash_attn_interface.py#

这个文件是 Python 和 CUDA 之间的桥梁。它定义了四个继承自 torch.autograd.Function 的 class:

Class输入形式使用场景
FlashAttnQKVPackedFuncQKV 打包为一个矩阵QKV 形状相同时
FlashAttnKVPackedFuncKV 打包,Q 单独Cross-attention
FlashAttnFuncQ、K、V 三个矩阵分开通用场景
FlashAttnSplitFunc同上,head_dim > 128 时使用大 head dimension

每个 class 都实现了 forwardbackward 两个静态方法:

forward 的核心逻辑:

  1. 调用 flash_attn_cuda.fwd(底层 CUDA function)
  2. 将中间值(softmax_lse、rng_state 等)存入 ctx(autograd context),供 backward 使用
  3. 返回 attention output

backward 的核心逻辑:

  1. ctx 取出 forward 时保存的中间值
  2. 调用 flash_attn_cuda.bwd
  3. 返回 Q、K、V 的梯度

torch.autograd.Functionapply 方法会自动建立计算图,使得 FlashAttention 可以无缝嵌入 PyTorch 的自动微分系统。

C++ 层: csrc/flash_attn/#

C++ 层包含一个关键文件 fmha_api.cpp,它定义了两组核心 function:

参数设置set_params_fprop(forward)和 set_params_dgrad(backward)。它们的工作是将 PyTorch tensor 的指针、stride、形状等信息填入一个 params 结构体,供 CUDA kernel 使用。stride 是指在内存中从一个元素跳到下一个元素需要跳过的字节数。

Dispatch functionrun_fmha_fwdrun_fmha_bwd。它们根据 head dimension 的大小分发到不同的 kernel:

  • d32d \leq 32:走一种 kernel
  • 32<d6432 < d \leq 64:走另一种 kernel
  • 64<d12864 < d \leq 128:走第三种 kernel

dispatch 完成后,主要流程是:torch check(验证维度是否匹配、形状是否正确)→ set params → launch CUDA kernel → return result。

CUDA Kernel 层#

真正的硬核实现在 hdm32/hdm64/hdm128/ 三个目录中,分别对应三种 head dimension 范围。每个目录中包含纯 C++/CUDA 的 kernel 实现。

核心 class 名为 fmha_kernel,它使用了大量 NVIDIA 底层的 CUDA primitives:

  • Softmax:在 SRAM 中实现的分块 safe softmax
  • LSE(Log-Sum-Exp):softmax 的归一化因子
  • GEMM(General Matrix Multiply):矩阵乘法,这是 attention 中 compute-bound 的核心操作

这些 kernel 将 NVIDIA 提供的底层 primitives 重新组合,形成了 fused multi-head attention 的 forward 和 backward kernel。这两个 kernel 文件是整个 FlashAttention 代码库中最底层、最硬核的部分。

Block Sparse Attention#

除了标准的 FlashAttention,代码库中还实现了 BlockSparseAttention,同样包含 forward 和 backward 两个 kernel。Block sparse 版本允许用户指定一个 block-level 的 sparsity mask,跳过不需要计算的 block,进一步减少 IO 和计算。

FlashAttention
https://www.bilibili.com/video/BV1rHEc6XEX3/
作者
xwysyy
发布于
2026-06-04
许可协议
CC BY-NC-SA 4.0

评论

0/1000
评论加载中…
对话列表
© 2026 xwysyy. All Rights Reserved.
Powered by Astro & Firefly

文章目录