Softmax 的公式:
为了防止指数爆炸问题,在实际计算的时候会采用 Safe Softmax:
一般来说,上述公式中
V2版本算法
如下图所示,蓝色的部分表示当前存储在 shared memory 中的部分
FlashAttention 的实现是不唯一的,事实上,很多实现都没有完全采用原始论文中的方法,会有一定程度的调整
假设一个 block 实际上会被 SM 划分成 4 个 warp,在 V1 版本中,矩阵 𝐾,𝑉 的 block 会被划分成 4 个 warp,每个 warp 计算
在 V2 版本中,矩阵 𝑄 的 block 会被划分成 4 个 warp,这种情况下每个 warp 计算出来的结果就是一个
https://marp.app/
