MoreRSS

site iconChaofa Yuan修改

大模型算法工程师, 写过《LLMs-Zero-to-Hero》
请复制 RSS 到你的阅读器,或快速订阅到 :

Inoreader Feedly Follow Feedbin Local Reader

Chaofa Yuan的 RSS 预览

DPO 算法原理与代码实现:让 LLM 对齐变得简单

2026-01-10 23:07:00

0. 阅读收获 (takeaway)

本文目标是搞懂 DPO(Direct Preference Optimization)算法,阅读完本文你将获得:

  • 理解 DPO 的核心思想:为什么 DPO 可以替代 RLHF 中的 PPO
  • 掌握 DPO 与 RLHF 的关键区别:从 4 个模型到 2 个模型
  • 手撕 DPO Loss:理解损失函数到底在算什么
  • Bonus 1:为什么 DPO 比 PPO 训练更稳定
  • Bonus 2:DPO 损失函数的完整数学推导
  • 源代码位于 Github -动手学习大模型-中文版-第 12.1章——动手学习 DPO

本文代码运行于:Featurize GPU 算力云平台,不喜欢看文字的同学可以看 B站视频-chaofa用代码打点酱油YouTube-chaofa用代码打点酱油,视频号:chaofa用代码打点酱油

1. 为什么需要 DPO?

在聊 DPO 之前,我们先快速回顾一下 LLM 训练的三个阶段(参考 OpenAI InstructGPT):

  1. 预训练(Pre-training):在海量文本上训练,让模型学会"说话"
  2. 监督微调(SFT):用高质量的指令数据微调,让模型学会"听话"
  3. 对齐(Alignment):让模型的输出符合人类偏好,学会"说人话"

假设读者对于前两个步骤已经有所了解,这篇文章的重点是第三步"对齐"。

1.1 RLHF 的问题

OpenAI 在训练 ChatGPT 的时候用的是 RLHF(Reinforcement Learning from Human Feedback),整个流程大概是这样的:

DPO 原论文中的 RLHF vs DPO 流程对比

RLHF 确实有效,但问题也很明显:

  1. 需要 4 个模型:Actor(待训练)、Reference(冻结的 SFT 模型)、Reward Model(奖励模型)、Critic(价值函数)
  2. PPO 算法复杂:超参数一堆,训练不稳定,调参调到怀疑人生
  3. 资源消耗大:4 个模型同时跑,显存吃不消

之前在 DeepSeek-R1 论文解读 里也提到过,DPO 是 RLHF 的一种替代方案,但 DeepSeek 最终还是用了 GRPO(一种改进的 PPO)。不过对于大多数场景来说,DPO 已经够用了。

1.2 DPO 的卖点

DPO 的核心思路是:既然 RLHF 这么麻烦,能不能把强化学习的部分去掉,直接用监督学习的方式来做对齐?

答案是可以的。DPO 的作者通过一系列数学推导(后面 Bonus 部分会讲),证明了可以把 RLHF 的优化目标转换成一个简单的损失函数,只需要 2 个模型就能搞定:

  • Actor:待训练的模型 πθ\pi_\theta
  • Reference:冻结的 SFT 模型 πref\pi_{ref}

不需要单独训练 Reward Model,也不需要 PPO 那套复杂的东西。训练过程和 SFT 差不多,非常稳定。

2. DPO 的核心思想

2.1 偏好数据长什么样?

DPO 需要的数据格式很简单,就是一个 prompt 配上两个回答:一个好的(chosen),一个差的(rejected)。

# DPO 偏好数据示例
{
    "prompt": "介绍一下 chaofa用代码打点酱油 这个博主",
    "chosen": "chaofa用代码打点酱油 是一位专注于大模型技术的博主,他在 B站、YouTube 等平台分享 LLM 相关的技术内容,包括动手学大模型系列教程。他的内容特点是注重代码实现和原理讲解,帮助读者从零理解大模型的各种技术细节。",
    "rejected": "不知道,没听说过,说不定是个弱智。"
}

简单说就是:同一个问题,告诉模型哪个回答是好的,哪个是不好的。这种数据可以通过人工标注获得,也可以用更强的模型(比如 gemini/claude/gpt)来生成。

TRICK: 非同源模型的数据训练的时候,可以先用 "chosen" 数据 SFT,不然可能导致 chosen 和 rejected 概率都变低。

2.2 DPO 想做什么?

DPO 的目标其实就两个:

  1. 让模型更喜欢生成 chosen 回答:提高 chosen 的生成概率
  2. 不要偏离原来的 SFT 模型太远:保持模型的基本能力,防止"忘记"之前学到的东西

第二点很重要,如果只追求第一点,模型可能会为了迎合偏好数据而变得很奇怪(比如每个回答都很长、很啰嗦)。所以需要用参考模型来"拉住"它。

2.3 DPO 损失函数

好了,到了最核心的部分。DPO 的损失函数长这样:

LDPO(πθ;πref)=E(x,yw,yl)D[logσ(βlogπθ(ywx)πref(ywx)βlogπθ(ylx)πref(ylx))] \mathcal{L}_{\mathrm{DPO}}(\pi_\theta; \pi_{\mathrm{ref}}) = - \mathbb{E}_{(x,y_w,y_l) \sim D} \left[ \log \sigma\Big(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\Big) \right]

这个公式看起来贼复杂,但逻辑其实很清晰。首先看公式里面的核心部分,是在比较两个东西:

  • logπθ(ywx)πref(ywx)\log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)}:当前模型相对于参考模型,在 chosen 回答上的对数概率变化
  • logπθ(ylx)πref(ylx)\log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}:当前模型相对于参考模型,在 rejected 回答上的对数概率变化

我们希望前者大于后者。也就是说,模型在 chosen 上的"提升幅度"要大于在 rejected 上的"提升幅度"。

β\beta 是一个超参数,用来控制"偏离参考模型的惩罚力度"。β\beta 越大,模型越不敢偏离参考模型;β\beta 越小,模型越"激进"。一般从 0.1 开始试。

σ\sigma 就是 sigmoid 函数,把差值映射到 (0, 1) 区间,然后取 log 变成 loss。

Q: 这个公式是怎么推导出来的?为什么这样设计就能达到我们的目标?这些问题留到 Bonus 部分再说。现在只要理解"DPO 在做什么"就够了。

3. 手撕 DPO Loss

理解了原理之后,我们来看看代码怎么写。其实 DPO 的核心代码非常简单,比公式看起来简单多了。

3.1 计算序列的 log 概率

首先,我们需要一个函数来计算模型在某个序列上的 log 概率。

对于语言模型来说,生成一个序列的概率就是每个 token 条件概率的乘积。取 log 之后,乘积变成求和:

logπ(yx)=tlogP(yty<t,x) \log \pi(y|x) = \sum_t \log P(y_t | y_{<t}, x)

import torch
import torch.nn.functional as F


def compute_log_probs(
    logits: torch.Tensor,       # (batch, seq_len, vocab_size)
    labels: torch.Tensor,       # (batch, seq_len)
    mask: torch.Tensor          # (batch, seq_len),标记哪些位置需要计算
) -> torch.Tensor:
    """
    计算序列的对数概率

    注意:这里只计算 response 部分的概率,prompt 部分不算
    """
    # 获取每个位置的 log softmax
    log_probs = F.log_softmax(logits, dim=-1)

    # 取出对应 label 的 log 概率
    # gather 操作:从 vocab_size 维度取出 labels 对应的概率
    per_token_log_probs = torch.gather(
        log_probs,
        dim=-1,
        index=labels.unsqueeze(-1)
    ).squeeze(-1)

    # 只计算 mask=1 的位置(response 部分)
    masked_log_probs = per_token_log_probs * mask

    # 求和得到整个序列的 log 概率
    return masked_log_probs.sum(dim=-1)

3.2 DPO Loss 核心实现

有了计算 log 概率的函数,DPO Loss 的实现就很直接了:

def dpo_loss(
    policy_chosen_logps: torch.Tensor,    # 当前模型在 chosen 上的 log 概率
    policy_rejected_logps: torch.Tensor,  # 当前模型在 rejected 上的 log 概率
    ref_chosen_logps: torch.Tensor,       # 参考模型在 chosen 上的 log 概率
    ref_rejected_logps: torch.Tensor,     # 参考模型在 rejected 上的 log 概率
    beta: float = 0.1,
) -> torch.Tensor:
    """
    DPO Loss 的核心实现

    代码比公式简单多了吧?
    """
    # 计算 log ratio:当前模型相对于参考模型的变化
    chosen_log_ratios = policy_chosen_logps - ref_chosen_logps
    rejected_log_ratios = policy_rejected_logps - ref_rejected_logps

    # 核心:我们希望 chosen 的 ratio 大于 rejected 的 ratio
    logits = beta * (chosen_log_ratios - rejected_log_ratios)

    # 用 logsigmoid 更数值稳定(等价于 -log(sigmoid(logits)))
    losses = -F.logsigmoid(logits)

    return losses.mean()

就这么简单。核心就三行:

  1. 计算 chosen 的 log ratio
  2. 计算 rejected 的 log ratio
  3. 用 sigmoid + log 算 loss

完整的训练代码涉及数据处理、模型加载等,这里就不展开了。可以参考 trl 源码

4. 用 trl 跑一下 DPO 训练

手写 DPO Loss 是为了理解原理,实际训练的话直接用 trl 就好了。trl 是 Hugging Face 出的强化学习库,DPO 训练用起来很简单。

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

# 1. 准备模型
model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # 用小模型演示
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# 参考模型(就是 SFT 后的模型,这里直接用同一个)
ref_model = AutoModelForCausalLM.from_pretrained(model_name)

# 2. 准备数据(trl 需要的格式)
train_data = Dataset.from_dict({
    "prompt": [
        "介绍一下 chaofa用代码打点酱油 这个博主",
        "DPO 和 RLHF 哪个更适合入门?",
    ],
    "chosen": [
        "chaofa用代码打点酱油 是一位专注于大模型技术的博主,在 B站、YouTube 分享 LLM 相关教程,内容注重代码实现和原理讲解,帮助读者从零理解大模型技术。",
        "建议先学 DPO,原理更简单,训练也更稳定。可以看 chaofa用代码打点酱油 的动手学大模型系列,有详细的代码实现。",
    ],
    "rejected": [
        "没听说过,应该是个小透明吧。",
        "都差不多,随便选一个。",
    ],
})

# 3. 配置训练参数
training_args = DPOConfig(
    output_dir="./dpo_output",
    beta=0.1,                    # DPO 的温度参数
    learning_rate=5e-7,          # DPO 通常用比较小的学习率
    per_device_train_batch_size=2,
    num_train_epochs=1,
    logging_steps=10,
    bf16=True,
)

# 4. 创建 Trainer 并训练
trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    args=training_args,
    train_dataset=train_data,
    tokenizer=tokenizer,
)

trainer.train()

关键参数说一下:

  • beta:前面说过,控制偏离参考模型的惩罚力度,一般从 0.1 开始试
  • learning_rate:DPO 通常用比较小的学习率,5e-7 到 5e-6 左右

5. Bonus 1:为什么 DPO 比 PPO 训练更稳定?

很多人说"DPO 比 PPO 更稳定",但到底为什么呢?这个问题其实可以从几个角度来理解:

5.1 Off-policy vs On-policy

PPO 是一种 on-policy 的强化学习算法,DPO 是 off-policy 的,它直接用离线的偏好数据来训练,训练过程和 SFT 差不多。

  • on policy 每一次样本都是采样出来的,梯度可能会随时发生变化,梯度方差大;数据分布随着模型的更新会发生变化,上一轮学好的参数可能不适用下一轮,reward 比较稀疏(SFT/DPO 是 Token 级别的监督信号)。

5.2 不需要 Reward Model 和 Critic

PPO 在 RLHF 中需要:

  • 一个 Reward Model 来打分(这个模型本身就可能有问题,比如 reward hacking)
  • 一个 Critic(Value Function)来估计优势函数(这个网络的训练也不简单)

这些额外的模型都会引入噪声和不稳定因素。DPO 把 Reward Model 直接"吸收"到了损失函数里,不需要单独训练,少了很多可能出错的地方。

5.3 超参数敏感度

PPO 有很多超参数需要调:

  • clip ratio(裁剪系数)
  • GAE lambda
  • 学习率、batch size、epoch 数
  • KL 惩罚系数
  • ...

这些参数之间还有复杂的相互作用,调参调到怀疑人生是常有的事。DPO 的核心超参数就一个 β\beta,最多再加上学习率。简单很多。

备注:这里说的"稳定"。PPO/GRPO 调好了效果可能更好,但训练成本也更高。对于大多数场景来说,DPO 是一个性价比很高的选择。

6. Bonus 2:DPO 数学推导

这部分是给想深入理解的同学看的,跳过也不影响使用 DPO。

DPO 的 Loss 不是凭空设计出来的,而是从 RLHF 的优化目标一步步推导出来的。

6.1 RLHF 的优化目标

RLHF 想要做的事情是:最大化奖励,同时不要偏离参考模型太远。用公式表示:

maxπ  ExD,yπ(yx)[r(x,y)]βDKL[π(yx)πref(yx)] \max_{\pi} \; \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi(y \mid x)} \Big[ r(x,y) \Big] - \beta\, \mathbb{D}_{\mathrm{KL}}\Big[ \pi(y \mid x) \,\|\, \pi_{\mathrm{ref}}(y \mid x) \Big]

其中:

  • r(x,y)r(x,y) 是奖励函数(需要单独训练一个 Reward Model)
  • KL 散度用来约束模型不要偏离参考模型太远
  • β\beta 控制约束的强度

6.2 最优策略的形式

这个优化问题有一个解析解。我们先假设存在这样一个最优策略 π\pi^*,(具体推导可以参考 DPO 原论文附录,但我没看懂直接抄过来了),可以得到最优策略满足:

π(yx)=1Z(x)πref(yx)exp(1βr(x,y)) \pi^*(y \mid x) = \frac{1}{Z(x)} \pi_{\mathrm{ref}}(y \mid x) \exp\Big(\frac{1}{\beta} r(x,y)\Big)

其中 Z(x)=yπref(yx)exp(1βr(x,y))Z(x) = \sum_y \pi_{\mathrm{ref}}(y \mid x) \exp\Big(\frac{1}{\beta} r(x,y)\Big) 是归一化常数(配分函数),保证概率和为 1。

备注:

  • 它说的是:最优策略在参考策略的基础上,根据奖励大小进行"加权"。奖励高的回答概率会指数级增大,奖励低的会被抑制。β\beta 控制这个"加权"的激进程度。
  • 这个最优策略就是我们要学习的「模型参数」

6.3 反解奖励函数

从上面的式子,我们可以反过来把奖励函数用策略来表示:

r(x,y)=βlogπ(yx)πref(yx)+βlogZ(x) r(x,y) = \beta \log \frac{\pi^*(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)} + \beta \log Z(x)

这告诉我们:奖励函数可以用"当前策略和参考策略的 log 概率比"来表示

6.4 Bradley-Terry 偏好模型

在有偏好数据的时候,我们通常用 Bradley-Terry 模型来建模"哪个回答更好":

P(ywylx)=exp[r(x,yw)]exp[r(x,yw)]+exp[r(x,yl)]=σ(r(x,yw)r(x,yl)) P(y_w \succ y_l \mid x) = \frac{\exp[r(x, y_w)]}{\exp[r(x, y_w)] + \exp[r(x, y_l)]} = \sigma(r(x, y_w) - r(x, y_l))

ywy_w 是 chosen 的样本, yly_l 是 rejected 的样本。ywy_w 被偏好的概率取决于两个回答的奖励之差。

6.5 代入得到 DPO Loss

现在把 6.3 中的奖励函数代入 Bradley-Terry 模型。关键观察是:logZ(x)\log Z(x) 在两个回答中是一样的,相减的时候会消掉!

r(x,yw)r(x,yl)=βlogπ(ywx)πref(ywx)βlogπ(ylx)πref(ylx) r(x, y_w) - r(x, y_l) = \beta \log \frac{\pi^*(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi^*(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}

前面提到,我们把 待训练的模型 πθ\pi_\theta 认为是最优策略 π\pi^*

最终,最大化偏好数据的似然(等价于最小化负对数似然),就得到了 DPO Loss:

LDPO=E(x,yw,yl)[logσ(βlogπθ(ywx)πref(ywx)βlogπθ(ylx)πref(ylx))] \mathcal{L}_{\mathrm{DPO}} = - \mathbb{E}_{(x,y_w,y_l)} \left[ \log \sigma\Big(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\Big) \right]

这就是我们在第 2 节看到的 DPO Loss。

7. 总结

一句话总结:DPO 用监督学习的方式实现了 RLHF 的效果,把 4 个模型简化成 2 个,训练更稳定、资源消耗更低

DPO 的局限性:

  • 依赖偏好数据的质量,数据不好效果就不好
  • β\beta 参数比较敏感,需要调参

后续还有一些 DPO 的变体,比如 IPO(Identity Preference Optimization)、KTO(Kahneman-Tversky Optimization)等,以后有机会再聊(其实就是大概率没有机会了,醒醒吧,2026 年了)。

8. 参考资料

  1. DPO 原论文: Direct Preference Optimization
  2. trl 库文档

其他

最后欢迎关注我,基本全网同名 chaofa用代码打点酱油

从零手写 RoPE 位置编码:原理、PyTorch 源码实现与可视化理解

2026-01-02 00:57:20

0. 阅读收获 (takeaway)

本文旨在彻底搞懂 RoPE(Rotary Position Embedding)位置编码,阅读完本文你将获得:

  • 理解 RoPE 的核心思想:为什么用"旋转"来编码位置信息
  • 掌握 RoPE 的数学原理:从旋转矩阵到三角函数证明
  • 从零手写 RoPE 实现:逐行代码讲解,可直接运行
  • bonus:可视化理解 RoPE:通过热力图和动画直观感受旋转编码

本文代码运行于: Featurize GPU 算力云平台,有 GPU 使用需求的同学希望能使用我的邀请链接注册

待更新:不喜欢看文字的同学可以看 B站视频-chaofa用代码打点酱油, YouTube-chaofa用代码打点酱油,或视频号:chaofa用代码打点酱油

1. 为什么需要位置编码?

在 Transformer 架构中,Self-Attention 机制本身是位置无关的。公式如下:

Attention(Q,K,V)=softmax(QKTdk)V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

softmax 中 QK 的乘积就是重要性权重,什么意思呢?

# 假设我们有两个句子
sentence1 = "朝发 写 代码"
sentence2 = "代码 写 朝发"

# 对于纯 Self-Attention 来说,这两个句子的表示是一样的!
# 从公式看 Attention 只关心 token 之间的权重关系,不关心它们的顺序

这显然是不对的。语言是有顺序的,顺序不同意思完全不同。因此,我们需要位置编码(Position Encoding, PE)来告诉模型每个 token 在序列中的位置。

1.1 绝对位置编码 vs 相对位置编码

用一个例子来理解这两种编码方式的区别:

句子: "朝发 写 代码"
位置:   0  1  2

绝对位置编码:给每个位置一个固定编号

"朝发" → 位置 0 → PE_0
"写"   → 位置 1 → PE_1
"代码" → 位置 2 → PE_2
备注:PE_0 表示第一个位置的 embedding

相对位置编码:关注两个 token 之间的距离

计算 "朝发" 和 "代码" 的关系时:
→ 不关心它们分别在位置 0 和 2
→ 只关心它们相距 2 个位置

同理:计算 "朝发" 和 "写" 之间的相对位置是 (1 - 0) = 1。

使用相对位置编码就是希望捕获 Token 之间位置的相对关系,保持(某些)语义的不变性,下面 「朝发」和「代码」之间的关系是一样的,尽管绝对位置不同:

句子 A: "朝发 写 代码"
句子 B: "今天 朝发 写 代码"

2. RoPE 的核心思想

RoPE(Rotary Position Embedding,旋转位置编码)的核心思想非常优雅,可以阅读苏神 RoPE blog

通过旋转变换为向量注入位置信息,使得两个向量的内积只依赖于它们的相对位置。

这句话怎么理解呢?让我们一步步拆解看。

2.1 从 2D 旋转说起

假设我们在二维平面上有一个向量 (x,y)(x, y),将它旋转角度 θ\theta 后得到新向量:

(xy)=(cosθsinθsinθcosθ)(xy) \begin{pmatrix} x' \\ y' \end{pmatrix} = \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix} \begin{pmatrix} x \\ y \end{pmatrix}

这就是经典的 2D 旋转矩阵。下面用一张图来直观理解:

2D 向量旋转示意图

从图中可以看到:蓝色向量 (x,y)(x, y) 绕原点逆时针旋转角度 θ\theta 后,变成红色向量 (x,y)(x', y')

2.2 RoPE 的目标与解决方案

目标:我们希望找到一个位置编码函数 ff,使得 query 向量 qm\mathbf{q}_m 和 key 向量 kn\mathbf{k}_n 的内积只依赖于它们的相对位置 (mn)(m-n)

fq(q,m),fk(k,n)=g(q,k,mn) \langle f_q(\mathbf{q}, m), f_k(\mathbf{k}, n) \rangle = g(\mathbf{q}, \mathbf{k}, m-n)

也就是说,无论 mmnn 的绝对值是多少,只要 mnm-n 相同,内积结果就相同。

解决方案:RoPE 发现,这个函数 ff 就是旋转函数!(实际上是可以通过求解出来的,可以参考:Transformer升级之路:2、博采众长的旋转式位置编码),这里我们假设「知道了这么一个函数」,然后我们去证明它符合我们的需求。

Keep Looking, Don&amp;apos;t Settle:重听乔布斯演讲(25-11-月度小结)

2025-12-07 06:52:00

Keep Looking, Don't Settle

这个月突然又偶然又听了一遍那场著名的「乔布斯在斯坦福毕业演讲」,最有名的一句话莫过于「Stay hungry. Stay foolish」。

‘You’ve got to find what you love,’ Jobs says|388x258

看这个 blog 的人应该每个人都听过好多遍,我也一样,但这次听真正击中的我并不是这句话,而是第二个创业故事中提到的「Keep Looking, Don't Settle」。

纯英文演讲我并没有 GET 到多少东西,我又翻了一遍原文,才知道原来我一直理解错了这个演讲的主题,主题并不是告诉我们要【永远保持学习和探索的热情,即这句最广为流传的 Stay hungry. Stay foolish】,真正的文章标题是:「You’ve got to find what you love」,原来这篇文章的主题一直都是寻找自己的热爱。

第一次这个演讲应该是初三的时候,那时候觉得很热血很受鼓舞,15 年后的今天重看有完全不一样的感觉,我觉得好感动,我觉得说出这样的话真的很让人感动。

I’m convinced that the only thing that kept me going was that I loved what I did. You’ve got to find what you love—and that is as true for work as it is for your lovers. Your work is going to fill a large part of your life, and the only way to be truly satisfied is to do what you believe is great work. And the only way to do great work is to love what you do.

If you haven’t found it yet, keep looking—and don’t settle. As with all matters of the heart, you’ll know when you find it. And like any great relationship, it just gets better and better as the years roll on. So keep looking. Don’t settle.

我想,尝试过寻找方向的你们一定很懂这里的话有多感人,So keep looking. Don't settle.

工作

回看了一下 11 月记录的 Flomo,这个月的工作压力依然非常的巨大,这里有很多原因造成的:

  • 任务真的太难了(某个非常复杂的生成场景)。尽管已经通过一些方式优化最终完成了交付,但也仅仅是可以使用,真正离达到好用还有很明显的差距。
  • 标准难以定义。可执行和执行好两个完全不同的维度,尤其是这个好很难在任务中很好的定义出来,虽然有隐隐的优化方向,但是前面已经在「可执行」上花费了太多的力气导致「执行好」的进度并没有如我自己想象中的那般快。
  • 项目进展管理一直做的不太好。从今年 2 月份开始一人做一个方向,到后面两三个人一起做,我要负责整个项目的多人进展时,就会明显感觉自己在这方面的能力有所欠缺,自己也会有明显的畏难情绪。整体虽然有进步但是离目标还差太远了。

这个月未其实有个我挺喜欢的任务要做,但是真的太耗费我晚上与周末的时间了,有娃之后的我真的顶不住,我周末只想陪娃一会和休息了,不想周末再卷了,想放弃了。难道这又是到了最关键的时候就开始泄气吗?毕竟已经念头通达了。

破产之路

本月出现巨额亏损,核心是因为我重仓了 FIGMA,亏成了麻瓜,我为什么要在阿里起飞之际卖出画成 FIGMA 啊,怎么买不买中概都是亏钱啊。

OK,好吧,我从现在开始,我就定投纳斯达克和标普 500,慢慢彻底不玩个股了,这样的我五年后还会亏钱吗?Looking my eyes, answer me!!!!

LLM MOE的进化之路,从普通简化 MOE,到 sparse moe,再到 deepseek 使用的 share_expert sparse moe

2025-01-28 03:30:00

1. 阅读前提

本次课一共讲解三个不同版本的 MOE,分别是基础版MOE,大模型训练用的 SparseMoE,还有 DeepSeek 用的比较多的 shared_expert 的 SparseMoE。

2. 版本1:基础版本MOE

输入是一个 Token, 输出是一个 Token Embedding。暂时先不考虑 MOE 得到的 Embedding 怎么使用。

因为 MOE 网络对应着 Expert,这个 Expert 一般是一个 FeadFoward Network,FFN。而为了简化,后续我们都用一层的 Linear 代替,更高级版本的 Expert 留给大家当做课后作业。下面是一个专家的定义。

class BasicExpert(nn.Module):
    # 一个 Expert 可以是一个最简单的, linear 层即可
    # 也可以是 MLP 层
    # 也可以是 更复杂的 MLP 层(active function 设置为 swiglu)
    def __init__(self, feature_in, feature_out):
        super().__init__()
        self.linear = nn.Linear(feature_in, feature_out)
    
    def forward(self, x):
        return self.linear(x)

基础版本的 MOE 可以看这个图,非常的简单。

llms-zero-to-hero-basic-moe-model


class BasicMOE(nn.Module):
    def __init__(self, feature_in, feature_out, expert_number):
        super().__init__()
        self.experts = nn.ModuleList(
            [
                BasicExpert(feature_in, feature_out) for _ in range(expert_number)
            ]
        )
        # gate 就是选一个 expert 
        self.gate = nn.Linear(feature_in, expert_number)
    
    def forward(self, x):
        # x 的 shape 是 (batch, feature_in)
        expert_weight = self.gate(x)  # shape 是 (batch, expert_number)
        expert_out_list = [
            expert(x).unsqueeze(1) for expert in self.experts
        ]  # 里面每一个元素的 shape 是: (batch, ) ??

        # concat 起来 (batch, expert_number, feature_out)
        expert_output = torch.cat(expert_out_list, dim=1)

        # print(expert_output.size())

        expert_weight = expert_weight.unsqueeze(1) # (batch, 1, expert_nuber)

        # expert_weight * expert_out_list
        output = expert_weight @ expert_output  # (batch, 1, feature_out)
        
        return output.squeeze()


def test_basic_moe():
    x = torch.rand(2, 4)

    basic_moe = BasicMOE(4, 3, 2)
    out = basic_moe(x)
    print(out)


test_basic_moe()

2. 版本2:SparseMoE (大模型训练使用)

这个一般我们用 switch transformers 这篇文章的图作为演示,详情看:

llms-zero-to-hero-switch-transformers-moe-model

和 Basic 区别是,MOE 选择 topK 个专家,然后对这 topK 个专家的输出进行加权求和,并且把输入样本变成了大模型中真实的输入 Shape,(batch, seq_len, hidden_dim)


# 主要参考自 mistral MOE 的实现

class MOERouter(nn.Module):
    def __init__(self, hidden_dim, expert_number, top_k):
        super().__init__()
        self.gate = nn.Linear(hidden_dim, expert_number)
        self.expert_number = expert_number
        self.top_k = top_k
    
    def forward(self, hidden_states):
        # 计算路由logits
        router_logits = self.gate(hidden_states)  # shape is (b * s, expert_number)
        
        # 计算专家经过softmax之后的概率
        routing_probs = F.softmax(router_logits, dim=-1, dtype=torch.float)
        
        # 计算topk的专家的输出
        router_weights, selected_experts = torch.topk(
            routing_probs, self.top_k, dim=-1
        )  # shape都是 (b * s, top_k)
        
        # 专家权重归一化
        router_weights = router_weights / router_weights.sum(dim=-1, keepdim=True)
        router_weights = router_weights.to(hidden_states.dtype)
        
        # 生成专家掩码
        expert_mask = F.one_hot(
            selected_experts,
            num_classes=self.expert_number
        )  # shape是 (b * s, top_k, expert_number)
        expert_mask = expert_mask.permute(2, 1, 0)  # (expert_number, top_k, b * s)
        
        return router_logits, router_weights, selected_experts, expert_mask


class MOEConfig:
    def __init__(
            self, 
            hidden_dim, 
            expert_number, 
            top_k, 
            shared_experts_number=2,
        ):
        self.hidden_dim = hidden_dim
        self.expert_number = expert_number
        self.top_k = top_k
        self.shared_experts_number = shared_experts_number

class SparseMOE(nn.Module):
    # 稀疏 MOE 模型,这里每一个 token 都会过 topk 个专家,得到对应token 的 hidden_embeddings
    def __init__(self, config):
        super().__init__()

        self.hidden_dim = config.hidden_dim

        self.expert_number = config.expert_number
        self.top_k = config.top_k

        self.experts = nn.ModuleList(
            [
                BasicExpert(self.hidden_dim, self.hidden_dim) for _ in range(self.expert_number)
            ]
        )

        self.router = MOERouter(self.hidden_dim, self.expert_number, self.top_k)
    
    def forward(self, x):
        # x shape is (b, s, hidden_dim)
        batch_size, seq_len, hidden_dim = x.size()

        # 合并前两个维度,因为不是 Sample 维度了,而是 token 维度
        hidden_states = x.view(-1, hidden_dim) # shape is(b * s, hidden_dim)

        router_logits, router_weights, selected_experts_indices, expert_mask = self.router(hidden_states)
        # 其中 selected_experts_indices shape 是 (b * s, top_k)
        # 其中 expert_mask shape 是 (expert_number, top_k, b * s)
        
        final_hidden_states = torch.zeros(
            (batch_size * seq_len, hidden_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device
        )

        for expert_idx in range(self.expert_number):
            expert_layer = self.experts[expert_idx]
            # expert_mask[expert_idx] shape 是 (top_k, b * s)
            idx, top_x = torch.where(expert_mask[expert_idx]) 
            # idx 和 top_x 都是一维 tensor
            # idx 的值是 0 或 1, 表示这个 token 是作为当前专家的 top1 还是 top2
            # top_x 的值是 token 在 batch*seq_len 中的位置索引
            # 例如对于 batch_size=2, seq_len=4 的输入:
            # top_x 的值范围是 0-7, 表示在展平后的 8 个 token 中的位置
            # idx 的值是 0/1, 表示这个 token 把当前专家作为其 top1/top2 专家

            # hidden_states 的 shape 是 (b * s, hidden_dim)
            # 需要取到 top_x 对应的 hidden_states
            current_state = hidden_states.unsqueeze(
                0
            )[:, top_x, :].reshape(-1, hidden_dim) # (selected_token_number, hidden_dim)

            # router_weight 的 shape 是 (b * s, top_k)
            current_hidden_states = expert_layer(
                current_state
            ) * router_weights[top_x, idx].unsqueeze(-1)  # (selected_token_number, 1) 这里有广播

            # 把当前专家的输出加到 final_hidden_states 中
            # 方式1 的写法性能更好,并且方式1容易出现
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
            # 方式2
            # final_hidden_states[top_x] += current_hidden_states.to(hidden_states.dtype)
            # 方式2 的写法性能更差,并且方式2容易出现错误,+= 操作在处理重复索引时需要多次读写内存,可能会导致竞争条件

        # 把 final_hidden_states 还原到原来的 shape
        final_hidden_states = final_hidden_states.reshape(batch_size, seq_len, hidden_dim)

        return final_hidden_states, router_logits # shape 是 (b * s, expert_number)


def test_token_level_moe():
    x = torch.rand(2, 4, 16)
    config = MOEConfig(16, 2, 2)
    token_level_moe = SparseMOE(config)
    out = token_level_moe(x)
    print(out[0].shape, out[1].shape)


test_token_level_moe()

3. 版本3:ShareExpert SparseMoE (deepseek 版本)

备注:这里是参考 deepseek moe 思想,写的一个共享 expert 的 MOE 网络,有一定的简化,但是可以方便理解训练过程。

和 版本2 的 SparseMOE 区别是,这里多了一个 shared experts 的模型,这个模型是所有 token 共享的,也就是说,所有 token 都过这个 shared experts 模型,然后每个 token 会用计算的 Router 权重,来选择 topK 个专家,然后和共享的专家的输出一起加权求和。

具体结构图为:

llms-zero-to-hero-deepseek-v3-model-architecture

class ShareExpertMOE(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.moe_model = SparseMOE(config)
        self.shared_experts = nn.ModuleList(
            [
                BasicExpert(
                    config.hidden_dim, config.hidden_dim
                ) for _ in range(config.shared_experts_number)
            ]
        )

    def forward(self, x):
        # x shape 是 (b, s, hidden_dim)
        # 首先过 moe 模型
        sparse_moe_out, router_logits = self.moe_model(x)
        
        # 针对的还是 x 的每一个 
        # 然后过 shared experts
        shared_experts_out = [
            expert(x) for expert in self.shared_experts
        ] # 每一个 expert 的输出 shape 是 (b, s, hidden_dim)
        
        shared_experts_out = torch.stack(
            shared_experts_out, dim=0
        ).sum(dim=0)
        
        # 把 sparse_moe_out 和 shared_experts_out 加起来
        return sparse_moe_out + shared_experts_out, router_logits


def test_share_expert_moe():
    x = torch.rand(2, 4, 16)
    config = MOEConfig(16, 2, 2)
    share_expert_moe = ShareExpertMOE(config)
    out = share_expert_moe(x)
    print(out[0].shape, out[1].shape)


test_share_expert_moe()

4. 模型训练测试

用于测试上面的代码是否可以跑通?


def switch_load_balancing_loss(router_logits: torch.Tensor, num_experts: int) -> torch.Tensor:
    """
    计算 Switch Transformers 的负载均衡损失
    
    Args:
        router_logits: shape [batch_size * sequence_length, num_experts]
        num_experts: 专家数量
    
    Returns:
        total_loss: 总损失 = auxiliary_loss + z_loss
    """
    # 计算路由概率
    router_probs = torch.softmax(router_logits, dim=-1)  # [b*s, num_experts]
    
    # 获取每个token的最优专家
    _, selected_experts = torch.topk(router_probs, k=2, dim=-1) 
    
    # 创建one-hot矩阵表示选中的专家
    mask = torch.nn.functional.one_hot(selected_experts, num_experts).float() 
    
    # 计算每个专家的期望负载 (理想情况下应该是 1/num_experts)
    expected_load = torch.ones_like(router_probs) / num_experts
    
    # 计算实际负载 (每个专家处理的token数量除以总token数量)
    # 在batch维度上计算平均值
    actual_load = mask.mean(dim=0)
    
    # 计算auxiliary loss
    # 这会惩罚负载分布与期望负载的差异
    aux_loss = torch.sum(actual_load * router_probs.mean(dim=0)) * num_experts
    
    # 计算z_loss (可选)
    # 这会惩罚过大的路由logits
    z_loss = torch.mean(torch.square(router_logits))
    z_loss_weight = 0.001  # 可调整的超参数
    
    # 总损失
    total_loss = aux_loss + z_loss * z_loss_weight
    
    return total_loss

def test_moe_training():
    # Create a simple dataset
    batch_size = 32
    seq_len = 16
    hidden_dim = 32
    num_batches = 100
    
    # Initialize model and optimizer
    config = MOEConfig(hidden_dim=hidden_dim, 
                      expert_number=4,
                      top_k=2,
                      shared_experts_number=2)
    model = ShareExpertMOE(config)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    # Training loop
    model.train()
    for batch in range(num_batches):
        # Generate random input data
        x = torch.randn(batch_size, seq_len, hidden_dim)
        target = torch.randn(batch_size, seq_len, hidden_dim)
        
        # Forward pass
        output, router_logits = model(x)

        # Compute losses
        # MSE loss for prediction
        mse_loss = F.mse_loss(output, target)
        
        aux_loss = switch_load_balancing_loss(router_logits, config.expert_number)
        # Combined loss
        total_loss = mse_loss + 0.01 * aux_loss
        
        # Backward pass and optimize
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        
        if batch % 10 == 0:
            print(f"Batch {batch}, Loss: {total_loss.item():.4f} "
                  f"(MSE: {mse_loss.item():.4f}, Aux: {aux_loss.item():.4f})")

# Run the training test
test_moe_training()

5. 课后作业

  1. 把 expert 改成 swishGLU 版本的 FFN 专家
  2. 把 MOE 应用到上一次的 build_nanoGPT 中,也就是替换掉原来的 FFN层,注意这里负载均衡 loss 要包含每一层的 MOE 的 router_logits
  3. 自己问一下 GPT topK 是怎么实现的反向传播,了解反向传播的梯度怎么流转的?

交个朋友🤣

最后欢迎关注我,基本全网同名 chaofa用代码打点酱油

LLM activate function激活函数的进化之路,从 ReLU,GELU 到 SwiGLU(swishGLU)

2025-01-28 02:58:00

1. 背景

自 chatGPT 22年底问世以来,大模型(Large Language Model, LLM)一般使用 Causal Language Model 的形式,属于 Transformers 中的 Decoder 部分,其中在 Decoder 的 Block 中有一个 FFN(FeadForward) 层,一般认为这部分参数用于存储知识。而标准的 FFN 一般有一个升维度和降维度的过程,一共有两个权重矩阵,用公式表示为

FFN(x)=ReLU(xW1+b1)W2+b2(1) FFN(x) = ReLU(xW_1 + b1)W2 + b2 \tag{1}

其中 x shape 是 (b,s,h)(b, s, h),w1 shape 是 (h,4h)(h, 4h),w2 shape 是 (4h,h)(4h, h), w1 是升维(up),w2 是降维(down)

激活函数主要是为了实现神经网络学习输入和输出之间的复杂非线性关系而使用的一个函数。在公式 (1) 中,ReLU 是一个激活函数(Transfromers原版),可以替换成其他的激活函数,比如 BERT 开始用 Gaussian Error Linear Unit,GELU 比较多,随后就成了激活函数的主流选择,但是随着大模型的爆火以及 PaLM 模型的发布,大家开始慢慢使用 swishGLU 作为激活函数,并且作为一个主要的优化点。

具体可以看下面一段代码即可清楚的理解 FFN 模型是什么实现的。

class FeedForward(nn.Module):
    # 实际上就是 MLP
    def __init__(self, config):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
             # 激活函数
             nn.ReLU(),  
             #  可以替换成 nn.GELU(),  
             #  但是 如果是 SwishGLU 则实现方式有所不同,接下来就会介绍 swishGLU 是怎么实现的
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.dropout)
        )
    
    def forward(self, x):
        return self.net(x)

2. 升级之路

1. ReLU

ReLU 深度学习以来最常用的激活函数,其公式非常的简单。

ReLU(x)=max(0,x)(2) ReLU(x) = max(0, x) \tag{2}

2. GELU

从 GPT、BERT 以来,GELU 似乎成了新时代取代 ReLU 的激活函数,具体形式如下:

GELU(x)=xP(Xx)=xΦ(x)(3) GELU(x) = x P(X \le x) = x \Phi(x) \tag{3}

其中 Φ(x)\Phi(x) 是标准正态分布的累计分布函数,定义为

Φ(x)=12(1+erf(x2))(4) \Phi(x) = \frac{1}{2}(1 + erf(\frac{x}{\sqrt{2}})) \tag{4}

这里的 erf 是误差函数

erf(x)=2π0xet2dt(5) erf(x) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2} dt \tag{5}

但是这个函数由于计算成本较高,因此有两个初等函数作为近似计算(但目前【2025年1月27日】其实很多框架已经可以精确计算 erf 函数)。

近似计算分析详细可以参见苏神的文章,GELU的两个初等函数近似是怎么来的

3. SwiGLU(SwishGLU)

SwiGLU(或者swishGLU,以下可能混用) 是 swish 激活函数和 GLU 门控单元的结合体,因此需要分别介绍两者的不同。

其中需要注意的是:在 T5 开始,很多模型(比如 PaLM )在FFN层都不用 bias 了,也就是说 FFN的公式变成了

FFN(x)=ActiveFunction(xW1)W2(6) FFN(x) = \text{ActiveFunction}(xW_1)W2 \tag{6}

注意公式 6 和公式 1 的区别,一共没有 bias 一个有 bias,但具体得看不同模型的实现,并不能一概而论。

3.1 swish 激活函数

swish 是一个非线性函数(激活函数都是如此,笑🤣),具体公式为:

Swish=xσ(βx) \text{Swish} = x \sigma(\beta x)

其中 β\beta 是一个超参数,当 β=1\beta = 1 时,Swish 就变成了 SiLU (Sigmoid Linear Unit),大多数框架的默认实现(如 PyTorch、TensorFlow 的 nn.SiLU())使用的是 β=1\beta = 1 的固定版本。

因此如果采用 swish 激活函数,FFN 的公式变成了

FFN(W1,W2,x)=Swish(xW1)W2 FFN(W_1, W_2, x) = \text{Swish}(xW_1)W2

共有两个可学习的矩阵,其中 w1,(h,4h)w_1,(h, 4h) 是升维矩阵,w2,(4h,h)w_2,(4h, h) 是降低维度的矩阵。

3.2 GLU 门控单元

GLU,Gated Linear Units,是一种门控结构(有参数,因此相对于普通的激活函数多了一个 gate 矩阵),通过 sigmoid 控制不同维度的激活。公式如下[1]

GLU(W,x,V,b,c)=(Wx+b)sigmoid(Vx+c)(7) GLU(W, x, V, b, c) = (Wx + b) \otimes \text{sigmoid}(Vx + c) \tag{7}

这里是不是熟悉 LSTM, GRU 的同学一下就理解,其中需要注意的是,b, c 对应的 bias 不是必须的。

对比公式 7 和公式 9,公式 9 中的 wupw_{up} 对应 公式 7 中的 WW,而 wgatew_{gate} 对应公式 7 中的 VV 矩阵。

3.3 SwiGLU 的表达形式

SwiGLU 就是把门控函数替换成了 swish,并且去除掉了 bias 部分,以及把 FFN 层的一个 Linear 层替换成了 GLU 层,因此一共有三个可训练的参数矩阵, w1, w2, w3。

因此最终的公式表达为,

FFN(W1,W2,W3,x)=W2(W1xSwish(W3x))(8) FFN(W_1, W_2, W_3, x) = W_2 \cdot (W_1x \otimes \text{Swish}(W_3x)) \tag{8}

而我们都知道 FFN 是一个升高维度,然后降低维度的过程,因此可以写成,W2 是一个降低维度的参数,W1 是升高维度的过程,而 W3 是一个 Gate 需要用到的参数矩阵。

FFN(wup,wdown,wgate)=wdown(wupxSwish(wgatex))(9) FFN(w_{up}, w_{down}, w_{gate}) = w_{down} \cdot (w_{up}x \otimes \text{Swish}(w_{gate}x)) \tag{9}

通过这个公式整体就非常的清晰理解使用 swiGLU 的 FFN。

而我们都知道在 basic 版本的 FFN,见公式(1), 只有 wupw_{up}wdownw_{down} 分别是 (h, 4h) 和(4h, h),因此整体参数是 8h28h^2

而公式9 中,一共有三个矩阵,如果想要实现总参数 8h28h^2,那么每一个参数矩阵的大小应该是 8h23\frac{8h^2}{3},因此 wup,wgatew_{up}, w_{gate} 的shape应该是 (h,8h3)(h, \frac{8h}{3})wdownw_{down} 的 shape 是 (8h3,h)(\frac{8h}{3}, h)

假设输入的 hidden_dim 大小是 hidden_dim,那么中间层(up 后的维度)大小是 mid_dim, 具体计算逻辑如下:

mid_dim = int(8 * hidden_dim / 3)
# multiple_of:make SwiGLU hidden layer size multiple of large power of 2
mid_dim = multiple_of * ((mid_dim + multiple_of - 1) // multiple_of)

# multiple_of 一般设置为 256, LLaMA 和 GPT等模型

注意,在 LLM (大语言模型) 架构中,multiple_of 是一个用于优化计算效率的参数,通常设置为 256 或其他 2 的幂次方数(如 128、512 等),最终让 mid_dim 调整为 multiple_of 的整数倍。这样做有几个原因:

  1. 硬件优化:现代 GPU/TPU 在处理 2 的幂次方大小的张量时效率最高
  2. 内存对齐:确保内存对齐可以提高计算速度
  3. 并行计算效率:某些并行计算操作在处理规整的数字时效率更高

3. 带有 swishGLU 的 FFN 代码实现

class FFNExpert(nn.Module):
    def __init__(self, hidden_dim, dropout):   # LLM 进化之路, FFN 激活函数从 GELU -> SwiGLU
        super().__init__()  

        # 有一个 magic number 叫做 8/3
        hidden_dim = hidden_dim
        # 这里可以自己去优化成 multiple_of 的倍数
        mid_dim = hidden_dim * 8 // 3

        self.up = nn.Linear(hidden_dim, mid_dim, bias=False)
        self.down = nn.Linear(mid_dim, hidden_dim, bias=False)
        self.gate = nn.Linear(hidden_dim, mid_dim, bias=False)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = self.dropout(
            self.down(
                # up 之后的 Shape 是(b, s, mid_dim)
                # gate 和 up 之后的Shape都是 (b, s, mid_dim)
                # 两者是 element-wise 相乘
                F.silu(
                    self.gate(x)
                ) * self.up(x)
            )
        )
        return out

参考

最后欢迎关注我,基本全网同名 chaofa用代码打点酱油


  1. https://zhuanlan.zhihu.com/p/693332639 ↩︎

手写大模型组件之Group Query Attention,从 MHA,MQA 到 GQA

2024-12-09 06:00:00

  • GQA(Group Query Attention)的优点:效果损失小,推理的时候可以加速(来自于kvcache小,内存取数少)。
  • 仔细阅读 MHA, MQA 和 GQA的区别,就会发现 MHA 和 MQA 都是 GQA 的特殊表达形式
    • 三者可以用同一套代码,只需要修改【GQA】代码里面的 nums_key_value_head 参数就可
    • nums_key_value_head 设置等于 1 就是 MQA
    • nums_key_value_head 设置等于 nums_head 就是 MHA

如果不喜欢看文字的同学可以查看 B站 或者 YouTube 视频。

B站:https://www.bilibili.com/video/BV1ZmqpYfEGY/

YouTube: https://www.youtube.com/watch?v=1jBW7qcyd7A&t=1s

multi-head self-attention

备注:也可以直接由 GQA 中修改参数得到。但是本代码更完整一些

import math
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_dim, nums_head) -> None:
        super().__init__()
        self.nums_head = nums_head

        # 一般来说,
        self.head_dim = hidden_dim // nums_head
        self.hidden_dim = hidden_dim

        # 一般默认有 bias,需要时刻注意,hidden_dim = head_dim * nums_head,所以最终是可以算成是 n 个矩阵
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)

        # gpt2 和 bert 类都有,但是 llama 其实没有
        self.att_dropout = nn.Dropout(0.1)
        # 输出时候的 proj
        self.o_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, X, attention_mask=None):
        # 需要在 mask 之前 masked_fill
        # X shape is (batch, seq, hidden_dim)
        # attention_mask shape is (batch, seq)

        batch_size, seq_len, _ = X.size()

        Q = self.q_proj(X)
        K = self.k_proj(X)
        V = self.v_proj(X)

        # shape 变成 (batch_size, num_head, seq_len, head_dim)
        q_state = Q.view(batch_size, seq_len, self.nums_head, self.head_dim).permute(
            0, 2, 1, 3
        )
        k_state = K.view(batch_size, seq_len, self.nums_head, self.head_dim).transpose(
            1, 2
        )
        v_state = V.view(batch_size, seq_len, self.nums_head, self.head_dim).transpose(
            1, 2
        )
        # 注意这里需要用 head_dim,而不是 hidden_dim
        attention_weight = (
            q_state @ k_state.transpose(-1, -2) / math.sqrt(self.head_dim)
        )
        print(type(attention_mask))
        if attention_mask is not None:
            attention_weight = attention_weight.masked_fill(
                attention_mask == 0, float("-1e20")
            )

        # 第四个维度 softmax
        attention_weight = torch.softmax(attention_weight, dim=3)
        print(attention_weight)

        attention_weight = self.att_dropout(attention_weight)
        output_mid = attention_weight @ v_state

        # 重新变成 (batch, seq_len, num_head, head_dim)
        # 这里的 contiguous() 是相当于返回一个连续内存的 tensor,一般用了 permute/tranpose 都要这么操作
        # 如果后面用 Reshape 就可以不用这个 contiguous(),因为 view 只能在连续内存中操作
        output_mid = output_mid.transpose(1, 2).contiguous()

        # 变成 (batch, seq, hidden_dim),
        output = output_mid.view(batch_size, seq_len, -1)
        output = self.o_proj(output)
        return output


attention_mask = (
    torch.tensor(
        [
            [0, 1],
            [0, 0],
            [1, 0],
        ]
    )
    .unsqueeze(1)
    .unsqueeze(2)
    .expand(3, 8, 2, 2)
)

x = torch.rand(3, 2, 128)
net = MultiHeadAttention(128, 8)
net(x, attention_mask).shape

Group Query Attention

备注:以下代码省略了 attention_dropout attention_mask等情况的处理,真实实现过程中需要考虑。

import torch
import torch.nn as nn
import math

# 忽略了 attention_mask, attention_dropout; 
class GroupQueryAttention(nn.Module):
    def __init__(self, hidden_dim, nums_head, nums_key_value_head):
        super().__init__()
        assert hidden_dim % nums_head == 0 # 可以整除
        assert nums_head % nums_key_value_head == 0  # N 个 query head 为一组

        self.hidden_dim = hidden_dim
        self.nums_head = nums_head
        self.nums_key_value_head = nums_key_value_head
        self.head_dim = hidden_dim // nums_head

        # 初始化 qkv o
        self.q_proj = nn.Linear(hidden_dim, nums_head * self.head_dim)  # out feature_size (nums_head * head_dim)
        # k v out shape (nums_key_value_head * head_dim)
        self.k_proj = nn.Linear(hidden_dim, nums_key_value_head * self.head_dim)
        self.v_proj = nn.Linear(hidden_dim, nums_key_value_head * self.head_dim)

        self.o_proj = nn.Linear(hidden_dim, hidden_dim) # input_size nums_head * head_dim

    def forward(self, X, attention_mask=None):
        # X shape (batch, seq, hidden_dim)
        batch_size, seq, _ = X.size()

        # qkv projection
        q = self.q_proj(X)  # (batch, seq, hidden_dim)
        k = self.k_proj(X)
        v = self.v_proj(X) 

        # attention_weight 目标shape 是 (batch, nums_head, seq, seq)
        q = q.view(batch_size, seq, self.nums_head, self.head_dim)
        k = k.view(batch_size, seq, self.nums_key_value_head, self.head_dim)
        v = v.view(batch_size, seq, self.nums_key_value_head, self.head_dim)

        # 关注: nums_head 和 nums_key_value_head 的关系
        q = q.transpose(1, 2) # (b, nums_head, seq, head_dim)
        k = k.transpose(1, 2) # (b, nums_key_value_head, seq, head_dim)
        v = v.transpose(1, 2)  # (b, nums_key_value_head, seq, head_dim)

        # k v repeat; (广播操作)
        k = k.repeat_interleave(self.nums_head // self.nums_key_value_head, dim=1)
        v = v.repeat_interleave(self.nums_head // self.nums_key_value_head, dim=1)

        attention_score = (q @ k.transpose(2, 3)) / math.sqrt(self.head_dim)

        attention_weight = torch.softmax(attention_score, dim=-1)
        # (attention_mask 忽略) # 可以看前面的视频

        output = attention_weight @ v  # (b, nums_head, seq, head_dim)

        # output projection 变成 (b, seq, hidden_dim)
        output = output.transpose(1, 2).contiguous()
        final_output = self.o_proj(output.view(batch_size, seq, -1))

        return final_output

# 测试
x = torch.rand(3, 2, 128)
net = GroupQueryAttention(128, 8, 4)
net(x).shape

Multi Query Attention

由于 MQA 是 GQA 的一种特殊形式,因此只要在参数设置的时候将 nums_key_value_head = 1 就是 Multi Query Self-Attention。

交个朋友🤣

最后欢迎关注我,基本全网同名 chaofa用代码打点酱油