MoreRSS

site iconChaofa Yuan修改

大模型算法工程师, 写过《LLMs-Zero-to-Hero》
请复制 RSS 到你的阅读器,或快速订阅到 :

Inoreader Feedly Follow Feedbin Local Reader

Chaofa Yuan的 RSS 预览

2025,浪潮与拧巴

2026-02-24 00:07:00

1. 乎乎:爸爸还有待提升

折腾妈妈九个多月的乎乎出生了,我很开心。尤其是前段时间,抱着她,转过头的一瞬间突然叫一声「爸」,顿时有些恍惚,原来不知不觉间已经能隐约发出「爸」的声音了。

一个月前,点点说:“在我们决定要娃的时候,你说,「乎乎出生后,我所有的业余时间都会用来陪乎乎」,但你没有,你不是一个好爸爸。” 事实确实如此,2025 年是我自工作以来最忙碌的一段时间,各方面的事情都非常的多,我没有很多的时间陪乎乎导致点点有一周很崩溃,离家出走。我也因各种事情搞得头昏脑胀的,一度觉得只有不上班或者离婚才能解决问题[1]

我尝试过每天通勤回龙岗,但是坚持不了几天身体就承受不住了,每天上班都脑袋昏昏沉沉;也试了带睡几天,可是心脏直砰砰的跳,根本睡不着,生怕要自己先猝死了。我时常想,其他人是怎么平衡工作和家庭的呢?是我精力太差了,还是其他人精力太好了呢?

我不是一个好爸爸,按照点点平时给我的打分,我只有 B-。过年放假的时候倒是花了更多的时间陪孩子,但如果工作忙起来,我周末还能有能量这样吗?希望 26 年能到 A,让乎乎和点点都给我打分,靠更频繁的奖励信号纠正爸爸不当的行为。

2. 工作:More Agency

第一个部分其实也提到了工作,似乎核心就一个字「忙」,但真有这么忙吗?也许是有的,有几个离职的同事都表示变好了。但忙出了什么东西吗?我觉得也是有的,团队的业务产出其实都挺好的,当然 2025 年全年来说,我个人产出也是超出自己预期的[2],也有更多的思考🤔。

这一年,我自己的工作主线是:Agent 落地。主要的目标是:给模型更多的「自主决策空间,Let Agent to have More Agency」,主要的方式也和业界主流发展方式一致,围绕着两点:「更好的 Context」以及「给定 Context 下更好的执行」。由于今年是我做公开表达更多的一年,因此对于业界发展的跟进是比较快的,比如近期 Claude Code / Codex 都发表了一些关于 Prompt Caching 对于 Agent 设计的影响,我春节期间也写了三篇文章来介绍「Agent 系统中的 Prompt Caching 设计」。

此外,今年也暴露了自己比较大的问题:更擅长自己执行而不是协作。因为今年自己作为部分子项目 Owner,需要对于项目的进度有更强的规划和跟进能力。尽管现在都在强调 AI 时代要充分发挥个体能力,但是就目前(2026-02-23)而言,我觉得协作能力还是挺重要的,因为协作是意味着多线程的管理和表达能力,这在 AI 时代更是放大人与人区别的关键点。

其实大模型越来越强,我越来越焦虑,牛逼的人已经产出 100X,而自己的效率却没有本质变化。我仿佛看到了上个改革开放年代或者更近的所谓的移动互联网腾飞之年,有大量的人攫取了巨大的财富,仿佛是个人都能吃上所谓的时代红利,但大多数人,浪潮过后好像也没什么变化

而这一次,AI 巨浪滚滚向前,又创造了一大波造富神话,身处其中,感觉每天都是翻天覆地的变化,各种新的产品怎么也跟踪不过来;但似乎:每一个好像都和我没什么太大的关系,我能提升效率吗?我能赚更多的钱吗?我能更快到达自由的彼岸吗?所以更大的可能性是:浪潮过后,自己什么也没得到,就像哪些在历史机遇中平平淡淡的大多数人一样。

3. 公开表达:生与死

3.1 技术表达永生

去年,我在年终总结的时候写:2024年是公开表达元年。是的,2025 年,我尝试了更多的个人表达:

  • 制作 B 站大模型相关教学视频共 20 个,累计播放超过 100W,获得 33k 用户关注,且收获了非常非常的鼓励和赞同,已经远远远超出了预期。
    • 立个 FLAG:我一定要做个用户墙来进行自我激励
  • 技术 BLOG 一共 12 篇,平均下来每个月产出一篇低质量博客(苏神每年每个月产出高质量的四篇,我真的非常非常的佩服,求真务实的学习榜样)。尽管我的 blog——chaofa 用代码打点酱油相对于「科学空间」具有巨大的差距,但是依然得到了很多大佬的认同,比如
    • llama_factory 的作者给我发消息表示赞同
    • 英伟达某负责人加我微信说文章写得不错
    • 公司内很多人(其中有好几个都是团队负责人)私信我说看过我的文章
    • GitHub 有 4k+ Star
    • 和大佬们有了一丝微不足道的「表层交集」,此类有不少,我感觉心里很满足
  • 坚持了 10 个月每个月都写「个人月度总结」,从 2 月份到 11 月份。但由于家庭、工作、其他琐事的变多,我实在是精疲力尽,无法坐在电脑前敲下 2025 年 12 月和 2026 年 1 月的月度总结[3]

3.2 商业化之死

我在多次的月度总结中提到「商业化」,比如:2025-08-孙宇晨真的很值得学习2025-09-合法赚钱就是高尚的2025-10-一个程序员对自媒体商业化的深度复盘等,如果不明所以的朋友可能会觉得赚了很多钱,但事实远非如此,加起来半个月工资都不到,可以说这方面是比较失败的。

但也不是完全没有收获,我觉得通过视频商业化的尝试,我理解了很多的商业化行为,对于完整的商业化闭环也有了更深入的思考——打工是一个「期望风险更低」「期望收获更高」的商业化行为[4]

由于前面的尝试以及认知,我不能再花时间去接所谓的商单了,我觉得商单确实是一种毒药,看上去好像赚了点小钱(见第一段),但实际上挺麻烦的,付出和收入不成正比。还是需要找到自己核心的竞争力和核心产品,才可能真正占用尽量少业余时间获得更高的复利。

此外,由于后面商单的出现,我甚至开始有点羞于向别人说:我做了一个技术频道叫做 chaofa 用代码打点酱油,对标油管的 Andrej Karpathy[5]。所以 2026 年 1-2 月的时候,有 4 个新商单找我,我都拒绝了[6],我还是想做更纯粹的表达(当然肯定是想赚钱的,所以还在拧巴中)。

所以以后会怎么样呢?暂时还不知道🤷‍♀️

4. 回顾与展望

4.1 回顾去年

4.2 展望 26

工作自不必多说,依然是2026 最需要重点投入的事情,要积极跟进前沿,与 LLM 多多探讨业务、技术的发展,争取在工作上有进一步的突破。

另外,生产变革也已经发生,生产方式已经发生了巨变,尤其是已经看到非常多的人在 AI 的加持下做出了让人瞩目的成绩,所以 2026 要更加彻底的拥抱 AI Coding,应该说在创造一些自己的 Product,而不是隔岸观火。

所以我斥巨资买了一个域名叫做:ApeCode.ai,代码都会放到 github.com/ApeCodeAI 下,Slogan 想了好多有意思的:

  • Ape Code, You Sleep (我最喜欢这个)
  • Ape Code, Not You
  • Unleash Your Inner Builder
  • Even Ape can BUILD PRODUCTS with the help of AI

image.png


写于:2026 年 2 月 23 日 20:38:56 新年春节假期返工前一天晚

Ref


  1. 实际上我和点点几乎没有什么矛盾,目前也没有太大的经济压力,只是人都会有情绪崩溃的时候,造成一些不可思议的想法。「不是客观上带孩子时间的问题,而是在带孩子的情绪价值和参与度上的问题」导致点点觉得我投入不够。 ↩︎

  2. 工作这么多年,虽然也有绩效不错的时候,但我很少自己给自己评价超出预期。不过这一年,不管别人怎么看,我自己是尽全力了,业务产出也还不错。另外,由于 25 年也作为面试官面试了非常多的候选人,让我对各种事情有了更加深刻的认识,也让我更加坚定的要建立一套自己的评估体系。 ↩︎

  3. 我记得 11 月度总结的时候,我说:还有一个月就 2025 年全勤了,没想到最后功亏一篑了。有朋友可能会问,AI 写得比你写得好多了,为什么不用 AI?我的答案是:这一类个人思考和总结的东西,我不想用 AI,因为本来是写给自己看的,如果我不知道「下笔的时刻我在思考什么」,那么写出的文章又有什么意义呢?思考过程的本身比结果更让人着迷。要是对此感兴趣的同学也可以关注:公众号——chaofa 用代码打点酱油 ↩︎

  4. 正好对应了我一直说的,打工一定要投入更多的时间。公开表达只能是业余时间中 20% 的精力,因为收益真的没想象中高。自媒体是一个极度放大幸存者偏差的地方。Keep it in mind. ↩︎

  5. 这是最初的愿景,因为我也是完全从零手写代码/或者读一些前沿的论文,这也是吸引很多「专业」的同行的原因,因此我的视频受众有非常多各种大厂在职的算法工程师。 ↩︎

  6. 26 年 1 月份的发的两个视频都是元旦期间做的,都是很早很早以前接的,只是平常时间做 ↩︎

Agent 系统中的 Prompt Caching 设计(下):上下文管理与子代理架构

2026-02-22 23:06:00

0. 阅读收获 (takeaway)

读完本文,你将了解:

  • Context Rot(上下文腐烂):为什么更大的 context window 不是万能解药
  • Cache-Safe Compaction:如何在压缩 context 时不破坏 cache
  • Plan 模式的演进:从 todo.md 到专门 planner agent
  • 文件系统 & Just-in-Time Context:Agent 的"延展记忆"
  • 子代理的 Cache 友好设计:90%+ prefix reuse 是怎么做到的
  • The Bitter Lesson:哪些设计是持久的,哪些会被模型进步淘汰

前置知识:本文是 Agent 系统中的 Prompt Caching 设计(上) 的续篇。如果你还没读过,建议先了解 Cache 破坏机制、Prompt 布局和工具管理策略。更基础的概念见 理解 KV Cache 与 Prompt Caching

1. Context Rot:上下文腐烂

Anthropic 的研究指出:随着 context 中 token 增加,模型注意力分散,性能下降

Attention 中每个 token 要和所有其他 token 建立 $n^2$ 的 pairwise 关系。Context 越长 → 每个 token 的"注意力预算"越少 → 模型可能"忘记"早期的重要指令,或被大量 tool output 稀释关键信息。

更大的 context window 不是万能解药。 能塞进去不代表模型能有效利用。

这就是为什么 Agent 不能简单地把所有信息堆进 context——我们需要主动管理。

2. Compaction:Cache-Safe 的上下文压缩

压缩是解决 context 增长的关键手段,但必须是 cache-safe 的。

2.1 Claude Code 的 Cache-Safe Compaction

Cache-Safe Compaction 流程

Claude Code 的压缩策略非常精巧:

  • 压缩请求使用完全相同的 system prompt + tools + 对话前缀
  • 只在末尾追加 compaction prompt
  • 这样压缩请求本身就能复用父会话的 cache
  • 预留 "compaction buffer"——context 快满之前就开始压缩

2.2 OpenAI Codex 的 /responses/compact

Codex 提供专门的 API 端点:

  • 返回压缩后的 item 列表 + encrypted compaction 项目
  • 保留模型对原始对话的"潜在理解"(指的是把 summary 内容放到上下文中)
  • 超过 auto_compact_limit 自动触发

两者共同点:压缩是必要的,但压缩过程本身不能破坏已有的 cache。

3. 注意力操纵:Plan 模式的演进

Agent 如何确保模型"聚焦"在正确的事情上?三家有一个有趣的演进过程。

3.1 Manus 的演进

  • 初期用 todo.md → 约 1/3 actions 浪费在更新 todo 上!
  • 最新:专门的 planner agent 替代 → 效率大幅提升

3.2 Claude Code 的 Plan Mode

  • 独立规划阶段 → 用户审批 → 再执行
  • 模型可自主调用 EnterPlanMode

3.3 Codex 的 update_plan

  • 执行中的一个工具
  • 无需用户审批,更轻量
方案 独立阶段? 用户审批? 自主控制?
Manus planner agent 独立 agent 无需审批 Agent 决定
Claude Code Plan Mode 独立阶段 需要审批 模型可自主进入
Codex update_plan 也有独立阶段 无需审批 执行中随时调用

4. 保留错误内容

一个有趣的共识:不要删除失败的 action 和 observation。

Manus 不会从 context 中删除失败的工具调用结果。双重好处:

  1. 保持 append-only → 保护 cache
  2. 模型从错误中学习 → 调整后续策略

错误恢复是真正 agentic 行为的标志。 看不到自己犯的错,怎么学会避免?

5. 文件系统 & Just-in-Time Context

Agent 不需要把所有信息都放在 context 里——文件系统可以作为"延展记忆"。

5.1 Manus 的文件系统策略

  • 文件系统当"无限 context":执行结果写入文件,context 只保留引用
  • Full vs Compact 表示:新结果保留完整内容(文件读写的结果),旧结果替换为文件路径引用(压缩的时候,这时候看上去已经破坏了 prompt caching)。来源于:Manus webinar notes (2025.10): Context Reduction/Isolation/Offloading
    • 备注,这里有点没看懂。「括号()部分的内容是我加的,是我个人的理解」。
  • 压缩后通过重新读取文件恢复信息
  • MCP 工具通过 CLI 在沙盒执行,避免工具列表膨胀

5.2 Just-in-Time 检索

方案 预加载 按需检索
Claude Code CLAUDE.md glob/grep 搜索文件系统
Codex AGENTS.md shell 工具探索
Anthropic 建议 最少必要信息 JIT 检索

共同点:用 glob/grep 搜索文件系统,无需向量索引。这和 Agentic RAG 的思路一脉相承——Agent 自主决定搜索什么,而不是被动接受检索结果。

6. 子代理架构与模型选择

6.1 不要在会话中切换模型

Cache 是 model-specific 的。各家的做法:

  • Claude Code:Sub-Agent handoff(Opus → Haiku for Explore)
  • Codex:同一对话保持同一模型
  • Manus:任务级路由(Claude 做代码,Gemini 做多模态,OpenAI 做数学)——不同任务不同模型,但单次对话内不变

6.2 子代理的 Cache 友好设计

Claude Code 架构(逆向分析数据):

子代理 工具数 Prefix Reuse
Main Agent 18
Explore × 3 并行 10/18 子集 92%
Plan 独立 context 93%
Execution 全部 97%

Claude Code 还使用 warm-up 调用:启动时预热 tool list 和 system prompt 的 cache。

Manus 多代理架构

  • Planner → Knowledge Manager → Executor 三层
  • 子代理有 submit_results 工具 + 约束解码确保输出格式

Anthropic 建议:子代理返回压缩 summary → 避免主 context 被"污染"。

6.3 Fork 操作必须共享父 prefix

Fork 出的子任务必须用和父对话相同的 prompt prefix,才能复用父对话 cache。

Claude Code 在 compaction、summarization、skill execution 中都遵循这个原则。核心思想:压缩/fork 是在现有 cache 基础上的延伸,而非另起炉灶。

7. The Bitter Lesson

最后分享 Manus 的反思,引用 Rich Sutton 的 "The Bitter Lesson":

Agent 的 harness(框架/约束)可能限制模型性能。随着模型进步,需要不断简化架构。

Manus 自 2025 年 3 月以来已重构无数次。每次模型能力提升,某些 workaround 就变得不必要。

但有些设计是"持久"的——围绕 cache 的架构决策就是。它们不是在弥补模型不足,而是在适配计算的物理现实。

Cache 是物理约束,不是工程 hack。 只要 Prefill 还是 Compute Bound,Prompt Cache 就会继续是 Agent 架构的核心考量。

参考

备注:本文主要受前 4 篇参考内容的启发

其他

最后欢迎关注我,基本全网同名 chaofa用代码打点酱油

Agent 系统中的 Prompt Caching 设计(上):Cache 破坏、Prompt 布局与工具管理

2026-02-22 18:16:00

0. 阅读收获 (takeaway)

读完本文,你将了解:

  • 从 Prompt Engineering 到 Context Engineering 的范式转变
  • 为什么 Agent 比普通 Chatbot 需要 Prompt Caching
  • 什么操作会破坏 Cache(比你想象的多)
  • Prompt 布局与动态信息管理的最佳实践
  • 工具管理的三种 Cache-aware 方案对比

前置知识:本文假设你已经理解 KV Cache、Prefill/Decode 两阶段、以及 Prompt Cache 的前缀匹配机制。如果不熟悉这些概念,建议先阅读 理解 KV Cache 与 Prompt Caching:LLM 推理加速的核心机制

1. 先说结论:Cache Rules Everything

在深入细节之前,我想先分享我对这个话题的理解:

Prompt Cache 不只是一个省钱技巧,它是 Agent 系统架构设计的核心约束。

就像数据库的 schema 设计会影响整个应用架构一样,Prompt Cache 的前缀匹配约束深刻地影响了 Agent 的每一个设计决策:

  • prompt 怎么组织?→ 稳定内容放前面,变化内容放后面
  • 工具怎么管理?→ 工具列表固定,通过其他机制限制可用范围
  • 状态怎么切换?→ 不切换工具,用工具模拟状态转换(Claude Code Plan Mode 就是最好的例子)
  • context 怎么压缩?→ 压缩操作本身必须 cache-safe
  • 模型怎么选?→ 不在同一会话中切换,用子代理隔离

1.1 从 Prompt Engineering 到 Context Engineering

我们已经越来越多地听到 "Context Engineering" 这个术语。区别在哪?

  • Prompt Engineering 关注 "怎么写指令让模型表现更好"——内容层面的优化。
  • Context Engineering 关注 "怎么组织整个上下文——指令、工具、历史、外部信息——让 Agent 系统整体高效运转"——系统架构层面的设计。

Manus 在 25 年底最新总结中提出了三个维度:Reduce(缩减)、Isolate(隔离)、Offload(卸载)。

后面的内容你会看到,各家的设计都在围绕这三个维度展开。

1.2 三家方案的共同规律

不同公司、不同架构,但核心规律惊人地一致:

  1. 前缀不变:system prompt、tools、早期历史永远不修改
  2. 追加不修改:Append-only,永远不编辑历史消息
  3. 工具定义稳定:tools 数组不变,通过其他机制控制可用范围
  4. 动态信息后置:时间戳、环境状态等放在后面的 user message 中
  5. 压缩必须 cache-safe:压缩操作复用父对话的 cache prefix

带着这些规律,我们来看具体的实践。

2. 为什么 Agent 比 Chatbot 更需要 Prompt Cache?

2.1 Agent 的 I/O 比例严重失衡

Agent 每一步都需要发送完整的对话历史给模型,模型只输出一小段。Manus 披露过一个数据:input:output ≈ 100:1

如果没有 Prompt Cache → 每一步重新 Prefill 所有历史 token → 成本二次方增长

2.2 经济账

场景 不缓存 缓存后 节约
Claude(正常 vs cached) $3/MTok $0.30/MTok 90%
OpenAI GPT-5(正常 vs cached) $10/MTok $2.50/MTok 75%
Claude Code 单任务(约 2M tokens) ~$6.00 | ~$1.15 81%

Thariq(Claude Code 团队):

"Coding agents would be cost prohibitive without prompt caching."

OpenAI Codex:cache 命中后,采样开销从二次降为线性

2.3 延迟:TTFT 的大幅改善

  • Manus:KV cache hit rate 是 "the single most important metric"
  • Claude Code 团队:cache 命中率下降 → 当作线上事故(SEV)处理
  • OpenAI:150K+ tokens 时,cached 请求 TTFT 快 67%

3. 什么操作会破坏 Cache?

前缀匹配是一切的基础:任何位置的任何改动 → 该位置之后的 cache 全部失效。

3.1 改动 System Prompt

在开头放时间戳——Manus 踩过的坑。时间戳每秒都变,第一个 token 就不同,整个 cache 废掉。

3.2 改动 Tool Definitions

Claude Code 团队经验中最常见的 cache 破坏方式。Tool definitions 在 prompt 前部,增删任何一个工具 → 后续所有 cache 失效。

具体场景:

  • 增删工具:动态加载不同工具集 → 每次变化 cache 全废
  • 工具顺序不确定:Codex 的 MCP 工具注册顺序不确定
  • 更新工具参数:如 allowed_subagents 列表变化

3.3 切换模型

Cache 是 model-specific 的。

一个反直觉推论:100K token 对话中,切换到更便宜的模型可能更贵 —— Opus 的 100K cached token 只需 $1.50,换 Haiku 后全部重算。

3.4 修改历史消息

  • 编辑或删除之前的 action/observation → 破坏 cache
  • 非确定性序列化:JSON key 排序不一致(Manus 踩坑)→ 相同语义不同 token 序列

核心原则:序列化必须是确定性的。

4. Prompt 布局与动态信息管理

核心思路:把稳定的内容放前面,把变化的内容放后面。

4.1 Claude Code 的四层缓存架构

Claude Code 四层布局与 Codex Prompt 构建过程对比

层级 内容 稳定性
Layer 1 Static System Prompt & Tools 全局不变
Layer 2 CLAUDE.md 项目配置 项目级不变
Layer 3 Session Context(git status 等) 会话级
Layer 4 Conversation Messages 每轮追加

每轮只有 Layer 4 增长,前 3 层稳定命中 cache。

4.2 OpenAI Codex 的 Prompt 构建

三层结构:instructions → tools → input(input 可能会包含 dev role message,上图)。关键设计:旧提示是新提示的精确前缀

配置变更(沙盒权限、工作目录)→ 追加新消息而非修改旧消息。

4.3 Manus 的三条规则

  1. 稳定前缀
  2. Append-only
  3. 确定性序列化

4.4 动态信息怎么更新?

方案 实现方式
Claude Code <system-reminder> 标签放在 user message 中
Codex 追加新的 developer/user 消息

永远追加,永远不修改。

4.5 Cache Breakpoint 与 Auto-caching

  • Claude API:从手动 cache_control breakpoint → auto-caching 一个参数搞定
  • OpenAI APIprompt_cache_key 路由优化,自动缓存 ≥1024 token 前缀

反直觉:900 token prompt 永远不 cache hit,扩展到 1024+ token 反而更省钱。

5. 工具管理:三种方案,殊途同归

5.1 问题本质

Agent 可能有 30 个工具,但不同阶段只需要一部分。如果按需加载 → 每次状态切换 cache 全废。

三家工具管理策略对比

5.2 Claude Code:状态转换 + defer_loading

  • Plan ModeEnterPlanMode/ExitPlanMode 本身作为工具 → 工具列表永远不变
  • Tool Searchdefer_loading stub → ToolSearch 按需获取完整 schema
  • 模型可自主决定何时进入 Plan Mode

5.3 Manus:Logits Masking

  • 所有工具始终在 prompt 中
  • 工具命名约定:browser_xxxshell_xxx
  • Token logits masking 控制可用工具
  • 三种模式:Auto / Required / Specified

5.4 OpenAI:allowed_tools 参数

  • tools 数组完整不变
  • allowed_tools 限制当前可用子集
  • 注意:MCP 服务器可动态变更工具列表 → 需谨慎处理

5.5 对比总结

本质:工具定义不变(保 cache),通过其他机制限制可选范围。

方案 实现方式 优点 限制
Claude Code tool 本身 + defer_loading 灵活,模型自主决策 需 API 支持
Manus logits masking 精细控制 需 self-hosting
OpenAI allowed_tools 参数 最简单 仅粗粒度

下一篇

本文聚焦于 Cache-aware 的 Prompt 设计和工具管理。但 Agent 还面临另一组挑战:context 越来越长怎么办?怎么压缩才不破坏 cache?子代理怎么设计?

下一篇 Agent 系统中的 Prompt Cache 设计(下):上下文管理与子代理架构 将深入这些话题。

参考

其他

最后欢迎关注我,基本全网同名 chaofa用代码打点酱油

理解 KV Cache 与 Prompt Caching:LLM 推理加速的核心机制

2026-02-21 18:00:00

0. 阅读收获 (takeaway)

读完本文,你将了解:

  • KV Cache 的原理以及它为什么对 LLM 推理如此重要
  • Prefill 与 Decode 两个推理阶段的区别
  • Compute Bound 与 Memory Bound 背后的直觉
  • 一个很好的问题:Prefill 阶段为什么需要计算所有 token 的 Q?
  • Prompt Caching(前缀缓存)的工作原理

KV Cache 和 Prompt Cache 对于 Agent 设计的影响:

1. 什么是 KV Cache?

1.1 Autoregressive 生成的重复计算问题

大语言模型(LLM)的文本生成是 自回归(autoregressive) 的:每次只生成一个 token,然后把这个 token 拼到已有序列后面,再预测下一个。

用伪代码表示就是:

# 自回归生成的朴素实现
output_tokens = []
for step in range(max_new_tokens):
    # 每一步都要把 整个序列 送进模型
    logits = model(input_tokens + output_tokens)
    next_token = sample(logits[-1])  # 只用最后一个位置的 logits
    output_tokens.append(next_token)

Q: 问题出在哪?

每一步生成,模型都要对所有历史 token 重新做 Attention 计算——包括 Q、K、V 矩阵乘法。但对于已经出现过的 token,它们的 K 和 V 其实不会变(因为参数没变、token 没变),唯一在变的只有 "最新生成的那个 token" 对应的 Q、K、V。

这就引出了一个自然的优化思路:能不能把已经算过的 K 和 V 缓存起来,下次直接用?

1.2 KV Cache 的核心思想

KV Cache 的核心思想非常直接:

把每一层 Attention 中、每个已生成 token 对应的 K 向量V 向量缓存下来。后续生成新 token 时,只需要计算新 token 自己的 Q、K、V,然后将新的 K、V 追加到缓存中,用缓存里的完整 K、V 序列做 Attention。

这样一来,生成第 $t$ 个 token 时,Attention 的计算从 $O(t \times d)$(重算所有 token 的 K、V)降低到 $O(d)$(只算 1 个新 token 的 K、V),避免了绝大部分重复计算

用带 KV Cache 的伪代码表示:

# 带 KV Cache 的生成
kv_cache = {}  # 每一层缓存 K, V
for step in range(max_new_tokens):
    if step == 0:
        # 第一步:处理所有 input tokens,填充 cache
        logits, kv_cache = model(input_tokens, kv_cache=None)
    else:
        # 后续步:只送入上一步生成的 1 个 token
        logits, kv_cache = model([last_token], kv_cache=kv_cache)
    next_token = sample(logits[-1])
    last_token = next_token

1.3 KV Cache 显存占用

KV Cache 不是免费的——它用显存计算。随着生成序列变长,KV Cache 占用的显存会线性增长。

具体公式(假设 float16 存储):

$$\text{KV Cache 显存} = 4blh(s + n) \text{ bytes}$$

其中:

  • $b$ = batch size
  • $l$ = Transformer 层数
  • $h$ = hidden size
  • $s$ = 输入序列长度
  • $n$ = 输出序列长度
  • 4 = 2(K 和 V)× 2(float16 占 2 bytes)

这个公式的详细推导和具体数值例子,可以参考我之前的文章 LLM 大模型训练-推理显存占用分析。这里只需要记住一个直觉:序列越长,KV Cache 越大。这也是为什么后续会有 GQA(Grouped Query Attention)、DeepSeek MLA 等 KV Cache 压缩技术出现。

2. Prefill vs Decode:推理的两个阶段

理解了 KV Cache 之后,我们可以把 LLM 推理过程清晰地分成两个阶段:PrefillDecode。这两个阶段的计算特性截然不同,理解它们的区别对后面理解 Prompt Cache 非常关键。

Prefill 与 Decode 两阶段对比

2.1 Prefill 阶段(并行处理 input tokens → Compute Bound)

Prefill 阶段就是上面伪代码中 step == 0 的那一步:模型一次性处理所有输入 token(system prompt + user message),为每一层、每个 token 计算出 K 和 V 并存入 cache。

关键特点:

  • 所有输入 token 可以并行处理(它们之间的 Attention mask 是 causal 的,但计算可以用矩阵乘法一次完成)
  • 计算量大:$n$ 个 token × 所有层 × Q/K/V 矩阵运算
  • Compute Bound:GPU 的算力是瓶颈

2.2 Decode 阶段(逐 token 生成 → Memory Bound)

Decode 阶段就是后续的 step > 0:每一步只输入 1 个 token,利用 KV Cache 做 Attention,生成下一个 token。

关键特点:

  • 每步只处理 1 个 token(无法并行,因为下一个 token 依赖上一个的输出)
  • 每步的计算量其实不大——1 个 token 的 Q 乘以 cache 中所有 K/V
  • 但每步都要从显存读取整个 KV Cache
  • Memory Bound:GPU 的显存带宽是瓶颈

2.3 Compute Bound vs Memory Bound

这里解释一下 Compute Bound 和 Memory Bound 的含义,核心概念是 Arithmetic Intensity(算术强度)

$$\text{Arithmetic Intensity} = \frac{\text{计算量 (FLOPs)}}{\text{数据搬运量 (Bytes)}}$$
  • Compute Bound:算术强度高,GPU 的计算单元忙不过来,数据搬运不是瓶颈。Prefill 就是这种情况——大矩阵乘法,计算密集。
  • Memory Bound:算术强度低,GPU 的计算单元在等数据从显存搬过来。Decode 就是这种情况——每步只有 1 个 token 的小矩阵运算,但要读取整个 KV Cache。

用一个直觉来理解:

  • Prefill 像是"一次批量处理 1000 个快递"——流水线拉满,打包效率高
  • Decode 像是"每次只来 1 个快递"——打包机器空闲大半时间,瓶颈在于快递从仓库取出来的速度

2.4 一个很好的问题:Prefill 为什么要计算所有 token 的 Q?

这个问题来自一位读者在 GitHub Discussion 的提问,我觉得是一个非常好的问题:

既然我们只需要预测 next token,Prefill 阶段不是只需要最后一个 token 的 Q 吗?为什么要计算所有 token 的 Q?

乍一看很有道理——Attention 的输出 $\text{softmax}(QK^T / \sqrt{d})V$ 中,我们只需要最后一个位置的结果来预测 next token。那 K、V 确实需要全算(因为最后一个 Q 要和所有 K 做 attention),但 Q 为什么不能只算最后一个?

答案的核心是:Decoder 有很多层

如果 Transformer 只有一层,那确实,我们只需要最后一个 token 的 Q。但实际的 Decoder 有几十层,上一层所有位置的输出是下一层所有位置的输入

  1. 第 1 层:为了得到所有位置的 K、V(这些 K、V 要存入 cache),需要知道所有位置的输入。而第 1 层的输入就是 token embedding,所以 Q、K、V 都要算全部位置。
  2. 第 2 层:第 2 层的输入是第 1 层的输出。第 1 层的输出取决于 Attention 的完整计算——包括所有位置的 Q。因此第 2 层的 K、V 计算依赖于第 1 层所有位置的 Q 计算结果。
  3. 第 N 层:同理,依赖前面所有层的完整输出。

看图:

Prefill 阶段的计算过程

所以结论是:

Prefill 必须计算所有 token 的 Q,不是因为最终预测需要,而是因为每一层的 K、V 缓存依赖于上一层所有位置的完整输出,而上一层的完整输出需要所有位置的 Q 参与计算。

这也解释了为什么 Prefill 阶段是 Compute Bound——它确实需要做大量计算,不是在浪费。

2.5 TTFT vs TPOT

从用户体验的角度,两个阶段对应两个不同的延迟指标:

  • TTFT(Time To First Token):用户发送请求到看到第一个输出 token 的时间。主要由 Prefill 阶段决定。
  • TPOT(Time Per Output Token):生成每个后续 token 的平均时间。主要由 Decode 阶段决定。

对于输入很长的场景(比如长文档问答、Agent 的多轮对话),Prefill 阶段的耗时会显著增加 TTFT。

这就引出了一个关键问题:如果我们能跳过 Prefill 中那些"之前已经算过"的部分,是不是就能大幅降低 TTFT?这就是 Prompt Cache 要解决的问题。

3. Prompt Cache(前缀缓存)

3.1 从 KV Cache 到 Prompt Cache

前面说的 KV Cache 是单次请求内部的优化——生成过程中缓存已算过的 K、V,避免重复计算。

Prompt Caching(前缀缓存) 则是跨请求的优化:

如果两次 API 调用的 prompt 有相同的前缀,那么第二次调用可以直接复用第一次 Prefill 阶段算出来的 KV Cache,跳过前缀部分的 Prefill 计算。

这对于以下场景特别有价值:

  • 多轮对话:每轮对话的 prompt 都以之前的对话历史作为前缀
  • 相同 system prompt:同一个应用的所有请求共享相同的 system prompt 前缀
  • Agent 系统:Agent 每一步的 prompt 都是上一步的 prompt 加上新的 action/observation

3.2 前缀匹配机制

Prompt Cache 的匹配规则非常严格:

必须从第一个 token 开始完全一致,一个 token 的差异就会导致该位置之后的 cache 全部失效。

前缀匹配:Cache Hit 与 Cache Miss

举个例子:

请求 内容 Cache 命中情况
请求 1 [System Prompt][User: Hello] 无 cache,全部计算
请求 2 [System Prompt][User: Hello][Assistant: Hi][User: 你好] [System Prompt][User: Hello] 部分 cache hit
请求 3 [Modified System Prompt][User: Hello] 完全 miss,因为第一个 token 就不一样了

这个"前缀精确匹配"的约束,是后面 Agent 系统设计的核心基础。在下一篇文章中,我们会详细讨论 Claude Code、Manus、OpenAI Codex 如何围绕这个约束设计整个系统架构。

3.3 开源推理引擎的实现

在开源推理引擎中,Prompt Cache(通常叫 Prefix Caching 或 Automatic Prefix Caching)已经是标配功能:

  • vLLM:通过 --enable-prefix-caching 开启,使用 hash-based 的 block 管理机制。将 KV Cache 按固定大小的 block 存储,相同前缀的 block 可以在不同请求间共享。
  • SGLang:默认开启 RadixAttention,用 Radix Tree(基数树)管理 KV Cache 前缀。相比 vLLM 的 hash 方案,Radix Tree 在前缀共享上有天然优势——可以高效处理多层级的前缀共享。

这些引擎的实现细节不是本文重点,关键是理解:Prompt Cache 在推理引擎层面已经是成熟技术,无论你用开源引擎自部署还是调用商业 API,都可以获得这个优化。

4. 总结与预告

让我们回顾一下本文的核心脉络:

  1. KV Cache 解决了自回归生成中的重复计算问题,是 LLM 推理的基础优化
  2. Prefill 阶段并行处理所有输入 token(Compute Bound),Decode 阶段逐个生成 token(Memory Bound)
  3. Prefill 必须计算所有 token 的 Q,因为多层 Decoder 的层间依赖
  4. Prompt Cache 把 KV Cache 的优化从"单次请求内"扩展到"跨请求",通过前缀匹配复用已计算的 KV Cache
  5. 前缀精确匹配的约束,决定了 Prompt Cache 的使用方式——任何位置的改动都会破坏该位置之后的 cache

这个"前缀精确匹配"的约束,在 AI Agent 系统中变得尤为关键。Agent 每一步都要发送越来越长的 context(历史对话 + 工具调用结果),如果不精心设计 prompt 结构,cache 命中率会很低,成本和延迟都会飙升。

在接下来的两篇文章中,我会详细分析 Claude Code、Manus、OpenAI Codex 等 AI Agent 如何围绕 Prompt Cache 设计整个系统架构:

参考

其他

最后欢迎关注我,基本全网同名 chaofa用代码打点酱油

DPO 算法原理与代码实现:让 LLM 对齐变得简单

2026-01-10 23:07:00

0. 阅读收获 (takeaway)

本文目标是搞懂 DPO(Direct Preference Optimization)算法,阅读完本文你将获得:

  • 理解 DPO 的核心思想:为什么 DPO 可以替代 RLHF 中的 PPO
  • 掌握 DPO 与 RLHF 的关键区别:从 4 个模型到 2 个模型
  • 手撕 DPO Loss:理解损失函数到底在算什么
  • Bonus 1:为什么 DPO 比 PPO 训练更稳定
  • Bonus 2:DPO 损失函数的完整数学推导
  • 源代码位于 Github -动手学习大模型-中文版-第 12.1章——动手学习 DPO

本文代码运行于:Featurize 蒜粒方块 GPU 算力平台,不喜欢看文字的同学可以看 B站视频-chaofa用代码打点酱油YouTube-chaofa用代码打点酱油,视频号:chaofa用代码打点酱油

1. 为什么需要 DPO?

在聊 DPO 之前,我们先快速回顾一下 LLM 训练的三个阶段(参考 OpenAI InstructGPT):

  1. 预训练(Pre-training):在海量文本上训练,让模型学会"说话"
  2. 监督微调(SFT):用高质量的指令数据微调,让模型学会"听话"
  3. 对齐(Alignment):让模型的输出符合人类偏好,学会"说人话"

假设读者对于前两个步骤已经有所了解,这篇文章的重点是第三步"对齐"。

1.1 RLHF 的问题

OpenAI 在训练 ChatGPT 的时候用的是 RLHF(Reinforcement Learning from Human Feedback),整个流程大概是这样的:

DPO 原论文中的 RLHF vs DPO 流程对比

RLHF 确实有效,但问题也很明显:

  1. 需要 4 个模型:Actor(待训练)、Reference(冻结的 SFT 模型)、Reward Model(奖励模型)、Critic(价值函数)
  2. PPO 算法复杂:超参数一堆,训练不稳定,调参调到怀疑人生
  3. 资源消耗大:4 个模型同时跑,显存吃不消

之前在 DeepSeek-R1 论文解读 里也提到过,DPO 是 RLHF 的一种替代方案,但 DeepSeek 最终还是用了 GRPO(一种改进的 PPO)。不过对于大多数场景来说,DPO 已经够用了。

1.2 DPO 的卖点

DPO 的核心思路是:既然 RLHF 这么麻烦,能不能把强化学习的部分去掉,直接用监督学习的方式来做对齐?

答案是可以的。DPO 的作者通过一系列数学推导(后面 Bonus 部分会讲),证明了可以把 RLHF 的优化目标转换成一个简单的损失函数,只需要 2 个模型就能搞定:

  • Actor:待训练的模型 $\pi_\theta$
  • Reference:冻结的 SFT 模型 $\pi_{ref}$

不需要单独训练 Reward Model,也不需要 PPO 那套复杂的东西。训练过程和 SFT 差不多,非常稳定。

2. DPO 的核心思想

2.1 偏好数据长什么样?

DPO 需要的数据格式很简单,就是一个 prompt 配上两个回答:一个好的(chosen),一个差的(rejected)。

# DPO 偏好数据示例
{
    "prompt": "介绍一下 chaofa用代码打点酱油 这个博主",
    "chosen": "chaofa用代码打点酱油 是一位专注于大模型技术的博主,他在 B站、YouTube 等平台分享 LLM 相关的技术内容,包括动手学大模型系列教程。他的内容特点是注重代码实现和原理讲解,帮助读者从零理解大模型的各种技术细节。",
    "rejected": "不知道,没听说过,说不定是个弱智。"
}

简单说就是:同一个问题,告诉模型哪个回答是好的,哪个是不好的。这种数据可以通过人工标注获得,也可以用更强的模型(比如 gemini/claude/gpt)来生成。

TRICK: 非同源模型的数据训练的时候,可以先用 "chosen" 数据 SFT,不然可能导致 chosen 和 rejected 概率都变低。

2.2 DPO 想做什么?

DPO 的目标其实就两个:

  1. 让模型更喜欢生成 chosen 回答:提高 chosen 的生成概率
  2. 不要偏离原来的 SFT 模型太远:保持模型的基本能力,防止"忘记"之前学到的东西

第二点很重要,如果只追求第一点,模型可能会为了迎合偏好数据而变得很奇怪(比如每个回答都很长、很啰嗦)。所以需要用参考模型来"拉住"它。

2.3 DPO 损失函数

好了,到了最核心的部分。DPO 的损失函数长这样:

$$\mathcal{L}_{\mathrm{DPO}}(\pi_\theta; \pi_{\mathrm{ref}}) = - \mathbb{E}_{(x,y_w,y_l) \sim D} \left[ \log \sigma\Big(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\Big) \right]$$

这个公式看起来贼复杂,但逻辑其实很清晰。首先看公式里面的核心部分,是在比较两个东西:

  • $\log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)}$:当前模型相对于参考模型,在 chosen 回答上的对数概率变化
  • $\log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}$:当前模型相对于参考模型,在 rejected 回答上的对数概率变化

我们希望前者大于后者。也就是说,模型在 chosen 上的"提升幅度"要大于在 rejected 上的"提升幅度"。

$\beta$ 是一个超参数,用来控制"偏离参考模型的惩罚力度"。$\beta$ 越大,模型越不敢偏离参考模型;$\beta$ 越小,模型越"激进"。一般从 0.1 开始试。

$\sigma$ 就是 sigmoid 函数,把差值映射到 (0, 1) 区间,然后取 log 变成 loss。

Q: 这个公式是怎么推导出来的?为什么这样设计就能达到我们的目标?这些问题留到 Bonus 部分再说。现在只要理解"DPO 在做什么"就够了。

3. 手撕 DPO Loss

理解了原理之后,我们来看看代码怎么写。其实 DPO 的核心代码非常简单,比公式看起来简单多了。

3.1 计算序列的 log 概率

首先,我们需要一个函数来计算模型在某个序列上的 log 概率。

对于语言模型来说,生成一个序列的概率就是每个 token 条件概率的乘积。取 log 之后,乘积变成求和:

$$\log \pi(y|x) = \sum_t \log P(y_t | y_{<t}, x)$$
import torch
import torch.nn.functional as F


def compute_log_probs(
    logits: torch.Tensor,       # (batch, seq_len, vocab_size)
    labels: torch.Tensor,       # (batch, seq_len)
    mask: torch.Tensor          # (batch, seq_len),标记哪些位置需要计算
) -> torch.Tensor:
    """
    计算序列的对数概率

    注意:这里只计算 response 部分的概率,prompt 部分不算
    """
    # 获取每个位置的 log softmax
    log_probs = F.log_softmax(logits, dim=-1)

    # 取出对应 label 的 log 概率
    # gather 操作:从 vocab_size 维度取出 labels 对应的概率
    per_token_log_probs = torch.gather(
        log_probs,
        dim=-1,
        index=labels.unsqueeze(-1)
    ).squeeze(-1)

    # 只计算 mask=1 的位置(response 部分)
    masked_log_probs = per_token_log_probs * mask

    # 求和得到整个序列的 log 概率
    return masked_log_probs.sum(dim=-1)

3.2 DPO Loss 核心实现

有了计算 log 概率的函数,DPO Loss 的实现就很直接了:

def dpo_loss(
    policy_chosen_logps: torch.Tensor,    # 当前模型在 chosen 上的 log 概率
    policy_rejected_logps: torch.Tensor,  # 当前模型在 rejected 上的 log 概率
    ref_chosen_logps: torch.Tensor,       # 参考模型在 chosen 上的 log 概率
    ref_rejected_logps: torch.Tensor,     # 参考模型在 rejected 上的 log 概率
    beta: float = 0.1,
) -> torch.Tensor:
    """
    DPO Loss 的核心实现

    代码比公式简单多了吧?
    """
    # 计算 log ratio:当前模型相对于参考模型的变化
    chosen_log_ratios = policy_chosen_logps - ref_chosen_logps
    rejected_log_ratios = policy_rejected_logps - ref_rejected_logps

    # 核心:我们希望 chosen 的 ratio 大于 rejected 的 ratio
    logits = beta * (chosen_log_ratios - rejected_log_ratios)

    # 用 logsigmoid 更数值稳定(等价于 -log(sigmoid(logits)))
    losses = -F.logsigmoid(logits)

    return losses.mean()

就这么简单。核心就三行:

  1. 计算 chosen 的 log ratio
  2. 计算 rejected 的 log ratio
  3. 用 sigmoid + log 算 loss

完整的训练代码涉及数据处理、模型加载等,这里就不展开了。可以参考 trl 源码

4. 用 trl 跑一下 DPO 训练

手写 DPO Loss 是为了理解原理,实际训练的话直接用 trl 就好了。trl 是 Hugging Face 出的强化学习库,DPO 训练用起来很简单。

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

# 1. 准备模型
model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # 用小模型演示
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# 参考模型(就是 SFT 后的模型,这里直接用同一个)
ref_model = AutoModelForCausalLM.from_pretrained(model_name)

# 2. 准备数据(trl 需要的格式)
train_data = Dataset.from_dict({
    "prompt": [
        "介绍一下 chaofa用代码打点酱油 这个博主",
        "DPO 和 RLHF 哪个更适合入门?",
    ],
    "chosen": [
        "chaofa用代码打点酱油 是一位专注于大模型技术的博主,在 B站、YouTube 分享 LLM 相关教程,内容注重代码实现和原理讲解,帮助读者从零理解大模型技术。",
        "建议先学 DPO,原理更简单,训练也更稳定。可以看 chaofa用代码打点酱油 的动手学大模型系列,有详细的代码实现。",
    ],
    "rejected": [
        "没听说过,应该是个小透明吧。",
        "都差不多,随便选一个。",
    ],
})

# 3. 配置训练参数
training_args = DPOConfig(
    output_dir="./dpo_output",
    beta=0.1,                    # DPO 的温度参数
    learning_rate=5e-7,          # DPO 通常用比较小的学习率
    per_device_train_batch_size=2,
    num_train_epochs=1,
    logging_steps=10,
    bf16=True,
)

# 4. 创建 Trainer 并训练
trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    args=training_args,
    train_dataset=train_data,
    tokenizer=tokenizer,
)

trainer.train()

关键参数说一下:

  • beta:前面说过,控制偏离参考模型的惩罚力度,一般从 0.1 开始试
  • learning_rate:DPO 通常用比较小的学习率,5e-7 到 5e-6 左右

5. Bonus 1:为什么 DPO 比 PPO 训练更稳定?

很多人说"DPO 比 PPO 更稳定",但到底为什么呢?这个问题其实可以从几个角度来理解:

5.1 Off-policy vs On-policy

PPO 是一种 on-policy 的强化学习算法,DPO 是 off-policy 的,它直接用离线的偏好数据来训练,训练过程和 SFT 差不多。

  • on policy 每一次样本都是采样出来的,梯度可能会随时发生变化,梯度方差大;数据分布随着模型的更新会发生变化,上一轮学好的参数可能不适用下一轮,reward 比较稀疏(SFT/DPO 是 Token 级别的监督信号)。

5.2 不需要 Reward Model 和 Critic

PPO 在 RLHF 中需要:

  • 一个 Reward Model 来打分(这个模型本身就可能有问题,比如 reward hacking)
  • 一个 Critic(Value Function)来估计优势函数(这个网络的训练也不简单)

这些额外的模型都会引入噪声和不稳定因素。DPO 把 Reward Model 直接"吸收"到了损失函数里,不需要单独训练,少了很多可能出错的地方。

5.3 超参数敏感度

PPO 有很多超参数需要调:

  • clip ratio(裁剪系数)
  • GAE lambda
  • 学习率、batch size、epoch 数
  • KL 惩罚系数
  • ...

这些参数之间还有复杂的相互作用,调参调到怀疑人生是常有的事。DPO 的核心超参数就一个 $\beta$,最多再加上学习率。简单很多。

备注:这里说的"稳定"。PPO/GRPO 调好了效果可能更好,但训练成本也更高。对于大多数场景来说,DPO 是一个性价比很高的选择。

6. Bonus 2:DPO 数学推导

这部分是给想深入理解的同学看的,跳过也不影响使用 DPO。

DPO 的 Loss 不是凭空设计出来的,而是从 RLHF 的优化目标一步步推导出来的。

6.1 RLHF 的优化目标

RLHF 想要做的事情是:最大化奖励,同时不要偏离参考模型太远。用公式表示:

$$\max_{\pi} \; \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi(y \mid x)} \Big[ r(x,y) \Big] - \beta\, \mathbb{D}_{\mathrm{KL}}\Big[ \pi(y \mid x) \,\|\, \pi_{\mathrm{ref}}(y \mid x) \Big]$$

其中:

  • $r(x,y)$ 是奖励函数(需要单独训练一个 Reward Model)
  • KL 散度用来约束模型不要偏离参考模型太远
  • $\beta$ 控制约束的强度

6.2 最优策略的形式

这个优化问题有一个解析解。我们先假设存在这样一个最优策略 $\pi^*$,(具体推导可以参考 DPO 原论文附录,但我没看懂直接抄过来了),可以得到最优策略满足:

$$\pi^*(y \mid x) = \frac{1}{Z(x)} \pi_{\mathrm{ref}}(y \mid x) \exp\Big(\frac{1}{\beta} r(x,y)\Big)$$

其中 $Z(x) = \sum_y \pi_{\mathrm{ref}}(y \mid x) \exp\Big(\frac{1}{\beta} r(x,y)\Big)$ 是归一化常数(配分函数),保证概率和为 1。

备注:

  • 它说的是:最优策略在参考策略的基础上,根据奖励大小进行"加权"。奖励高的回答概率会指数级增大,奖励低的会被抑制。$\beta$ 控制这个"加权"的激进程度。
  • 这个最优策略就是我们要学习的「模型参数」

6.3 反解奖励函数

从上面的式子,我们可以反过来把奖励函数用策略来表示:

$$r(x,y) = \beta \log \frac{\pi^*(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)} + \beta \log Z(x)$$

这告诉我们:奖励函数可以用"当前策略和参考策略的 log 概率比"来表示

6.4 Bradley-Terry 偏好模型

在有偏好数据的时候,我们通常用 Bradley-Terry 模型来建模"哪个回答更好":

$$P(y_w \succ y_l \mid x) = \frac{\exp[r(x, y_w)]}{\exp[r(x, y_w)] + \exp[r(x, y_l)]} = \sigma(r(x, y_w) - r(x, y_l))$$

$y_w$ 是 chosen 的样本, $y_l$ 是 rejected 的样本。$y_w$ 被偏好的概率取决于两个回答的奖励之差。

6.5 代入得到 DPO Loss

现在把 6.3 中的奖励函数代入 Bradley-Terry 模型。关键观察是:$\log Z(x)$ 在两个回答中是一样的,相减的时候会消掉!

$$r(x, y_w) - r(x, y_l) = \beta \log \frac{\pi^*(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi^*(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}$$

前面提到,我们把 待训练的模型 $\pi_\theta$ 认为是最优策略 $\pi^*$

最终,最大化偏好数据的似然(等价于最小化负对数似然),就得到了 DPO Loss:

$$\mathcal{L}_{\mathrm{DPO}} = - \mathbb{E}_{(x,y_w,y_l)} \left[ \log \sigma\Big(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\Big) \right]$$

这就是我们在第 2 节看到的 DPO Loss。

7. 总结

一句话总结:DPO 用监督学习的方式实现了 RLHF 的效果,把 4 个模型简化成 2 个,训练更稳定、资源消耗更低

DPO 的局限性:

  • 依赖偏好数据的质量,数据不好效果就不好
  • $\beta$ 参数比较敏感,需要调参

后续还有一些 DPO 的变体,比如 IPO(Identity Preference Optimization)、KTO(Kahneman-Tversky Optimization)等,以后有机会再聊(其实就是大概率没有机会了,醒醒吧,2026 年了)。

8. 参考资料

  1. DPO 原论文: Direct Preference Optimization
  2. trl 库文档

其他

最后欢迎关注我,基本全网同名 chaofa用代码打点酱油

从零手写 RoPE 位置编码:原理、PyTorch 源码实现与可视化理解

2026-01-02 00:57:20

0. 阅读收获 (takeaway)

本文旨在彻底搞懂 RoPE(Rotary Position Embedding)位置编码,阅读完本文你将获得:

  • 理解 RoPE 的核心思想:为什么用"旋转"来编码位置信息
  • 掌握 RoPE 的数学原理:从旋转矩阵到三角函数证明
  • 从零手写 RoPE 实现:逐行代码讲解,可直接运行
  • bonus:可视化理解 RoPE:通过热力图和动画直观感受旋转编码

本文代码运行于: Featurize 蒜粒方块 GPU 算力平台,有 GPU 使用需求的同学希望能使用我的邀请链接注册

待更新:不喜欢看文字的同学可以看 B站视频-chaofa用代码打点酱油, YouTube-chaofa用代码打点酱油,或视频号:chaofa用代码打点酱油

1. 为什么需要位置编码?

在 Transformer 架构中,Self-Attention 机制本身是位置无关的。公式如下:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

softmax 中 QK 的乘积就是重要性权重,什么意思呢?

# 假设我们有两个句子
sentence1 = "朝发 写 代码"
sentence2 = "代码 写 朝发"

# 对于纯 Self-Attention 来说,这两个句子的表示是一样的!
# 从公式看 Attention 只关心 token 之间的权重关系,不关心它们的顺序

这显然是不对的。语言是有顺序的,顺序不同意思完全不同。因此,我们需要位置编码(Position Encoding, PE)来告诉模型每个 token 在序列中的位置。

1.1 绝对位置编码 vs 相对位置编码

用一个例子来理解这两种编码方式的区别:

句子: "朝发 写 代码"
位置:   0  1  2

绝对位置编码:给每个位置一个固定编号

"朝发" → 位置 0 → PE_0
"写"   → 位置 1 → PE_1
"代码" → 位置 2 → PE_2
备注:PE_0 表示第一个位置的 embedding

相对位置编码:关注两个 token 之间的距离

计算 "朝发" 和 "代码" 的关系时:
→ 不关心它们分别在位置 0 和 2
→ 只关心它们相距 2 个位置

同理:计算 "朝发" 和 "写" 之间的相对位置是 (1 - 0) = 1。

使用相对位置编码就是希望捕获 Token 之间位置的相对关系,保持(某些)语义的不变性,下面 「朝发」和「代码」之间的关系是一样的,尽管绝对位置不同:

句子 A: "朝发 写 代码"
句子 B: "今天 朝发 写 代码"

2. RoPE 的核心思想

RoPE(Rotary Position Embedding,旋转位置编码)的核心思想非常优雅,可以阅读苏神 RoPE blog

通过旋转变换为向量注入位置信息,使得两个向量的内积只依赖于它们的相对位置。

这句话怎么理解呢?让我们一步步拆解看。

2.1 从 2D 旋转说起

假设我们在二维平面上有一个向量 $(x, y)$,将它旋转角度 $\theta$ 后得到新向量:

$$\begin{pmatrix} x' \\ y' \end{pmatrix} = \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix} \begin{pmatrix} x \\ y \end{pmatrix}$$

这就是经典的 2D 旋转矩阵。下面用一张图来直观理解:

2D 向量旋转示意图

从图中可以看到:蓝色向量 $(x, y)$ 绕原点逆时针旋转角度 $\theta$ 后,变成红色向量 $(x', y')$

2.2 RoPE 的目标与解决方案

目标:我们希望找到一个位置编码函数 $f$,使得 query 向量 $\mathbf{q}_m$ 和 key 向量 $\mathbf{k}_n$ 的内积只依赖于它们的相对位置 $(m-n)$

$$\langle f_q(\mathbf{q}, m), f_k(\mathbf{k}, n) \rangle = g(\mathbf{q}, \mathbf{k}, m-n)$$

也就是说,无论 $m$$n$ 的绝对值是多少,只要 $m-n$ 相同,内积结果就相同。

解决方案:RoPE 发现,这个函数 $f$ 就是旋转函数!(实际上是可以通过求解出来的,可以参考:Transformer升级之路:2、博采众长的旋转式位置编码),这里我们假设「知道了这么一个函数」,然后我们去证明它符合我们的需求。


假设词嵌入维度是 2 维($d=2$),对位置 $m$ 的向量 $\mathbf{q}$,应用旋转角度 $m\theta$

$$f_q(\mathbf{q}, m) = \begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix} \begin{pmatrix} q_1 \\ q_2 \end{pmatrix}$$

同理,对位置 $n$ 的向量 $\mathbf{k}$,应用旋转角度 $n\theta$

$$f_k(\mathbf{k}, n) = \begin{pmatrix} \cos n\theta & -\sin n\theta \\ \sin n\theta & \cos n\theta \end{pmatrix} \begin{pmatrix} k_1 \\ k_2 \end{pmatrix}$$

这就是为什么叫做旋转位置编码:位置信息通过旋转变换注入到向量中。

2.3 证明:旋转函数满足相对位置条件

现在我们来证明,旋转函数确实能让内积只依赖于相对位置 $(m-n)$

备注:推导有点复杂,其实看前后即可。

$$\begin{aligned} &\langle f_q(\mathbf{q}, m), f_k(\mathbf{k}, n) \rangle \\[8pt] &= \begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix} \begin{pmatrix} q_1 \\ q_2 \end{pmatrix} \cdot \begin{pmatrix} \cos n\theta & -\sin n\theta \\ \sin n\theta & \cos n\theta \end{pmatrix} \begin{pmatrix} k_1 \\ k_2 \end{pmatrix} \\[8pt] &= \begin{pmatrix} q_1 \cos m\theta - q_2 \sin m\theta \\ q_1 \sin m\theta + q_2 \cos m\theta \end{pmatrix} \cdot \begin{pmatrix} k_1 \cos n\theta - k_2 \sin n\theta \\ k_1 \sin n\theta + k_2 \cos n\theta \end{pmatrix} \\[8pt] &= (q_1 \cos m\theta - q_2 \sin m\theta)(k_1 \cos n\theta - k_2 \sin n\theta) \\ &\quad + (q_1 \sin m\theta + q_2 \cos m\theta)(k_1 \sin n\theta + k_2 \cos n\theta) \\[8pt] &= q_1 k_1 (\cos m\theta \cos n\theta + \sin m\theta \sin n\theta) \\ &\quad + q_2 k_2 (\sin m\theta \sin n\theta + \cos m\theta \cos n\theta) \\ &\quad + q_1 k_2 (-\cos m\theta \sin n\theta + \sin m\theta \cos n\theta) \\ &\quad + q_2 k_1 (-\sin m\theta \cos n\theta + \cos m\theta \sin n\theta) \\[8pt] &= q_1 k_1 \cos((m-n)\theta) + q_2 k_2 \cos((m-n)\theta) \\ &\quad + q_1 k_2 \sin((m-n)\theta) - q_2 k_1 \sin((m-n)\theta) \\[8pt] &= (q_1 k_1 + q_2 k_2) \cos((m-n)\theta) + (q_1 k_2 - q_2 k_1) \sin((m-n)\theta) \\[8pt] &= \begin{pmatrix} q_1 & q_2 \end{pmatrix} \underbrace{\begin{pmatrix} \cos((m-n)\theta) & -\sin((m-n)\theta) \\ \sin((m-n)\theta) & \cos((m-n)\theta) \end{pmatrix}}_{R_{m-n}} \begin{pmatrix} k_1 \\ k_2 \end{pmatrix} \\[8pt] &= \mathbf{q}^T \cdot R_{m-n} \cdot \mathbf{k} \end{aligned}$$

证毕:我们把中间这个只依赖于 $(m-n)$ 的旋转矩阵记为 $R_{m-n}$,最终结果 $\mathbf{q}^T \cdot R_{m-n} \cdot \mathbf{k}$$m$$n$ 的绝对值无关,只与相对位置 $(m-n)$ 有关。

3. RoPE 的数学原理

现在让我们严格推导 RoPE 的数学形式。

3.1 频率设计

RoPE 对于维度 $d$ 的向量,两两配对处理。对于第 $i$ 对(共 $d/2$ 对),使用频率:

$$\theta_i = 10000^{-2i/d}$$

这个频率设计非常关键:

  • 低维度(小 $i$):频率高,变化快,捕捉短距离依赖
  • 高维度(大 $i$):频率低,变化慢,捕捉长距离依赖

3.2 旋转矩阵的完整形式

对于位置 $m$,向量 $\mathbf{x} = [x_0, x_1, x_2, x_3, ..., x_{d-1}]$,RoPE 的旋转操作可以写成:

$$\text{RoPE}(\mathbf{x}, m) = \begin{pmatrix} x_0 \cos(m\theta_0) - x_1 \sin(m\theta_0) \\ x_1 \cos(m\theta_0) + x_0 \sin(m\theta_0) \\ x_2 \cos(m\theta_1) - x_3 \sin(m\theta_1) \\ x_3 \cos(m\theta_1) + x_2 \sin(m\theta_1) \\ \vdots \end{pmatrix}$$

每两个维度组成一对,用对应的角度进行旋转。

3.3 在 Attention 中的应用

在 Self-Attention 中,RoPE 应用于 Query 和 Key:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q_{\text{rope}} K_{\text{rope}}^T}{\sqrt{d}}\right) V$$

其中 $Q_{\text{rope}} = \text{RoPE}(Q, m)$$K_{\text{rope}} = \text{RoPE}(K, n)$

由于旋转的特性,$Q_{\text{rope}} \cdot K_{\text{rope}}^T$ 的结果只依赖于相对位置 $m - n$

4. 从零手写 RoPE 实现

现在让我们一步步实现 RoPE。

4.1 Step 1: 生成旋转频率

import torch
import numpy as np

def get_rotary_frequencies(dim: int, seq_len: int, theta: float = 10000.0):
    """
    生成 RoPE 的旋转频率

    Args:
        dim: 嵌入维度(必须是偶数)
        seq_len: 序列长度
        theta: 基础频率参数

    Returns:
        freqs: shape (seq_len, dim // 2),每个位置每个维度对的频率
    """
    # 计算每个维度对的基础频率
    # theta_i = 10000^(-2i/d),i = 0, 1, ..., d/2-1
    i = torch.arange(0, dim // 2, dtype=torch.float32)
    freqs = theta ** (-2 * i / dim)  # shape: (dim // 2,)

    # 生成位置索引
    positions = torch.arange(seq_len, dtype=torch.float32)  # shape: (seq_len,)

    # 计算每个位置的角度:position * frequency
    # 外积得到 (seq_len, dim // 2) 的矩阵
    angles = torch.outer(positions, freqs)  # shape: (seq_len, dim // 2)

    return angles


# 测试
dim = 64
seq_len = 128
angles = get_rotary_frequencies(dim, seq_len)
print(f"Angles shape: {angles.shape}")  # (128, 32)
print(f"Angles[0]: {angles[0][:5]}")    # 位置 0 的前 5 个维度对的角度
print(f"Angles[1]: {angles[1][:5]}")    # 位置 1 的前 5 个维度对的角度

4.2 Step 2: 构建 sin/cos 缓存

def get_rotary_embedding(dim: int, seq_len: int, theta: float = 10000.0):
    """
    预计算 RoPE 的 sin 和 cos 值

    Returns:
        cos: shape (seq_len, dim)
        sin: shape (seq_len, dim)
    """
    angles = get_rotary_frequencies(dim, seq_len, theta)

    # 计算 cos 和 sin
    cos = torch.cos(angles)
    sin = torch.sin(angles)

    # 将 (seq_len, dim//2) 扩展为 (seq_len, dim),与 rotate_half 配合使用
    cos = torch.cat([cos, cos], dim=-1)
    sin = torch.cat([sin, sin], dim=-1)

    return cos, sin


# 测试
cos, sin = get_rotary_embedding(dim=64, seq_len=128)
print(f"Cos shape: {cos.shape}")  # (128, 64)
print(f"Sin shape: {sin.shape}")  # (128, 64)

4.3 Step 3: 应用旋转变换

这是 RoPE 的核心,参考 LLaMA 的实现方式:

def rotate_half(x):
    """
    将向量的前半部分和后半部分交换,并对后半部分取负
    [x1, x2, x3, x4] -> [-x3, -x4, x1, x2]

    这是实现旋转的关键辅助函数
    """
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    """
    应用 RoPE 旋转变换(LLaMA 风格实现)

    Args:
        q: Query,shape (batch, seq_len, num_heads, head_dim)
        k: Key,shape (batch, seq_len, num_heads, head_dim)
        cos: shape (seq_len, head_dim)
        sin: shape (seq_len, head_dim)

    Returns:
        q_rot, k_rot: 旋转后的 Query 和 Key

    旋转公式:
        q' = q * cos + rotate_half(q) * sin
        k' = k * cos + rotate_half(k) * sin
    """
    # 调整 cos/sin 形状以便广播: (seq_len, head_dim) -> (1, seq_len, 1, head_dim)
    cos = cos.unsqueeze(0).unsqueeze(2)
    sin = sin.unsqueeze(0).unsqueeze(2)

    # 应用旋转
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)

    return q_embed, k_embed

为什么这个公式是对的?

回顾 2D 旋转公式:

$$\begin{pmatrix} x' \\ y' \end{pmatrix} = \begin{pmatrix} x \cos\theta - y \sin\theta \\ x \sin\theta + y \cos\theta \end{pmatrix}$$

对于向量 $[x, y]$rotate_half 会把它变成 $[-y, x]$,所以:

原向量 * cos + rotate_half(原向量) * sin
= [x, y] * cos + [-y, x] * sin
= [x*cos - y*sin, y*cos + x*sin]

这正是旋转公式!

4.4 Step 4: 完整的 RoPE 模块

class RotaryPositionEmbedding(torch.nn.Module):
    """
    完整的 RoPE 实现
    """
    def __init__(self, dim: int, max_seq_len: int = 4096, theta: float = 10000.0):
        """
        Args:
            dim: 每个注意力头的维度
            max_seq_len: 最大序列长度
            theta: 基础频率参数
        """
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.theta = theta

        # 预计算并缓存 sin/cos 值
        cos, sin = get_rotary_embedding(dim, max_seq_len, theta)
        self.register_buffer('cos_cached', cos)
        self.register_buffer('sin_cached', sin)

    def forward(self, q: torch.Tensor, k: torch.Tensor, positions: torch.Tensor = None):
        """
        对 Query 和 Key 应用 RoPE

        Args:
            q: Query,shape (batch, seq_len, num_heads, head_dim)
            k: Key,shape (batch, seq_len, num_heads, head_dim)
            positions: 位置索引,默认为 [0, 1, 2, ..., seq_len-1]

        Returns:
            q_rot, k_rot: 旋转后的 Query 和 Key
        """
        seq_len = q.shape[1]

        # 获取当前序列长度的 cos/sin
        cos = self.cos_cached[:seq_len]
        sin = self.sin_cached[:seq_len]

        # 应用旋转
        q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)

        return q_rot, k_rot


# 测试
rope = RotaryPositionEmbedding(dim=64, max_seq_len=4096)

# 模拟输入
batch_size = 2
seq_len = 128
num_heads = 8
head_dim = 64

q = torch.randn(batch_size, seq_len, num_heads, head_dim)
k = torch.randn(batch_size, seq_len, num_heads, head_dim)

q_rot, k_rot = rope(q, k)
print(f"Q_rot shape: {q_rot.shape}")
print(f"K_rot shape: {k_rot.shape}")

4.5 验证:相对位置不变性

RoPE 最重要的性质是:两个位置的 Query 和 Key 的内积只依赖于它们的相对位置。让我们验证一下:

def verify_relative_position_invariance():
    """
    验证 RoPE 的相对位置不变性
    """
    dim = 64
    max_seq_len = 100

    # 预计算 cos/sin
    cos, sin = get_rotary_embedding(dim, max_seq_len)

    # 创建两个相同的向量
    torch.manual_seed(42)
    q = torch.randn(1, 1, 1, dim)
    k = torch.randn(1, 1, 1, dim)

    # 场景 1:q 在位置 0,k 在位置 5(相对位置 = 5)
    cos1_q, sin1_q = cos[0:1], sin[0:1]
    cos1_k, sin1_k = cos[5:6], sin[5:6]

    q1_rot, _ = apply_rotary_pos_emb(q, q, cos1_q, sin1_q)
    _, k1_rot = apply_rotary_pos_emb(k, k, cos1_k, sin1_k)

    dot_product_1 = (q1_rot * k1_rot).sum()

    # 场景 2:q 在位置 10,k 在位置 15(相对位置仍然是 5)
    cos2_q, sin2_q = cos[10:11], sin[10:11]
    cos2_k, sin2_k = cos[15:16], sin[15:16]

    q2_rot, _ = apply_rotary_pos_emb(q, q, cos2_q, sin2_q)
    _, k2_rot = apply_rotary_pos_emb(k, k, cos2_k, sin2_k)

    dot_product_2 = (q2_rot * k2_rot).sum()

    print(f"位置 (0, 5) 的内积: {dot_product_1.item():.6f}")
    print(f"位置 (10, 15) 的内积: {dot_product_2.item():.6f}")
    print(f"差异: {abs(dot_product_1.item() - dot_product_2.item()):.10f}")
    print("验证通过!" if abs(dot_product_1.item() - dot_product_2.item()) < 1e-5 else "验证失败!")


verify_relative_position_invariance()

5. 为什么 RoPE 位置编码好?

  • 相对位置编码:内积只依赖相对位置,天然适合语言建模
  • 外推性能好:配合 NTK/YaRN 可以泛化到更长序列
  • 计算高效:不增加额外的位置嵌入,只需旋转操作
  • 无需额外参数:基于固定的三角函数,不增加可学习参数
  • 兼容 KV Cache:缓存的 K 无需重新计算位置编码

6. 可视化理解 RoPE (Bonus)

以下是选看内容(为了帮助理解 RoPE 的内容)

6.1 位置编码热力图

import matplotlib.pyplot as plt
import seaborn as sns

def visualize_rope_heatmap():
    """
    可视化 RoPE 编码的热力图
    """
    dim = 64
    seq_len = 128

    cos, sin = get_rotary_embedding(dim, seq_len)

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Cos 热力图
    sns.heatmap(cos.numpy()[:64, :], ax=axes[0], cmap='RdBu', center=0)
    axes[0].set_title('RoPE Cos Values')
    axes[0].set_xlabel('Dimension')
    axes[0].set_ylabel('Position')

    # Sin 热力图
    sns.heatmap(sin.numpy()[:64, :], ax=axes[1], cmap='RdBu', center=0)
    axes[1].set_title('RoPE Sin Values')
    axes[1].set_xlabel('Dimension')
    axes[1].set_ylabel('Position')

    plt.tight_layout()
    plt.savefig('rope_heatmap.png', dpi=150)
    plt.show()

    print("观察要点:")
    print("1. 低维度(左侧)变化快 -> 捕捉短距离依赖")
    print("2. 高维度(右侧)变化慢 -> 捕捉长距离依赖")
    print("3. 每个维度都是周期函数,频率不同")


visualize_rope_heatmap()

6.2 2D 旋转动画

import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

def visualize_2d_rotation():
    """
    可视化 2D 空间中的旋转效果
    """
    # 原始向量
    original = np.array([1.0, 0.5])

    # 不同位置的旋转角度
    positions = range(0, 16)
    theta_base = 0.5  # 基础角度

    fig, ax = plt.subplots(figsize=(8, 8))

    colors = plt.cm.viridis(np.linspace(0, 1, len(positions)))

    for pos, color in zip(positions, colors):
        angle = pos * theta_base
        # 旋转矩阵
        cos_a, sin_a = np.cos(angle), np.sin(angle)
        rotated = np.array([
            original[0] * cos_a - original[1] * sin_a,
            original[0] * sin_a + original[1] * cos_a
        ])

        ax.arrow(0, 0, rotated[0], rotated[1],
                head_width=0.05, head_length=0.03,
                fc=color, ec=color, alpha=0.7,
                label=f'pos={pos}' if pos % 4 == 0 else None)

    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-1.5, 1.5)
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)
    ax.axhline(y=0, color='k', linewidth=0.5)
    ax.axvline(x=0, color='k', linewidth=0.5)
    ax.legend(loc='upper right')
    ax.set_title('RoPE: 不同位置的向量旋转效果\n(同一向量在不同位置被旋转不同角度)')

    plt.savefig('rope_rotation.png', dpi=150)
    plt.show()

    print("观察要点:")
    print("1. 同一向量在不同位置被旋转不同角度")
    print("2. 位置越大,旋转角度越大")
    print("3. 这就是 RoPE 编码位置信息的方式")


visualize_2d_rotation()

6.3 相对位置注意力分数

def visualize_relative_attention():
    """
    可视化 RoPE 对注意力分数的影响
    """
    dim = 64
    seq_len = 32

    # 生成随机 Q 和 K
    torch.manual_seed(42)
    q = torch.randn(1, seq_len, 1, dim)
    k = torch.randn(1, seq_len, 1, dim)

    # 应用 RoPE
    rope = RotaryPositionEmbedding(dim, seq_len)
    q_rot, k_rot = rope(q, k)

    # 计算注意力分数
    # 无 RoPE
    attn_no_rope = torch.matmul(q.squeeze(), k.squeeze().transpose(-2, -1)) / np.sqrt(dim)

    # 有 RoPE
    attn_with_rope = torch.matmul(q_rot.squeeze(), k_rot.squeeze().transpose(-2, -1)) / np.sqrt(dim)

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    sns.heatmap(attn_no_rope.squeeze().detach().numpy(), ax=axes[0], cmap='viridis')
    axes[0].set_title('Attention Scores (No RoPE)')
    axes[0].set_xlabel('Key Position')
    axes[0].set_ylabel('Query Position')

    sns.heatmap(attn_with_rope.squeeze().detach().numpy(), ax=axes[1], cmap='viridis')
    axes[1].set_title('Attention Scores (With RoPE)')
    axes[1].set_xlabel('Key Position')
    axes[1].set_ylabel('Query Position')

    plt.tight_layout()
    plt.savefig('rope_attention.png', dpi=150)
    plt.show()

    print("观察要点:")
    print("1. 无 RoPE 时,注意力分数与位置无关")
    print("2. 有 RoPE 时,注意力分数体现位置关系")
    print("3. 对角线附近通常有更高的注意力(局部依赖)")


visualize_relative_attention()

6.4. 实际应用:集成到 Transformer

最后,让我们看看如何将 RoPE 集成到完整的 Multi-Head Attention 中:

class MultiHeadAttentionWithRoPE(torch.nn.Module):
    """
    带 RoPE 的多头注意力
    """
    def __init__(self, d_model: int, num_heads: int, max_seq_len: int = 4096):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.q_proj = torch.nn.Linear(d_model, d_model, bias=False)
        self.k_proj = torch.nn.Linear(d_model, d_model, bias=False)
        self.v_proj = torch.nn.Linear(d_model, d_model, bias=False)
        self.o_proj = torch.nn.Linear(d_model, d_model, bias=False)

        # RoPE
        self.rope = RotaryPositionEmbedding(self.head_dim, max_seq_len)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        """
        Args:
            x: 输入,shape (batch, seq_len, d_model)
            mask: 注意力掩码,shape (seq_len, seq_len)
        """
        batch, seq_len, _ = x.shape

        # 线性投影
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # 重塑为多头形式
        q = q.view(batch, seq_len, self.num_heads, self.head_dim)
        k = k.view(batch, seq_len, self.num_heads, self.head_dim)
        v = v.view(batch, seq_len, self.num_heads, self.head_dim)

        # 应用 RoPE(只对 Q 和 K)
        q, k = self.rope(q, k)

        # 转置用于矩阵乘法:(batch, num_heads, seq_len, head_dim)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # 计算注意力分数
        scale = self.head_dim ** -0.5
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale

        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))

        attn_probs = torch.softmax(attn_scores, dim=-1)

        # 加权求和
        attn_output = torch.matmul(attn_probs, v)

        # 重塑并输出投影
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch, seq_len, self.d_model)
        output = self.o_proj(attn_output)

        return output


# 测试
mha = MultiHeadAttentionWithRoPE(d_model=512, num_heads=8)
x = torch.randn(2, 128, 512)
output = mha(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")

7. 参考资料

  1. RoFormer: Enhanced Transformer with Rotary Position Embedding
  2. LLaMA: Open and Efficient Foundation Language Models
  3. YaRN: Efficient Context Window Extension of Large Language Models
  4. Extending Context Window of Large Language Models via Positional Interpolation
  5. Transformer升级之路:2、博采众长的旋转式位置编码
  6. 十分钟读懂旋转编码(RoPE)
  7. 解密旋转位置编码:数学基础、代码实现与绝对编码一体化探索

8. 其他

最后欢迎关注我,基本全网同名 chaofa用代码打点酱油