2026-02-24 00:07:00
折腾妈妈九个多月的乎乎出生了,我很开心。尤其是前段时间,抱着她,转过头的一瞬间突然叫一声「爸」,顿时有些恍惚,原来不知不觉间已经能隐约发出「爸」的声音了。
一个月前,点点说:“在我们决定要娃的时候,你说,「乎乎出生后,我所有的业余时间都会用来陪乎乎」,但你没有,你不是一个好爸爸。” 事实确实如此,2025 年是我自工作以来最忙碌的一段时间,各方面的事情都非常的多,我没有很多的时间陪乎乎导致点点有一周很崩溃,离家出走。我也因各种事情搞得头昏脑胀的,一度觉得只有不上班或者离婚才能解决问题[1]。
我尝试过每天通勤回龙岗,但是坚持不了几天身体就承受不住了,每天上班都脑袋昏昏沉沉;也试了带睡几天,可是心脏直砰砰的跳,根本睡不着,生怕要自己先猝死了。我时常想,其他人是怎么平衡工作和家庭的呢?是我精力太差了,还是其他人精力太好了呢?
我不是一个好爸爸,按照点点平时给我的打分,我只有 B-。过年放假的时候倒是花了更多的时间陪孩子,但如果工作忙起来,我周末还能有能量这样吗?希望 26 年能到 A,让乎乎和点点都给我打分,靠更频繁的奖励信号纠正爸爸不当的行为。
第一个部分其实也提到了工作,似乎核心就一个字「忙」,但真有这么忙吗?也许是有的,有几个离职的同事都表示变好了。但忙出了什么东西吗?我觉得也是有的,团队的业务产出其实都挺好的,当然 2025 年全年来说,我个人产出也是超出自己预期的[2],也有更多的思考🤔。
这一年,我自己的工作主线是:Agent 落地。主要的目标是:给模型更多的「自主决策空间,Let Agent to have More Agency」,主要的方式也和业界主流发展方式一致,围绕着两点:「更好的 Context」以及「给定 Context 下更好的执行」。由于今年是我做公开表达更多的一年,因此对于业界发展的跟进是比较快的,比如近期 Claude Code / Codex 都发表了一些关于 Prompt Caching 对于 Agent 设计的影响,我春节期间也写了三篇文章来介绍「Agent 系统中的 Prompt Caching 设计」。
此外,今年也暴露了自己比较大的问题:更擅长自己执行而不是协作。因为今年自己作为部分子项目 Owner,需要对于项目的进度有更强的规划和跟进能力。尽管现在都在强调 AI 时代要充分发挥个体能力,但是就目前(2026-02-23)而言,我觉得协作能力还是挺重要的,因为协作是意味着多线程的管理和表达能力,这在 AI 时代更是放大人与人区别的关键点。
其实大模型越来越强,我越来越焦虑,牛逼的人已经产出 100X,而自己的效率却没有本质变化。我仿佛看到了上个改革开放年代或者更近的所谓的移动互联网腾飞之年,有大量的人攫取了巨大的财富,仿佛是个人都能吃上所谓的时代红利,但大多数人,浪潮过后好像也没什么变化。
而这一次,AI 巨浪滚滚向前,又创造了一大波造富神话,身处其中,感觉每天都是翻天覆地的变化,各种新的产品怎么也跟踪不过来;但似乎:每一个好像都和我没什么太大的关系,我能提升效率吗?我能赚更多的钱吗?我能更快到达自由的彼岸吗?所以更大的可能性是:浪潮过后,自己什么也没得到,就像哪些在历史机遇中平平淡淡的大多数人一样。
去年,我在年终总结的时候写:2024年是公开表达元年。是的,2025 年,我尝试了更多的个人表达:
我在多次的月度总结中提到「商业化」,比如:2025-08-孙宇晨真的很值得学习,2025-09-合法赚钱就是高尚的,2025-10-一个程序员对自媒体商业化的深度复盘等,如果不明所以的朋友可能会觉得赚了很多钱,但事实远非如此,加起来半个月工资都不到,可以说这方面是比较失败的。
但也不是完全没有收获,我觉得通过视频商业化的尝试,我理解了很多的商业化行为,对于完整的商业化闭环也有了更深入的思考——打工是一个「期望风险更低」「期望收获更高」的商业化行为[4]。
由于前面的尝试以及认知,我不能再花时间去接所谓的商单了,我觉得商单确实是一种毒药,看上去好像赚了点小钱(见第一段),但实际上挺麻烦的,付出和收入不成正比。还是需要找到自己核心的竞争力和核心产品,才可能真正占用尽量少业余时间获得更高的复利。
此外,由于后面商单的出现,我甚至开始有点羞于向别人说:我做了一个技术频道叫做 chaofa 用代码打点酱油,对标油管的 Andrej Karpathy[5]。所以 2026 年 1-2 月的时候,有 4 个新商单找我,我都拒绝了[6],我还是想做更纯粹的表达(当然肯定是想赚钱的,所以还在拧巴中)。
所以以后会怎么样呢?暂时还不知道🤷♀️


工作自不必多说,依然是2026 最需要重点投入的事情,要积极跟进前沿,与 LLM 多多探讨业务、技术的发展,争取在工作上有进一步的突破。
另外,生产变革也已经发生,生产方式已经发生了巨变,尤其是已经看到非常多的人在 AI 的加持下做出了让人瞩目的成绩,所以 2026 要更加彻底的拥抱 AI Coding,应该说在创造一些自己的 Product,而不是隔岸观火。
所以我斥巨资买了一个域名叫做:ApeCode.ai,代码都会放到 github.com/ApeCodeAI 下,Slogan 想了好多有意思的:

写于:2026 年 2 月 23 日 20:38:56 新年春节假期返工前一天晚
实际上我和点点几乎没有什么矛盾,目前也没有太大的经济压力,只是人都会有情绪崩溃的时候,造成一些不可思议的想法。「不是客观上带孩子时间的问题,而是在带孩子的情绪价值和参与度上的问题」导致点点觉得我投入不够。 ↩︎
工作这么多年,虽然也有绩效不错的时候,但我很少自己给自己评价超出预期。不过这一年,不管别人怎么看,我自己是尽全力了,业务产出也还不错。另外,由于 25 年也作为面试官面试了非常多的候选人,让我对各种事情有了更加深刻的认识,也让我更加坚定的要建立一套自己的评估体系。 ↩︎
我记得 11 月度总结的时候,我说:还有一个月就 2025 年全勤了,没想到最后功亏一篑了。有朋友可能会问,AI 写得比你写得好多了,为什么不用 AI?我的答案是:这一类个人思考和总结的东西,我不想用 AI,因为本来是写给自己看的,如果我不知道「下笔的时刻我在思考什么」,那么写出的文章又有什么意义呢?思考过程的本身比结果更让人着迷。要是对此感兴趣的同学也可以关注:公众号——chaofa 用代码打点酱油 ↩︎
正好对应了我一直说的,打工一定要投入更多的时间。公开表达只能是业余时间中 20% 的精力,因为收益真的没想象中高。自媒体是一个极度放大幸存者偏差的地方。Keep it in mind. ↩︎
这是最初的愿景,因为我也是完全从零手写代码/或者读一些前沿的论文,这也是吸引很多「专业」的同行的原因,因此我的视频受众有非常多各种大厂在职的算法工程师。 ↩︎
26 年 1 月份的发的两个视频都是元旦期间做的,都是很早很早以前接的,只是平常时间做 ↩︎
2026-02-22 23:06:00
读完本文,你将了解:
前置知识:本文是 Agent 系统中的 Prompt Caching 设计(上) 的续篇。如果你还没读过,建议先了解 Cache 破坏机制、Prompt 布局和工具管理策略。更基础的概念见 理解 KV Cache 与 Prompt Caching。
Anthropic 的研究指出:随着 context 中 token 增加,模型注意力分散,性能下降。
Attention 中每个 token 要和所有其他 token 建立 $n^2$ 的 pairwise 关系。Context 越长 → 每个 token 的"注意力预算"越少 → 模型可能"忘记"早期的重要指令,或被大量 tool output 稀释关键信息。
更大的 context window 不是万能解药。 能塞进去不代表模型能有效利用。
这就是为什么 Agent 不能简单地把所有信息堆进 context——我们需要主动管理。
压缩是解决 context 增长的关键手段,但必须是 cache-safe 的。

Claude Code 的压缩策略非常精巧:
Codex 提供专门的 API 端点:
auto_compact_limit 自动触发两者共同点:压缩是必要的,但压缩过程本身不能破坏已有的 cache。
Agent 如何确保模型"聚焦"在正确的事情上?三家有一个有趣的演进过程。
todo.md → 约 1/3 actions 浪费在更新 todo 上!EnterPlanMode
| 方案 | 独立阶段? | 用户审批? | 自主控制? |
|---|---|---|---|
| Manus planner agent | 独立 agent | 无需审批 | Agent 决定 |
| Claude Code Plan Mode | 独立阶段 | 需要审批 | 模型可自主进入 |
| Codex update_plan | 也有独立阶段 | 无需审批 | 执行中随时调用 |
一个有趣的共识:不要删除失败的 action 和 observation。
Manus 不会从 context 中删除失败的工具调用结果。双重好处:
错误恢复是真正 agentic 行为的标志。 看不到自己犯的错,怎么学会避免?
Agent 不需要把所有信息都放在 context 里——文件系统可以作为"延展记忆"。
| 方案 | 预加载 | 按需检索 |
|---|---|---|
| Claude Code | CLAUDE.md | glob/grep 搜索文件系统 |
| Codex | AGENTS.md | shell 工具探索 |
| Anthropic 建议 | 最少必要信息 | JIT 检索 |
共同点:用 glob/grep 搜索文件系统,无需向量索引。这和 Agentic RAG 的思路一脉相承——Agent 自主决定搜索什么,而不是被动接受检索结果。
Cache 是 model-specific 的。各家的做法:
Claude Code 架构(逆向分析数据):
| 子代理 | 工具数 | Prefix Reuse |
|---|---|---|
| Main Agent | 18 | — |
| Explore × 3 并行 | 10/18 子集 | 92% |
| Plan | 独立 context | 93% |
| Execution | 全部 | 97% |
Claude Code 还使用 warm-up 调用:启动时预热 tool list 和 system prompt 的 cache。
Manus 多代理架构:
submit_results 工具 + 约束解码确保输出格式Anthropic 建议:子代理返回压缩 summary → 避免主 context 被"污染"。
Fork 出的子任务必须用和父对话相同的 prompt prefix,才能复用父对话 cache。
Claude Code 在 compaction、summarization、skill execution 中都遵循这个原则。核心思想:压缩/fork 是在现有 cache 基础上的延伸,而非另起炉灶。
最后分享 Manus 的反思,引用 Rich Sutton 的 "The Bitter Lesson":
Agent 的 harness(框架/约束)可能限制模型性能。随着模型进步,需要不断简化架构。
Manus 自 2025 年 3 月以来已重构无数次。每次模型能力提升,某些 workaround 就变得不必要。
但有些设计是"持久"的——围绕 cache 的架构决策就是。它们不是在弥补模型不足,而是在适配计算的物理现实。
Cache 是物理约束,不是工程 hack。 只要 Prefill 还是 Compute Bound,Prompt Cache 就会继续是 Agent 架构的核心考量。
备注:本文主要受前 4 篇参考内容的启发
最后欢迎关注我,基本全网同名 chaofa用代码打点酱油
2026-02-22 18:16:00
读完本文,你将了解:
前置知识:本文假设你已经理解 KV Cache、Prefill/Decode 两阶段、以及 Prompt Cache 的前缀匹配机制。如果不熟悉这些概念,建议先阅读 理解 KV Cache 与 Prompt Caching:LLM 推理加速的核心机制。
在深入细节之前,我想先分享我对这个话题的理解:
Prompt Cache 不只是一个省钱技巧,它是 Agent 系统架构设计的核心约束。
就像数据库的 schema 设计会影响整个应用架构一样,Prompt Cache 的前缀匹配约束深刻地影响了 Agent 的每一个设计决策:
我们已经越来越多地听到 "Context Engineering" 这个术语。区别在哪?
Manus 在 25 年底最新总结中提出了三个维度:Reduce(缩减)、Isolate(隔离)、Offload(卸载)。
后面的内容你会看到,各家的设计都在围绕这三个维度展开。
不同公司、不同架构,但核心规律惊人地一致:
带着这些规律,我们来看具体的实践。
Agent 每一步都需要发送完整的对话历史给模型,模型只输出一小段。Manus 披露过一个数据:input:output ≈ 100:1。
如果没有 Prompt Cache → 每一步重新 Prefill 所有历史 token → 成本二次方增长。
| 场景 | 不缓存 | 缓存后 | 节约 |
|---|---|---|---|
| Claude(正常 vs cached) | $3/MTok | $0.30/MTok | 90% |
| OpenAI GPT-5(正常 vs cached) | $10/MTok | $2.50/MTok | 75% |
| Claude Code 单任务(约 2M tokens) | ~$6.00 | ~$1.15 | 81% |
Thariq(Claude Code 团队):
"Coding agents would be cost prohibitive without prompt caching."
OpenAI Codex:cache 命中后,采样开销从二次降为线性。
前缀匹配是一切的基础:任何位置的任何改动 → 该位置之后的 cache 全部失效。
在开头放时间戳——Manus 踩过的坑。时间戳每秒都变,第一个 token 就不同,整个 cache 废掉。
Claude Code 团队经验中最常见的 cache 破坏方式。Tool definitions 在 prompt 前部,增删任何一个工具 → 后续所有 cache 失效。
具体场景:
allowed_subagents 列表变化Cache 是 model-specific 的。
一个反直觉推论:100K token 对话中,切换到更便宜的模型可能更贵 —— Opus 的 100K cached token 只需 $1.50,换 Haiku 后全部重算。
核心原则:序列化必须是确定性的。
核心思路:把稳定的内容放前面,把变化的内容放后面。

| 层级 | 内容 | 稳定性 |
|---|---|---|
| Layer 1 | Static System Prompt & Tools | 全局不变 |
| Layer 2 | CLAUDE.md 项目配置 | 项目级不变 |
| Layer 3 | Session Context(git status 等) | 会话级 |
| Layer 4 | Conversation Messages | 每轮追加 |
每轮只有 Layer 4 增长,前 3 层稳定命中 cache。
三层结构:instructions → tools → input(input 可能会包含 dev role message,上图)。关键设计:旧提示是新提示的精确前缀。
配置变更(沙盒权限、工作目录)→ 追加新消息而非修改旧消息。
| 方案 | 实现方式 |
|---|---|
| Claude Code |
<system-reminder> 标签放在 user message 中 |
| Codex | 追加新的 developer/user 消息 |
永远追加,永远不修改。
cache_control breakpoint → auto-caching 一个参数搞定prompt_cache_key 路由优化,自动缓存 ≥1024 token 前缀反直觉:900 token prompt 永远不 cache hit,扩展到 1024+ token 反而更省钱。
Agent 可能有 30 个工具,但不同阶段只需要一部分。如果按需加载 → 每次状态切换 cache 全废。

EnterPlanMode/ExitPlanMode 本身作为工具 → 工具列表永远不变defer_loading stub → ToolSearch 按需获取完整 schemabrowser_xxx、shell_xxx
tools 数组完整不变allowed_tools 限制当前可用子集本质:工具定义不变(保 cache),通过其他机制限制可选范围。
| 方案 | 实现方式 | 优点 | 限制 |
|---|---|---|---|
| Claude Code | tool 本身 + defer_loading | 灵活,模型自主决策 | 需 API 支持 |
| Manus | logits masking | 精细控制 | 需 self-hosting |
| OpenAI | allowed_tools 参数 | 最简单 | 仅粗粒度 |
本文聚焦于 Cache-aware 的 Prompt 设计和工具管理。但 Agent 还面临另一组挑战:context 越来越长怎么办?怎么压缩才不破坏 cache?子代理怎么设计?
下一篇 Agent 系统中的 Prompt Cache 设计(下):上下文管理与子代理架构 将深入这些话题。
最后欢迎关注我,基本全网同名 chaofa用代码打点酱油
2026-02-21 18:00:00
读完本文,你将了解:
KV Cache 和 Prompt Cache 对于 Agent 设计的影响:
- Agent 系统中的 Prompt Caching 设计(上):Cache 破坏、Prompt 布局与工具管理 —— 为什么 Agent 更需要 Cache、什么会破坏 Cache、三家工具管理方案对比
- Agent 系统中的 Prompt Caching 设计(下):上下文管理与子代理架构 —— 上下文压缩、Plan 模式演进、子代理 Cache 友好设计
大语言模型(LLM)的文本生成是 自回归(autoregressive) 的:每次只生成一个 token,然后把这个 token 拼到已有序列后面,再预测下一个。
用伪代码表示就是:
# 自回归生成的朴素实现
output_tokens = []
for step in range(max_new_tokens):
# 每一步都要把 整个序列 送进模型
logits = model(input_tokens + output_tokens)
next_token = sample(logits[-1]) # 只用最后一个位置的 logits
output_tokens.append(next_token)
Q: 问题出在哪?
每一步生成,模型都要对所有历史 token 重新做 Attention 计算——包括 Q、K、V 矩阵乘法。但对于已经出现过的 token,它们的 K 和 V 其实不会变(因为参数没变、token 没变),唯一在变的只有 "最新生成的那个 token" 对应的 Q、K、V。
这就引出了一个自然的优化思路:能不能把已经算过的 K 和 V 缓存起来,下次直接用?
KV Cache 的核心思想非常直接:
把每一层 Attention 中、每个已生成 token 对应的 K 向量和 V 向量缓存下来。后续生成新 token 时,只需要计算新 token 自己的 Q、K、V,然后将新的 K、V 追加到缓存中,用缓存里的完整 K、V 序列做 Attention。
这样一来,生成第 $t$ 个 token 时,Attention 的计算从 $O(t \times d)$(重算所有 token 的 K、V)降低到 $O(d)$(只算 1 个新 token 的 K、V),避免了绝大部分重复计算。
用带 KV Cache 的伪代码表示:
# 带 KV Cache 的生成
kv_cache = {} # 每一层缓存 K, V
for step in range(max_new_tokens):
if step == 0:
# 第一步:处理所有 input tokens,填充 cache
logits, kv_cache = model(input_tokens, kv_cache=None)
else:
# 后续步:只送入上一步生成的 1 个 token
logits, kv_cache = model([last_token], kv_cache=kv_cache)
next_token = sample(logits[-1])
last_token = next_token
KV Cache 不是免费的——它用显存换计算。随着生成序列变长,KV Cache 占用的显存会线性增长。
具体公式(假设 float16 存储):
其中:
这个公式的详细推导和具体数值例子,可以参考我之前的文章 LLM 大模型训练-推理显存占用分析。这里只需要记住一个直觉:序列越长,KV Cache 越大。这也是为什么后续会有 GQA(Grouped Query Attention)、DeepSeek MLA 等 KV Cache 压缩技术出现。
理解了 KV Cache 之后,我们可以把 LLM 推理过程清晰地分成两个阶段:Prefill 和 Decode。这两个阶段的计算特性截然不同,理解它们的区别对后面理解 Prompt Cache 非常关键。

Prefill 阶段就是上面伪代码中 step == 0 的那一步:模型一次性处理所有输入 token(system prompt + user message),为每一层、每个 token 计算出 K 和 V 并存入 cache。
关键特点:
Decode 阶段就是后续的 step > 0:每一步只输入 1 个 token,利用 KV Cache 做 Attention,生成下一个 token。
关键特点:
这里解释一下 Compute Bound 和 Memory Bound 的含义,核心概念是 Arithmetic Intensity(算术强度):
用一个直觉来理解:
这个问题来自一位读者在 GitHub Discussion 的提问,我觉得是一个非常好的问题:
既然我们只需要预测 next token,Prefill 阶段不是只需要最后一个 token 的 Q 吗?为什么要计算所有 token 的 Q?
乍一看很有道理——Attention 的输出 $\text{softmax}(QK^T / \sqrt{d})V$ 中,我们只需要最后一个位置的结果来预测 next token。那 K、V 确实需要全算(因为最后一个 Q 要和所有 K 做 attention),但 Q 为什么不能只算最后一个?
答案的核心是:Decoder 有很多层。
如果 Transformer 只有一层,那确实,我们只需要最后一个 token 的 Q。但实际的 Decoder 有几十层,上一层所有位置的输出是下一层所有位置的输入:
看图:

所以结论是:
Prefill 必须计算所有 token 的 Q,不是因为最终预测需要,而是因为每一层的 K、V 缓存依赖于上一层所有位置的完整输出,而上一层的完整输出需要所有位置的 Q 参与计算。
这也解释了为什么 Prefill 阶段是 Compute Bound——它确实需要做大量计算,不是在浪费。
从用户体验的角度,两个阶段对应两个不同的延迟指标:
对于输入很长的场景(比如长文档问答、Agent 的多轮对话),Prefill 阶段的耗时会显著增加 TTFT。
这就引出了一个关键问题:如果我们能跳过 Prefill 中那些"之前已经算过"的部分,是不是就能大幅降低 TTFT?这就是 Prompt Cache 要解决的问题。
前面说的 KV Cache 是单次请求内部的优化——生成过程中缓存已算过的 K、V,避免重复计算。
Prompt Caching(前缀缓存) 则是跨请求的优化:
如果两次 API 调用的 prompt 有相同的前缀,那么第二次调用可以直接复用第一次 Prefill 阶段算出来的 KV Cache,跳过前缀部分的 Prefill 计算。
这对于以下场景特别有价值:
Prompt Cache 的匹配规则非常严格:
必须从第一个 token 开始完全一致,一个 token 的差异就会导致该位置之后的 cache 全部失效。

举个例子:
| 请求 | 内容 | Cache 命中情况 |
|---|---|---|
| 请求 1 | [System Prompt][User: Hello] |
无 cache,全部计算 |
| 请求 2 | [System Prompt][User: Hello][Assistant: Hi][User: 你好] |
[System Prompt][User: Hello] 部分 cache hit |
| 请求 3 | [Modified System Prompt][User: Hello] |
完全 miss,因为第一个 token 就不一样了 |
这个"前缀精确匹配"的约束,是后面 Agent 系统设计的核心基础。在下一篇文章中,我们会详细讨论 Claude Code、Manus、OpenAI Codex 如何围绕这个约束设计整个系统架构。
在开源推理引擎中,Prompt Cache(通常叫 Prefix Caching 或 Automatic Prefix Caching)已经是标配功能:
--enable-prefix-caching 开启,使用 hash-based 的 block 管理机制。将 KV Cache 按固定大小的 block 存储,相同前缀的 block 可以在不同请求间共享。这些引擎的实现细节不是本文重点,关键是理解:Prompt Cache 在推理引擎层面已经是成熟技术,无论你用开源引擎自部署还是调用商业 API,都可以获得这个优化。
让我们回顾一下本文的核心脉络:
这个"前缀精确匹配"的约束,在 AI Agent 系统中变得尤为关键。Agent 每一步都要发送越来越长的 context(历史对话 + 工具调用结果),如果不精心设计 prompt 结构,cache 命中率会很低,成本和延迟都会飙升。
在接下来的两篇文章中,我会详细分析 Claude Code、Manus、OpenAI Codex 等 AI Agent 如何围绕 Prompt Cache 设计整个系统架构:
最后欢迎关注我,基本全网同名 chaofa用代码打点酱油
2026-01-10 23:07:00
本文目标是搞懂 DPO(Direct Preference Optimization)算法,阅读完本文你将获得:
本文代码运行于:Featurize 蒜粒方块 GPU 算力平台,不喜欢看文字的同学可以看 B站视频-chaofa用代码打点酱油,YouTube-chaofa用代码打点酱油,视频号:chaofa用代码打点酱油
在聊 DPO 之前,我们先快速回顾一下 LLM 训练的三个阶段(参考 OpenAI InstructGPT):
假设读者对于前两个步骤已经有所了解,这篇文章的重点是第三步"对齐"。
OpenAI 在训练 ChatGPT 的时候用的是 RLHF(Reinforcement Learning from Human Feedback),整个流程大概是这样的:

RLHF 确实有效,但问题也很明显:
之前在 DeepSeek-R1 论文解读 里也提到过,DPO 是 RLHF 的一种替代方案,但 DeepSeek 最终还是用了 GRPO(一种改进的 PPO)。不过对于大多数场景来说,DPO 已经够用了。
DPO 的核心思路是:既然 RLHF 这么麻烦,能不能把强化学习的部分去掉,直接用监督学习的方式来做对齐?
答案是可以的。DPO 的作者通过一系列数学推导(后面 Bonus 部分会讲),证明了可以把 RLHF 的优化目标转换成一个简单的损失函数,只需要 2 个模型就能搞定:
不需要单独训练 Reward Model,也不需要 PPO 那套复杂的东西。训练过程和 SFT 差不多,非常稳定。
DPO 需要的数据格式很简单,就是一个 prompt 配上两个回答:一个好的(chosen),一个差的(rejected)。
# DPO 偏好数据示例
{
"prompt": "介绍一下 chaofa用代码打点酱油 这个博主",
"chosen": "chaofa用代码打点酱油 是一位专注于大模型技术的博主,他在 B站、YouTube 等平台分享 LLM 相关的技术内容,包括动手学大模型系列教程。他的内容特点是注重代码实现和原理讲解,帮助读者从零理解大模型的各种技术细节。",
"rejected": "不知道,没听说过,说不定是个弱智。"
}
简单说就是:同一个问题,告诉模型哪个回答是好的,哪个是不好的。这种数据可以通过人工标注获得,也可以用更强的模型(比如 gemini/claude/gpt)来生成。
TRICK: 非同源模型的数据训练的时候,可以先用 "chosen" 数据 SFT,不然可能导致 chosen 和 rejected 概率都变低。
DPO 的目标其实就两个:
第二点很重要,如果只追求第一点,模型可能会为了迎合偏好数据而变得很奇怪(比如每个回答都很长、很啰嗦)。所以需要用参考模型来"拉住"它。
好了,到了最核心的部分。DPO 的损失函数长这样:
这个公式看起来贼复杂,但逻辑其实很清晰。首先看公式里面的核心部分,是在比较两个东西:
我们希望前者大于后者。也就是说,模型在 chosen 上的"提升幅度"要大于在 rejected 上的"提升幅度"。
$\beta$ 是一个超参数,用来控制"偏离参考模型的惩罚力度"。$\beta$ 越大,模型越不敢偏离参考模型;$\beta$ 越小,模型越"激进"。一般从 0.1 开始试。
$\sigma$ 就是 sigmoid 函数,把差值映射到 (0, 1) 区间,然后取 log 变成 loss。
Q: 这个公式是怎么推导出来的?为什么这样设计就能达到我们的目标?这些问题留到 Bonus 部分再说。现在只要理解"DPO 在做什么"就够了。
理解了原理之后,我们来看看代码怎么写。其实 DPO 的核心代码非常简单,比公式看起来简单多了。
首先,我们需要一个函数来计算模型在某个序列上的 log 概率。
对于语言模型来说,生成一个序列的概率就是每个 token 条件概率的乘积。取 log 之后,乘积变成求和:
import torch
import torch.nn.functional as F
def compute_log_probs(
logits: torch.Tensor, # (batch, seq_len, vocab_size)
labels: torch.Tensor, # (batch, seq_len)
mask: torch.Tensor # (batch, seq_len),标记哪些位置需要计算
) -> torch.Tensor:
"""
计算序列的对数概率
注意:这里只计算 response 部分的概率,prompt 部分不算
"""
# 获取每个位置的 log softmax
log_probs = F.log_softmax(logits, dim=-1)
# 取出对应 label 的 log 概率
# gather 操作:从 vocab_size 维度取出 labels 对应的概率
per_token_log_probs = torch.gather(
log_probs,
dim=-1,
index=labels.unsqueeze(-1)
).squeeze(-1)
# 只计算 mask=1 的位置(response 部分)
masked_log_probs = per_token_log_probs * mask
# 求和得到整个序列的 log 概率
return masked_log_probs.sum(dim=-1)
有了计算 log 概率的函数,DPO Loss 的实现就很直接了:
def dpo_loss(
policy_chosen_logps: torch.Tensor, # 当前模型在 chosen 上的 log 概率
policy_rejected_logps: torch.Tensor, # 当前模型在 rejected 上的 log 概率
ref_chosen_logps: torch.Tensor, # 参考模型在 chosen 上的 log 概率
ref_rejected_logps: torch.Tensor, # 参考模型在 rejected 上的 log 概率
beta: float = 0.1,
) -> torch.Tensor:
"""
DPO Loss 的核心实现
代码比公式简单多了吧?
"""
# 计算 log ratio:当前模型相对于参考模型的变化
chosen_log_ratios = policy_chosen_logps - ref_chosen_logps
rejected_log_ratios = policy_rejected_logps - ref_rejected_logps
# 核心:我们希望 chosen 的 ratio 大于 rejected 的 ratio
logits = beta * (chosen_log_ratios - rejected_log_ratios)
# 用 logsigmoid 更数值稳定(等价于 -log(sigmoid(logits)))
losses = -F.logsigmoid(logits)
return losses.mean()
就这么简单。核心就三行:
完整的训练代码涉及数据处理、模型加载等,这里就不展开了。可以参考 trl 源码。
手写 DPO Loss 是为了理解原理,实际训练的话直接用 trl 就好了。trl 是 Hugging Face 出的强化学习库,DPO 训练用起来很简单。
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer
# 1. 准备模型
model_name = "Qwen/Qwen2.5-0.5B-Instruct" # 用小模型演示
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# 参考模型(就是 SFT 后的模型,这里直接用同一个)
ref_model = AutoModelForCausalLM.from_pretrained(model_name)
# 2. 准备数据(trl 需要的格式)
train_data = Dataset.from_dict({
"prompt": [
"介绍一下 chaofa用代码打点酱油 这个博主",
"DPO 和 RLHF 哪个更适合入门?",
],
"chosen": [
"chaofa用代码打点酱油 是一位专注于大模型技术的博主,在 B站、YouTube 分享 LLM 相关教程,内容注重代码实现和原理讲解,帮助读者从零理解大模型技术。",
"建议先学 DPO,原理更简单,训练也更稳定。可以看 chaofa用代码打点酱油 的动手学大模型系列,有详细的代码实现。",
],
"rejected": [
"没听说过,应该是个小透明吧。",
"都差不多,随便选一个。",
],
})
# 3. 配置训练参数
training_args = DPOConfig(
output_dir="./dpo_output",
beta=0.1, # DPO 的温度参数
learning_rate=5e-7, # DPO 通常用比较小的学习率
per_device_train_batch_size=2,
num_train_epochs=1,
logging_steps=10,
bf16=True,
)
# 4. 创建 Trainer 并训练
trainer = DPOTrainer(
model=model,
ref_model=ref_model,
args=training_args,
train_dataset=train_data,
tokenizer=tokenizer,
)
trainer.train()
关键参数说一下:
beta:前面说过,控制偏离参考模型的惩罚力度,一般从 0.1 开始试learning_rate:DPO 通常用比较小的学习率,5e-7 到 5e-6 左右很多人说"DPO 比 PPO 更稳定",但到底为什么呢?这个问题其实可以从几个角度来理解:
PPO 是一种 on-policy 的强化学习算法,DPO 是 off-policy 的,它直接用离线的偏好数据来训练,训练过程和 SFT 差不多。
PPO 在 RLHF 中需要:
这些额外的模型都会引入噪声和不稳定因素。DPO 把 Reward Model 直接"吸收"到了损失函数里,不需要单独训练,少了很多可能出错的地方。
PPO 有很多超参数需要调:
这些参数之间还有复杂的相互作用,调参调到怀疑人生是常有的事。DPO 的核心超参数就一个 $\beta$,最多再加上学习率。简单很多。
备注:这里说的"稳定"。PPO/GRPO 调好了效果可能更好,但训练成本也更高。对于大多数场景来说,DPO 是一个性价比很高的选择。
这部分是给想深入理解的同学看的,跳过也不影响使用 DPO。
DPO 的 Loss 不是凭空设计出来的,而是从 RLHF 的优化目标一步步推导出来的。
RLHF 想要做的事情是:最大化奖励,同时不要偏离参考模型太远。用公式表示:
其中:
这个优化问题有一个解析解。我们先假设存在这样一个最优策略 $\pi^*$,(具体推导可以参考 DPO 原论文附录,但我没看懂直接抄过来了),可以得到最优策略满足:
其中 $Z(x) = \sum_y \pi_{\mathrm{ref}}(y \mid x) \exp\Big(\frac{1}{\beta} r(x,y)\Big)$ 是归一化常数(配分函数),保证概率和为 1。
备注:
- 它说的是:最优策略在参考策略的基础上,根据奖励大小进行"加权"。奖励高的回答概率会指数级增大,奖励低的会被抑制。$\beta$ 控制这个"加权"的激进程度。
- 这个最优策略就是我们要学习的「模型参数」
从上面的式子,我们可以反过来把奖励函数用策略来表示:
这告诉我们:奖励函数可以用"当前策略和参考策略的 log 概率比"来表示。
在有偏好数据的时候,我们通常用 Bradley-Terry 模型来建模"哪个回答更好":
$y_w$ 是 chosen 的样本, $y_l$ 是 rejected 的样本。$y_w$ 被偏好的概率取决于两个回答的奖励之差。
现在把 6.3 中的奖励函数代入 Bradley-Terry 模型。关键观察是:$\log Z(x)$ 在两个回答中是一样的,相减的时候会消掉!
前面提到,我们把 待训练的模型 $\pi_\theta$ 认为是最优策略 $\pi^*$。
最终,最大化偏好数据的似然(等价于最小化负对数似然),就得到了 DPO Loss:
这就是我们在第 2 节看到的 DPO Loss。
一句话总结:DPO 用监督学习的方式实现了 RLHF 的效果,把 4 个模型简化成 2 个,训练更稳定、资源消耗更低。
DPO 的局限性:
后续还有一些 DPO 的变体,比如 IPO(Identity Preference Optimization)、KTO(Kahneman-Tversky Optimization)等,以后有机会再聊(其实就是大概率没有机会了,醒醒吧,2026 年了)。
最后欢迎关注我,基本全网同名 chaofa用代码打点酱油
2026-01-02 00:57:20
本文旨在彻底搞懂 RoPE(Rotary Position Embedding)位置编码,阅读完本文你将获得:
本文代码运行于: Featurize 蒜粒方块 GPU 算力平台,有 GPU 使用需求的同学希望能使用我的邀请链接注册
待更新:不喜欢看文字的同学可以看 B站视频-chaofa用代码打点酱油, YouTube-chaofa用代码打点酱油,或视频号:chaofa用代码打点酱油
在 Transformer 架构中,Self-Attention 机制本身是位置无关的。公式如下:
softmax 中 QK 的乘积就是重要性权重,什么意思呢?
# 假设我们有两个句子
sentence1 = "朝发 写 代码"
sentence2 = "代码 写 朝发"
# 对于纯 Self-Attention 来说,这两个句子的表示是一样的!
# 从公式看 Attention 只关心 token 之间的权重关系,不关心它们的顺序
这显然是不对的。语言是有顺序的,顺序不同意思完全不同。因此,我们需要位置编码(Position Encoding, PE)来告诉模型每个 token 在序列中的位置。
用一个例子来理解这两种编码方式的区别:
句子: "朝发 写 代码"
位置: 0 1 2
绝对位置编码:给每个位置一个固定编号
"朝发" → 位置 0 → PE_0
"写" → 位置 1 → PE_1
"代码" → 位置 2 → PE_2
备注:PE_0 表示第一个位置的 embedding
相对位置编码:关注两个 token 之间的距离
计算 "朝发" 和 "代码" 的关系时:
→ 不关心它们分别在位置 0 和 2
→ 只关心它们相距 2 个位置
同理:计算 "朝发" 和 "写" 之间的相对位置是 (1 - 0) = 1。
使用相对位置编码就是希望捕获 Token 之间位置的相对关系,保持(某些)语义的不变性,下面 「朝发」和「代码」之间的关系是一样的,尽管绝对位置不同:
句子 A: "朝发 写 代码"
句子 B: "今天 朝发 写 代码"
RoPE(Rotary Position Embedding,旋转位置编码)的核心思想非常优雅,可以阅读苏神 RoPE blog:
通过旋转变换为向量注入位置信息,使得两个向量的内积只依赖于它们的相对位置。
这句话怎么理解呢?让我们一步步拆解看。
假设我们在二维平面上有一个向量 $(x, y)$,将它旋转角度 $\theta$ 后得到新向量:
这就是经典的 2D 旋转矩阵。下面用一张图来直观理解:

从图中可以看到:蓝色向量 $(x, y)$ 绕原点逆时针旋转角度 $\theta$ 后,变成红色向量 $(x', y')$。
目标:我们希望找到一个位置编码函数 $f$,使得 query 向量 $\mathbf{q}_m$ 和 key 向量 $\mathbf{k}_n$ 的内积只依赖于它们的相对位置 $(m-n)$:
也就是说,无论 $m$ 和 $n$ 的绝对值是多少,只要 $m-n$ 相同,内积结果就相同。
解决方案:RoPE 发现,这个函数 $f$ 就是旋转函数!(实际上是可以通过求解出来的,可以参考:Transformer升级之路:2、博采众长的旋转式位置编码),这里我们假设「知道了这么一个函数」,然后我们去证明它符合我们的需求。
假设词嵌入维度是 2 维($d=2$),对位置 $m$ 的向量 $\mathbf{q}$,应用旋转角度 $m\theta$:
同理,对位置 $n$ 的向量 $\mathbf{k}$,应用旋转角度 $n\theta$:
这就是为什么叫做旋转位置编码:位置信息通过旋转变换注入到向量中。
现在我们来证明,旋转函数确实能让内积只依赖于相对位置 $(m-n)$。
备注:推导有点复杂,其实看前后即可。
证毕:我们把中间这个只依赖于 $(m-n)$ 的旋转矩阵记为 $R_{m-n}$,最终结果 $\mathbf{q}^T \cdot R_{m-n} \cdot \mathbf{k}$ 与 $m$ 和 $n$ 的绝对值无关,只与相对位置 $(m-n)$ 有关。
现在让我们严格推导 RoPE 的数学形式。
RoPE 对于维度 $d$ 的向量,两两配对处理。对于第 $i$ 对(共 $d/2$ 对),使用频率:
这个频率设计非常关键:
对于位置 $m$,向量 $\mathbf{x} = [x_0, x_1, x_2, x_3, ..., x_{d-1}]$,RoPE 的旋转操作可以写成:
每两个维度组成一对,用对应的角度进行旋转。
在 Self-Attention 中,RoPE 应用于 Query 和 Key:
其中 $Q_{\text{rope}} = \text{RoPE}(Q, m)$,$K_{\text{rope}} = \text{RoPE}(K, n)$。
由于旋转的特性,$Q_{\text{rope}} \cdot K_{\text{rope}}^T$ 的结果只依赖于相对位置 $m - n$。
现在让我们一步步实现 RoPE。
import torch
import numpy as np
def get_rotary_frequencies(dim: int, seq_len: int, theta: float = 10000.0):
"""
生成 RoPE 的旋转频率
Args:
dim: 嵌入维度(必须是偶数)
seq_len: 序列长度
theta: 基础频率参数
Returns:
freqs: shape (seq_len, dim // 2),每个位置每个维度对的频率
"""
# 计算每个维度对的基础频率
# theta_i = 10000^(-2i/d),i = 0, 1, ..., d/2-1
i = torch.arange(0, dim // 2, dtype=torch.float32)
freqs = theta ** (-2 * i / dim) # shape: (dim // 2,)
# 生成位置索引
positions = torch.arange(seq_len, dtype=torch.float32) # shape: (seq_len,)
# 计算每个位置的角度:position * frequency
# 外积得到 (seq_len, dim // 2) 的矩阵
angles = torch.outer(positions, freqs) # shape: (seq_len, dim // 2)
return angles
# 测试
dim = 64
seq_len = 128
angles = get_rotary_frequencies(dim, seq_len)
print(f"Angles shape: {angles.shape}") # (128, 32)
print(f"Angles[0]: {angles[0][:5]}") # 位置 0 的前 5 个维度对的角度
print(f"Angles[1]: {angles[1][:5]}") # 位置 1 的前 5 个维度对的角度
def get_rotary_embedding(dim: int, seq_len: int, theta: float = 10000.0):
"""
预计算 RoPE 的 sin 和 cos 值
Returns:
cos: shape (seq_len, dim)
sin: shape (seq_len, dim)
"""
angles = get_rotary_frequencies(dim, seq_len, theta)
# 计算 cos 和 sin
cos = torch.cos(angles)
sin = torch.sin(angles)
# 将 (seq_len, dim//2) 扩展为 (seq_len, dim),与 rotate_half 配合使用
cos = torch.cat([cos, cos], dim=-1)
sin = torch.cat([sin, sin], dim=-1)
return cos, sin
# 测试
cos, sin = get_rotary_embedding(dim=64, seq_len=128)
print(f"Cos shape: {cos.shape}") # (128, 64)
print(f"Sin shape: {sin.shape}") # (128, 64)
这是 RoPE 的核心,参考 LLaMA 的实现方式:
def rotate_half(x):
"""
将向量的前半部分和后半部分交换,并对后半部分取负
[x1, x2, x3, x4] -> [-x3, -x4, x1, x2]
这是实现旋转的关键辅助函数
"""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
"""
应用 RoPE 旋转变换(LLaMA 风格实现)
Args:
q: Query,shape (batch, seq_len, num_heads, head_dim)
k: Key,shape (batch, seq_len, num_heads, head_dim)
cos: shape (seq_len, head_dim)
sin: shape (seq_len, head_dim)
Returns:
q_rot, k_rot: 旋转后的 Query 和 Key
旋转公式:
q' = q * cos + rotate_half(q) * sin
k' = k * cos + rotate_half(k) * sin
"""
# 调整 cos/sin 形状以便广播: (seq_len, head_dim) -> (1, seq_len, 1, head_dim)
cos = cos.unsqueeze(0).unsqueeze(2)
sin = sin.unsqueeze(0).unsqueeze(2)
# 应用旋转
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
为什么这个公式是对的?
回顾 2D 旋转公式:
对于向量 $[x, y]$,rotate_half 会把它变成 $[-y, x]$,所以:
原向量 * cos + rotate_half(原向量) * sin
= [x, y] * cos + [-y, x] * sin
= [x*cos - y*sin, y*cos + x*sin]
这正是旋转公式!
class RotaryPositionEmbedding(torch.nn.Module):
"""
完整的 RoPE 实现
"""
def __init__(self, dim: int, max_seq_len: int = 4096, theta: float = 10000.0):
"""
Args:
dim: 每个注意力头的维度
max_seq_len: 最大序列长度
theta: 基础频率参数
"""
super().__init__()
self.dim = dim
self.max_seq_len = max_seq_len
self.theta = theta
# 预计算并缓存 sin/cos 值
cos, sin = get_rotary_embedding(dim, max_seq_len, theta)
self.register_buffer('cos_cached', cos)
self.register_buffer('sin_cached', sin)
def forward(self, q: torch.Tensor, k: torch.Tensor, positions: torch.Tensor = None):
"""
对 Query 和 Key 应用 RoPE
Args:
q: Query,shape (batch, seq_len, num_heads, head_dim)
k: Key,shape (batch, seq_len, num_heads, head_dim)
positions: 位置索引,默认为 [0, 1, 2, ..., seq_len-1]
Returns:
q_rot, k_rot: 旋转后的 Query 和 Key
"""
seq_len = q.shape[1]
# 获取当前序列长度的 cos/sin
cos = self.cos_cached[:seq_len]
sin = self.sin_cached[:seq_len]
# 应用旋转
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
return q_rot, k_rot
# 测试
rope = RotaryPositionEmbedding(dim=64, max_seq_len=4096)
# 模拟输入
batch_size = 2
seq_len = 128
num_heads = 8
head_dim = 64
q = torch.randn(batch_size, seq_len, num_heads, head_dim)
k = torch.randn(batch_size, seq_len, num_heads, head_dim)
q_rot, k_rot = rope(q, k)
print(f"Q_rot shape: {q_rot.shape}")
print(f"K_rot shape: {k_rot.shape}")
RoPE 最重要的性质是:两个位置的 Query 和 Key 的内积只依赖于它们的相对位置。让我们验证一下:
def verify_relative_position_invariance():
"""
验证 RoPE 的相对位置不变性
"""
dim = 64
max_seq_len = 100
# 预计算 cos/sin
cos, sin = get_rotary_embedding(dim, max_seq_len)
# 创建两个相同的向量
torch.manual_seed(42)
q = torch.randn(1, 1, 1, dim)
k = torch.randn(1, 1, 1, dim)
# 场景 1:q 在位置 0,k 在位置 5(相对位置 = 5)
cos1_q, sin1_q = cos[0:1], sin[0:1]
cos1_k, sin1_k = cos[5:6], sin[5:6]
q1_rot, _ = apply_rotary_pos_emb(q, q, cos1_q, sin1_q)
_, k1_rot = apply_rotary_pos_emb(k, k, cos1_k, sin1_k)
dot_product_1 = (q1_rot * k1_rot).sum()
# 场景 2:q 在位置 10,k 在位置 15(相对位置仍然是 5)
cos2_q, sin2_q = cos[10:11], sin[10:11]
cos2_k, sin2_k = cos[15:16], sin[15:16]
q2_rot, _ = apply_rotary_pos_emb(q, q, cos2_q, sin2_q)
_, k2_rot = apply_rotary_pos_emb(k, k, cos2_k, sin2_k)
dot_product_2 = (q2_rot * k2_rot).sum()
print(f"位置 (0, 5) 的内积: {dot_product_1.item():.6f}")
print(f"位置 (10, 15) 的内积: {dot_product_2.item():.6f}")
print(f"差异: {abs(dot_product_1.item() - dot_product_2.item()):.10f}")
print("验证通过!" if abs(dot_product_1.item() - dot_product_2.item()) < 1e-5 else "验证失败!")
verify_relative_position_invariance()
以下是选看内容(为了帮助理解 RoPE 的内容)
import matplotlib.pyplot as plt
import seaborn as sns
def visualize_rope_heatmap():
"""
可视化 RoPE 编码的热力图
"""
dim = 64
seq_len = 128
cos, sin = get_rotary_embedding(dim, seq_len)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# Cos 热力图
sns.heatmap(cos.numpy()[:64, :], ax=axes[0], cmap='RdBu', center=0)
axes[0].set_title('RoPE Cos Values')
axes[0].set_xlabel('Dimension')
axes[0].set_ylabel('Position')
# Sin 热力图
sns.heatmap(sin.numpy()[:64, :], ax=axes[1], cmap='RdBu', center=0)
axes[1].set_title('RoPE Sin Values')
axes[1].set_xlabel('Dimension')
axes[1].set_ylabel('Position')
plt.tight_layout()
plt.savefig('rope_heatmap.png', dpi=150)
plt.show()
print("观察要点:")
print("1. 低维度(左侧)变化快 -> 捕捉短距离依赖")
print("2. 高维度(右侧)变化慢 -> 捕捉长距离依赖")
print("3. 每个维度都是周期函数,频率不同")
visualize_rope_heatmap()
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
def visualize_2d_rotation():
"""
可视化 2D 空间中的旋转效果
"""
# 原始向量
original = np.array([1.0, 0.5])
# 不同位置的旋转角度
positions = range(0, 16)
theta_base = 0.5 # 基础角度
fig, ax = plt.subplots(figsize=(8, 8))
colors = plt.cm.viridis(np.linspace(0, 1, len(positions)))
for pos, color in zip(positions, colors):
angle = pos * theta_base
# 旋转矩阵
cos_a, sin_a = np.cos(angle), np.sin(angle)
rotated = np.array([
original[0] * cos_a - original[1] * sin_a,
original[0] * sin_a + original[1] * cos_a
])
ax.arrow(0, 0, rotated[0], rotated[1],
head_width=0.05, head_length=0.03,
fc=color, ec=color, alpha=0.7,
label=f'pos={pos}' if pos % 4 == 0 else None)
ax.set_xlim(-1.5, 1.5)
ax.set_ylim(-1.5, 1.5)
ax.set_aspect('equal')
ax.grid(True, alpha=0.3)
ax.axhline(y=0, color='k', linewidth=0.5)
ax.axvline(x=0, color='k', linewidth=0.5)
ax.legend(loc='upper right')
ax.set_title('RoPE: 不同位置的向量旋转效果\n(同一向量在不同位置被旋转不同角度)')
plt.savefig('rope_rotation.png', dpi=150)
plt.show()
print("观察要点:")
print("1. 同一向量在不同位置被旋转不同角度")
print("2. 位置越大,旋转角度越大")
print("3. 这就是 RoPE 编码位置信息的方式")
visualize_2d_rotation()
def visualize_relative_attention():
"""
可视化 RoPE 对注意力分数的影响
"""
dim = 64
seq_len = 32
# 生成随机 Q 和 K
torch.manual_seed(42)
q = torch.randn(1, seq_len, 1, dim)
k = torch.randn(1, seq_len, 1, dim)
# 应用 RoPE
rope = RotaryPositionEmbedding(dim, seq_len)
q_rot, k_rot = rope(q, k)
# 计算注意力分数
# 无 RoPE
attn_no_rope = torch.matmul(q.squeeze(), k.squeeze().transpose(-2, -1)) / np.sqrt(dim)
# 有 RoPE
attn_with_rope = torch.matmul(q_rot.squeeze(), k_rot.squeeze().transpose(-2, -1)) / np.sqrt(dim)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
sns.heatmap(attn_no_rope.squeeze().detach().numpy(), ax=axes[0], cmap='viridis')
axes[0].set_title('Attention Scores (No RoPE)')
axes[0].set_xlabel('Key Position')
axes[0].set_ylabel('Query Position')
sns.heatmap(attn_with_rope.squeeze().detach().numpy(), ax=axes[1], cmap='viridis')
axes[1].set_title('Attention Scores (With RoPE)')
axes[1].set_xlabel('Key Position')
axes[1].set_ylabel('Query Position')
plt.tight_layout()
plt.savefig('rope_attention.png', dpi=150)
plt.show()
print("观察要点:")
print("1. 无 RoPE 时,注意力分数与位置无关")
print("2. 有 RoPE 时,注意力分数体现位置关系")
print("3. 对角线附近通常有更高的注意力(局部依赖)")
visualize_relative_attention()
最后,让我们看看如何将 RoPE 集成到完整的 Multi-Head Attention 中:
class MultiHeadAttentionWithRoPE(torch.nn.Module):
"""
带 RoPE 的多头注意力
"""
def __init__(self, d_model: int, num_heads: int, max_seq_len: int = 4096):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.q_proj = torch.nn.Linear(d_model, d_model, bias=False)
self.k_proj = torch.nn.Linear(d_model, d_model, bias=False)
self.v_proj = torch.nn.Linear(d_model, d_model, bias=False)
self.o_proj = torch.nn.Linear(d_model, d_model, bias=False)
# RoPE
self.rope = RotaryPositionEmbedding(self.head_dim, max_seq_len)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
"""
Args:
x: 输入,shape (batch, seq_len, d_model)
mask: 注意力掩码,shape (seq_len, seq_len)
"""
batch, seq_len, _ = x.shape
# 线性投影
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
# 重塑为多头形式
q = q.view(batch, seq_len, self.num_heads, self.head_dim)
k = k.view(batch, seq_len, self.num_heads, self.head_dim)
v = v.view(batch, seq_len, self.num_heads, self.head_dim)
# 应用 RoPE(只对 Q 和 K)
q, k = self.rope(q, k)
# 转置用于矩阵乘法:(batch, num_heads, seq_len, head_dim)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# 计算注意力分数
scale = self.head_dim ** -0.5
attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale
if mask is not None:
attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
attn_probs = torch.softmax(attn_scores, dim=-1)
# 加权求和
attn_output = torch.matmul(attn_probs, v)
# 重塑并输出投影
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch, seq_len, self.d_model)
output = self.o_proj(attn_output)
return output
# 测试
mha = MultiHeadAttentionWithRoPE(d_model=512, num_heads=8)
x = torch.randn(2, 128, 512)
output = mha(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
最后欢迎关注我,基本全网同名 chaofa用代码打点酱油