Transformer Basics
Transformer Model Architecture

Main components: Encoder, Decoder, Generator.
Encoder
Composed of $N$ EncoderLayer networks with identical structure but independent parameters.
$\textbf{In}: [batch\_sz, seq\_len, d_{model}]$, $\textbf{Out}: [batch\_sz, seq\_len, d_{model}]$
EncoderLayer: consists of one self-attention Multi-Head Attention sub-network, one Position-wise Feed-Forward sub-network, and the Residual Connection and Layer Normalization that connect the sub-networks.
$\textbf{In}: [batch\_sz, seq\_len, d_{model}]$, $\textbf{Out}: [batch\_sz, seq\_len, d_{model}]$
- Self-attention Multi-Head Attention network: Q, K, V all come from the previous layer (the Input Embedding or the previous EncoderLayer).
  $\textbf{In}: [batch\_sz, seq\_len, d_{model}]$, $\textbf{Out}: [batch\_sz, seq\_len, d_{model}]$
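To make the stacking concrete, here is a minimal PyTorch sketch (not the Annotated Transformer's exact code) of an EncoderLayer and an Encoder. The `self_attn` and `feed_forward` arguments are assumed to be the Multi-Head Attention and Position-wise Feed-Forward modules described later, called as `self_attn(q, k, v, mask)` and `feed_forward(x)`; the residual + LayerNorm wrapping follows the pre-Norm style discussed in the Add&Norm section.

```python
import copy
import torch.nn as nn

class EncoderLayer(nn.Module):
    """One encoder layer: self-attention and feed-forward sub-networks,
    each wrapped with a residual connection and LayerNorm (pre-Norm style)."""
    def __init__(self, d_model, self_attn, feed_forward, dropout=0.1):
        super().__init__()
        self.self_attn = self_attn        # Multi-Head Attention, Q = K = V = x
        self.feed_forward = feed_forward  # Position-wise Feed-Forward
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # x: [batch_sz, seq_len, d_model] -> [batch_sz, seq_len, d_model]
        y = self.norm1(x)
        x = x + self.dropout(self.self_attn(y, y, y, mask))
        return x + self.dropout(self.feed_forward(self.norm2(x)))

class Encoder(nn.Module):
    """Stack of N structurally identical EncoderLayers with independent parameters."""
    def __init__(self, layer, N):
        super().__init__()
        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(N)])

    def forward(self, x, mask=None):
        for layer in self.layers:
            x = layer(x, mask)
        return x
```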
Decoder
Composed of $N$ DecoderLayer networks with identical structure but independent parameters.
$\textbf{In}: [batch\_sz, seq\_len, d_{model}]$, $\textbf{Out}: [batch\_sz, seq\_len, d_{model}]$
DecoderLayer: consists of one self-attention Masked Multi-Head Attention sub-network, one (Encoder-Decoder) attention Multi-Head Attention sub-network, one Position-wise Feed-Forward sub-network, and the Residual Connection and Layer Normalization that connect the sub-networks.
$\textbf{In}: [batch\_sz, seq\_len, d_{model}]$, $\textbf{Out}: [batch\_sz, seq\_len, d_{model}]$
- Self-attention Masked Multi-Head Attention network: Q, K, V all come from the previous layer (the Output Embedding or the previous DecoderLayer). "Masked" means that a mask of shape $[1,seq\_len,seq\_len]$ blocks the subsequent positions, so each position attends only to itself and the positions before it when predicting the next token.
  $\textbf{In}: [batch\_sz, seq\_len, d_{model}]$, $\textbf{Out}: [batch\_sz, seq\_len, d_{model}]$
- (Encoder-Decoder) attention Multi-Head Attention network: Q comes from the previous layer (the Masked Multi-Head Attention); K and V come from the Encoder output memory.
  $\textbf{In}: [batch\_sz, seq\_len, d_{model}]$, $\textbf{Out}: [batch\_sz, seq\_len, d_{model}]$
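A matching sketch of one DecoderLayer, under the same assumptions as the encoder sketch above: `self_attn` and `src_attn` are Multi-Head Attention modules called as `attn(q, k, v, mask)`, `memory` is the Encoder output, and `tgt_mask` is the subsequent (causal) mask described at the end of this note.

```python
import torch.nn as nn

class DecoderLayer(nn.Module):
    """Masked self-attention, encoder-decoder attention, and feed-forward,
    each wrapped with a residual connection and LayerNorm (pre-Norm style)."""
    def __init__(self, d_model, self_attn, src_attn, feed_forward, dropout=0.1):
        super().__init__()
        self.self_attn = self_attn      # Q, K, V all come from the previous decoder layer
        self.src_attn = src_attn        # Q from the decoder, K and V from the encoder memory
        self.feed_forward = feed_forward
        self.norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(3)])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, memory, src_mask=None, tgt_mask=None):
        # x, memory: [batch_sz, seq_len, d_model] -> [batch_sz, seq_len, d_model]
        y = self.norms[0](x)
        x = x + self.dropout(self.self_attn(y, y, y, tgt_mask))           # masked self-attention
        y = self.norms[1](x)
        x = x + self.dropout(self.src_attn(y, memory, memory, src_mask))  # encoder-decoder attention
        return x + self.dropout(self.feed_forward(self.norms[2](x)))
```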
Generator
Composed of a linear layer of shape $[\text{In}: d_{model}, \text{Out}: vocab\_sz]$ and a Softmax operation.
$$y = \mathrm{softmax}(\mathrm{Linear}(x))=\mathrm{softmax}(xA^T+b)$$
The generator outputs the predicted probability distribution for one next position at a time, following the sequence order.
$\textbf{In}: [batch\_sz, d_{model}]$, $\textbf{Out}: [batch\_sz, vocab\_sz]$
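A minimal sketch of the Generator: a single Linear layer projecting $d_{model}$ to $vocab\_sz$ followed by softmax over the vocabulary dimension. (The Annotated Transformer uses log_softmax here because it trains with a KL-divergence loss; the shapes are the same either way.)

```python
import torch
import torch.nn as nn

class Generator(nn.Module):
    """Linear projection d_model -> vocab_sz, then softmax over the vocabulary."""
    def __init__(self, d_model, vocab_sz):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_sz)   # y = x A^T + b

    def forward(self, x):
        # x:   [batch_sz, d_model] hidden state of the position being predicted
        # out: [batch_sz, vocab_sz] probability distribution over the vocabulary
        return torch.softmax(self.proj(x), dim=-1)

# e.g. Generator(d_model=512, vocab_sz=10000)(torch.randn(2, 512)).shape == (2, 10000)
```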
※ Multi-Head Attention
Scaled Dot-Product Attention:
$$\mathrm{Attention}(Q,K,V) = \mathrm{softmax}\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V$$
Shape changes:
- Input:
  - $Q\ [batch\_sz,h,seq\_len,d_k]$
  - $K\ [batch\_sz,h,seq\_len,d_k]$, $K^{\top}\ [batch\_sz,h,d_k,seq\_len]$
  - $V\ [batch\_sz,h,seq\_len,d_k]$
- $QK^{\top}\ [batch\_sz,h,seq\_len,seq\_len]$
- $\frac{QK^{\top}}{\sqrt{d_k}}$ and the Mask operation: shape unchanged, $[batch\_sz,h,seq\_len,seq\_len]$
- $\mathrm{softmax}(\frac{QK^{\top}}{\sqrt{d_k}})$: Softmax over the last dimension, shape unchanged, $[batch\_sz,h,seq\_len,seq\_len]$
- $\mathrm{softmax}(\frac{QK^{\top}}{\sqrt{d_k}})V$: $[batch\_sz,h,seq\_len,d_k]$
Full formula (as given in FlashAttention):
$$\begin{aligned} & S=\tau QK^{\top}\in\mathbb{R}^{N\times N}\\ & S^{\text{masked}}=\text{MASK}(S)\in\mathbb{R}^{N\times N}\\ & P=\text{softmax}(S^{\text{masked}})\in\mathbb{R}^{N\times N}\\ & P^{\text{dropped}}=\text{dropout}(P, p_{drop})\\ & \text{Attention}(Q,K,V)=O=P^{\text{dropped}}V\in\mathbb{R}^{N\times d} \end{aligned}$$
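The two formulas above map directly onto a few tensor operations. Below is a minimal sketch of the attention computation (similar in spirit to the Annotated Transformer's attention helper, though the exact code here is illustrative); `mask` is assumed to broadcast against the `[batch_sz, h, seq_len, seq_len]` score matrix, with 0 marking blocked positions.

```python
import math
import torch

def attention(query, key, value, mask=None, dropout=None):
    """Scaled Dot-Product Attention.
    query, key, value: [batch_sz, h, seq_len, d_k]
    mask:              broadcastable to [batch_sz, h, seq_len, seq_len]; 0 = blocked
    returns:           output [batch_sz, h, seq_len, d_k] and the attention weights
    """
    d_k = query.size(-1)
    # S = Q K^T / sqrt(d_k): [batch_sz, h, seq_len, seq_len]
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)  # S^masked: block future / padded positions
    p_attn = scores.softmax(dim=-1)                   # P: softmax over the last dimension
    if dropout is not None:
        p_attn = dropout(p_attn)                      # P^dropped
    return torch.matmul(p_attn, value), p_attn        # O = P V
```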
Multi-Head Attention mechanism:
$$\begin{aligned} \mathrm{MultiHeadAttn}(Q,K,V) &= \mathrm{Concat}(head_1, ..., head_h)W^O\\ \text{where}\ head_i &= \mathrm{Attention}(QW^Q_i, KW^K_i, VW^V_i) \end{aligned}$$
Here $W^Q_i\in\mathbb{R}^{d_{model}\times d_k}$, $W^K_i\in\mathbb{R}^{d_{model}\times d_k}$, $W^V_i\in\mathbb{R}^{d_{model}\times d_v}$, $W^O\in\mathbb{R}^{hd_v\times d_{model}}$.
In the implementation, $W^Q=(W^Q_1,...,W^Q_h)$, $W^K=(W^K_1,...,W^K_h)$, $W^V=(W^V_1,...,W^V_h)$, and $W^O$ are realized as 4 linear layers of shape $[\text{In}: d_{model}, \text{Out}: d_{model}]$, with $d_k=d_v=d_{model}/h$.

Shape changes:
- Input: $X\ [batch\_sz, seq\_len, d_{model}]$
- Multi-head preprocessing: $X\ [batch\_sz, seq\_len, d_{model}]$ → $X\ [batch\_sz,h,seq\_len,d_k]$
- Attention: $X\ [batch\_sz,h,seq\_len,d_k]$ → $Q,K,V\ [batch\_sz,h,seq\_len,d_k]$ → $\mathrm{Attention}(Q,K,V)\ [batch\_sz,h,seq\_len,d_k]$
- Concatenating the heads: the per-head results $[batch\_sz,h,seq\_len,d_k]$ are merged back into $\mathrm{Concat}(head_1,...,head_h)\ [batch\_sz, seq\_len, d_{model}]$
- Output: $\mathrm{MultiHeadAttn}(Q,K,V)\ [batch\_sz, seq\_len, d_{model}]$ (see the sketch below)
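A sketch of the multi-head wrapper that matches the shape changes just listed: four $[\text{In}: d_{model}, \text{Out}: d_{model}]$ linear layers play the roles of $W^Q$, $W^K$, $W^V$, $W^O$, the projections are reshaped into $h$ heads of size $d_k$, and the `attention` function from the sketch above is reused (so that function is assumed to be in scope).

```python
import torch.nn as nn

class MultiHeadedAttention(nn.Module):
    """MultiHeadAttn(Q, K, V) = Concat(head_1, ..., head_h) W^O."""
    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0
        self.h, self.d_k = h, d_model // h            # d_k = d_v = d_model / h
        # W^Q, W^K, W^V, W^O as four [d_model -> d_model] linear layers
        self.linears = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(4)])
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            mask = mask.unsqueeze(1)                  # same mask for every head
        n = query.size(0)
        # project, then split d_model into h heads: [batch_sz, h, seq_len, d_k]
        q, k, v = (lin(x).view(n, -1, self.h, self.d_k).transpose(1, 2)
                   for lin, x in zip(self.linears, (query, key, value)))
        x, _ = attention(q, k, v, mask=mask, dropout=self.dropout)  # from the earlier sketch
        # "concat": merge the heads back into [batch_sz, seq_len, d_model], then apply W^O
        x = x.transpose(1, 2).contiguous().view(n, -1, self.h * self.d_k)
        return self.linears[-1](x)
```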
Position-wise Feed-Forward
$$\mathrm{FFN}(x)=\mathrm{Linear}_2(\mathrm{ReLU}(\mathrm{Linear}_1(x)))=\max(0, xW_1 + b_1) W_2 + b_2$$
$\mathrm{Linear}_1(x)$: $[\text{In}:d_{model},\ \text{Out}:d_{ff}]$
$\mathrm{Linear}_2(x)$: $[\text{In}:d_{ff},\ \text{Out}:d_{model}]$
$\textbf{In}: [batch\_sz, seq\_len, d_{model}]$, $\textbf{Out}: [batch\_sz, seq\_len, d_{model}]$
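A minimal sketch of the Position-wise Feed-Forward network. It acts only on the last ($d_{model}$) dimension, so every position is transformed independently; the dropout between the two linear layers follows the Annotated Transformer implementation rather than the formula above.

```python
import torch.nn as nn

class PositionwiseFeedForward(nn.Module):
    """FFN(x) = max(0, x W1 + b1) W2 + b2, applied to each position independently."""
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)   # Linear_1: d_model -> d_ff
        self.w_2 = nn.Linear(d_ff, d_model)   # Linear_2: d_ff -> d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: [batch_sz, seq_len, d_model] -> [batch_sz, seq_len, d_model]
        return self.w_2(self.dropout(self.w_1(x).relu()))
```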
Add&Norm
In the paper (post-Norm):
$$\mathrm{SublayerConnection}(X)= \mathrm{LayerNorm}(X +\mathrm{Sublayer}(X))$$
In the Annotated Transformer implementation (pre-Norm):
$$\mathrm{SublayerConnection}(X)= X+\mathrm{Sublayer}(\mathrm{LayerNorm}(X))$$
$\textbf{In}: [batch\_sz, seq\_len, d_{model}]$, $\textbf{Out}: [batch\_sz, seq\_len, d_{model}]$
where:
- $\mathrm{Sublayer}\in\{\mathrm{MultiHeadAttn},\mathrm{FFN}\}$
- Layer Normalization $\mathrm{LayerNorm}(X)$: normalizes each sample $x=X[b,pos,:]\in\mathbb{R}^{d_{model}}$, i.e. the last ($d_{model}$) dimension of the tensor $X$.
  $\mathrm{Norm}(x)=\frac{x-E(x)}{SD(x)+\epsilon}*\gamma+\beta$, where $E(x)$ is the mean (expectation), $SD(x)$ is the standard deviation, $\gamma,\beta\in\mathbb{R}^{d_{model}}$ are learnable parameters, and $\epsilon$ is a tiny scalar added to the denominator for numerical stability (to avoid division by zero).
- Residual Connection: $y=x+\mathcal{F}(x)$ (a sketch of the pre-Norm sublayer connection follows this list)
- Note: on the difference between pre-Norm and post-Norm, see 【重新了解Transformer模型系列_1】PostNorm/PreNorm的差别 - 知乎
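A sketch of the pre-Norm sublayer connection, with LayerNorm written out explicitly so that the $\gamma$, $\beta$, and $\epsilon$ of the formula above are visible (nn.LayerNorm implements the same computation).

```python
import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    """Norm(x) = (x - E(x)) / (SD(x) + eps) * gamma + beta over the last (d_model) dimension."""
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))   # learnable scale
        self.beta = nn.Parameter(torch.zeros(d_model))   # learnable shift
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)                  # E(x) per sample
        std = x.std(-1, keepdim=True)                    # SD(x) per sample
        return self.gamma * (x - mean) / (std + self.eps) + self.beta

class SublayerConnection(nn.Module):
    """Pre-Norm residual wrapper: x + Sublayer(LayerNorm(x))."""
    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        self.norm = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        # sublayer is MultiHeadAttn or FFN; shape [batch_sz, seq_len, d_model] is preserved
        return x + self.dropout(sublayer(self.norm(x)))
```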
Token Embedding
A lookup table of size $vocab\_sz$ with embedding dimension $d_{model}$.
$\textbf{In}: [batch\_sz, seq\_len]$, $\textbf{Out}: [batch\_sz, seq\_len, d_{model}]$
$$\mathrm{Embedding}(x) = \mathrm{lut}(x)\cdot\sqrt{d_{model}}$$
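A sketch of the token embedding: an nn.Embedding lookup table whose output is scaled by $\sqrt{d_{model}}$.

```python
import math
import torch.nn as nn

class Embeddings(nn.Module):
    """Lookup table of size vocab_sz with embedding dimension d_model, scaled by sqrt(d_model)."""
    def __init__(self, d_model, vocab_sz):
        super().__init__()
        self.lut = nn.Embedding(vocab_sz, d_model)
        self.d_model = d_model

    def forward(self, x):
        # x: [batch_sz, seq_len] token ids -> [batch_sz, seq_len, d_model]
        return self.lut(x) * math.sqrt(self.d_model)
```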
Positional Encoding
Used to add position information to the token embeddings:
$$\begin{aligned} &PE_{(pos,2i)} = \sin(pos / 10000^{2i/d_{\text{model}}})\\ &PE_{(pos,2i+1)} = \cos(pos / 10000^{2i/d_{\text{model}}}) \end{aligned}$$
$$\mathrm{PE}(X)=X+ P,\ \text{where}\ (p_{(b,pos,i)})=P,\ p_{(b,pos,i)} = PE_{(pos,i)}$$
Here $X,P\in\mathbb{R}^{batch\_sz\times seq\_len\times d_{model}}$, i.e. $X$ and $P$ are tensors of shape $[batch\_sz,seq\_len,d_{model}]$; $p_{(b,pos,i)}$ is the element of $P$ at that position, $pos$ is the token's position in the sequence of length $seq\_len$, and $i$ is the dimension index within $d_{model}$.
$\textbf{In}: [batch\_sz, seq\_len, d_{model}]$, $\textbf{Out}: [batch\_sz, seq\_len, d_{model}]$
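A sketch of the sinusoidal positional encoding: the table $P$ is precomputed once for up to `max_len` positions (assuming an even $d_{model}$) and the slice matching `seq_len` is added to the input; the dropout after the addition follows the paper and the Annotated Transformer.

```python
import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1).float()          # positions 0 .. max_len-1
        div = torch.exp(torch.arange(0, d_model, 2).float()
                        * (-math.log(10000.0) / d_model))            # 1 / 10000^(2i/d_model)
        pe[:, 0::2] = torch.sin(pos * div)                           # even dimensions
        pe[:, 1::2] = torch.cos(pos * div)                           # odd dimensions
        self.register_buffer("pe", pe.unsqueeze(0))                  # [1, max_len, d_model]

    def forward(self, x):
        # x: [batch_sz, seq_len, d_model] -> same shape, with position information added
        return self.dropout(x + self.pe[:, : x.size(1)])
```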
Subsequent Mask
Also called the "Causal Attention Mask" (the term used in FlashAttention). It is used in the Decoder's attention networks to block information from positions after the one being predicted, i.e. each prediction may only use the current position and the positions before it.
The mask is applied to the matrix $QK^T/\sqrt{d_k}$. It is a lower-triangular matrix including the diagonal (it keeps the entries where the $seq\_len$ index $i$ of $Q$ is greater than or equal to the $seq\_len$ index $j$ of $K^T$); the matrix entries whose mask value is 0 (the upper-triangular part) are replaced with a very small value (e.g. -1e9).
$\text{shape}: [1,seq\_len, seq\_len]$
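A sketch of the subsequent (causal) mask: torch.triu builds the strictly upper-triangular part, and comparing with 0 keeps the lower triangle including the diagonal; these True positions are allowed to attend, while the False positions are the ones filled with -1e9 in the attention sketch above.

```python
import torch

def subsequent_mask(seq_len):
    """Causal mask of shape [1, seq_len, seq_len]: entry (i, j) is True when j <= i,
    i.e. query position i may only attend to key positions up to and including i."""
    upper = torch.triu(torch.ones(1, seq_len, seq_len, dtype=torch.uint8), diagonal=1)
    return upper == 0

# subsequent_mask(3):
# tensor([[[ True, False, False],
#          [ True,  True, False],
#          [ True,  True,  True]]])
```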
 
Code Implementation
- Official Colab notebook of The Annotated Transformer: AnnotatedTransformer.ipynb
- Colab notebook with detailed Chinese comments: AnnotatedTransformer.ipynb
- Official GitHub repository of The Annotated Transformer: harvardnlp/annotated-transformer
- GitHub repository with detailed Chinese comments and the model code split into separate files: peakcrosser7/annotated-transformer
References
- Vaswani A, Shazeer N, Parmar N, et al. Attention is all you need[J]. Advances in neural information processing systems, 2017, 30. https://dl.acm.org/doi/10.5555/3295222.3295349
- The Annotated Transformer - Harvard University
- Self-Attention v/s Attention: understanding the differences | by Nishant Usapkar | Medium
- Self attention vs attention in transformers | MLearning.ai
- 【重新了解Transformer模型系列_1】PostNorm/PreNorm的差别 - 知乎