Prerequsite:Adam优化算法
Adam优化算法很长一段时间都是比较主流的参数更新算法,也有很多变种,本文介绍在大模型训练过程中使用的AdamW和Adafator
AdamW
原论文:Decoupled Weight Decay Regularization
AdamW指的是Adam + Weight Decay(权重衰减)。
Adam相信很多读者已经了解了,Weight Decay解释起来也比较容易,为了防止过拟合,在计算损失函数时需要增加L2正则项:
  
      
       
        
        
          L 
         
        
          ( 
         
         
         
           θ 
          
          
          
            n 
           
          
            e 
           
          
            w 
           
          
         
        
          ) 
         
        
          = 
         
        
          L 
         
        
          ( 
         
         
         
           θ 
          
          
          
            o 
           
          
            l 
           
          
            d 
           
          
         
        
          ) 
         
        
          + 
         
        
          γ 
         
        
          / 
         
        
          2 
         
        
          ∣ 
         
        
          ∣ 
         
         
         
           θ 
          
         
           2 
          
         
        
          ∣ 
         
        
          ∣ 
         
        
          (公式 
         
        
          1 
         
        
          ) 
         
        
       
         L(\theta_{new})=L(\theta_{old})+\gamma/2||\theta^2|| (公式1) 
        
       
     L(θnew)=L(θold)+γ/2∣∣θ2∣∣(公式1)
求导计算梯度时:
  
      
       
        
         
         
           g 
          
         
           t 
          
         
        
          ← 
         
        
          ∇ 
         
         
         
           f 
          
         
           t 
          
         
        
          ( 
         
         
         
           θ 
          
          
          
            t 
           
          
            − 
           
          
            1 
           
          
         
        
          ) 
         
        
          + 
         
        
          γ 
         
         
         
           θ 
          
          
          
            t 
           
          
            − 
           
          
            1 
           
          
         
        
          (公式 
         
        
          2 
         
        
          ) 
         
        
       
         g_t \leftarrow \nabla f_t(\theta_{t-1}) + \gamma \theta_{t-1}(公式2) 
        
       
     gt←∇ft(θt−1)+γθt−1(公式2)
Weight Decay即在正则项前面乘以 γ ( 0 < γ < 1 ) \gamma (0<\gamma<1) γ(0<γ<1),用来缩放正则项产生的影响:L2正则会使得参数趋近于0,Weight Decay减轻这种趋势。
AdamW将Weight Decay应用在优化算法最后一步参数更新,参见下图(下图中的w等价于上面公式内的 
     
      
       
       
         γ 
        
       
      
        \gamma 
       
      
    γ)。
 
 图中紫色部分和绿色部分等价于公式2,紫色部分是原始的Adam应用Weight Decay的地方,绿色部分是AdamW应用Weight Decay的地方。
代码实现可以参见:理解AdamW
Adafator
原论文:Adafactor: Adaptive Learning Rates with Sublinear Memory Cost
Adafator没有像Adam那样保存权重矩阵每个元素的滑动平均值,而是保存了行维度或者是列维度的滑动平均值之和,这样显著降低了需要参数更新时需要的存储空间,计算方法如下图所示:
 
注意:由于 β 1 = 0 \beta_1=0 β1=0,相当于去掉了Adam的Weight Decay。这导致相较于Adam算法, Adafator存在表现不稳定的缺陷,有时候能比Adam更快收敛,有时候则不能。
参考文献
- Optimizer
- 理解AdamW
- 权重衰减/权重衰退——weight_decay



















