GRU 参数梯度推导与梯度消失分析
1. GRU 前向计算回顾
GRU 单元的核心计算步骤(忽略偏置项):
更新门: z_t = σ(W_z · [h_{t-1}, x_t])
重置门: r_t = σ(W_r · [h_{t-1}, x_t])
候选状态: ̃h_t = tanh(W_h · [r_t ⊙ h_{t-1}, x_t])
新状态: h_t = (1 - z_t) ⊙ h_{t-1} + z_t ⊙ ̃h_t
其中 σ
为 sigmoid 函数,⊙
表示逐元素乘法。
2. 关键梯度推导(以 ∂L/∂W_h 为例)
设时间 T
的损失为 L
。需计算 ∂L/∂W_h
(影响候选状态 ̃h_t
)。
反向传播从 h_t
开始:
∂L/∂h_t = δ_t // 从更高层或损失函数接收的梯度
h_t
对 ̃h_t
的梯度:
∂h_t/∂̃h_t = diag(z_t) // 对角矩阵,元素为 z_t
̃h_t
对 W_h
的梯度:
̃h_t = tanh(W_h · [r_t ⊙ h_{t-1}, x_t])
∂̃h_t/∂W_h = [∂̃h_t/∂(W_h · in)] · [∂(W_h · in)/∂W_h]
= diag(tanh'(net_h)) · [r_t ⊙ h_{t-1}, x_t]^T
其中 net_h = W_h · [r_t ⊙ h_{t-1}, x_t]
。
合并得 ∂L/∂W_h
:
∂L/∂W_h = (∂L/∂h_t) · (∂h_t/∂̃h_t) · (∂̃h_t/∂W_h)
= δ_t^T · diag(z_t) · diag(tanh'(net_h)) · [r_t ⊙ h_{t-1}, x_t]^T
= [δ_t ⊙ z_t ⊙ tanh'(net_h)] · [r_t ⊙ h_{t-1}, x_t]^T
3. 时间反向传播与梯度消失分析
损失 L
对历史状态 h_k (k < t)
的梯度是分析梯度消失的关键:
∂L/∂h_k = ∂L/∂h_t · (∂h_t/∂h_k)
计算 ∂h_t/∂h_k
(核心路径):
h_t = (1 - z_t) ⊙ h_{t-1} + z_t ⊙ ̃h_t
̃h_t = tanh(W_h · [r_t ⊙ h_{t-1}, x_t])
展开递归关系:
∂h_t/∂h_k = ∏_{i=k+1}^{t} ∂h_i/∂h_{i-1}
∂h_i/∂h_{i-1}
的具体形式:
∂h_i/∂h_{i-1} = diag(1 - z_i) + // 直接传递项
diag(z_i ⊙ tanh'(net_{h_i})) · W_h^h · diag(r_i) + // 候选状态路径
(∂h_i/∂z_i) · (∂z_i/∂h_{i-1}) + // 更新门路径
(∂h_i/∂r_i) · (∂r_i/∂h_{i-1}) // 重置门路径
其中 W_h^h
是 W_h
中对应 h_{i-1}
的子矩阵。
4. GRU 如何避免梯度消失
GRU 通过以下机制有效缓解梯度消失:
✅ 1. 加性状态更新
h_t = (1 - z_t) ⊙ h_{t-1} + z_t ⊙ ̃h_t
- 梯度路径多样性:梯度可通过两条路径传播:
(1 - z_t) ⊙ h_{t-1}
→ 梯度乘以(1 - z_t)
z_t ⊙ ̃h_t
→ 梯度乘以z_t
- 无损传播通道:当
z_t ≈ 0
时,h_t ≈ h_{t-1}
,梯度直接传递:
此时梯度可跨时间步无损传播,类似残差连接。∂h_t/∂h_{t-1} ≈ I (单位矩阵)
✅ 2. 门控机制调节
- 更新门
z_t
的作用:- 若
z_t ≈ 0
:模型保留历史信息,梯度主要走(1 - z_t)
路径。 - 若
z_t ≈ 1
:模型重置状态,梯度来自当前输入(避免旧信息干扰)。
- 若
- 重置门
r_t
的作用:- 控制历史状态
h_{t-1}
对候选状态̃h_t
的影响:̃h_t = tanh(W_h · [r_t ⊙ h_{t-1}, x_t])
- 当
r_t ≈ 0
时,h_{t-1}
不影响̃h_t
,适合忽略无关历史。
- 控制历史状态
✅ 3. 梯度幅度分析
∂h_i/∂h_{i-1}
的主项为 diag(1 - z_i)
:
- 该矩阵特征值接近 1(因
z_i ∈ (0,1)
→1 - z_i ∈ (0,1)
)。 - 乘积
∏_{i} (1 - z_i)
不会指数级衰减到 0(除非所有z_i ≈ 1
,但罕见)。
📊 与传统RNN对比:
传统RNN:h_t = tanh(W·[h_{t-1}, x_t])
→∂h_t/∂h_{t-1} = diag(tanh'(...)) · W
梯度包含W
的连乘,若|W| < 1
则指数衰减。
5. 效果总结
机制 | 效果 |
---|---|
加性更新 | 提供低衰减梯度路径 (∂h_t/∂h_{t-1} ≈ I ),避免连乘权重矩阵 |
更新门 (z_t) | 自适应选择梯度来源:历史状态 (梯度保持) 或新输入 (及时更新) |
重置门 (r_t) | 控制历史信息对当前候选状态的影响,防止无关历史干扰梯度计算 |
门控导数有界 | sigmoid 导数最大值为 0.25,但加性路径的 (1 - z_t) 项主导,整体梯度更稳定 |
结论
GRU 通过门控加性状态更新,在参数梯度计算中引入了近似恒等映射的路径(当 z_t ≈ 0
时)。这使其梯度 ∂h_t/∂h_k
的衰减速度远低于传统RNN,显著缓解了梯度消失问题,尤其适用于学习长序列依赖。实验表明,GRU 在语言建模、机器翻译等任务中能有效捕捉超过 100 步的依赖关系。