Contents
- RoPE
- Background
- Deriving RoPE
- Implementing RoPE in code
- References
 
 
RoPE
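RoPE (Rotary Position Embedding) is a design that, used together with the attention mechanism, "implements relative position encoding by way of absolute position encoding". It was proposed by Su Jianlin (苏剑林) in February 2021 and is one of the most widely used position encodings in today's large models.
Let $\mathbf{q}$ and $\mathbf{k}$ be the query and key vectors of self-attention, and let $j$ be the element index. Assume $0 < \theta \leq \frac{\pi}{2N}$, where $N$ is the maximum sequence length. Let $i$ denote the imaginary unit, $\langle \cdot, \cdot \rangle$ the inner product, and $\overline{z}$ the complex conjugate of $z$ (some texts write the conjugate as $z^*$). RoPE can then be expressed as follows: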
  
      
       
        
         
          
           
            
            
$$
\begin{aligned}
\text{RoPE}(\mathbf{x}, m) &= \mathbf{x} e^{im\theta} \\
\langle \text{RoPE}(q_j, m), \text{RoPE}(k_j, n) \rangle &= \langle q_j e^{im\theta}, k_j e^{in\theta} \rangle \\
&= q_j k_j e^{im\theta} \overline{e^{in\theta}} \\
&= q_j k_j e^{i(m-n)\theta} \\
&= \text{RoPE}(q_j k_j, m-n)
\end{aligned}
$$

The conjugate is written here only on the phase factor. When $q_j$ and $k_j$ are themselves complex numbers, $k_j$ is conjugated as well, and the real part of $q_j \overline{k_j} e^{i(m-n)\theta}$ is the ordinary dot product of the corresponding two-dimensional real vectors.
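The identity above can be checked numerically. The following is a minimal sketch, not code from the article: it assumes one complex number per query and key (a single two-dimensional pair), and the helper names `rope` and `score` as well as the values of `theta`, `q`, and `k` are illustrative assumptions. The score is computed as the real part of $q\overline{k}$.

```python
import cmath

def rope(x: complex, m: int, theta: float) -> complex:
    """Rotate x by the angle m * theta: RoPE(x, m) = x * e^{i m theta}."""
    return x * cmath.exp(1j * m * theta)

def score(q: complex, k: complex) -> float:
    """Real part of q * conj(k), i.e. the dot product of the 2-D vectors."""
    return (q * k.conjugate()).real

theta = 0.01             # illustrative value (assumption)
q = complex(0.3, -1.2)   # illustrative query pair (assumption)
k = complex(0.7, 0.5)    # illustrative key pair (assumption)

# Two position pairs with the same offset m - n = 3.
s1 = score(rope(q, 5, theta), rope(k, 2, theta))
s2 = score(rope(q, 103, theta), rope(k, 100, theta))
# The same score obtained by rotating q by the relative offset only.
s3 = score(rope(q, 3, theta), k)
```

All three scores agree up to floating-point error, so the score depends on the positions only through $m - n$.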
The figure below (from the RoFormer paper) illustrates RoPE.
 
Background
- Cartesian form of a complex number: $z = a + ib$
- Polar form of a complex number: $z = r(\cos\theta + i\sin\theta)$, where $r = |z| = \sqrt{a^2 + b^2}$ and $\theta = \arg(z) = \tan^{-1}\frac{b}{a}$ (for $a > 0$; the other quadrants need the usual adjustment)
- Euler's formula: $e^{ix} = \cos(x) + i\sin(x)$
- Exponential form of a complex number: $z = re^{i\theta}$, with the same $r = |z|$ and $\theta = \arg(z)$ as in the polar form
- In exponential form, the conjugate of $z = re^{i\theta}$ is $\overline{z} = re^{-i\theta}$, and the product of $z = re^{i\theta}$ and $w = te^{i\phi}$ is $zw = rte^{i(\theta + \phi)}$
- In polar form, the product of $z = r(\cos\theta + i\sin\theta)$ and $w = t(\cos\phi + i\sin\phi)$ is $zw = rt(\cos(\theta + \phi) + i\sin(\theta + \phi))$
- The complex number $z = a + ib$ can be represented by the matrix $\begin{pmatrix} a & -b \\ b & a \end{pmatrix}$, and the rotation matrix has the form $\begin{bmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{bmatrix}$. Multiplying by the unit complex number $e^{i\theta}$ therefore rotates a vector counterclockwise by $\theta$ (the rotation can also be read off from the polar-form product)
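The correspondence between complex multiplication and the rotation matrix can be checked with a short script. This is a sketch using only the Python standard library; the function names `as_matrix` and `matvec` and the sample values are illustrative assumptions.

```python
import cmath
import math

def as_matrix(z: complex):
    """2x2 real matrix representing multiplication by z = a + ib."""
    a, b = z.real, z.imag
    return [[a, -b], [b, a]]

def matvec(mat, vec):
    """Multiply a 2x2 matrix by a length-2 vector."""
    return [mat[0][0] * vec[0] + mat[0][1] * vec[1],
            mat[1][0] * vec[0] + mat[1][1] * vec[1]]

angle = math.pi / 6
unit = cmath.exp(1j * angle)   # e^{i*angle} = cos(angle) + i*sin(angle)
w = complex(2.0, 1.0)          # the vector (2, 1) viewed as a complex number

# Rotate w in two ways: complex multiplication and the matrix of `unit`.
via_complex = unit * w
via_matrix = matvec(as_matrix(unit), [w.real, w.imag])
```

Both routes give the same point, with the modulus of `w` unchanged and its argument increased by `angle`.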
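Deriving RoPE
Suppose a function $f(\mathbf{x}, l)$ adds absolute position information to the element $\mathbf{x}$ at position $l$. Encoding $\mathbf{q}$ and $\mathbf{k}$ with $f$ gives: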
  
      
       
        
         
          
          
$$
\tilde{\mathbf{q}}_m = f(\mathbf{q}, m), \qquad \tilde{\mathbf{k}}_n = f(\mathbf{k}, n) \qquad (1)
$$
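That is, we want to design functions $f(\cdot, m)$ and $f(\cdot, n)$ for $\mathbf{q}$ and $\mathbf{k}$ such that, after encoding, $\tilde{\mathbf{q}}_m$ and $\tilde{\mathbf{k}}_n$ carry the absolute position information of positions $m$ and $n$. Because the core operation of self-attention is the inner product, we also want the inner product of the encoded $\mathbf{q}$ and $\mathbf{k}$ to carry relative position information. In other words, we require the following identity to hold: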
  
      
       
        
        
$$
\langle f(\mathbf{q}, m), f(\mathbf{k}, n) \rangle = g(\mathbf{q}, \mathbf{k}, m-n) \qquad (2)
$$
Our goal is to find one solution of this identity. As initial conditions we set $f(\mathbf{q}, 0) = \mathbf{q}$ and $f(\mathbf{k}, 0) = \mathbf{k}$, which can be read as the case where no position information is added.
We first consider the two-dimensional case and solve it with the help of complex numbers. Write the functions in exponential form:
$$
\begin{aligned}
f(\mathbf{q}, m) &= R_f(\mathbf{q}, m)e^{i\Theta_f(\mathbf{q}, m)} \qquad (3a) \\
f(\mathbf{k}, n) &= R_f(\mathbf{k}, n)e^{i\Theta_f(\mathbf{k}, n)} \qquad (3b) \\
g(\mathbf{q}, \mathbf{k}, m - n) &= R_g(\mathbf{q}, \mathbf{k}, m - n)e^{i\Theta_g(\mathbf{q}, \mathbf{k}, m - n)} \qquad (3c)
\end{aligned}
$$
Here $R_f$ and $R_g$ are the radial components of $f$ and $g$, and $\Theta_f$ and $\Theta_g$ are their angular components. Substitute these into identity (2), taking the inner product of two complex numbers $z$ and $w$ as $z\overline{w}$ (whose real part is the usual dot product). Matching moduli and arguments on both sides gives:
$$
\begin{aligned}
R_f(\mathbf{q}, m) R_f(\mathbf{k}, n) &= R_g(\mathbf{q}, \mathbf{k}, m - n) \qquad (4a) \\
\Theta_f(\mathbf{q}, m) - \Theta_f(\mathbf{k}, n) &= \Theta_g(\mathbf{q}, \mathbf{k}, m - n) \qquad (4b)
\end{aligned}
$$
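To see where (4a) and (4b) come from, one can check numerically that in the product $z\overline{w}$ the moduli multiply and the arguments subtract. This is a small sketch with arbitrary illustrative values, not part of the original derivation.

```python
import cmath

z = cmath.rect(2.0, 0.9)   # modulus 2.0, argument 0.9 (illustrative)
w = cmath.rect(0.5, 0.4)   # modulus 0.5, argument 0.4 (illustrative)

prod = z * w.conjugate()   # the complex product z * conj(w)

modulus = abs(prod)            # equals |z| * |w|, the pattern behind (4a)
argument = cmath.phase(prod)   # equals arg(z) - arg(w), the pattern behind (4b)
```

Note that `cmath.phase` returns a value in $(-\pi, \pi]$, so the comparison of arguments holds modulo $2\pi$ in general; the values here are chosen so that no wrap-around occurs.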
From the initial conditions $f(\mathbf{q}, 0) = \mathbf{q}$ and $f(\mathbf{k}, 0) = \mathbf{k}$ we have (where $\|\mathbf{q}\|$, $\|\mathbf{k}\|$ and $\theta_q$, $\theta_k$ are the radial and angular components of $\mathbf{q}$ and $\mathbf{k}$ in the two-dimensional plane):
  
      
       
        
         
          
           
            
             
             
$$
\begin{aligned}
\mathbf{q} &= \|\mathbf{q}\| e^{i\theta_q} = R_f(\mathbf{q}, 0) e^{i\Theta_f(\mathbf{q}, 0)} \\
\mathbf{k} &= \|\mathbf{k}\| e^{i\theta_k} = R_f(\mathbf{k}, 0) e^{i\Theta_f(\mathbf{k}, 0)}
\end{aligned} \qquad (5)
$$
Setting $m = n$ and using the initial condition $f(\mathbf{x}, 0) = \mathbf{x}$, equation (4a) gives:
  
      
       
        
         
         
$$
R_f(\mathbf{q}, m) R_f(\mathbf{k}, m) = R_g(\mathbf{q}, \mathbf{k}, 0) = R_f(\mathbf{q}, 0) R_f(\mathbf{k}, 0) = \|\mathbf{q}\| \|\mathbf{k}\| \qquad (6)
$$
Setting $\mathbf{k} = \mathbf{q}$ in (6) gives $R_f(\mathbf{q}, m)^2 = \|\mathbf{q}\|^2$, and since a modulus is non-negative, $R_f(\mathbf{q}, m) = R_f(\mathbf{q}, 0) = \|\mathbf{q}\|$. Likewise $R_f(\mathbf{k}, m) = R_f(\mathbf{k}, 0) = \|\mathbf{k}\|$, and by (4a) $R_g(\mathbf{q}, \mathbf{k}, m-n) = R_g(\mathbf{q}, \mathbf{k}, 0) = \|\mathbf{q}\| \|\mathbf{k}\|$. In other words, $R_f$ and $R_g$ do not depend on position.
Similarly, setting $m = n$ and using the initial condition $\Theta_f(\mathbf{x}, 0) = \Theta(\mathbf{x})$, equation (4b) gives (where $\Theta(\mathbf{q})$ and $\Theta(\mathbf{k})$ are the arguments of $\mathbf{q}$ and $\mathbf{k}$):
  
      
       
        
         
         
$$
\Theta_f(\mathbf{q}, m) - \Theta_f(\mathbf{k}, m) = \Theta_g(\mathbf{q}, \mathbf{k}, 0) = \Theta_f(\mathbf{q}, 0) - \Theta_f(\mathbf{k}, 0) = \Theta(\mathbf{q}) - \Theta(\mathbf{k}) \qquad (7)
$$
Rearranging the first and last expressions above gives $\Theta_f(\mathbf{q}, m) - \Theta(\mathbf{q}) = \Theta_f(\mathbf{k}, m) - \Theta(\mathbf{k})$. The left-hand side does not involve $\mathbf{k}$, the right-hand side does not involve $\mathbf{q}$, and the equality holds for any $\mathbf{q}$ and $\mathbf{k}$, so $\Theta_f(\mathbf{q}, m) - \Theta(\mathbf{q})$ is a function of $m$ alone, independent of $\mathbf{q}$. Denote it $\varphi(m)$, so that $\Theta_f(\mathbf{q}, m) = \Theta(\mathbf{q}) + \varphi(m)$. Setting $n = m - 1$, substituting into (4b), and rearranging gives:
  
      
       
        
        
$$
\varphi(m) - \varphi(m-1) = \Theta_g(\mathbf{q}, \mathbf{k}, 1) + \Theta(\mathbf{k}) - \Theta(\mathbf{q})
$$
The right-hand side above does not depend on $m$, so the left-hand side must not depend on $m$ either. Hence $\varphi$ is an arithmetic progression. If we set its initial values to $\varphi(0) = 0$ and $\varphi(1) = \theta$, we obtain $\varphi(m) = m\theta$.
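Spelled out, the arithmetic-progression step is a telescoping sum. Writing the constant difference as $\theta = \varphi(1) - \varphi(0)$ and using $\varphi(0) = 0$, for any non-negative integer $m$:

$$
\varphi(m) = \varphi(0) + \sum_{j=1}^{m} \bigl(\varphi(j) - \varphi(j-1)\bigr) = 0 + \sum_{j=1}^{m} \theta = m\theta
$$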
In summary, we have obtained RoPE in the two-dimensional case, expressed with complex numbers. It is one solution of identity (2):
  
      
       
        
        
$$
f(\mathbf{q}, m) = R_f(\mathbf{q}, m)e^{i\Theta_f(\mathbf{q}, m)} = \|\mathbf{q}\| e^{i(\Theta(\mathbf{q}) + m\theta)} = \mathbf{q} e^{im\theta}
$$
By the geometric meaning of complex multiplication, this transformation rotates the vector, which is why the RoPE author calls it "rotary position embedding".
Writing the expression above in matrix form:
$$
f(\mathbf{q}, m)=\begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix}\begin{pmatrix} q_0 \\ q_1 \end{pmatrix}
$$
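As a quick sanity check that the complex form and the matrix form describe the same operation, the sketch below rotates one two-dimensional vector both ways. The function names and the numeric values are illustrative assumptions, not part of the original derivation; only the Python standard library is used.

```python
import cmath
import math


def rotate_complex(q0, q1, m, theta):
    """Rotate (q0, q1), read as the complex number q0 + i*q1, by the angle m*theta."""
    z = complex(q0, q1) * cmath.exp(1j * m * theta)
    return z.real, z.imag


def rotate_matrix(q0, q1, m, theta):
    """Apply the 2x2 rotation matrix with angle m*theta to the column vector (q0, q1)."""
    c, s = math.cos(m * theta), math.sin(m * theta)
    return c * q0 - s * q1, s * q0 + c * q1


# Illustrative values (assumptions, not taken from the article).
q0, q1, m, theta = 3.0, 4.0, 5, 0.1
via_complex = rotate_complex(q0, q1, m, theta)
via_matrix = rotate_matrix(q0, q1, m, theta)

# Both forms should give the same point, and the norm ||q|| = 5 should be preserved.
same_point = all(math.isclose(a, b, abs_tol=1e-12) for a, b in zip(via_complex, via_matrix))
norm_after = math.hypot(*via_matrix)
print(same_point, norm_after)
```

The rotation changes only the angle of the vector, never its length, which is why the position information enters the attention score purely through the phase.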
Because the inner product is additive over coordinates, RoPE in any even dimension $d$ can be written as a concatenation of two-dimensional blocks:
$$
f(\mathbf{q}, m) = \begin{pmatrix} M_0 & & & \\ & M_1 & & \\ & & \ddots & \\ & & & M_{d/2-1} \end{pmatrix} \begin{pmatrix} q_0\\ q_1\\ \vdots\\ q_{d-1} \end{pmatrix} = \mathbf{R}_m \mathbf{q}_m = \mathbf{R}_m \mathbf{W}_q \mathbf{x}_m
$$
Here $M_j=\begin{pmatrix}\cos m\theta_j & -\sin m\theta_j \\ \sin m\theta_j & \cos m\theta_j\end{pmatrix}$ and $\Theta = \{\theta_j=10000^{-2j/d},\ j \in \{0,1,2,\ldots,d/2-1\}\}$ (equivalently $\theta_i=10000^{-2(i-1)/d}$ for $i \in \{1,2,\ldots,d/2\}$ in the 1-based indexing of the RoPE paper). $\mathbf{R}_m$ is a block-diagonal rotation matrix and is therefore orthogonal, $\mathbf{W}_q$ is the learnable query weight, and $\mathbf{x}_m$ is the embedding of the token at position $m$.
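To make the frequency set $\Theta$ concrete, here is a minimal sketch that lists the per-pair frequencies for a small dimension. It assumes the 0-based convention $\theta_j = 10000^{-2j/d}$ for $j = 0, \ldots, d/2-1$; the function name `rope_frequencies` and the choice $d = 8$ are hypothetical and only serve the illustration.

```python
import math


def rope_frequencies(d, base=10000.0):
    """Return [theta_0, ..., theta_{d/2-1}] with theta_j = base ** (-2j / d)."""
    if d % 2 != 0:
        raise ValueError("RoPE pairs up coordinates, so d must be even")
    return [base ** (-2.0 * j / d) for j in range(d // 2)]


d = 8  # illustrative dimension (assumption)
thetas = rope_frequencies(d)
# Wavelength in positions: how many tokens it takes for pair j to complete a full turn.
wavelengths = [2.0 * math.pi / t for t in thetas]
for j, (t, w) in enumerate(zip(thetas, wavelengths)):
    print(f"j={j}  theta_j={t:.6f}  wavelength={w:.1f}")
```

The first pair rotates fastest and later pairs rotate more and more slowly, so different coordinate pairs encode position at different scales.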
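In other words, multiply the vector $\mathbf{q}$ at position $m$ by the matrix $\mathbf{R}_m$ and the vector $\mathbf{k}$ at position $n$ by the matrix $\mathbf{R}_n$, then run Attention on the transformed sequences. Attention then carries relative position information automatically, because $\mathbf{R}_m^T = \mathbf{R}_{-m}$ yields the identity:
$$
(\mathbf{R}_m \mathbf{q})^T(\mathbf{R}_n \mathbf{k}) = \mathbf{q}^T \mathbf{R}_m^T \mathbf{R}_n \mathbf{k} = \mathbf{q}^T \mathbf{R}_{n-m} \mathbf{k}
$$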
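The relative-position property can be checked numerically. The sketch below is a dependency-free illustration under stated assumptions: the helper `rope_rotate`, the vectors, and the positions are all made up for the example, and the frequencies use the 0-based form $\theta_j = 10000^{-2j/d}$.

```python
import math


def rope_rotate(x, pos, base=10000.0):
    """Apply R_pos to x: rotate each pair (x[2j], x[2j+1]) by pos * theta_j."""
    d = len(x)
    out = []
    for j in range(d // 2):
        angle = pos * base ** (-2.0 * j / d)
        c, s = math.cos(angle), math.sin(angle)
        a, b = x[2 * j], x[2 * j + 1]
        out += [c * a - s * b, s * a + c * b]
    return out


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


# Illustrative query/key vectors with d = 4 (assumptions).
q = [0.5, -1.0, 2.0, 0.3]
k = [1.5, 0.2, -0.7, 1.1]

# Two (m, n) position pairs with the same offset n - m = 3 ...
score_a = dot(rope_rotate(q, 2), rope_rotate(k, 5))
score_b = dot(rope_rotate(q, 10), rope_rotate(k, 13))
# ... should agree with each other and with rotating only k by the offset.
score_rel = dot(q, rope_rotate(k, 3))
# A different offset (n - m = 4) generally gives a different score.
score_other = dot(rope_rotate(q, 2), rope_rotate(k, 6))
print(score_a, score_b, score_rel, score_other)
```

The absolute positions 2, 5, 10 and 13 never appear in the score on their own; only their difference does.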
Because $\mathbf{R}_m$ is sparse, implementing it as a full matrix multiplication wastes compute, so the RoPE author recommends implementing RoPE as follows:
$$
\begin{pmatrix} q_0 \\ q_1 \\ q_2 \\ q_3 \\ \vdots \\ q_{d-2} \\ q_{d-1} \end{pmatrix} \otimes \begin{pmatrix} \cos m\theta_0 \\ \cos m\theta_0 \\ \cos m\theta_1 \\ \cos m\theta_1 \\ \vdots \\ \cos m\theta_{d/2-1} \\ \cos m\theta_{d/2-1} \end{pmatrix} + \begin{pmatrix} -q_1 \\ q_0 \\ -q_3 \\ q_2 \\ \vdots \\ -q_{d-1} \\ q_{d-2} \end{pmatrix} \otimes \begin{pmatrix} \sin m\theta_0 \\ \sin m\theta_0 \\ \sin m\theta_1 \\ \sin m\theta_1 \\ \vdots \\ \sin m\theta_{d/2-1} \\ \sin m\theta_{d/2-1} \end{pmatrix}
$$
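Here $\otimes$ denotes element-wise multiplication, i.e. the `*` operator in NumPy and similar libraries. This implementation also explains why RoPE can be regarded as a variant of multiplicative position encoding.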
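To see that the element-wise form is just a cheaper way of computing the block-diagonal product, the following sketch evaluates both on one vector and compares them. It is a plain-Python illustration, not the author's code: the helper names, the vector, and the position `m=3` are assumptions, and the frequencies use $\theta_j = 10000^{-2j/d}$.

```python
import math


def thetas_for(d, base=10000.0):
    return [base ** (-2.0 * j / d) for j in range(d // 2)]


def rope_elementwise(q, m):
    """q * cos + swap(q) * sin, where swap(q) = (-q1, q0, -q3, q2, ...)."""
    d = len(q)
    angles = [m * t for t in thetas_for(d) for _ in range(2)]  # each angle repeated twice
    swapped = []
    for j in range(0, d, 2):
        swapped += [-q[j + 1], q[j]]
    return [x * math.cos(a) + y * math.sin(a) for x, y, a in zip(q, swapped, angles)]


def rope_dense(q, m):
    """Build the full d x d block-diagonal matrix R_m and multiply it with q."""
    d = len(q)
    R = [[0.0] * d for _ in range(d)]
    for j, t in enumerate(thetas_for(d)):
        c, s = math.cos(m * t), math.sin(m * t)
        R[2 * j][2 * j], R[2 * j][2 * j + 1] = c, -s
        R[2 * j + 1][2 * j], R[2 * j + 1][2 * j + 1] = s, c
    return [sum(R[r][col] * q[col] for col in range(d)) for r in range(d)]


q = [1.0, 2.0, 4.0, 5.0, 6.0, 7.0]  # illustrative vector with d = 6
fast = rope_elementwise(q, m=3)
dense = rope_dense(q, m=3)
max_diff = max(abs(a - b) for a, b in zip(fast, dense))
print(max_diff)
```

The dense version performs on the order of $d^2$ multiplications, most of them by zero, while the element-wise version needs only a constant number of operations per coordinate.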
Code implementation of RoPE
The following example implements RoPE in the way recommended by the RoPE author above (from reference 4):
```python
import math

import torch


def rotary_position_embedding(q, k):
    """
    Rotary Position Embedding (RoPE) for queries and keys.

    Args:
        q: tensor for queries of shape (batch_size, num_heads, seq_len, dim)
        k: tensor for keys of shape (batch_size, num_heads, seq_len, dim)

    Returns:
        Rotated queries and keys
    """
    _, _, seq_len, dim = q.size()

    # Position index of every token in the sequence: shape (seq_len, 1)
    position = torch.arange(seq_len, dtype=torch.float, device=q.device).unsqueeze(-1)
    # theta_j = 10000^(-2j/dim) for each coordinate pair: shape (dim/2,)
    div_term = torch.exp(
        torch.arange(0, dim, 2, dtype=torch.float, device=q.device) * -(math.log(10000.0) / dim)
    )

    # Rotation angle m * theta_j: shape (seq_len, dim/2)
    angles = position * div_term
    # Duplicate each cos/sin value so that both coordinates of a pair share one angle.
    # The (seq_len, dim) result broadcasts over batch_size and num_heads.
    cos_emb = torch.cos(angles).repeat_interleave(2, dim=-1)
    sin_emb = torch.sin(angles).repeat_interleave(2, dim=-1)

    # Build (-q1, q0, -q3, q2, ...) and the same for k
    q_alternate = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1).reshape(q.size())
    k_alternate = torch.stack([-k[..., 1::2], k[..., ::2]], dim=-1).reshape(k.size())

    # Rotate queries and keys
    q_rotated = q * cos_emb + q_alternate * sin_emb
    k_rotated = k * cos_emb + k_alternate * sin_emb
    return q_rotated, k_rotated
```
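Llama implements RoPE by first viewing the vectors as complex numbers, then rotating the query and key vectors in the complex domain, and finally converting them back to real numbers.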
```python
import torch

# Illustrate how Llama implements RoPE, using a q with sequence length 4 and dim 6
q = torch.tensor([[1, 2, 4, 5, 6, 7], [1, 2, 5, 6, 7, 8], [2, 5, 4, 6, 7, 8], [1, 3, 5, 6, 7, 9]])
seq_len, dim = q.shape  # [4, 6]

# Split q into consecutive pairs along the embedding dimension
q_per_token_split_into_pairs = q.float().view(seq_len, -1, 2)  # shape [4, 3, 2]

# Per-pair frequencies theta_j = rope_theta^(-j / (dim/2)) = rope_theta^(-2j / dim)
zero_to_one_split_into_dim_parts = torch.arange(dim // 2, dtype=torch.float32) / (dim // 2)
rope_theta = 10000.0
freqs = 1.0 / (rope_theta ** zero_to_one_split_into_dim_parts)
# Rotation angle m * theta_j for every position m: shape [4, 3]
freqs_for_each_token = torch.outer(torch.arange(seq_len, dtype=torch.float32), freqs)
# Unit complex numbers e^{i m theta_j} = cos(m theta_j) + i sin(m theta_j)
freqs_cis = torch.polar(torch.ones_like(freqs_for_each_token), freqs_for_each_token)

# View each pair (q_{2j}, q_{2j+1}) as the complex number q_{2j} + i q_{2j+1}
q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
# Element-wise complex multiplication rotates each pair by its position-dependent angle
q_per_token_as_complex_numbers_rotated = q_per_token_as_complex_numbers * freqs_cis

# Convert the rotated q back to real pairs: shape [4, 3, 2]
q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers_rotated)
# Restore the original shape [4, 6]
q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q.shape)
```
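The complex-number route and the cos/sin route should produce the same numbers, since both pair up adjacent coordinates. The sketch below checks this for a single token without any dependencies. The chosen row (the second row of the example `q`) and its position `m = 1` are picked for illustration, and the variable names are assumptions.

```python
import cmath
import math

q = [1.0, 2.0, 5.0, 6.0, 7.0, 8.0]  # second row of the example q above
m = 1                                # its 0-based position in the sequence
d = len(q)
thetas = [10000.0 ** (-2.0 * j / d) for j in range(d // 2)]

# Complex route: (q[2j] + i*q[2j+1]) * e^{i*m*theta_j}, then back to real pairs.
via_complex = []
for j, t in enumerate(thetas):
    z = complex(q[2 * j], q[2 * j + 1]) * cmath.exp(1j * m * t)
    via_complex += [z.real, z.imag]

# Real route: q * cos + (-q1, q0, -q3, q2, ...) * sin, written out pair by pair.
via_real = []
for j, t in enumerate(thetas):
    c, s = math.cos(m * t), math.sin(m * t)
    a, b = q[2 * j], q[2 * j + 1]
    via_real += [a * c - b * s, b * c + a * s]

agree = all(math.isclose(x, y, abs_tol=1e-12) for x, y in zip(via_complex, via_real))
print(agree)
```

Implementations that pair coordinate $j$ with coordinate $j + d/2$ instead of adjacent coordinates would give a differently ordered result, so the pairing convention has to match between the code that trains a model and the code that runs it.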
References
-  RoPE paper: Su, Jianlin, Yu Lu, Shengfeng Pan, Bo Wen, and Yunfeng Liu. 2021. "RoFormer: Enhanced Transformer with Rotary Position Embedding." arXiv, April.
-  Blog posts by RoPE author Su Jianlin: 让研究人员绞尽脑汁的Transformer位置编码; Transformer升级之路:2、博采众长的旋转式位置编码
-  eleuther.ai blog post on RoPE
-  Zhihu article: 位置编码|RoPE|ALiBi
-  Wikipedia: Complex number; Rotation matrix
-  rotary_embedding-torch on GitHub
-  llama3 from scratch



















