快速上手Pytorch Lighting框架 | 深度学习入门

news2025/5/12 16:39:03

快速上手Pytorch Lighting框架 | 深度学习入门

  • 前言
    • 参考官方文档
  • 介绍
  • 快速上手
    • 基本流程
    • 常用接口
      • LightningModule
        • \_\_init\_\_ & setup()
        • \*\_step()
        • configure_callbacks()
        • configure_optimizers()
        • load_from_checkpoint
      • Trainer
        • 常用参数
    • 可选接口
      • Loggers
        • TensorBoard Logger
      • Callbacks
        • EarlyStopping
        • ModelCheckpoint
        • ProgressBar

前言

本文将介绍一个深度学习的训练框架——Pytorch Lighting框架。首先会介绍Pytorch Lighting框架的特点,然后会聚焦于你使用该框架时一定会使用的那些接口,包括我个人学习该框架时的经验传授。

参考官方文档

  • Welcome to ⚡ PyTorch Lightning — PyTorch Lightning 2.5.1.post0 documentation
  • Lightning in 15 minutes — PyTorch Lightning 2.5.1.post0 documentation
  • How to Organize PyTorch Into Lightning — PyTorch Lightning 2.5.1.post0 documentation

介绍

Pytorch Lightning是一个基于Pytorch的深度学习与机器学习的框架,它进一步封装Pytorch的接口,简化了深度学习训练代码的搭建过程,帮助用户能够关注于模型本身,而不需要再反复书写重复的训练代码。

Pytorch Lighting框架本质是对Pytorch的进一步封装,所以如果熟悉Pytorch框架,那么很容易上手Pytorch Lighting。结合官方文档以及个人使用体验,相比Pytorch,我认为Pytorch Lightning具有以下特点:

  • 代码复用性:Pytorch Lightning提供训练流程的所有接口,可以通过继承的方式,准备训练不同阶段的组件,从而在相似任务之间使用同一份代码。
  • 代码可读性:原本的Pytorch代码被进一步封装到框架中,让代码的聚合程度更高,训练流程更清晰,提高了代码的可读性。
  • 灵活性:通过框架类方法,可以根据需求定制特定环节的计算逻辑,精细控制训练的每个细节。
  • 可移植性:Pytorch Lightning的框架添加了自动检测训练设备的功能,同一份代码可以不仅在本地的CPU上训练,也可以通过远程服务器使用多GPU训练。
  • 自动化:框架集成了一些训练会用到的工具,比如日志输出、检查点记录等等。

更多详细内容,可以查阅官方文档介绍!

快速上手

基本流程

使用Lighting框架训练一个深度学习模型,遵循以下的流程:

  1. 安装Pytorch Lighting
  2. 定义Pytorch Lighting模块
  3. 定义数据集(生成样本迭代器)
  4. 配置训练器,训练模型
  5. 使用模型:包括测试模型或使用模型预测…
  6. 可视化训练过程

常用接口

上一小节,简单介绍了使用Pytorch Lighting框架的流程。其本质和普通的机器学习训练流程是一致的,如果只是简单的使用PL框架,几乎可以不输入多余的参数,就能直接开始训练,PL会帮助你完成大量的任务。同时框架提供了训练流程中每一步的对应接口,让用户可以根据需求,修改不同的细节。本小节中将具体介绍这些重要的接口,主要对应上述流程的第2步、第4步及第5步。

对于第3步,PL训练时需要迭代器类型的输入,可以手动生成样本迭代器,也可以使用Pytorch中的Dataloader等,此处将不再展开。

LightningModule

LightningModule是框架的核心部件,该类中提供关于训练的所有核心方法,涵盖6个方面:

  • 模型初始化:init & setup()
  • 训练循环:training_step()
  • 验证循环:validation_step()
  • 测试循环:test_step()
  • 预测循环:predict_step()
  • 优化器及学习率调整
__init__ & setup()

与类的基本使用方法相同,在LightingModule类的构造函数中,需要对类做必要的初始化,比如导入核心模型结构、优化器方法、损失函数类型等等。

setup(_trainer_, _pl_module_, _stage_)

setup()本质是一个回调函数,功能也是对类进行初始化设置,一般用于不同的训练阶段(predict,test,…)。调用该接口可以在不同的阶段采用不同的初始化策略。

*_step()

在不同的阶段的循环步中,可以部署期望的任务结果。除了基本的前馈计算、反向传播等操作,可以添加日志输出、指标收集等等。比如在train阶段,只获取loss指标;在test阶段,同时获取loss指标、acc指标等。

configure_callbacks()

通过重写该方法,可以定制训练所需的回调函数。当模型被调用的时候,比如执行test()的时候,框架会自动调用这些回调函数。
如果与Trainer中的回调函数表有冲突时,框架会优先使用此处的回调函数配置。

configure_optimizers()

该方法下,可以配置训练过程中使用的优化器类型以及具体的学习率。在常规模型的训练中,只会配置一个优化器,那么返回值就是单个优化器。如果是GANs或其他需要多个优化器的模型,支持返回多个迭代器,但是需要手动进行模型优化,即需要配置optimizer_step()方法。

load_from_checkpoint
load_from_checkpoint(_checkpoint_path_, _map_location=None_, _hparams_file=None_, _**kwargs_)

一般在测试阶段会需要调用该函数,用一个已经训练好的模型来初始化LightingModule类。checkpoint_path是训练好的模型的.ckpt文件存储位置,PL框架也支持传入URL,或一个类。

TIPS: 如果构造函数传入超参数,记得在构造函数中调用调用self.save_hyperparameters()。这样框架才会自动保存这些超参数到.ckpt文件中。否则如果训练、测试阶段分开进行时,需要重新导入模型,则需要准备.yaml文件,或超参数列表,才能正确的初始化模型。

Trainer

如果完成了LightningModule的配置,直接实例化一个训练器Trainer,便可以直接开始训练,默认生成的Train可以自动的帮助你完成所有训练任务:

model = MyLightningModule()

trainer = Trainer()
trainer.fit(model, train_dataloader, val_dataloader)

模型完成训练后,单独调用test()、validate()方法,对模型进行测试、验证。如果有特殊的训练、测试、验证需求,可以在实例化Trainer的时候进行配置。

常用参数
  • accelerator & devices::
    该参数是PL框架的特点之一,只需要实例化不同的Trainer就可以实现在不同硬件设备下的训练。也可以不指定参数,框架会自动匹配对应设备完成训练。
accelerator = ["cpu"] ["gpu"] ["tpu"] ["hpu"] ["auto"]
devices = [number of devices] ["auto"]
  • callbacks:: 传入单个回调类或回调列表。当传入的是列表时,框架会自动根据顺序逐个调用回调类。如果在PL框架中重写了configure_callbacks()方法,则以框架中的回调类优先。
  • max_epochs:: 最大的训练周期。
  • enable_progress_bar:: 是否显示进度条,默认将会为True。
  • logger:: 传入一个Loggers的实例,默认会使用TensorBoard Logger。设置为False则会禁用日志功能。
  • log_every_n_steps:: 日志记录的步长
  • strategy:: 训练策略,如ddp, fsdp等。
  • limit_train_batches:: 限制训练时的batch数量,一般在调试时使用。传入一个数字,当数字小于1时按比例计算【0.25,则使用Dataloader总数的25%的batch】;当数字大于1时按个数计算【5,则使用5个batch】

可选接口

Loggers

在PL中,继承自基类Logger有多种log格式可选,比如MLflow Logger,CSV logger,TensorBoard Logger等等。可以根据自己的需要,使用不同的日志记录形式。此处着重介绍TensorBoard Logger。

TensorBoard Logger

调用该类,日志将会以tensorboard格式进行记录,训练结束后可以可视化看到训练过程。

TensorBoardLogger(_save_dir_, _name='lightning_logs'_, _version=None_, _log_graph=False_, _default_hp_metric=True_, _prefix=''_, _sub_dir=None_, _**kwargs_)

重要的参数是save_dir,name,version。因为这将决定日志的保存位置:save_dir/name/version。在不同的训练阶段可以实例化不同的logger,就可以将不同的阶段的日志放置在不同路径,方便分析研究。

构建好Logger的实例后,作为参数传入到Trainer中即可,以下是官方文档中的例子:

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import TensorBoardLogger

logger = TensorBoardLogger("tb_logs", name="my_model")
trainer = Trainer(logger=logger)

Callbacks

EarlyStopping

通过该类配置训练早停的策略。

EarlyStopping(_monitor_, _min_delta=0.0_, _patience=3_, _verbose=False_, _mode='min'_, _strict=True_, _check_finite=True_, _stopping_threshold=None_, _divergence_threshold=None_, _check_on_train_epoch_end=None_, _log_rank_zero_only=False_)
  • monitor:: 监视指标。
  • patience:: 传入一个整数n。默认情况下,每个epoch后都会检查指标的数值,当指标n次检查都一样时会触发早停。
  • mode:: 可选max或min模式:max模式下,指标不再增长时会触发早停;min模式下,指标不再下降时会触发早停。
ModelCheckpoint

通过该类配置模型保存的保存策略。

ModelCheckpoint(_dirpath=None_, _filename=None_, _monitor=None_, _verbose=False_, _save_last=None_, _save_top_k=1_, _save_weights_only=False_, _mode='min'_, _auto_insert_metric_name=True_, _every_n_train_steps=None_, _train_time_interval=None_, _every_n_epochs=None_, _save_on_train_epoch_end=None_, _enable_version_counter=True_)
  • dirpath & filename:: 模型文件将存储为dirpath/filename。
  • monitor:: 评价指标,需要搭配save_top_k选项一起使用。
  • save_top_k:: 传入一个整数n,指定保存模型的数量。
    1. n为0,不会保存模型。
    2. n为-1,会保存所有检查点时的模型。
    3. n大于2,模型会保存指标最好的n个模型。
ProgressBar

通过继承该类,重写成员方法,以按需求定制进度条的形式。

  • get_metrics :: 可以从基类获得所有指标,然后返回想要显示的指标的字典
  • print:: 定制进度条的输出样式。原文提到without breaking the progress bar.,应该是要注意输出的方式,比如不能重新刷新屏幕缓冲区?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2374074.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++学习之STL学习

在经过前面的简单的C入门语法的学习后,我们开始接触C最重要的组成部分之一:STL 目录 STL的介绍 什么是STL STL的历史 UTF-8编码原理(了解) UTF-8编码原理 核心编码规则 规则解析 编码步骤示例 1. 确定码点范围 2. 转换为…

3. 仓颉 CEF 库封装

文章目录 1. capi 使用说明2. Cangjie CEF2. 1实现目标 3. 实现示例 1. capi 使用说明 根据上一节 https://blog.csdn.net/qq_51355375/article/details/147880718?spm1011.2415.3001.5331 所述, cefcapi 是libcef 共享库导出一个 C API, 而以源代码形式分发的 li…

LabVIEW多通道并行数据存储系统

在工业自动化监测、航空航天测试、生物医学信号采集等领域,常常需要对多个传感器通道的数据进行同步采集,并根据后续分析需求以不同采样率保存特定通道组合。传统单线程数据存储方案难以满足实时性和资源利用效率的要求,因此设计一个高效的多…

谷歌在即将举行的I/O大会之前,意外泄露了其全新设计语言“Material 3 Expressive”的细节

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

十三、基于大模型的在线搜索平台——整合function calling流程

基于大模型的在线搜索平台——整合function calling流程 一、function calling调用总结 上篇文章已经实现了信息抓取能力,并封装成了函数。现在最后一步将能力转换为大模型可以调用的能力,实现搜索功能就可以了。这篇主要实现大模型的function calling能…

力扣70题解

记录 2025.5.8 题目: 思路: 1.初始化:p 和 q 初始化为 0,表示到达第 0 级和第 1 级前的方法数。r 初始化为 1,表示到达第 1 级台阶有 1 种方法。 2.循环迭代:从第 1 级到第 n 级台阶进行迭代: p 更新为前…

电商双11美妆数据分析

1、初步了解 2.2 缺失值处理 通过上面观察数据发现sale_count,comment_count 存在缺失值,先观察存在缺失值的行的基本情况 2.3 数据挖掘寻找新的特征 给出各个关键词的分类类别 由title新生成两列类别 对是否是男性专用进行分析并新增一列 对每个产品总销量新增销售额这一列

24、TypeScript:预言家之书——React 19 类型系统

一、预言家的本质 "TypeScript是魔法世界的预言家之书,用静态类型编织代码的命运轨迹!" 霍格沃茨符文研究院的巫师挥动魔杖,类型注解与泛型的星轨在空中交织成防护矩阵。 ——基于《国际魔法联合会》第12号类型协议,Ty…

第8章-1 查询性能优化-优化数据访问

上一篇:《第7章-3 维护索引和表》 在前面的章节中,我们介绍了如何设计最优的库表结构、如何建立最好的索引,这些对于提高性能来说是必不可少的。但这些还不够——还需要合理地设计查询。如果查询写得很糟糕,即使库表结构再合理、索…

PCL点云按指定方向进行聚类(指定类的宽度)

需指定方向和类的宽度。测试代码如下&#xff1a; #include <iostream> #include <fstream> #include <vector> #include <string> #include <pcl/point_types.h> #include <pcl/point_cloud.h> #include <pcl/visualization/pcl_visu…

C#对SQLServer增删改查

1.创建数据库 2.SqlServerHelper using System; using System.Collections.Generic; using System.Data.SqlClient; using System.Data; using System.Linq; using System.Text; using System.Threading.Tasks;namespace WindowsFormsApp1 {internal class SqlServerHelper{//…

模拟太阳系(C#编写的maui跨平台项目源码)

源码下载地址&#xff1a;https://download.csdn.net/download/wgxds/90789056 本资源为用C#编写的maui跨平台项目源码&#xff0c;使用Visual Studio 2022开发环境&#xff0c;基于.net8.0框架&#xff0c;生成的程序为“模拟太阳系运行”。经测试&#xff0c;生成的程序可运行…

蓝桥杯14届 数三角

问题描述 小明在二维坐标系中放置了 n 个点&#xff0c;他想在其中选出一个包含三个点的子集&#xff0c;这三个点能组成三角形。然而这样的方案太多了&#xff0c;他决定只选择那些可以组成等腰三角形的方案。请帮他计算出一共有多少种选法可以组成等腰三角形&#xff1f; 输…

HTML12:文本框和单选框

表单元素格式 属性说明type指定元素的类型。text、password、 checkbox、 radio、submit、reset、file、hidden、image 和button&#xff0c;默认为textname指定表单元素的名称value元素的初始值。type为radio时必须指定一个值size指定表单元素的初始宽度。当type为text 或pas…

机器人厨师上岗!AI在餐饮界掀起新风潮!

想要了解人工智能在其他各个领域的应用&#xff0c;可以查看下面一篇文章 《AI在各领域的应用》 餐饮业是与我们日常生活息息相关的行业&#xff0c;而人工智能&#xff08;AI&#xff09;正在迅速改变这个传统行业的面貌。从智能点餐到食材管理&#xff0c;再到个性化推荐&a…

MySQL开篇

文章目录 一、前置知识1. MySQL的安装2. 前置一些概念知识 二、MySQL数据库操作2.1 概念2.2 数据库的操作2.2.1创建数据库命令2.2.2 查看数据库2.2.3 选中数据库2.2.4 删除数据库 三、MySQL数据表操作3.1 概念3.2 数据表的操作3.2.1 创建表 一、前置知识 1. MySQL的安装 MySQ…

Linux电脑本机使用小皮面板集成环境开发调试WEB项目

开发调试WEB项目&#xff0c;有时开发环境配置繁琐&#xff0c;可以使用小皮面板集成环境。 小皮面板官网&#xff1a; https://www.xp.cn/1.可以使用小皮面板安装脚本一键安装。登陆小皮面板管理后台 2.在“软件商店”使用LNMP一键部署集成环境。 3.添加网站&#xff0c;本…

问题及解决01-面板无法随着窗口的放大而放大

在MATLAB的App Designer中&#xff0c;默认情况下&#xff0c;组件的位置是固定的&#xff0c;不会随着父容器的大小变化而改变。问题图如下图所示。 解决&#xff1a; 为了让Panel面板能够随着UIFigure父容器一起缩放&#xff0c;需要使用布局管理器&#xff0c;我利用 MATLA…

操作系统原理实验报告

操作系统原理课程的实验报告汇总 实验三&#xff1a;线程的创建与撤销 实验环境&#xff1a;计算机一台&#xff0c;内装有VC、office等软件 实验日期&#xff1a;2024.4.11 实验要求&#xff1a; 1.理解&#xff1a;Windows系统调用的基本概念&#xff0c;进程与线程的基…

《Linux命令行大全(第2版)》PDF下载

内容简介 本书对Linux命令行进行详细的介绍&#xff0c;全书内容包括4个部分&#xff0c;第一部分由Shell的介绍开启命令行基础知识的学习之旅&#xff1b;第二部分讲述配置文件的编辑&#xff0c;如何通过命令行控制计算机&#xff1b;第三部分探讨常见的任务与必备工具&…