在本次实验中,我们将实现深度学习模型中非常核心且基础的运算模块——矩阵乘(Matrix Multiplication)。矩阵乘法不仅是构建神经网络中全连接层、卷积操作和注意力机制等模块的基础,同时也是高性能计算优化的重要对象。
Task 1: MalMul with multi-head variant
在 task 1 中,我们要实现两个矩阵相乘的逻辑,我们有以下两个矩阵:
A1:一个 3D 的输入张量,形状为[batch_size, seq_len, hidden_size],batch_size表示序列的数量,seqlen表示一个序列的最大长度,hidden_size表示序列中每一个token拥有的维度。我们简写A1的形状为[b, s, h]。W1:一个 2D 的权重张量,形状为[hidden_size, embed_size],它表示一个投影矩阵,将任何行向量从hidden_size-dim 投影到embed_size-dim。我们简写W1的形状为[h, e]。
朴素的矩阵乘法仅对 A1 中 batch_size 维度,针对每个序列索引i,都执行 O1[i] = A1[i] @ W1 计算,从而得到形状为 [b, s, e] 的张量 O1。
在多头矩阵乘法中,我们首先将输入张量 A1 和权重张量 W1 的 h 维度均分为 num_heads 个子维度(记为 nh,表示头的数量),由此得到形状为 [b, s, nh, hd] 的四维张量 A2 和形状为 [nh, hd, e] 的三维张量 W2。接下来,对于 A2 中 batch_size 维度下的每个序列,遍历其 num_heads 维度上的每个 [s, hd] 矩阵,并将其与 W2 中 num_heads 维度下对应的 [hd, e] 矩阵进行乘法运算。通过多头并行计算,最终输出一个形状为 [b, s, nh, e] 的四维张量 O2。
TODO
完成 src/functional.py 中的 matmul_with_multi_head 函数 ,实现上述多头矩阵乘法的逻辑,输入张量 A1 和 W1,返回计算值 O2。
- 输入的张量是 A1 和 W1,你需要自己将其转换为 A2 和 W2 再进行计算,请注意 torch 中
reshape,view,transpose,permute等函数的用法和区别。 - 虽然逻辑上矩阵的乘法是用遍历进行计算的,但请勿使用 for 循环的方式进行实现,请自行查阅 pytorch 的计算函数,如
@,torch.bmm,torch.mm,torch.matmul,torch.einsum等。 - 了解并使用 pytorch 计算中的广播机制,有助于简化计算逻辑。
- 所有输入张量均在同一设备(CPU 或 CUDA)上从标准正态分布 N (0, 1) 随机初始化,具有相同的数据类型(float32、float16 或 bfloat16),并且在所有测试用例中均未设置
require_grad; - 在所有测试用例中,
hidden_size均会被保证能被num_heads整除。
Task 2: MalMul with importance
在多头矩阵乘法的基础上,我们引入一个表示“重要性”的概率张量 P,其形状为 [b, s]。P 中的每个元素表示 A1 中对应位置的元素的重要程度。基于这个重要性概率,我们的目标是只对每个序列中的 “重要” 元素执行矩阵乘法运算。这些重要元素总共有total_important_seq_len 个,简记为 t,其计算结果会被收集到输出张量 O3 中,其形状为 [t, nh, e]。
为了精确界定 “重要” 元素的范围,我们提供两个可选参数:
top_p:取值范围为[0., 1.]的浮点数。只有概率值大于或等于top_p的元素才被视为 “重要” 元素,默认值为1.0。top_k:取值范围为[1, ..., seq_len]的整数。对于批次中的每个序列,只将概率最高的top_k个元素视为 “重要” 元素。如果未设置top_k(默认值为None),则等价于top_k = seq_len。
注意,必须同时满足上述两点的元素才是重要元素。
TODO
完成 src/functional.py 中 matmul_with_importance 函数 Task2 的部分,实现上述重要性乘法。首先,你需要根据 top_p 和 top_k 的值,从 A1 中挑选出“重要”的元素,组成 [t, h] 的张量 A3,再仿造 Task1 中的多头矩阵乘法,输出 [t, nh, e] 的张量 O3。
可以使用 torch.topk 计算 topk 个重要元素。
在所有测试用例中,top_p 和 top_k 参数均会被保证在各自有效范围内取值。
Task 3: MalMul with grad
此外,如果提供了输出张量的可选梯度(记为 dO3,其形状与 O3 相同),我们还需要计算输入张量的梯度(记为 dA1,形状与 A1 相同)和权重张量的梯度(记为 dW1,形状与 W1 相同)。若未提供 dO3,则 dA1 和 dW1 均返回 None。
TODO
完成 src/functional.py/matmul_with_importance 中 Task3 的部分,请参考 A0 中介绍的两种求梯度的方式,返回 A1 和 W1 的梯度。
- 若未提供
grad_output参数,应避免计算梯度以提高效率并节省内存。 - 若提供了
grad_output参数,可使用 PyTorch 的自动求导机制计算梯度,但需注意潜在的副作用,这些副作用可能会在测试中被测试。
References
提示:以下是一些可能对你的任务有帮助的参考资料,或者可以加深/拓展你对 PyTorch 的理解:
!! 请记住:查阅论文、源码以及官方文档,并从中进行思考和学习,是一项基本且至关重要的能力。请尽量不要过度依赖一些带有偏见或内容浅显的博客,例如 CSDN !!