RMSNorm
 
论文:https://openreview.net/pdf?id=SygkZ3MTJE
 Github:https://github.com/bzhangGo/rmsnorm?tab=readme-ov-file
 
 论文假设LayerNorm中的重新居中不变性是可有可无的,并提出了均方根层归一化(RMSNorm)。RMSNorm根据均方根(RMS)将一层神经元的总和输入正则化,得到模型重新缩放不变性特性和隐式学习率适应能力
LayerNorm 公式
深度学习当中,没有线性激活函数的预测公式
a i = ∑ j = 1 m w i j x j , y i = f ( a i + b i ) , \begin{aligned}a_i=\sum_{j=1}^mw_{ij}x_j,\quad y_i=f\left(a_i+b_i\right),\end{aligned} ai=j=1∑mwijxj,yi=f(ai+bi),
通过激活函数后,其中,随着前一层的更新,层的输入分布会发生变化。这可能会对参数梯度的稳定性产生负面影响,延迟模型收敛。为了减少这种转变,LayerNorm 对求和的输入进行归一化,以固定它们的均值和方差,如下所示:
a ˉ i = a i − μ σ g i , y i = f ( a ˉ i + b i ) , \begin{aligned}\bar{a}_i=\frac{a_i-\mu}{\sigma}g_i,\quad y_i=f\left(\bar{a}_i+b_i\right),\end{aligned} aˉi=σai−μgi,yi=f(aˉi+bi),
其中 a ˉ i \bar{a}_i aˉi是向量 a ˉ ∈ R n \bar{a}\in\mathbb{R}^n aˉ∈Rn的第 i i i个值,作为 α i \alpha_i αi的归一化替代值用于层激活。 g ∈ R n \mathbf{g}\in\mathbb{R}^n g∈Rn是增益参数,用于重新调整标准化求和输入的大小,一开始设置为 1。 μ \mu μ 和 σ 2 \sigma^2 σ2 分别是根据原始求和输入估计的均值和方差统计量。
μ = 1 n ∑ i = 1 n a i , σ = 1 n ∑ i = 1 n ( a i − μ ) 2 . \begin{aligned}\mu=\frac{1}{n}\sum_{i=1}^na_i,\quad\sigma=\sqrt{\frac{1}{n}\sum_{i=1}^n(a_i-\mu)^2}.\end{aligned} μ=n1i=1∑nai,σ=n1i=1∑n(ai−μ)2.
在本文中,假设重新缩放不变性是LayerNorm成功的原因,而不是重新定中心不变性。我们提出了RMSNorm,它只关注重新缩放不变性,并简单地根据均方根(RMS)统计对求和输入进行正则化:
  
      
       
        
         
          
           
            
             
              
              
                a 
               
              
                ˉ 
               
              
             
               i 
              
             
            
              = 
             
             
              
              
                a 
               
              
                i 
               
              
              
              
                RMS 
               
              
                ( 
               
              
                a 
               
              
                ) 
               
              
             
             
             
               g 
              
             
               i 
              
             
            
              , 
             
             
            
              where RMS 
             
            
              ( 
             
            
              a 
             
            
              ) 
             
            
              = 
             
             
              
               
               
                 1 
                
               
                 n 
                
               
               
               
                 ∑ 
                
                
                
                  i 
                 
                
                  = 
                 
                
                  1 
                 
                
               
                 n 
                
               
               
               
                 a 
                
               
                 i 
                
               
                 2 
                
               
              
             
            
              . 
             
            
           
          
         
        
       
         \begin{aligned}\bar{a}_i=\frac{a_i}{\text{RMS}(\mathbf{a})}g_i,\quad\text{where RMS}(\mathbf{a})=\sqrt{\frac{1}{n}\sum_{i=1}^na_i^2}.\end{aligned} 
        
       
     aˉi=RMS(a)aigi,where RMS(a)=n1i=1∑nai2.
python实现
# root mean square layer normalization
def rln(x, s):
    _eps = 1e-5
    output = x / tensor.sqrt((x * x).mean(1)[:,None] + _eps)
    output = s[None, :] * output
    return output
# layer normalization
def ln(x, b, s):
    _eps = 1e-5
    output = (x - x.mean(1)[:,None]) / tensor.sqrt((x.var(1)[:,None] + _eps))
    output = s[None, :] * output + b[None,:]
    return output
使用pytorch来写RMSNorm的函数
import torch
import torch.nn as nn
class RMSNorm(nn.Module):
    def __init__(self, d, p=-1., eps=1e-8, bias=False):
        """
            Root Mean Square Layer Normalization
        :param d: model size
        :param p: partial RMSNorm, valid value [0, 1], default -1.0 (disabled)
        :param eps:  epsilon value, default 1e-8
        :param bias: whether use bias term for RMSNorm, disabled by
            default because RMSNorm doesn't enforce re-centering invariance.
        """
        super(RMSNorm, self).__init__()
        self.eps = eps
        self.d = d
        self.p = p
        self.bias = bias
        self.scale = nn.Parameter(torch.ones(d))
        self.register_parameter("scale", self.scale)
        if self.bias:
            self.offset = nn.Parameter(torch.zeros(d))
            self.register_parameter("offset", self.offset)
    def forward(self, x):
        if self.p < 0. or self.p > 1.:
            norm_x = x.norm(2, dim=-1, keepdim=True)
            d_x = self.d
        else:
            partial_size = int(self.d * self.p)
            partial_x, _ = torch.split(x, [partial_size, self.d - partial_size], dim=-1)
            norm_x = partial_x.norm(2, dim=-1, keepdim=True)
            d_x = partial_size
        rms_x = norm_x * d_x ** (-1. / 2)
        x_normed = x / (rms_x + self.eps)
        if self.bias:
            return self.scale * x_normed + self.offset
        return self.scale * x_normed



















