PyTorch使用Tricks:学习率衰减 !!

news2025/6/24 20:27:19

文章目录

前言

1、指数衰减

2、固定步长衰减

3、多步长衰减

4、余弦退火衰减

5、自适应学习率衰减

6、自定义函数实现学习率调整:不同层不同的学习率


前言

在训练神经网络时,如果学习率过大,优化算法可能会在最优解附近震荡而无法收敛;如果学习率过小,优化算法的收敛速度可能会非常慢。因此,一种常见的策略是在训练初期使用较大的学习率来快速接近最优解,然后逐渐减小学习率,使得优化算法可以更精细地调整模型参数,从而找到更好的最优解。

通常学习率衰减有以下的措施:

  • 指数衰减:学习率按照指数的形式衰减,每次乘以一个固定的衰减系数,可以使用 torch.optim.lr_scheduler.ExponentialLR 类来实现,需要指定优化器和衰减系数。
  • 固定步长衰减:学习率每隔一定步数(或者epoch)就减少为原来的一定比例,可以使用 torch.optim.lr_scheduler.StepLR 类来实现,需要指定优化器、步长和衰减比例。
  • 多步长衰减:学习率在指定的区间内保持不变,在区间的右侧值进行一次衰减,可以使用 torch.optim.lr_scheduler.MultiStepLR 类来实现,需要指定优化器、区间列表和衰减比例。
  • 余弦退火衰减:学习率按照余弦函数的周期和最值进行变化,可以使用 torch.optim.lr_scheduler.CosineAnnealingLR 类来实现,需要指定优化器、周期和最小值。
  • 自适应学习率衰减:这种策略会根据模型的训练进度自动调整学习率,可以使用 torch.optim.lr_scheduler.ReduceLROnPlateau 类来实现。例如,如果模型的验证误差停止下降,那么就减小学习率;如果模型的训练误差上升,那么就增大学习率。
  • 自适应函数实现学习率调整:不同层不同的学习率。

1、指数衰减

指数衰减是一种常用的学习率调整策略,其主要思想是在每个训练周期(epoch)结束时,将当前学习率乘以一个固定的衰减系数(gamma),从而实现学习率的指数衰减。这种策略可以帮助模型在训练初期快速收敛,同时在训练后期通过降低学习率来提高模型的泛化能力。

在PyTorch中,可以使用 torch.optim.lr_scheduler.ExponentialLR 类来实现指数衰减。该类的构造函数需要两个参数:一个优化器对象和一个衰减系数。在每个训练周期结束时,需要调用ExponentialLR 对象的 step() 方法来更新学习率。

以下是一个使用 ExponentialLR代码示例:

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ExponentialLR

# 假设有一个模型参数
model_param = torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))

# 使用SGD优化器,初始学习率设置为0.1
optimizer = SGD([model_param], lr=0.1)

# 创建ExponentialLR对象,衰减系数设置为0.9
scheduler = ExponentialLR(optimizer, gamma=0.9)

# 在每个训练周期结束时,调用step()方法来更新学习率
for epoch in range(100):
    # 这里省略了模型的训练代码
    # ...
    
    # 更新学习率
    scheduler.step()

在这个例子中,初始的学习率是0.1,每个训练周期结束时,学习率会乘以0.9,因此学习率会按照指数的形式衰减。

2、固定步长衰减

固定步长衰减是一种学习率调整策略,它的原理是每隔一定的迭代次数(或者epoch),就将学习率乘以一个固定的比例,从而使学习率逐渐减小。这样做的目的是在训练初期使用较大的学习率,加快收敛速度,而在训练后期使用较小的学习率,提高模型精度。

PyTorch提供了 torch.optim.lr_scheduler.StepLR 类来实现固定步长衰减,它的参数有:

  • optimizer:要进行学习率衰减的优化器,例如 torch.optim.SGD torch.optim.Adam等。
  • step_size:每隔多少隔迭代次数(或者epoch)进行一次学习率衰减,必须是正整数。
  • gamma:学习率衰减的乘法因子,必须是0到1之间的数,表示每次衰减为原来的 gamma倍。
  • last_epoch:最后一个epoch的索引,用于恢复训练的状态,默认为-1,表示从头开始训练。
  • verbose:是否打印学习率更新的信息,默认为False。

下面是一个使用  torch.optim.lr_scheduler.StepLR 类的具体例子,假设有一个简单的线性模型,使用 torch.optim.SGD 作为优化器,初始学习率为0.1,每隔5个epoch就将学习率乘以0.8,训练100个epoch:

import torch
import matplotlib.pyplot as plt

# 定义一个简单的线性模型
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = torch.nn.Linear(1, 1)

    def forward(self, x):
        return self.fc(x)

# 创建模型和优化器
model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# 创建固定步长衰减的学习率调度器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.8)

# 记录学习率变化
lr_list = []

# 模拟训练过程
for epoch in range(100):
    # 更新学习率
    scheduler.step()
    # 记录当前学习率
    lr_list.append(optimizer.param_groups[0]['lr'])

# 绘制学习率曲线
plt.plot(range(100), lr_list, color='r')
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.show()

学习率在每个5个epoch后都会下降为原来的0.8倍,直到最后接近0。

固定步长衰减指数衰减都是学习率衰减的策略,但它们在衰减的方式和速度上有所不同:

  • 固定步长衰减:在每隔固定的步数(或epoch)后,学习率会减少为原来的一定比例。这种策略的衰减速度是均匀的,不会随着训练的进行而改变。
  • 指数衰减:在每个训练周期(或epoch)结束时,学习率会乘以一个固定的衰减系数,从而实现学习率的指数衰减。这种策略的衰减速度是逐渐加快的,因为每次衰减都是基于当前的学习率进行的。

3、多步长衰减

多步长衰减是一种学习率调整策略,它在指定的训练周期(或epoch)达到预设的里程碑时,将学习率减少为原来的一定比例。这种策略可以在模型训练的关键阶段动态调整学习率。

在PyTorch中,可以使用 torch.optim.lr_scheduler.MultiStepLR 类来实现多步长衰减。以下是一个使用 MultiStepLR 的代码示例:

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR

# 假设有一个简单的模型
model = torch.nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)

# 创建 MultiStepLR 对象,设定在第 30 和 80 个 epoch 时学习率衰减为原来的 0.1 倍
scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)

# 在每个训练周期结束时,调用 step() 方法来更新学习率
for epoch in range(100):
    # 这里省略了模型的训练代码
    # ...
    
    # 更新学习率
    scheduler.step()

在这个例子中,初始的学习率是0.1,当训练到第30个epoch时,学习率会变为0.01(即0.1*0.1),当训练到第80个epoch时,学习率会再次衰减为0.001(即0.01*0.1)。

4、余弦退火衰减

余弦退火衰减是一种学习率调整策略,它按照余弦函数的周期和最值来调整学习率。在PyTorch中,可以使用 torch.optim.lr_scheduler.CosineAnnealingLR 类来实现余弦退火衰减。

以下是一个使用 CosineAnnealingLR 的代码示例:

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
import matplotlib.pyplot as plt

# 假设有一个简单的模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(10, 2)

model = SimpleModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# 创建 CosineAnnealingLR 对象,周期设置为 10,最小学习率设置为 0.01
scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=0.01)

lr_list = []
# 在每个训练周期结束时,调用 step() 方法来更新学习率
for epoch in range(100):
    # 这里省略了模型的训练代码
    # ...
    
    # 更新学习率
    scheduler.step()
    lr_list.append(optimizer.param_groups[0]['lr'])

# 绘制学习率变化曲线
plt.plot(range(100), lr_list)
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.title("Learning rate schedule: CosineAnnealingLR")
plt.show()

在这个例子中,初始的学习率是0.1,学习率会按照余弦函数的形式在0.01到0.1之间变化,周期为10个epoch。

5、自适应学习率衰减

自适应学习率衰减是一种学习率调整策略,它会根据模型的训练进度自动调整学习率。例如,如果模型的验证误差停止下降,那么就减小学习率;如果模型的训练误差上升,那么就增大学习率。在PyTorch中,可以使用 torch.optim.lr_scheduler.ReduceLROnPlateau 类来实现自适应学习率衰减。

以下是一个使用 ReduceLROnPlateau 的代码示例:

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau

# 假设有一个简单的模型
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# 创建 ReduceLROnPlateau 对象,当验证误差在 10 个 epoch 内没有下降时,将学习率减小为原来的 0.1 倍
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.1)

# 模拟训练过程
for epoch in range(100):
    # 这里省略了模型的训练代码
    # ...

    # 假设有一个验证误差
    val_loss = ...

    # 在每个训练周期结束时,调用 step() 方法来更新学习率
    scheduler.step(val_loss)

6、自定义函数实现学习率调整:不同层不同的学习率

可以通过为优化器提供一个参数组列表来实现对不同层使用不同的学习率。每个参数组是一个字典,其中包含一组参数和这组参数的学习率。

以下是一个具体的例子,假设有一个包含两个线性层的模型,想要对第一层使用学习率0.01,对第二层使用学习率0.001:

import torch
import torch.nn as nn
import torch.optim as optim

# 定义一个包含两个线性层的模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.layer1 = nn.Linear(10, 2)
        self.layer2 = nn.Linear(2, 10)

model = SimpleModel()

# 创建一个参数组列表,每个参数组包含一组参数和这组参数的学习率
params = [{'params': model.layer1.parameters(), 'lr': 0.01},
          {'params': model.layer2.parameters(), 'lr': 0.001}]

# 创建优化器
optimizer = optim.SGD(params)

# 现在,当调用 optimizer.step() 时,第一层的参数会使用学习率 0.01 进行更新,第二层的参数会使用学习率 0.001 进行更新

在这个例子中,首先定义了一个包含两个线性层的模型。然后,创建一个参数组列表,每个参数组都包含一组参数和这组参数的学习率。最后创建了一个优化器SGD,将这个参数组列表传递给优化器。这样,当调用 optimizer.step() 时,第一层的参数会使用学习率0.01进行更新,第二层的参数会使用学习率0.001进行更新。

参考:深度图学习与大模型LLM

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1453910.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

算法刷题:长度最小的子数组

长度最小的子数组 .题目链接题目详情算法原理滑动窗口定义指针进窗口判断出窗口 我的答案 . 题目链接 长度最小的子数组 题目详情 算法原理 滑动窗口 这道题,我们采用滑动窗口的思想来解决,具体步骤如图所示 定义指针 如图所示,两个指针都需要从左往右进行遍历,因此初始值…

AIGC实战——能量模型(Energy-Based Model)

AIGC实战——能量模型 0. 前言1. 能量模型1.1 模型原理1.2 MNIST 数据集1.3 能量函数 2. 使用 Langevin 动力学进行采样2.1 随机梯度 Langevin 动力学2.2 实现 Langevin 采样函数 3. 利用对比散度训练小结系列链接 0. 前言 能量模型 (Energy-based Model, EBM) 是一类常见的生…

食物厨艺展示404错误页面模板源码

食物厨艺展示404错误页面模板源码,HTMLCSSJSCSS,记事本打开源码文件可以进行内容文字之类的修改,双击html文件可以本地运行效果,也可以上传到服务器里面,重定向这个界面 蓝奏云:https://wfr.lanzout.com/i3uC71oj52ah…

初识数据库:探索数据的世界

初识数据库:探索数据的世界 1. 什么是数据库?2. 数据库的类型2.1 关系型数据库(RDBMS)2.2 非关系型数据库(NoSQL) 3. 为什么使用数据库?4. 如何选择合适的数据库?5. 结语 在信息技术…

ovs和ovn安装

ovn和ovs介绍 ovn架构图 CMS||-----------|-----------| | || OVN/CMS Plugin || | || | || OVN Northbound DB || | || | || ovn-northd || …

设置windows10资源管理器等的边框

Windows10默认状态下,资源管理器、浏览器等没有边框,在打开多个窗口等情况下,想要拖动或选择某个窗口时,不是很好定位到窗口标题栏。 通过:设置(可以通过wini组合键打开设置)---》个性化 ---》…

第十二章[模块]:12.4:标准库:datetime

一,官方文档: 1,文档地址 datetime --- 基本日期和时间类型 — Python 3.12.2 文档源代码: Lib/datetime.py datetime 模块提供了用于操作日期和时间的类。 在支持日期时间数学运算的同时,实现的关注点更着重于如何能够更有效地解析其属性用于格式化输出和数据操作。 感知型…

vue-自定义创建项目(六)

为什么要自定义创建项目? 因为VueCli默认创建的项目不能够满足我们的要求,比如默认的项目中没有帮我们集成路由,vuex,eslink等功能。 默认项目 自定义创建项目 流程: 创建项目命令:vue create custom_dem…

VNCTF 2024 Web方向 WP

Checkin 题目描述:Welcome to VNCTF 2024~ long time no see. 开题,是前端小游戏 源码里面发现一个16进制编码字符串 解码后是flag CutePath 题目描述:源自一次现实渗透 开题 当前页面没啥好看的,先爆破密码登录试试。爆破无果…

项目第一次git commit后如何撤销

问题描述: # 1. 新建gitcode目录,然后在目录下 git init# 2. 用idea打开目录后,新建.gitignore文件后 git add .git commit -m "init project"git log# 3. 就出现如下图情况目的:向撤销该次代码提交 # 仅撤销 git com…

电商API接口|大数据关键技术之数据采集发展趋势

在大数据和人工智能时代,数据之于人工智能的重要性不言而喻。今天,让我们一起聊聊数据采集相关的发展趋势。 本文从电商API接口数据采集场景、数据采集系统、数据采集技术方面阐述数据采集的发展趋势。 01 数据采集场景的发展趋势 作为大数据和人工智…

Eclipse 分栏显示同一文件

Eclipse 分栏显示同一文件 1. Window -> EditorReferences 1. Window -> Editor Toggle Split Editor (Horizontal) :取消或设置水平分栏显示 Toggle Split Editor (Vertical) :取消或设置垂直分栏显示 References [1] Yongqiang Cheng, https:/…

【制作100个unity游戏之25】3D背包、库存、制作、快捷栏、存储系统、砍伐树木获取资源、随机战利品宝箱11(附带项目源码)

效果演示 文章目录 效果演示系列目录前言物品堆叠源码完结 系列目录 前言 欢迎来到【制作100个Unity游戏】系列!本系列将引导您一步步学习如何使用Unity开发各种类型的游戏。在这第25篇中,我们将探索如何用unity制作一个3D背包、库存、制作、快捷栏、存…

JoySSL免费可续签通配符证书:中小企业网站安全的理想选择

在数字时代,网络安全已成为企业和个人用户不可忽视的重要议题。其中,SSL证书作为保障网站数据传输加密和身份验证的关键工具,对于维护网站的安全性和可信度至关重要。然而,对于许多中小企业来说,高昂的SSL证书费用常常…

网络原理(UDP与TCP篇)

网络原理 协议应用层现成的自定义协议格式运输层端口号UDP协议首部格式TCP报文段的首部格式首部格式源端口号 和 目的端口号序号确认号ack 和 ACK数据偏移(**首部长度**)保留窗口校验和SYNFINRSTPSHURG 和 紧急指针扩展首部填充 TCP的可靠传输TCP的超时重…

函数、极限、连续——刷题(7

目录 1.题目:2.解题思路和步骤:3.总结:小结: 1.题目: 2.解题思路和步骤: 记住和差化积、积化和差公式即可: 然后就可以化简: 3.总结: 记住和差化积、积化和差公式即可…

Ps:污点修复画笔工具

污点修复画笔工具 Spot Healing Brush Tool专门用于快速清除图像中的小瑕疵、污点、尘埃或其他不想要的小元素。 它通过分析被修复区域周围的内容,无需手动取样,自动选择最佳的修复区域来覆盖和融合这些不完美之处,从而实现无痕修复的效果。 …

7.1 Qt 中输入行与按钮

目录 前言: 技能: 内容: 参考: 前言: line edit 与pushbotton的一点联动 当输入行有内容时,按钮才能使用,并能读出输入行的内容 技能: pushButton->setEnabled(false) 按钮不…

代码随想录 Leetcode56. 合并区间

题目&#xff1a; 代码(首刷自解 2024年2月18日&#xff09;&#xff1a; 这题与气球扎针&#xff0c;删除重复的大体逻辑相似。需要额外定义些变量来存储头尾 class Solution { private:const static bool cmp(vector<int>& a, vector<int>& b) {return …

解决:docker创建Redis容器成功,但无法启动Redis容器、也无报错提示

解决&#xff1a;docker创建Redis容器成功&#xff0c;但无法启动Redis容器、也无报错提示 一问题描述&#xff1a;1.docker若是直接简单使用run命令&#xff0c;但不挂载容器数据卷等参数&#xff0c;则可以启动Redis容器2.docker复杂使用run命令&#xff0c;使用指定redis.co…