No More Adam: 新型优化器SGD_SaI

news2025/7/17 23:22:08
image-20250517152635918 image-20250517152701732

一.核心思想和创新点

2024年12月提出的SGD-SaI算法(Stochastic Gradient Descent with Scaling at Initialization)本质上是一种在训练初始阶段对不同参数块(parameter block)基于**梯度信噪比(g-SNR, Gradient Signal-to-Noise Ratio)进行局部学习率缩放的SGDM变体。**代码开源:https://github.com/AnonymousAlethiometer/SGD_SaI/

通过在训练初始化阶段针对参数分组进行一次性学习率缩放(基于g-SNR)即可实现与自适应方法类似的性能,完全不依赖于动态二阶动量,不仅保留了SGDM的高效性和简洁性,还显著提升了大模型训练的资源利用率和稳定性。这为未来深度学习模型的高效训练,特别是大规模Transformer的可扩展性提供了一条全新路径。

1.质疑自适应梯度方法的必要性

作者质疑了当前深度学习中广泛应用的自适应梯度优化器(如Adam及其变体)的必要性,认为其主要优势(即根据梯度历史动态调整每个参数的学习率以应对训练初期的梯度噪声、稀疏和不同参数组间的学习率不均衡)可以通过更简单、更高效的方法实现。

2.提出SGD-SaI优化器

作者提出了一种基于动量的随机梯度下降优化器改进版——SGD-SaI(Stochastic Gradient Descent with Scaling at Initialization),其核心做法是在训练初始阶段,利用梯度信噪比(g-SNR),对不同参数分组的学习率进行一次性分组缩放,而非每步动态自适应。这种方法完全摒弃了存储和更新每个参数的二阶动量(即方差估计),极大减少了优化器的内存开销和计算复杂度。

3.梯度信噪比(g-SNR)引导的分组缩放

g-SNR度量了参数块的梯度范数与方差之比,可以稳定反映各参数块在不同训练阶段的梯度特性。实验表明g-SNR在参数块内具有时间上的稳定性(即初始化时刻的分布基本决定整个训练过程),据此对参数分组的学习率进行归一化缩放,有效平衡了参数块间的训练进度。

4.广泛的适用性和实证优势

SGD-SaI方法不仅在传统卷积神经网络(CNN)任务上表现良好,在Transformer、ViT、GPT-2等参数分布高度异质的大模型任务中,也能够实现与主流自适应方法(AdamW、Adam-mini等)相当甚至更优的性能,同时具备更好的超参数鲁棒性和极低的内存占用,显著提升了大模型训练的可扩展性与资源利用效率。

image-20250517155635651

5.实验验证

在大规模语言模型、视觉Transformer、LoRA微调、扩散模型微调以及CNN等多类任务上的实验证明,SGD-SaI在准确率、收敛稳定性、内存效率、训练速度等方面均表现优异,尤其是在Transformer类任务中,解决了传统SGD难以收敛的问题,并可节省高达50%甚至75%的优化器状态内存,显著降低了训练门槛。

二.算法流程

1.参数分块(Parameter Grouping)

神经网络的所有参数 θ \theta θ 按照网络结构分为 B B B 个参数块(如不同层、不同类型参数等),记为 θ ( i ) \theta^{(i)} θ(i) i = 1 , 2 , . . . , B i=1,2,...,B i=1,2,...,B)。

2.计算每个参数块的梯度信噪比(g-SNR)

对于每个参数块 i i i,在第一次训练迭代时计算:

  • 梯度范数

    G n o r m ( i ) = ∑ j = 1 d i ( g j ( i ) ) 2 G_{\mathrm{norm}}^{(i)}=\sqrt{\sum_{j=1}^{d_i}(g_j^{(i)})^2} Gnorm(i)=j=1di(gj(i))2

其中 g j ( i ) g_j^{(i)} gj(i)是第 i i i 个参数块第 j j j个参数的梯度, d i d_i di 是该参数块参数数量。

  • 梯度均值

g ˉ ( i ) = 1 d i ∑ j = 1 d i g j ( i ) \bar{g}^{(i)}=\frac{1}{d_i}\sum_{j=1}^{d_i}g_j^{(i)} gˉ(i)=di1j=1digj(i)

  • 梯度方差

G v a r ( i ) = 1 d i ∑ j = 1 d i ( g j ( i ) − g ˉ ( i ) ) 2 G_\mathrm{var}^{(i)}=\frac{1}{d_i}\sum_{j=1}^{d_i}(g_j^{(i)}-\bar{g}^{(i)})^2 Gvar(i)=di1j=1di(gj(i)gˉ(i))2

  • g-SNR定义

G s n r ( i ) = G n o r m ( i ) G v a r ( i ) + ϵ G_{\mathrm{snr}}^{(i)}=\frac{G_{\mathrm{norm}}^{(i)}}{\sqrt{G_{\mathrm{var}}^{(i)}+\epsilon}} Gsnr(i)=Gvar(i)+ϵ Gnorm(i)

​ 其中 ϵ \epsilon ϵ 是防止分母为零的小常数。

3. 归一化g-SNR得到缩放因子

对所有参数块的 G snr ( i ) G_{\text{snr}}^{(i)} Gsnr(i)做最大值归一化:

G ~ s n r ( i ) = G s n r ( i ) max ⁡ k G s n r ( k ) \tilde{G}_\mathrm{snr}^{(i)}=\frac{G_\mathrm{snr}^{(i)}}{\max_kG_\mathrm{snr}^{(k)}} G~snr(i)=maxkGsnr(k)Gsnr(i)

这样归一化后的值在0到1之间。

4. 局部学习率缩放

对于每个参数块,设全局基础学习率为 η \eta η,则每个参数块的实际学习率为

η ( i ) = G ~ s n r ( i ) ⋅ η \eta^{(i)}=\tilde{G}_{\mathrm{snr}}^{(i)}\cdot\eta η(i)=G~snr(i)η

5.训练过程

  • 动量项采用传统SGDM:

m t ( i ) = μ m t − 1 ( i ) + ( 1 − μ ) g t ( i ) m_t^{(i)}=\mu m_{t-1}^{(i)}+(1-\mu)g_t^{(i)} mt(i)=μmt1(i)+(1μ)gt(i)

  • 权重更新(以decoupled weight decay为例):

θ t ( i ) = θ t − 1 ( i ) − λ η θ t − 1 ( i ) − η ( i ) m t ( i ) \theta_t^{(i)}=\theta_{t-1}^{(i)}-\lambda\eta\theta_{t-1}^{(i)}-\eta^{(i)}m_t^{(i)} θt(i)=θt1(i)ληθt1(i)η(i)mt(i)

其中 λ \lambda λ为权重衰减系数。

补充解释一下Decoupled weight decay(解耦权重衰减):

是一种针对权重衰减(weight decay)正则化项的优化策略,最早由Loshchilov和Hutter在AdamW优化器中系统提出。 其核心思想是将权重衰减正则项与梯度更新过程解耦,从而更好地控制正则化效果,提升模型泛化能力,避免对自适应梯度的干扰。

在SGD及其变体(如Adam)中,L2正则化通常被实现为在每次参数更新时,将权重衰减项( λ θ \lambda \theta λθ)加入到梯度中:

g ′ = g + λ θ g^{\prime}=g+\lambda\theta g=g+λθ

其中, g g g 是损失函数关于参数的梯度, λ \lambda λ 是权重衰减系数, θ \theta θ是参数。

然后按照普通优化器的参数更新公式进行迭代: θ t + 1 = θ t − η g ′ \theta_{t+1}=\theta_t-\eta g^{\prime} θt+1=θtηg

这种方式实际上把权重衰减项当做损失函数梯度的一部分来处理

解耦策略下,权重衰减项不再与梯度混合计算,而是在参数更新时直接对参数进行衰减,其更新公式为:

θ t + 1 = θ t − η g − η λ θ t \theta_{t+1}=\theta_t-\eta g-\eta\lambda\theta_t θt+1=θtηgηλθt

也可以拆分为两步:

1.正常的梯度下降更新: θ t + 1 / 2 = θ t − η g \theta_{t+1/2}=\theta_t-\eta g θt+1/2=θtηg

2.单独进行权重衰减: θ t + 1 = θ t + 1 / 2 − η λ θ t \theta_{t+1}=\theta_{t+1/2}-\eta\lambda\theta_t θt+1=θt+1/2ηλθt

这样做的好处是,权重衰减只针对参数本身进行缩减,而不会受梯度自适应调整的影响,能更精确地施加正则化,从而提升模型泛化效果。

SGD-SaI算法采用了decoupled weight decay,即先计算梯度用于g-SNR,再单独对参数进行衰减,这样能够保证g-SNR反映真实的梯度稀疏性和噪声特征,而不会被权重衰减项混淆,从而提升分组学习率缩放的有效性

值得强调的是:整个训练过程中每个参数块的缩放因子 G ~ snr ( i ) \tilde{G}_{\text{snr}}^{(i)} G~snr(i)只在初始化阶段计算一次,后续训练保持不变,极大降低了内存和计算开销。

三.代码解释

核心代码是sgd_sai.py,这里完整注释如下

import torch  
from torch.optim.optimizer import Optimizer  # PyTorch优化器基类

class SGD_sai(Optimizer):  # 定义SGD_sai优化器类,继承自PyTorch优化器
    r"""
    该优化器实现了论文"SGD-SaI: Stochastic Gradient Descent with Scaling at Initialization"的核心算法思想。
    支持标准SGD参数设置及momentum、weight_decay等常用优化选项。
    """

    def __init__(self, params, lr=1e-2, momentum=0.9, eps=1e-8, weight_decay=0, maximize=False):
        # 构造函数,初始化优化器的各项参数和默认设置
        defaults = dict(lr=lr, momentum=momentum, eps=eps, weight_decay=weight_decay, maximize=maximize)
        super(SGD_sai, self).__init__(params, defaults)  # 调用父类初始化
        self.gsnr_initialized = False   # 标志变量,指示g-SNR缩放因子是否已完成初始化

    @torch.no_grad()
    def step(self, closure=None):
        """
        执行一次优化器更新。包括g-SNR的初始化、动量累计、权重衰减和参数更新。
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()  # 支持自定义loss回调(如二阶梯度)

        # 如果还没有初始化g-SNR缩放因子,则进行一次初始化
        if not self.gsnr_initialized:
            gsnr_list = []  # 用于存储每个参数组的g-SNR
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:  # 跳过无梯度参数
                        continue
                    grad = p.grad.data  # 获取梯度
                    grad_norm = grad.norm()  # 计算L2范数
                    grad_var = grad.var()  # 计算方差
                    eps = group['eps']  # 取数值稳定用小常数
                    gsnr = grad_norm / (grad_var.sqrt() + eps)  # 计算g-SNR
                    gsnr_list.append(gsnr)  # 存入列表
            # 将所有g-SNR按最大值归一化
            max_gsnr = torch.max(torch.stack(gsnr_list))
            norm_gsnr_list = [x / max_gsnr for x in gsnr_list]
            # 保存每个参数的g-SNR缩放因子
            idx = 0
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    p.gsnr_scale = norm_gsnr_list[idx]  # 动态添加属性
                    idx += 1
            self.gsnr_initialized = True  # 完成g-SNR初始化
            return loss  # 初始化阶段不做参数更新

        # 正式参数更新过程
        for group in self.param_groups:
            lr = group['lr']  # 获取全局学习率
            momentum = group['momentum']  # 获取动量系数
            weight_decay = group['weight_decay']  # 获取权重衰减系数
            maximize = group['maximize']  # 是否最大化目标
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if maximize:
                    grad = -grad  # 支持最大化模式

                # 解耦式权重衰减:对参数本身做缩放,不叠加到梯度
                if weight_decay != 0:
                    p.data.add_(p.data, alpha=-lr * weight_decay)

                # 获取g-SNR缩放因子
                scale = getattr(p, 'gsnr_scale', 1.0)
                # 获取或初始化动量
                param_state = self.state[p]
                if 'momentum_buffer' not in param_state:
                    buf = param_state['momentum_buffer'] = torch.clone(grad).detach()
                else:
                    buf = param_state['momentum_buffer']
                    buf.mul_(momentum).add_(grad, alpha=1 - momentum)
                # 参数更新:带g-SNR缩放的动量SGD
                p.data.add_(buf, alpha=-lr * scale)

        return loss  # 返回损失值以便监控

代码中有几点重点说明:

(1)g-SNR初始化

  • 仅在第一次调用step()时触发,遍历所有参数,依据当前梯度分布计算g-SNR,并最大归一化。
  • 利用动态属性p.gsnr_scale将每个参数的缩放因子缓存下来,后续训练反复使用。

(2)动量与权重衰减

  • 动量缓存采用PyTorch标准做法(momentum_buffer)。
  • 权重衰减采用解耦式,即直接对参数做缩放操作,避免正则项混入梯度统计,理论上等同于AdamW等现代优化器。

(3)参数更新

  • 更新公式为:基础学习率 × g-SNR缩放 × 动量项,精确实现论文中的“初始化缩放,分组局部自适应学习率”思想。
  • 若未初始化g-SNR则直接跳过参数更新。

(4)max/min目标灵活性
maximize选项用于兼容极大极小化目标。

四.使用优化器

安装:

pip install sgd-sai

使用:

from sgd_sai import SGD_sai

# 初始化优化器
optimizer = SGD_sai(model.parameters(), lr=lr, momentum=0.9, eps=1e-08, weight_decay=weight_decay)

for _ in range(steps):
    pred = model(input_ids)
    loss = loss_fn(pred, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

在每个训练step前调用 optimizer.zero_grad(),是为了清空所有参数的 .grad 属性,以避免梯度累积。否则多次反向传播会让梯度不断相加,导致梯度异常。

set_to_none=False(默认值)时,会将每个参数的 .grad 置为与原来形状相同的全零张量

set_to_none=True 时,则会直接把 .grad 设为 None

set_to_none=True 通常更高效,节省了将张量置零的时间与显存开销,特别适合大模型或分布式训练,且官方推荐优先使用。

只要后续所有反向传播都能正确地重新分配 .grad,则功能完全等价。

某些情况下(比如自定义梯度操作),如果你的代码假设 .grad 一定存在且是全零tensor,才建议用默认方式。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2378400.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JSP链接MySQL8.0(Eclipse+Tomcat9.0+MySQL8.0)

所用环境 Eclipse Tomcat9.0 MySQL8.0.21(下载:MySQL Community Server 8.0.21 官方镜像源下载 | Renwole) mysql-connector-java-8.0.21(下载:MySQL :: Begin Your Download) .NET Framework 4.5.2(下…

SEO长尾词与关键词优化实战

内容概要 在SEO优化体系中,长尾关键词与核心关键词的协同作用直接影响流量获取效率与用户转化路径。长尾词通常由3-5个词组构成,搜索量较低但意图明确,能精准触达细分需求用户;核心关键词则具备高搜索量与广泛覆盖能力&#xff0…

机器学习-人与机器生数据的区分模型测试-数据处理1

附件为训练数据,总体的流程可以作为参考。 导入依赖 import pandas as pd import os import numpy as np from sklearn.model_selection import train_test_split,GridSearchCV from sklearn.ensemble import RandomForestClassifier,VotingClassifier from skle…

HelloWorld

HelloWorld 新建一个java文件 文件后缀名为 .javahello.java【注意】系统可能没有显示文件后缀名,我们需要手动打开 编写代码 public class hello {public static void main(String[] args) {System.out.print(Hello,World)} }编译 javac java文件,会生…

SEO 优化实战:ZKmall模板商城的 B2C商城的 URL 重构与结构化数据

在搜索引擎算法日益复杂的今天,B2C商城想要在海量信息中脱颖而出,仅靠优质商品和营销活动远远不够。ZKmall模板商城以实战为导向,通过URL 重构与结构化数据优化两大核心策略,帮助 B2C 商城实现从底层架构到搜索展示的全面升级&…

数字万用表与指针万用表使用方法及注意事项

在电子测量领域,万用表是极为常用的工具,数字万用表和指针万用表各具特点。熟练掌握它们的使用方法与注意事项,能确保测量的准确性与安全性。下面为您详细介绍: 一 、数字万用表按钮功能 > 进入及退出手动量程模式 每 按 […

【读代码】端到端多模态语言模型Ultravox深度解析

一、项目基本介绍 Ultravox是由Fixie AI团队开发的开源多模态大语言模型,专注于实现音频-文本的端到端实时交互。项目基于Llama 3、Mistral等开源模型,通过创新的跨模态投影架构,绕过了传统语音识别(ASR)的中间步骤,可直接将音频特征映射到语言模型的高维空间。 核心优…

RabbitMQ工作流程及使用方法

一、什么是RabbitMQ RabbitMQ 是一款基于 ‌AMQP(高级,消息队列协议)‌ 的开源消息中间件,专为分布式系统设计,用于实现应用程序间的异步通信,其核心功能是通过 ‌消息代理(Message Broker&…

算法:分治法

实验内容 在一个2kⅹ2k个方格组成的棋盘中,若恰有一个方格与其他方格不同,则称该方格为特殊方格,且称该棋盘为一特殊棋盘。 显然,特殊方格出现的位置有4k 种情况,即k>0,有4k 种不同的特殊棋盘 棋盘覆盖&#xff1a…

MySQL初阶:sql事务和索引

索引(index) 可以类似理解为一本书的目录,一个表可以有多个索引。 索引的意义和代价 在MySQL中使用select进行查询时会经过: 1.先遍历表 2.将条件带入每行记录中进行判断,看是否符合 3.不符合就跳过 但当表中的…

docker部署第一个Go项目

1.前期准备 目录结构 main.go package mainimport ("fmt""github.com/gin-gonic/gin""net/http" )func main() {fmt.Println("\n .::::.\n .::::::::.\n :::::::::::\n …

Visual Studio2022跨平台Avalonia开发搭建

由于我已经下载并安装了 VS2022版本,这里就跳过不做阐述。 1.安装 Visual Studio 2022 安装时工作负荷Tab页勾选 ‌“.NET 桌面开发”‌ 和“Visual Studio扩展开发”‌ ,这里由于不是用的微软的MAUI,所以不用选择其他的来支持跨平台开发&a…

css iconfont图标样式修改,js 点击后更改样式

背景: 在vue项目中,通过点击/鼠标覆盖,更改选中元素的样式,可以通过js逻辑,也可以根据css样式修改。包括以下内容:iconfont图标的引入以及使用,iconfont图标样式修改【导入文件是纯白&#xff0…

开源项目实战学习之YOLO11:12.4 ultralytics-models-sam-memory_attention.py源码分析

👉 点击关注不迷路 👉 点击关注不迷路 👉 另外,前些天发现了一个巨牛的AI人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。感兴趣的可以点击相关跳转链接。 点击跳转到网站。 ultralytics-models-sam 1.sam-modules-memory_attention.pyblocks.py: 定义模…

【沉浸式求职学习day42】【算法题:滑动窗口】

沉浸式求职学习 长度最小的子数组水果成篮 关于算法题:滑动窗口的几个题目 长度最小的子数组 给定一个含有 n 个正整数的数组和一个正整数 s ,找出该数组中满足其和 ≥ s 的长度最小的 连续 子数组,并返回其长度。如果不存在符合条件的子数组…

LIIGO ❤️ RUST 12 YEARS

LIIGO 💖 RUST 12 YEARS 今天是RUST语言1.0发布十周年纪念日。十年前的今天,2015年的今天,Rust 1.0 正式发行。这是值得全球Rust支持者隆重纪念的日子。我借此机会衷心感谢Rust语言创始人Graydon Hoare,Mozilla公司,以…

Linux基础开发工具二(gcc/g++,自动化构建makefile)

3. 编译器gcc/g 3.1 背景知识 1. 预处理(进行宏替换/去注释/条件编译/头文件展开等) 2. 编译(生成汇编) 3. 汇编(生成机器可识别代码) 4. 连接(生成可执行文件或库文件) 3.2 gcc编译选项 格式 : gcc …

全局异常处理:如何优雅地统一管理业务异常

在软件开发中,异常处理是保证系统健壮性的重要环节。一个良好的异常处理机制不仅能提高代码的可维护性,还能为使用者提供清晰的错误反馈。本文将介绍如何通过全局异常处理和业务异常统一处理来编写更加优雅的代码。 一、传统异常处理的痛点 1.1 典型问…

动态规划-LCR 166.珠宝的最大价值-力扣(LeetCode)

一、题目解析 frame二维矩阵中每个值代表珠宝的价值,现在从左上角开始拿珠宝,只能向右或向下拿珠宝,到达右下角时停止拿珠宝,要求拿的珠宝价值最大。 二、算法解析 1.状态表示 我们想要知道的是到达[i,j]为位置时的最大价值&am…

JDBC实现模糊、动态与分页查询的详解

文章目录 一. 模糊查询1. Mysql的写法2. JDBC的实现 二. 动态条件查询1. 创建生成动态条件查询sql的方法2. 完整的动态条件查询类以及测试类 三. 分页查询1. 什么是分页查询?2. 分页查询的分类3. MySQL的实现4. JDBC实现4.1. 创建page页4.2. 分页的实现 本章来讲一下…