LLM智能应用开发

第6讲: 大语言模型解析 III

基于HF LlaMA实现的讲解

LLM结构的学习路径

  • LLM结构解析(开源LlaMA)
  • 自定义数据集构造
  • 自定义损失函数和模型训练/微调

Transformer经典结构

  • Encoder-decoder结构
  • 输入部分
    • Input embedding
    • Positional embedding
  • Transformer部分
    • Feed forward network
    • Attention module

HF LlaMA模型结构

LlamaForCausalLM(
  (model): LlamaModel(
    ...
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention
        (mlp): LlamaMLP
        (input_layernorm): LlamaRMSNorm
        (post_attention_layernorm): LlamaRMSNorm
    )
    ...
  )
  (lm_head): Linear(in_features=2048, out_features=128256, bias=False)
)

LlamaDecoderLayer内部结构

Transformer架构的核心: attention(注意力机制)

(self_attn): LlamaAttention(
  (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
  (k_proj): Linear(in_features=2048, out_features=512, bias=False)
  (v_proj): Linear(in_features=2048, out_features=512, bias=False)
  (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
  (rotary_emb): LlamaRotaryEmbedding()
)

Attention内部结构

  • 静: 结构视角(init function...)
    • 4个Linear层
      • q_proj、k_proj、v_proj、o_proj
  • 动: 推理视角(Forward,bp靠Autograd自动求导)

Attention模块的输入

问题:QKV是输入吗?

  • 非也,输入是上一层的hidden states
class LlamaAttention(nn.Module):
...
def forward(hidden_states)
  ...
  query_states = self.q_proj(hidden_states)
  key_states = self.k_proj(hidden_states)
  value_states = self.v_proj(hidden_states)
  • 思考:hidden states的shape是怎样的?

标准Attention的第一步: 获得

  • 给定hidden states(后续简写为),通过前向传播(执行forward)得到
    • 的shape: [batch_size, seq_len, hidden_size]
    • : ,
      • 的shape: [hidden_size, hidden_size]
    • :
      • 的shape: [hidden_size, hidden_size]
    • :
      • 的shape: [hidden_size, hidden_size]

标准Attention的第一步: 获得

  • 为方便理解方法,脑补通过tensor.view改变shape
    • [batch_size, seq_len, hidden_size] -> [N, d]
      • N = batch_size * seq_len, d = hidden_size

center

标准Attention的第二步: 计算

  • 给定,计算,此处考虑mask

center

标准Attention的第三步: 计算Attention

  • 给定,计算
    • row-wise softmax:
    • , ,

center

标准Attention的第四步: 计算输出

  • 给定,计算

center

标准Attention回顾

  • 给定 (来自)

Attention中mask的作用

  • 回顾
  • <PAD>: 一种表示“padding”的特殊token,用来避免对句子中的某些token的影响
  • 为了避免padding对attention的影响,在计算时,我们可以将padding的部分设置为一个很大的数,如

Attention中mask的作用

center

对应的实现

移步notebook

MuliHeadAttention

  • 标准Attention只生成一个输出A,考虑多种角度,期望得到不同的A
    • 靠多个头实现,什么头??
    • 进行拆分,拆成多个头
    • 拆分为多头:[batch_size, seq_len, num_heads, head_dim]
    • 些许改造Attention计算过程


其中,

MultiHeadAttention

  • 给定 (shape [bs, seq, hs]),shape简化为
  • 多个heads
  • shape的变换(tensor.view实现): [N, d] -> [N, num_heads, head_dim]
    • 其中, d = hidden_size = num_heads * head_dim
    • 实现中,[bs, seq, hs] -> [bs, seq, nh, hd]
      • 再transpose为[bs, nh, seq, hd]

对应的实现

移步notebook

Attention计算开销

  • 的计算过程是的复杂度,那么多头的情况下,的计算复杂度是
  • 实际上,可依赖GPU并行执行提升速度
    • 分块的并行计算(sm计算单元)
  • 如何加速Attention计算?
    • BlockedAttention
    • FlashAttention

BlockedAttention第一步: 获得

  • 给定 (shape [batch_size, seq_len, hidden_size]),shape简化为

center

center

BlockedAttention第二步: 计算

  • 给定,
    • 的shape: , 的shape:

center

center
center

BlockedAttention第二步: 计算

  • 给定,
    • 的shape: , 的shape:

center

BlockedAttention第三步: 计算Attention

给定,计算

center
center

BlockedAttention第三步: 计算Attention

给定,计算

center
center

BlockedAttention第三步: 计算Attention

给定,计算

center
center

BlockedAttention第四步: 计算

给定,计算

center
center

BlockedAttention回顾

center

https://marp.app/

![bg right:40% 100%](images/l4/transformer.png)