2025-05-31 Python深度学习9——网络模型的加载与保存

news2025/6/4 11:25:35

文章目录

  • 1 使用现有网络
  • 2 修改网络结构
    • 2.1 添加新层
    • 2.2 替换现有层
  • 3 保存网络模型
    • 3.1 完整保存
    • 3.2 参数保存(推荐)
  • 4 加载网络模型
    • 4.1 加载完整模型文件
    • 4.2 加载参数文件
  • 5 Checkpoint
    • 5.1 保存 Checkpoint
    • 5.2 加载 Checkpoint

本文环境:

  • Pycharm 2025.1
  • Python 3.12.9
  • Pytorch 2.6.0+cu124

​ PyTorch 通过torchvision.models提供预训练模型(如 VGG16)。

​ 网址链接:https://docs.pytorch.org/vision/stable/models.html。

1 使用现有网络

​ 以 VGG16 为例,进入网址:https://docs.pytorch.org/vision/stable/models/generated/torchvision.models.vgg16.html#torchvision.models.vgg16。

image-20250531103635500

方法一:使用随机初始化权重

​ 将 weights 设置为 None,从 0 开始训练自己的网络。

vgg16_false = torchvision.models.vgg16(weights=None)  # 权重随机初始化

方法二:加载预训练权重

​ 也可以使用预训练好的网络参数,加载后可直接使用网络。
这将从官网上下载已训练好的模型文件。

vgg16_true = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1)

​ 可打印网络查看其模型结构:

print(vgg16_true)
image-20250531104433678
...
image-20250531104447912

2 修改网络结构

2.1 添加新层

​ 使用add_module在分类器(classifier)后追加全连接层:

vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))
image-20250531104536554

2.2 替换现有层

​ 直接修改分类器的最后一层(如适配 CIFAR10 的 10 分类任务):

vgg16_false.classifier[6] = nn.Linear(4096, 10)  # 替换第6层
image-20250531104551228

3 保存网络模型

​ 使用torch.save()方法保存网络模型。文件扩展名推荐使用.pt.pth

3.1 完整保存

​ 将模型类和参数一并保存到文件中。

torch.save(vgg16, 'vgg16_method1.pth')  # 包含模型类和参数
  • 优点:加载时无需重新定义模型结构。
  • 缺点:文件较大,且依赖原始代码环境(见 4.1 节)。

3.2 参数保存(推荐)

​ 仅保存参数字典到文件中。

torch.save(vgg16.state_dict(), 'vgg16_method2.pth')  # 仅保存参数字典
  • 优点:文件小,灵活性强,适合生产部署。

示例

import torch
import torchvision.models
from torch import nn

vgg16 = torchvision.models.vgg16(weights=None)

# 保存方式 1,模型结构 + 模型参数
torch.save(vgg16, 'vgg16_method1.pth')

# 保存方式 2,模型参数(官方推荐)
torch.save(vgg16.state_dict(), 'vgg16_method2.pth')

4 加载网络模型

​ 使用torch.load()方法加载网络模型。

4.1 加载完整模型文件

​ 加载完整模型时,需将 weights_only 参数设置为 False。

model = torch.load('vgg16_method1.pth', weights_only=False)  # 需确保模型类已定义

​ 模型打印结果如下:

print(model)
image-20250531111142517

注意

​ 若保存自定义模型,加载时必须确保环境中也有该模型的定义,否则会出现报错。

  • model_save.py

    # model_save.py
    
    import torch
    from torch import nn
    
    
    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            self.conv1 = nn.Conv2d(3, 64, 3)
    
        def forward(self, x):
            return self.conv1(x)
    
    
    model = MyModel()
    torch.save(model, 'my_model_method1.pth')
    
  • model_load.py

    import torch
    
    model = torch.load('my_model_method1.pth', weights_only=False)  # 报错,找不到 MyModel 的定义
    

    先运行 model_save.py,再运行 model_load.py,则会出现以下报错:

image-20250531110244566

4.2 加载参数文件

​ 首先,使用torch.load()方法加载网络模型。

​ 使用模型时,需先创建匹配的网络结构,再使用model.load_state_dict()加载参数数据。

vgg16 = torchvision.models.vgg16(weights=None)
model_dict = torch.load('vgg16_method2.pth')
vgg16.load_state_dict(model_dict)  # 需结构匹配

​ 模型打印结果是参数字典:

print(model_dict)
image-20250531111411199

注意

​ 模型保存时若在 GPU 上,加载时需指定 map_location 为 cup。

torch.load('model.pth', map_location=torch.device('cpu'))

​ 将参数加载到模型后,手动迁移到 GPU:

model = MyModel()
model.load_state_dict(model_dict)
model.to('cuda:0')

5 Checkpoint

​ 使用 Checkpoint 可以在训练过程中定期保存模型的状态,以便在中断后可以恢复训练,或者在测试时使用最终的模型。文件扩展名推荐使用.tar

5.1 保存 Checkpoint

​ 要保存一个模型的 Checkpoint,通常需要保存以下数据:

  • 模型的 state_dict(状态字典);
  • 优化器的状态;
  • 额外的信息,如 epoch 等。
import torch
 
# 假设 model 是你的模型,optimizer 是你的优化器
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss
}
 
# 保存checkpoint
torch.save(checkpoint, 'checkpoint.tar')

5.2 加载 Checkpoint

​ 加载 Checkpoint,首先需要加载文件,然后将其内容恢复到模型和优化器的状态中。

# 假设 model 和 optimizer 是你的模型和优化器实例
checkpoint = torch.load('checkpoint.tar')
 
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

# 如果需要,可以继续训练
model.train()  # 确保模型处于训练模式

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2396353.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

长安链起链调用合约时docker ps没有容器的原因

在调用这个命令的时候,发现并没有出现官方预期的合约容器,这是因为我们在起链的时候没有选择用docker的虚拟环境,实际上这不影响后续的调用,如果想要达到官方的效果那么你只需要在起链的时候输入yes即可,如图三所示

Appium+python自动化(七)- 认识Appium- 上

简介 经过前边的各项准备工作,终于才把appium搞定。 一、appium自我介绍 appium是一款开源的自动化测试工具,可以支持iOS和安卓平台上的原生的,基于移动浏览器的,混合的应用(APP)。 1、 使用appium进…

模块联邦:更快的微前端方式!

什么是模块联邦 在前端项目中,不同团队之间的业务模块可能有耦合,比如A团队的页面里有一个富文本模块(组件),而B团队 的页面恰好也需要使用这个富文本模块。 传统模式下,B团队只能去抄A团队的代码&#x…

前端基础学习html+css+js

HTML 区块 div标签,块级标签 span包装小部分文本,行内元素 表单 CSS css选择器 css属性 特性blockinlineinline-block是否换行✅ 换行❌ 不换行❌ 不换行可设置宽高✅ 支持❌ 不支持✅ 支持常见元素div容器 p段落 h标题span文本容器 a超链接img图片…

手机打电话时将对方DTMF数字转为RFC2833发给局域网SIP坐席

手机打电话时将对方DTMF数字转为RFC2833发给局域网SIP坐席 --局域网SIP坐席呼叫 上一篇:手机打电话时由对方DTMF响应切换多级IVR语音菜单(完结) 下一篇:安卓App识别手机系统弹授权框包含某段文字-并自动点击确定按钮 一、前言 …

SAP Business One:无锡哲讯科技助力中小企业数字化转型的智慧之选

数字化转型,中小企业的必经之路 在当今竞争激烈的商业环境中,数字化转型已不再是大型企业的专利,越来越多的中小企业开始寻求高效、灵活的管理系统来优化业务流程、提升运营效率。作为全球领先的企业管理软件,SAP Business One…

小型语言模型:为何“小”才是“大”?

当说到人工智能(AI)的时候,大家通常会想到那些拥有数十亿参数的超大型语言模型,它们能做出一些令人惊叹的事情。 厉害不厉害?绝对厉害! 但对于大多数企业和开发者来说,实用吗?可能…

秋招Day12 - 计算机网络 - 网络综合

从浏览器地址栏输入URL到显示网页的过程了解吗? 从在浏览器地址栏输入 URL 到显示网页的完整过程,并不是一个单一的数据包从头到尾、一次性地完成七层封装再七层解析的过程。 而是涉及到多次、针对不同目的、与不同服务器进行的、独立的网络通信交互&a…

QT-JSON

#include <QJsonDocument>#include <QJsonObject>#include <QJsonArray>#include <QFile>#include <QDebug>void createJsonFile() {// 创建一个JSON对象 键值对QJsonObject jsonObj;jsonObj["name"] "John Doe";jsonObj[…

IP 风险画像技术略解

IP 风险画像的技术定义与价值 IP 风险画像通过整合 IP 查询数据与 IP 离线库信息&#xff0c;结合机器学习算法&#xff0c;为每个 IP 地址生成多维度风险评估模型。其核心价值在于将传统的静态 IP 黑名单升级为动态风险评估体系&#xff0c;可实时识别新型网络威胁&#xff0…

秋招Day12 - 计算机网络 - 基础

说一下计算机网络体系结构 OSI七层模型&#xff0c;TCP/IP四层模型和五层体系结构 说说OSI七层模型&#xff1f; 应用层&#xff1a;最靠近用户的层&#xff0c;用于处理特定应用程序的细节&#xff0c;提供了应用程序和网络服务之间的接口。表示层&#xff1a;确保从一个系…

【网络安全】——Modbus协议详解:工业通信的“通用语言”

目录 一、初识Modbus&#xff1a;工业通信的基石 1.1 协议全称 1.2 协议简史 二、核心特性解析 2.1 架构设计 2.2 典型应用场景 三、协议族全景图 3.1 协议栈分类 3.2 版本演进对比 四、协议报文深度解析 4.1 Modbus RTU帧结构 4.2 Modbus TCP报文 五、通信机制实…

【GlobalMapper精品教程】095:如何获取无人机照片的拍摄方位角

文章目录 一、加载无人机照片二、计算方位角三、Globalmapper符号化显示方向四、arcgis符号化显示方向一、加载无人机照片 打开软件,加载无人机照片,在GLobalmapperV26中文版中,默认显示如下的航线信息。 关于航线的起止问题,可以直接从照片名称来确定。 二、计算方位角 …

小提琴图绘制-Graph prism

在 GraphPad Prism 中为小提琴图添加显著性标记(如*P<0.05)的步骤如下: 步骤1:完成统计检验 选择数据表:确保数据已按分组排列(如A列=Group1,B列=Group2)。执行统计检验: 点击工具栏 Analyze → Column analyses → Mann-Whitney test(非参数检验,适用于非正态数…

[GHCTF 2025]SQL???

打开题目在线环境&#xff1a; 先尝试注入&#xff1a; id1;show databases; 发现报错&#xff0c;后来看了wp才知道这个题目是SQLite注入。 我看的是这个师傅的wp: https://blog.csdn.net/2401_86190146/article/details/146164505?ops_request_misc%257B%2522request%255Fid…

【科研绘图系列】R语言绘制GO term 富集分析图(enrichment barplot)

禁止商业或二改转载,仅供自学使用,侵权必究,如需截取部分内容请后台联系作者! 文章目录 介绍加载R包数据下载导入数据数据预处理画图code 2code 3系统信息介绍 本文介绍了使用R语言绘制GO富集分析条形图的方法。通过加载ggplot2等R包,对GO term数据进行预处理,包括p值转换…

Laravel单元测试使用示例

Date: 2025-05-28 17:35:46 author: lijianzhan 在 Laravel 框架中&#xff0c;单元测试是一种常用的测试方法&#xff0c;它是允许你测试应用程序中的最小可测试单元&#xff0c;通常是方法或函数。Laravel 提供了内置的测试工具PHPUnit&#xff0c;实践中进行单元测试是保障代…

Kotlin委托机制使用方式和原理

目录 类委托属性委托简单的实现属性委托Kotlin标准库中提供的几个委托延迟属性LazyLazy委托参数可观察属性Observable委托vetoable委托属性储存在Map中 实践方式双击back退出Fragment/Activity传参ViewBinding和委托 类委托 类委托有点类似于Java中的代理模式 interface Base…

基于 HT for Web 轻量化 3D 数字孪生数据中心解决方案

一、技术架构&#xff1a;HT for Web 的核心能力 图扑软件自主研发的 HT for Web 是基于 HTML5 的 2D/3D 可视化引擎&#xff0c;核心技术特性包括&#xff1a; 跨平台渲染&#xff1a;采用 WebGL 技术&#xff0c;支持 PC、移动端浏览器直接访问&#xff0c;兼容主流操作系统…

精英-探索双群协同优化(Elite-Exploration Dual Swarm Cooperative Optimization, EEDSCO)

一种多群体智能优化算法&#xff0c;其核心思想是通过两个分工明确的群体——精英群和探索群——协同工作&#xff0c;平衡算法的全局探索与局部开发能力&#xff0c;从而提高收敛精度并避免早熟收敛。 一 核心概念 在传统优化算法&#xff08;如粒子群优化、遗传算法&#xf…