# RoPE → Attention: The Complete Flow
OK — here is the earlier "Transformer input → RoPE → Attention" walkthrough, organized into one complete, coherent document. Each step includes the math plus PyTorch example code, ready to reference or implement directly.

## 1️⃣ Input: Token → Embedding

**Math.** Given a token sequence $t_1, t_2, \dots, t_n$, the embedding lookup gives

$$E = \text{Embedding}(t) \in \mathbb{R}^{n \times d}$$

**Example code:**

```python
import torch
import torch.nn as nn

vocab_size = 10000
d = 8  # small for demonstration; in practice d might be 512/1024

embedding = nn.Embedding(vocab_size, d)
tokens = torch.tensor([[1, 5, 9, 2]])  # shape: (1, n)
E = embedding(tokens)                  # (1, 4, 8)
print(E.shape)                         # (1, 4, 8)
```

## 2️⃣ Linear projections: Q / K / V

**Math.**

$$Q = E W_Q, \quad K = E W_K, \quad V = E W_V, \qquad W_Q, W_K, W_V \in \mathbb{R}^{d \times d_k}$$

**Example code:**

```python
d_k = d

W_Q = nn.Linear(d, d_k, bias=False)
W_K = nn.Linear(d, d_k, bias=False)
W_V = nn.Linear(d, d_k, bias=False)

Q = W_Q(E)
K = W_K(E)
V = W_V(E)
```

## 3️⃣ Build the RoPE angles

**Math.**

$$\theta_{i,\,pos} = \frac{pos}{10000^{2i/d_k}}, \qquad i = 0, 1, \dots, d_k/2 - 1, \quad pos = 0, 1, \dots, n - 1$$

**Example code:**

```python
def get_rope_angles(seq_len, dim):
    pos = torch.arange(seq_len).float()    # (n,)
    i = torch.arange(0, dim, 2).float()    # (d/2,)
    inv_freq = 1.0 / (10000 ** (i / dim))  # (d/2,)
    theta = torch.outer(pos, inv_freq)     # (n, d/2)
    return theta

theta = get_rope_angles(seq_len=E.shape[1], dim=d_k)
```

## 4️⃣ Compute sin / cos

$$\sin(\theta), \quad \cos(\theta)$$

**Example code:**

```python
sin = theta.sin()[None, :, :]  # (1, n, d/2)
cos = theta.cos()[None, :, :]
```

## 5️⃣ Apply RoPE (a 2D rotation per channel pair)

**Math.**

$$\begin{aligned} x'_{2i} &= x_{2i} \cos\theta - x_{2i+1} \sin\theta \\ x'_{2i+1} &= x_{2i} \sin\theta + x_{2i+1} \cos\theta \end{aligned}$$

**Example code** (note this version concatenates the two rotated halves instead of re-interleaving them; see the note at the end of the post):

```python
def apply_rope(x, sin, cos):
    # x: (B, n, d)
    x1 = x[..., ::2]   # even-indexed channels
    x2 = x[..., 1::2]  # odd-indexed channels
    x_rot = torch.cat([
        x1 * cos - x2 * sin,
        x1 * sin + x2 * cos
    ], dim=-1)
    return x_rot
```

## 6️⃣ Apply RoPE to Q / K

$$Q' = \text{RoPE}(Q), \quad K' = \text{RoPE}(K)$$

**Example code:**

```python
Q_rot = apply_rope(Q, sin, cos)
K_rot = apply_rope(K, sin, cos)
```

## 7️⃣ Attention

**Math.**

$$A = \frac{Q' K'^{T}}{\sqrt{d_k}}, \qquad \alpha = \mathrm{softmax}(A), \qquad \text{Output} = \alpha V$$

**Example code:**

```python
scores = torch.matmul(Q_rot, K_rot.transpose(-2, -1)) / (d_k ** 0.5)
attn = torch.softmax(scores, dim=-1)
output = torch.matmul(attn, V)
print(output.shape)  # (1, n, d)
```
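As a quick sanity check (not part of the original walkthrough), the key property behind step 7 can be verified numerically: with RoPE applied to Q and K, the pre-softmax score between two positions depends only on their relative offset, not on their absolute positions. The helper `rope_rotate` below is introduced just for this check; it reuses the angle schedule from step 3 on a single vector.

```python
import torch

def rope_rotate(x, pos, dim):
    # rotate one (dim,) vector as if it sat at absolute position `pos`,
    # using the same angle schedule as get_rope_angles
    i = torch.arange(0, dim, 2).float()
    theta = pos / (10000 ** (i / dim))  # (dim/2,)
    x1, x2 = x[::2], x[1::2]
    return torch.cat([x1 * theta.cos() - x2 * theta.sin(),
                      x1 * theta.sin() + x2 * theta.cos()])

dim = 8
q, k = torch.randn(dim), torch.randn(dim)

# same relative offset (2), different absolute positions
s_a = rope_rotate(q, 3, dim) @ rope_rotate(k, 1, dim)
s_b = rope_rotate(q, 10, dim) @ rope_rotate(k, 8, dim)
print(torch.allclose(s_a, s_b, atol=1e-4))  # True
```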
## 8️⃣ Pipeline summary

**Math version:**

$$\begin{aligned} E &= \text{Embedding}(t) \\ Q &= E W_Q, \quad K = E W_K, \quad V = E W_V \\ Q' &= R_{\theta(pos)}(Q), \quad K' = R_{\theta(pos)}(K) \\ \text{Attn} &= \mathrm{softmax}\!\Big( \frac{Q' K'^{T}}{\sqrt{d_k}} \Big) V \end{aligned}$$

**Code version:**

```python
# 1. Embedding
E = embedding(tokens)

# 2. Linear projection
Q = W_Q(E)
K = W_K(E)
V = W_V(E)

# 3. RoPE angles
theta = get_rope_angles(E.shape[1], d_k)
sin = theta.sin()[None, :, :]
cos = theta.cos()[None, :, :]

# 4. Apply RoPE
Q_rot = apply_rope(Q, sin, cos)
K_rot = apply_rope(K, sin, cos)

# 5. Attention
scores = torch.matmul(Q_rot, K_rot.transpose(-2, -1)) / (d_k ** 0.5)
attn = torch.softmax(scores, dim=-1)
output = torch.matmul(attn, V)
```

## 9️⃣ Key points

- RoPE acts only on Q and K; V is left unchanged.
- Rotating by position-dependent angles encodes position, so attention depends on **relative** position.
- Compared with traditional additive PE ($E + PE$) or learned PE, RoPE is more stable, extrapolates better, and adds zero parameters.
- Combined with multi-head attention, different heads can exploit the rotation at different frequencies, capturing relations at multiple scales (see the multi-head sketch at the end of this post).

If you'd like, I can also draw a complete flow diagram that labels every dimension, shape, and matrix relation along Token → Embedding → Q/K/V → RoPE → Attention → Output — it makes the whole thing very intuitive. Would you like me to?

**Complete script** (the steps above consolidated into one runnable file):

```python
# -*- coding: utf-8 -*-
"""
Created on Fri Apr 3 14:13:21 2026

@author: luogan
"""
import torch
import torch.nn as nn

vocab_size = 10000
d = 8  # small for demonstration; in practice d might be 512/1024

embedding = nn.Embedding(vocab_size, d)
tokens = torch.tensor([[1, 5, 9, 2]])  # shape: (1, n)
E = embedding(tokens)                  # (1, 4, 8)
print(E.shape)                         # (1, 4, 8)

d_k = d
W_Q = nn.Linear(d, d_k, bias=False)
W_K = nn.Linear(d, d_k, bias=False)
W_V = nn.Linear(d, d_k, bias=False)

Q = W_Q(E)
K = W_K(E)
V = W_V(E)

def get_rope_angles(seq_len, dim):
    pos = torch.arange(seq_len).float()    # (n,)
    i = torch.arange(0, dim, 2).float()    # (d/2,)
    inv_freq = 1.0 / (10000 ** (i / dim))  # (d/2,)
    theta = torch.outer(pos, inv_freq)     # (n, d/2)
    return theta

theta = get_rope_angles(seq_len=E.shape[1], dim=d_k)
sin = theta.sin()[None, :, :]  # (1, n, d/2)
cos = theta.cos()[None, :, :]

def apply_rope(x, sin, cos):
    # x: (B, n, d)
    x1 = x[..., ::2]   # even-indexed channels
    x2 = x[..., 1::2]  # odd-indexed channels
    x_rot = torch.cat([
        x1 * cos - x2 * sin,
        x1 * sin + x2 * cos
    ], dim=-1)
    return x_rot

Q_rot = apply_rope(Q, sin, cos)
K_rot = apply_rope(K, sin, cos)

scores = torch.matmul(Q_rot, K_rot.transpose(-2, -1)) / (d_k ** 0.5)
attn = torch.softmax(scores, dim=-1)
output = torch.matmul(attn, V)
print(output.shape)  # (1, n, d)
```
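A side note on `apply_rope`, promised in step 5: it concatenates the rotated even and odd channels rather than re-interleaving them, so the output channel order differs from the textbook formulation. Attention scores are unaffected, because Q and K undergo the same fixed permutation and dot products are permutation-invariant. For completeness, here is a sketch of a variant that restores the interleaved layout; `apply_rope_interleaved` is a name introduced here for illustration, not from the original post.

```python
import torch

def apply_rope_interleaved(x, sin, cos):
    # x: (B, n, d); sin/cos: (1, n, d/2)
    x1 = x[..., ::2]   # even-indexed channels
    x2 = x[..., 1::2]  # odd-indexed channels
    out = torch.empty_like(x)
    out[..., ::2] = x1 * cos - x2 * sin   # x'_{2i}
    out[..., 1::2] = x1 * sin + x2 * cos  # x'_{2i+1}
    return out
```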
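On the multi-head point in step 9: in common implementations RoPE is applied per head, over each head's `d_h = d / h` channels, so every head sees the same frequency schedule over its own channels, and heads differentiate through their learned projections. A minimal sketch of the reshape pattern, assuming `get_rope_angles` and `apply_rope` from above are in scope (the concrete shapes are illustrative only):

```python
import torch

B, n, d, h = 1, 4, 8, 2
d_h = d // h
x = torch.randn(B, n, d)  # stand-in for a projected Q or K

# split channels into heads: (B, n, d) -> (B, h, n, d_h)
x_heads = x.view(B, n, h, d_h).transpose(1, 2)

# angles are built per head dimension d_h and broadcast over the head axis
theta = get_rope_angles(seq_len=n, dim=d_h)  # (n, d_h/2)
sin = theta.sin()[None, None, :, :]          # (1, 1, n, d_h/2)
cos = theta.cos()[None, None, :, :]

x_rot = apply_rope(x_heads, sin, cos)        # broadcasting handles the heads
print(x_rot.shape)                           # (1, 2, 4, 4)
```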