2025-05-31 Python深度学习10——模型训练流程

news2025/6/6 5:25:43

文章目录

  • 1 数据准备
    • 1.1 下载与预处理
    • 1.2 数据加载
  • 2 模型构建
    • 2.1 自定义 CNN 模型
    • 2.2 GPU加速
  • 3 训练配置
    • 3.1 损失函数
    • 3.2 优化器
    • 3.3 训练参数
  • 4 训练循环
    • 4.1 训练模式 (`model.train()`)
    • 4.2 评估模式 (`model.eval()`)
  • 5 模型验证

本文环境:

  • Pycharm 2025.1
  • Python 3.12.9
  • Pytorch 2.6.0+cu124

​ 本文以 CIFAR-10 为例,介绍模型的大致训练流程。相关的 Python 包如下:

import torch
import torchvision
from torch import nn, Tensor
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import time

1 数据准备

1.1 下载与预处理

​ 使用torchvision.datasets.CIFAR10下载 CIFAR-10 数据集(32x32 彩色图像,10 类),分为训练集(train=True,5 万张)和测试集(train=False,1 万张)。

# 准备数据集
train_data = torchvision.datasets.CIFAR10(
    root='./dataset',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True
)

test_data = torchvision.datasets.CIFAR10(
    root='./dataset',
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True
)
  • transform=torchvision.transforms.ToTensor():将图像转为 PyTorch 张量(Tensor),并自动归一化到 [0, 1] 范围。
  • download=True:若本地无数据,自动下载。

1.2 数据加载

​ 通过DataLoader分批次加载数据:

# 加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
  • batch_size=64:每批次处理 64 张图片,平衡内存占用和训练效率。
  • 训练集默认不打乱(未设置shuffle),测试集可添加shuffle=True以增强评估可靠性。

2 模型构建

2.1 自定义 CNN 模型

MyModel是一个3层卷积神经网络(CNN):

  1. 卷积层:nn.Conv2d(3, 32, 5, 1, 2)

    输入通道 3(RGB),输出通道 32,5×5 卷积核,步长 1,填充 2(保持尺寸不变)。

  2. 池化层:nn.MaxPool2d(2)

    2×2最大池化,尺寸减半。

  3. 全连接层

    • nn.Linear(64 * 4 * 4, 64)

      将展平后的特征(64 通道×4×4尺寸)映射到 64 维。

    • nn.Linear(64, 10)

      最终输出 10 类。

image-20250527161255093
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        return self.model(x)

2.2 GPU加速

​ 通过.to(device)将模型和数据移至 GPU(若可用),显著加速计算。

# 定义训练的设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = MyModel().to(device)  # 使用GPU

3 训练配置

3.1 损失函数

​ 使用nn.CrossEntropyLoss(),适用于多分类任务,计算预测概率与真实标签的交叉熵。

# 损失函数
loss_fn = nn.CrossEntropyLoss().to(device)  # 使用GPU

3.2 优化器

​ 使用torch.optim.SGD,随机梯度下降,学习率lr=1e-2,控制参数更新步长。

# 损失函数
loss_fn = nn.CrossEntropyLoss().to(device)  # 使用GPU

3.3 训练参数

  1. total_train_step:记录训练次数,用于日志和调试。
  2. total_test_step:记录测试次数,用于日志和调试。
  3. epoch=20:遍历完整数据集 20 次。
# 设置训练网络的一些参数
total_train_step = 0  # 记录训练次数
total_test_step = 0  # 记录测试次数
epoch = 20  # 训练的轮数

4 训练循环

数据加载  →  模型初始化  →  训练循环  →  测试评估  →  保存模型
    ↑          ↑                  ↓           ↓
    └───TensorBoard日志 ←────── 参数更新 ←── 梯度计算

4.1 训练模式 (model.train())

  1. 前向传播:输入图像imgs,模型输出预测outputs
  2. 计算损失loss = loss_fn(outputs, targets),衡量预测误差。
  3. 反向传播
    • optimizer.zero_grad():清空梯度,避免累积。
    • loss.backward():计算梯度(链式法则)。
    • optimizer.step():更新模型参数。
  4. 日志记录:每 100 次训练记录损失和时间到 TensorBoard 中。
for i in range(epoch):
    print(f"------------第 {i + 1} 轮训练开始------------")

    # 训练步骤开始
    model.train()
    for data in train_dataloader:
        imgs, targets = data
        imgs = imgs.to(device)  # 使用GPU
        targets = targets.to(device)  # 使用GPU
        outputs = model(imgs)
        loss = loss_fn(outputs, targets)
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_train_step += 1
        if total_train_step % 100 == 0:
            end_time = time.time()
            print(f"第 {total_train_step} 次训练,Loss:{loss.item()},Time:{end_time - start_time}")
            writer.add_scalar("train_loss", loss.item(), total_train_step)
            start_time = time.time()

4.2 评估模式 (model.eval())

  1. 关闭梯度计算with torch.no_grad(),节省内存并加速。
  2. 测试指标
    • 总损失:累加所有批次的损失total_test_loss
    • 准确率:统计预测正确的样本数(outputs.argmax(dim=1) == targets)。
  3. 日志记录:每轮测试后保存损失和准确率到 TensorBoard。
  4. 保存模型:通过torch.save()方法将模型的 state_dict 保存到本地文件中。
    # 测试步骤开始
    model.eval()
    total_test_loss = 0
    total_accuracy = 0
    accuracy_rate = 0
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            imgs = imgs.to(device)  # 使用GPU
            targets = targets.to(device)  # 使用GPU
            outputs: Tensor = model(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss += loss
            accuracy = outputs.argmax(dim=1) == targets
            total_accuracy += accuracy.sum()
    total_test_step += 1
    accuracy_rate = total_accuracy / test_data_size
    print(f"第 {i + 1} 轮测试,Loss:{total_test_loss},Accuracy:{total_accuracy} ({accuracy_rate})")
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar("test_accuracy", total_accuracy, total_test_step)
    writer.add_scalar("accuracy_rate", accuracy_rate, total_test_step)

torch.save(model.state_dict(), f"model/my_model.pth")  # 保存模型

writer.close()
image-20250531123058128

说明

  1. 训练与评估模式切换

    • model.train():启用 Dropout 和 BatchNorm 的训练行为(如随机丢弃神经元)。

    • model.eval():固定 Dropout 和 BatchNorm 的统计量,确保评估一致性。

  2. GPU 数据迁移

    需将输入数据 imgs 和标签 targets 均移至 GPU,否则会报错。

  3. 梯度清零

    避免梯度累加导致参数更新错误。

完整代码

# train_gpu_2.py

import torch
import torchvision
from torch import nn, Tensor
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import time

# 定义训练的设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 准备数据集
train_data = torchvision.datasets.CIFAR10(
    root='./dataset',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True
)

test_data = torchvision.datasets.CIFAR10(
    root='./dataset',
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True
)

# 数据集大小
train_data_size = len(train_data)
test_data_size = len(test_data)

print(f"训练集数量:{train_data_size}")
print(f"测试集数量:{test_data_size}")

# 加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)


# 创建网络模型
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        return self.model(x)


model = MyModel().to(device)  # 使用GPU

# 损失函数
loss_fn = nn.CrossEntropyLoss().to(device)  # 使用GPU

# 优化器
lr = 1e-2
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

# 设置训练网络的一些参数
total_train_step = 0  # 记录训练次数
total_test_step = 0  # 记录测试次数
epoch = 20  # 训练的轮数

# 添加 tensorboard
writer = SummaryWriter("../logs_train")

for i in range(epoch):
    print(f"------------第 {i + 1} 轮训练开始------------")

    start_time = time.time()
    # 训练步骤开始
    model.train()
    for data in train_dataloader:
        imgs, targets = data
        imgs = imgs.to(device)  # 使用GPU
        targets = targets.to(device)  # 使用GPU
        outputs = model(imgs)
        loss = loss_fn(outputs, targets)
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_train_step += 1
        if total_train_step % 100 == 0:
            end_time = time.time()
            print(f"第 {total_train_step} 次训练,Loss:{loss.item()},Time:{end_time - start_time}")
            writer.add_scalar("train_loss", loss.item(), total_train_step)
            start_time = time.time()

    # 测试步骤开始
    model.eval()
    total_test_loss = 0
    total_accuracy = 0
    accuracy_rate = 0
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            imgs = imgs.to(device)  # 使用GPU
            targets = targets.to(device)  # 使用GPU
            outputs: Tensor = model(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss += loss
            accuracy = outputs.argmax(dim=1) == targets
            total_accuracy += accuracy.sum()
    total_test_step += 1
    accuracy_rate = total_accuracy / test_data_size
    print(f"第 {i + 1} 轮测试,Loss:{total_test_loss},Accuracy:{total_accuracy} ({accuracy_rate})")
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar("test_accuracy", total_accuracy, total_test_step)
    writer.add_scalar("accuracy_rate", accuracy_rate, total_test_step)

torch.save(model.state_dict(), f"model/my_model.pth")  # 保存模型

writer.close()

5 模型验证

​ 准备待验证的图片,放在 imgae 目录下。

image-20250531123319819

​ 编写 test.py 文件,用于验证模型。

# test.py

import torch
import torchvision
from PIL import Image
from torch import nn

# 定义图片路径
image_path = "image/dog.png"

# 打开图片并转换为RGB格式
image = Image.open(image_path).convert('RGB')

# 定义图片转换操作
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),  # 将图片大小调整为32x32
    torchvision.transforms.ToTensor()  # 将图片转换为张量
])

# 对图片进行转换操作
image = transform(image).reshape(1, 3, 32, 32)


# 定义模型类
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # 定义模型结构
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),  # 第一个卷积层,输入通道数为3,输出通道数为32,卷积核大小为5,步长为1,填充为2
            nn.MaxPool2d(2),  # 最大池化层,池化核大小为2
            nn.Conv2d(32, 32, 5, 1, 2),  # 第二个卷积层,输入通道数为32,输出通道数为32,卷积核大小为5,步长为1,填充为2
            nn.MaxPool2d(2),  # 最大池化层,池化核大小为2
            nn.Conv2d(32, 64, 5, 1, 2),  # 第三个卷积层,输入通道数为32,输出通道数为64,卷积核大小为5,步长为1,填充为2
            nn.MaxPool2d(2),  # 最大池化层,池化核大小为2
            nn.Flatten(),  # 展平操作
            nn.Linear(64 * 4 * 4, 64),  # 全连接层,输入维度为64*4*4,输出维度为64
            nn.Linear(64, 10)  # 全连接层,输入维度为64,输出维度为10
        )

    def forward(self, x):
        return self.model(x)


# 加载模型参数
model_dict = torch.load('model/my_model.pth')
model = MyModel()
model.load_state_dict(model_dict)
model.to('cuda')

# 设置模型为评估模式
model.eval()
# 关闭梯度计算
with torch.no_grad():
    # 将图片转换为GPU格式
    image = image.to('cuda')
    # 进行模型推理
    output = model(image)

# 打印输出结果
print(output)
# 打印输出结果中最大值的索引
print(output.argmax(1))

​ 验证结果如下,表明 dog.png 图片的预测结果索引为 5,即第 6 类预测。

image-20250531123550636

​ 依据分类规则,预测结果为 dog,是正确的。

image-20250531123902927

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2399108.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

K8S StatefulSet 快速开始

其实这篇文章的梗概已经写了很久了,中间我小孩出生了,从此人间多了一份牵挂。抽出一些时间去办理新生儿相关手续。初为人父确实艰辛,就像学技术一样,都需要有极大的耐心,付出很多的时间。 一、引子 1.1、独立的存储 …

nav2笔记-250603

合作背景: AMD与Open Navigation在过去几个月里进行了合作,旨在向ROS 2社区展示AMD强大的Ryzen AI、Embedded和Kria能力。 演示内容: 帖子提到,他们已经开始展示如何使用Ryzen AI为自主机器人产品提供动力,在各种现实世…

指纹识别+精准化POC攻击

开发目的 解决漏洞扫描器的痛点 第一就是扫描量太大,对一个站点扫描了大量的无用 POC,浪费时间 指纹识别后还需要根据对应的指纹去进行 payload 扫描,非常的麻烦 开发思路 我们的思路分为大体分为指纹POC扫描 所以思路大概从这几个方面…

mac环境下的python、pycharm和pip安装使用

Python安装 Mac环境下的python安装 下载地址:https://www.jetbrains.com.cn/pycharm/ 一直点击下一步即可完成 在应用程序中会多了两个图标 IDLE 和 Python launcher IDLE支持在窗口中直接敲python命令并立即执行,双击即可打开 Python launcher双击打…

BUUCTF[极客大挑战 2019]Havefun 1题解

BUUCTF[极客大挑战 2019]Havefun 1题解 题目分析解题理解代码逻辑:构造Payload: 总结 题目分析 生成靶机,进入网址: 首页几乎没有任何信息,公式化F12打开源码,发现一段被注释的源码: 下面我们…

Tomcat优化篇

目录 一、Tomcat自身配置 1.Tomcat管理页面 2. 禁用AJP服务 3.Executor优化 4.三种运行模式 5.web.xml 6.Host标签 7.Context标签 8.启动速度优化 9.其他方面 二、JMeter测试 笔者推荐 一、Tomcat自身配置 1.Tomcat管理页面 我们可以打开Tomcat的管理页面&#xff…

Temporal Fusion Transformer(TFT)扩散模型时间序列预测模型

1. TFT 简介 Temporal Fusion Transformer(TFT)模型是一种专为时间序列预测设计的高级深度学习模型。它结合了神经网络的多种机制处理时间序列数据中的复杂关系。TFT 由 Lim et al. 于 2019年提出,旨在处理时间序列中的不确定性和多尺度的依…

【LangServe部署流程】5 分钟部署你的 AI 服务

目录 一、LangServe简介 二、环境准备 1. 安装必要依赖 2. 编写一个 LangChain 可运行链(Runnable) 3. 启动 LangServe 服务 4. 启动服务 5. 使用 API 进行调用 三、可选:访问交互式 Swagger 文档 四、基于 LangServe 的 RAG 应用部…

攻防世界-unseping

进入环境 在获得的场景中发现PHP代码并进行分析 编写PHP编码 得到 Tzo0OiJlYXNlIjoyOntzOjEyOiIAZWFzZQBtZXRob2QiO3M6NDoicGluZyI7czoxMDoiAGVhc2UAYXJncyI7YToxOntpOjA7czozOiJwd2QiO319 将其传入 想执行ls,但是发现被过滤掉了 使用环境变量进行绕过 $a new…

[yolov11改进系列]基于yolov11使用FasterNet替换backbone用于轻量化网络的python源码+训练源码

【FasterNet介绍】 为了设计快速神经网络,许多工作都集中在减少浮点运算的数量(FLOPs)上。 然而,我们观察到FLOPs的减少并不一定会导致延迟的类似程度的减少。 这主要源于低效率的每秒浮点运算(FLOPS)。 为了实现更快的网络&#…

一周学会Pandas2之Python数据处理与分析-Pandas2数据绘图与可视化

锋哥原创的Pandas2 Python数据处理与分析 视频教程: 2025版 Pandas2 Python数据处理与分析 视频教程(无废话版) 玩命更新中~_哔哩哔哩_bilibili Pandas 集成了 Matplotlib,提供了简单高效的绘图接口,使数据可视化变得直观便捷。本指南将详…

企业级安全实践:SSL/TLS 加密与权限管理(一)

引言 ** 在数字化转型的浪潮中,企业对网络的依赖程度与日俱增,从日常办公到核心业务的开展,都离不开网络的支持。与此同时,网络安全问题也日益严峻,成为企业发展过程中不可忽视的重要挑战。 一旦企业遭遇网络安全事…

2025——》VSCode Windows 最新安装指南/VSCode安装完成后如何验证是否成功?2025最新VSCode安装配置全攻略

1.VSCode Windows 最新安装指南: 以下是 2025 年 Windows 系统下安装 Visual Studio Code(VSCode)的最新指南,结合官方文档与实际操作经验整理而成: 一、下载官方安装包: 1.访问官网: 打开浏览器,进入 VSCode 官方下载页面https://code.visualstudio.com/Download 2…

【MATLAB代码】制导——三点法,二维平面下的例程|运动目标制导,附完整源代码

三点法制导是一种导弹制导策略,主要用于确保导弹能够准确追踪并击中移动目标。该方法通过计算导弹、目标和制导站之间的相对位置关系,实现对目标的有效制导。 本文给出MATLAB下的三点法例程,模拟平面上捕获运动目标的情况订阅专栏后可直接查看源代码,粘贴到MATLAB空脚本中即…

如何爬取google应用商店的应用分类呢?

以下是爬取Google Play商店应用包名(package name)和对应分类的完整解决方案,采用ScrapyPlaywright组合应对动态渲染页面,并处理反爬机制: 完整爬虫实现 1. 安装必要库 # 卸载现有安装pip uninstall playwright scrapy-playwright -y# 重新…

SQL Relational Algebra(数据库关系代数)

目录 What is an “Algebra” What is Relational Algebra? Core Relational Algebra Selection Projection Extended Projection Product(笛卡尔积) Theta-Join Natural Join Renaming Building Complex Expressions Sequences of Assignm…

智能工业时代:工业场景下的 AI 大模型体系架构与应用探索

自工业革命以来,工业生产先后经历了机械化、电气化、自动化、信息化的演进,正从数字化向智能化迈进,人工智能技术是新一轮科技革命和产业变革的重要驱动力量,AI 大模型以其强大的学习计算能力掀开了人工智能通用化的序幕&#xff…

易语言使用OCR

易语言使用OCR 用易语言写个脚本,需要用到OCR,因此我自己封装了一个OCR到DLL。 http://lkinfer.1it.top/ 视频演示:https://www.bilibili.com/video/BV1Zg7az2Eq3/ 支持易语言、c、c#使用,平台限制:window 10 介绍…

C++和C#界面开发方式的全面对比

文章目录 C界面开发方式1. **MFC(Microsoft Foundation Classes)**2. **Qt**3. **WTL(Windows Template Library)**4. **wxWidgets**5. **DirectUI** C#界面开发方式1. **WPF(Windows Presentation Foundation&#xf…

算法-集合的使用

1、set常用操作 set<int> q; //以int型为例 默认按键值升序 set<int,greater<int>> p; //降序排列 int x; q.insert(x); //将x插入q中 q.erase(x); //删除q中的x元素,返回0或1,0表示set中不存在x q.clear(); //清空q q.empty(); //判断q是否为空&a…