










Softmax 的公式:

为了防止指数爆炸问题,在实际计算的时候会采用 Safe Softmax:

一般来说,上述公式中
inf,下溢会得到 0,最终 softmax 可能变成 NaN

V2版本算法





如下图所示,蓝色的部分表示当前存储在 shared memory 中的部分

FlashAttention 的实现是不唯一的,事实上,很多实现都没有完全采用原始论文中的方法,会有一定程度的调整

假设一个 block 实际上会被 SM 划分成 4 个 warp,在 V1 版本中,矩阵 𝐾,𝑉 的 block 会被划分成 4 个 warp,每个 warp 计算

在 V2 版本中,矩阵 𝑄 的 block 会被划分成 4 个 warp,这种情况下每个 warp 计算出来的结果就是一个

https://marp.app/
