【PyTorch】Training Model

news2025/7/19 9:01:48

文章目录

  • 七、Training Model
    • 1、模型训练
    • 2、GPU训练
      • 2.1 .cuda()
      • 2.2 .to(device)
      • 2.3 Google Colab
    • 3、模型验证

七、Training Model

1、模型训练

CIFAR10数据集为例:

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import time

from Model import *

# 准备数据集
train_data = torchvision.datasets.CIFAR10("../data", train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)
test_data = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor(),
                                         download=True)

# length长度
train_data_len = len(train_data)
test_data_len = len(test_data)
print("训练集: {}".format(train_data_len))
print("测试集: {}".format(test_data_len))

# 利用DataLoader加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

# 创建网络模型
liang = Liang()

# 损失函数
loss_fn = nn.CrossEntropyLoss()

# 优化器
# learning_rate = 0.01
learning_rate = 1e-2  # 1*(10)^(-2)=1/100
optimizer = torch.optim.SGD(liang.parameters(), lr=learning_rate)

# 设置训练网络的一些参数
total_train_step = 0  # 训练次数
total_test_step = 0  # 测试次数
epoch = 10  # 训练轮数

# 添加TensorBoard
writer = SummaryWriter("../logs")
start_time = time.time()

for i in range(epoch):
    print("-----------第 {} 轮训练开始----------".format(i + 1))

    # 训练步骤开始
    liang.train()
    for data in train_dataloader:
        imgs, targets = data
        outputs = liang(imgs)
        loss = loss_fn(outputs, targets)

        # 优化调优
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step += 1
        if total_train_step % 100 == 0:
            end_time = time.time()
            print(end_time - start_time)
            print("训练次数: {}, Loss: {}".format(total_train_step, loss.item()))
            writer.add_scalar("train_loss", loss.item(), total_train_step)

    # 测试步骤开始
    liang.eval()
    total_test_loss = 0
    total_accuracy = 0
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            outputs = liang(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss += loss.item()

            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy += accuracy

    print("整体测试集上的Loss: {}".format(total_test_loss))
    print("整体测试集上的正确率: {}".format(total_accuracy / test_data_len))
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar("test_accuracy", total_accuracy / test_data_len, total_test_step)
    total_test_step += 1

    torch.save(liang, "../model/liang_{}.pth".format(i))
    print("模型已保存")

writer.close()
Files already downloaded and verified
Files already downloaded and verified
训练集: 50000
测试集: 10000
-----------1 轮训练开始----------
6.537519931793213
训练次数: 100, Loss: 2.288882255554199
13.001430749893188
训练次数: 200, Loss: 2.271170139312744
19.13790225982666
训练次数: 300, Loss: 2.247511148452759
25.20561981201172
训练次数: 400, Loss: 2.168041706085205
31.378580570220947
训练次数: 500, Loss: 2.049440383911133
37.541871309280396
训练次数: 600, Loss: 2.054497241973877
43.90901756286621
训练次数: 700, Loss: 1.9997793436050415
整体测试集上的Loss: 309.624484539032
整体测试集上的正确率: 0.2912999987602234
模型已保存
...

2、GPU训练

神经网络、损失函数、数据 转为 cuda(GPU型) 执行,我们可以发现,速度明显比CPU执行的快很多!

2.1 .cuda()

# 神经网络
liang = Liang()
if torch.cuda.is_available():
    liang = liang.cuda()
# 损失函数
loss_fn = nn.CrossEntropyLoss()
if torch.cuda.is_available():
    loss_fn = loss_fn.cuda()
# 数据
imgs, targets = data
    if torch.cuda.is_available():
        imgs = imgs.cuda()
        targets = targets.cuda()
Files already downloaded and verified
Files already downloaded and verified
训练集: 50000
测试集: 10000
-----------1 轮训练开始----------
10.994545936584473
训练次数: 100, Loss: 2.2849647998809814
12.99094533920288
训练次数: 200, Loss: 2.2762258052825928
14.33635950088501
训练次数: 300, Loss: 2.230626106262207
16.00475764274597
训练次数: 400, Loss: 2.1230242252349854
17.964726209640503
训练次数: 500, Loss: 2.022688150405884
19.61249876022339
训练次数: 600, Loss: 2.01230788230896
20.96266460418701
训练次数: 700, Loss: 1.9741096496582031
整体测试集上的Loss: 305.68632411956787
整体测试集上的正确率: 0.29739999771118164
模型已保存
...

2.2 .to(device)

# 定义训练的设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
# 神经网络
liang = Liang()
liang = liang.to(device)
# 损失函数
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)
# 数据
imgs, targets = data
imgs = imgs.to(device)
targets = targets.to(device)
Files already downloaded and verified
Files already downloaded and verified
cuda
训练集: 50000
测试集: 10000
-----------1 轮训练开始----------
9.563657283782959
训练次数: 100, Loss: 2.2956345081329346
10.768706560134888
训练次数: 200, Loss: 2.2770333290100098
11.968295335769653
训练次数: 300, Loss: 2.26665997505188
13.181000471115112
训练次数: 400, Loss: 2.2037200927734375
14.387518167495728
训练次数: 500, Loss: 2.0665152072906494
15.585152387619019
训练次数: 600, Loss: 2.0054214000701904
16.81506586074829
训练次数: 700, Loss: 2.0446667671203613
整体测试集上的Loss: 320.6275497674942
整体测试集上的正确率: 0.2667999863624573
模型已保存
...

2.3 Google Colab

我们可以借助Google提供的Colab来进行GPU训练:https://colab.research.google.com/ (需要VPN)

在Colab中如果想要使用GPU进行训练,需要在笔记本设置中选择GPU。

明显快好多!!!

3、模型验证

import torch
import torchvision
from PIL import Image
from torch import nn

image_pth = "../images/dog.jpg"
image = Image.open(image_pth)
print(image)

transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),
                                            torchvision.transforms.ToTensor()])
image = transform(image)
print(image.shape)


class Liang(nn.Module):
    def __init__(self):
        super(Liang, self).__init__()
        self.module = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.module(x)
        return x


model = torch.load("../model/liang_4.pth")
print(model)

image = torch.reshape(image, (1, 3, 32, 32))
print(image.shape)

model.eval()
image = image.cuda()

with torch.no_grad():
    output = model(image)
print(output)

print(output.argmax(1))
<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=300x200 at 0x19E66C68100>
torch.Size([3, 32, 32])
Liang(
  (module): Sequential(
    (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=1024, out_features=64, bias=True)
    (8): Linear(in_features=64, out_features=10, bias=True)
  )
)
torch.Size([1, 3, 32, 32])
tensor([[-0.8167, -2.1763,  1.3891,  0.7956,  1.2035,  1.8374, -0.7936,  1.7908,
         -2.0639, -1.4441]], device='cuda:0')
tensor([5], device='cuda:0')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/38010.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【算法】2022第五届“传智杯”全国大学生计算机大赛(练习赛)

【参考&#xff1a;第五届“传智杯”全国大学生计算机大赛&#xff08;练习赛&#xff09; - 洛谷 | 计算机科学教育新生态】 练习赛满分程序&#xff08;多语言&#xff09;&#xff1a;https://www.luogu.com.cn/paste/fi60s4yu CPU一秒大概运行 10810^8108 次&#xff0c;…

年产10万吨环氧树脂车间工艺设计

目 录 摘 要 1 ABSTRACT 2 1 绪论 3 1.1环氧树脂的基本性质 3 1.2 环氧树脂的特点和用途 3 1.3环氧树脂发展的历史、现状及趋势 3 1.3.1环氧树脂的发展历史 4 1.3.2环氧树脂的生产现状 4 1.3.3 环氧树脂的发展趋势 5 1.4本设计的目的、意义及内容 5 1.4.1本设计的目的 5 1.4.2…

Matlab顶级期刊配色工具Rggsci

颜色搭配是一件非常让人头疼的事情。 一方面&#xff0c;如果忽视了配色&#xff0c;就好像是做菜没放盐&#xff0c;总会感觉少些味道。 另一方面&#xff0c;如果太注重配色&#xff0c;又感觉不是很有必要&#xff0c;毕竟数据结果好看才是第一位的。 想要平衡两者&#…

18.4 嵌入式指针概念及范例、内存池改进版

一&#xff1a;嵌入式指针&#xff08;embedded pointer&#xff09; 1、嵌入式指针概念 一般应用在内存池相关的代码中&#xff0c;成功使用嵌入式指针有个前提条件&#xff1a;&#xff08;类A对象的sizeof必须不小于4字节&#xff09; 嵌入式指针工作原理&#xff1a;借用…

文华财经期货K线多周期画线技术,多重短线技术共振通道线指标公式——多周期主图自动画线

期货指标公式是通过数学逻辑角度计算而来&#xff0c;仅是期货分析环节中的一个辅助工具。期货市场具有不确定性和不可预测性的&#xff0c;请正常对待和使用指标公式! 期货指标公式信号本身就有滞后性&#xff0c;周期越大&#xff0c;滞后性越久。指标公式不是100%稳赚的工具…

cocos2dx创建工程并在androidstudio平台编译

本文主要是通过androidstudio进行编译运行cocos2dx工程。 前置条件&#xff1a; 1&#xff1a;androidstudio已经下载并安装。 2&#xff1a;cocos2dx已经下载并打开。 这里androidstudio使用2021.3.1版本&#xff0c;cocos2dx使用4.0版本。 第一步&#xff0c;首先安装py…

Hive之数据类型和视图

Hive系列 第八章 数据类型和视图 8.1 数据类型 8.1.1 原子数据类型 &#xff08;其实上图中有一点错误&#xff0c;大家可以找找看&#xff09; 说明&#xff1a; 1、Hive 支持日期类型(老版本不支持)&#xff0c;在 Hive 里日期一般都是用字符串来表示的&#xff0c;而常用…

STC 51单片机40——汇编语言 串口 接收与发送

实际运行&#xff0c;正常 ; 仿真时&#xff0c;单步运行&#xff0c;记得设置虚拟串口数据【仿真有问题&#xff0c;虚拟串口助手工作不正常&#xff01;】 ORG 0000H MOV TMOD ,#20H ;定时器1&#xff0c;工作方式2&#xff0c;8位重装载 MOV TH1,#0FDH ; 波特率…

智慧酒店解决方案-最新全套文件

智慧酒店解决方案-最新全套文件一、建设背景为什么要建设智慧酒店一、智慧酒店功能亮点 &#xff1a;二、智慧酒店八大特色&#xff1a;二、建设思路三、建设方案四、获取 - 智慧酒店全套最新解决方案合集一、建设背景 为什么要建设智慧酒店 一、智慧酒店功能亮点 &#xff1…

mysql-8.0.31-macos12-x86_64记录

常用的命令 停止MySQL服务 : sudo /usr/local/mysql/support-files/mysql.server stop 启动MySQL服务 : sudo /usr/local/mysql/support-files/mysql.server start 重启MySQL服务 : sudo /usr/local/mysql/support-files/mysql.server restart 修改mysql密码 关闭mysql服务…

Qt5开发从入门到精通——第十二篇二节(Qt5 事件处理及实例——多线程控制、互斥量、信号量、线程等待与唤醒)

提示&#xff1a;欢迎小伙伴的点评✨✨&#xff0c;相互学习c/c应用开发。&#x1f373;&#x1f373;&#x1f373; 博主&#x1f9d1;&#x1f9d1; 本着开源的精神交流Qt开发的经验、将持续更新续章&#xff0c;为社区贡献博主自身的开源精神&#x1f469;‍&#x1f680; 文…

【C语言数据结构】带头节点与不带头节点的单链表头插法对比

前言 近期在学习STM32代码框架的过程中&#xff0c;老师使用链表来注册设备&#xff0c;发现使用了不带头节点的单链表&#xff0c;注册时使用头插法。之前在本专题整理学习过带头节点的单链表&#xff0c;因此本文整理对比一下两种方式的头插法区别&#xff0c;具体实现在次&…

html表白代码

目录一.引言二.表白效果展示1.惊喜表白2.烟花表白3.玫瑰花表白4.心形表白5.心加文字6.炫酷的特效一.引言 我们可以用一下好看的网页来表白&#xff0c;下面就有我觉得很有趣的表白代码。评论直接找我要源码也行。 下载整套表白文件 二.表白效果展示 1.惊喜表白 2.烟花表白 源码…

【TS】泛型以及多个泛型参数

泛型 给函数或者属性定义类型的时候&#xff0c;类型是固定的&#xff0c;当业务发生变动时可能不好维护&#xff0c;例如&#xff1a;函数类型固定为string,后续需求更改不好维护&#xff0c;比如需要传入number类型&#xff0c;那么这个函数就不适用了 function add( val :…

数学题类英语作文

最近我看到过这样一道英语作文题&#xff0c;这类英语作文题很少见&#xff0c;但也有必要讲一讲怎么写。 简化题意&#xff1a;帮Peter完成一下一道题&#xff1a; f(x)ax2−(a6)x3ln⁡xf(x)ax^2-(a6)x3\ln xf(x)ax2−(a6)x3lnx &#xff08;1&#xff09;讨论当a1a1a1时&am…

CMake中file的使用

CMake中的file命令用于文件操作&#xff0c;其文件格式如下&#xff1a;此命令专用于需要访问文件系统的文件和路径操作 Readingfile(READ <filename> <variable>[OFFSET <offset>] [LIMIT <max-in>] [HEX])file(STRINGS <filename> <variab…

Java8-新特性及Lambda表达式

1、Java8新特性内容概述 1.1、简介 Java 8(又称为jdk1.8)是Java语言开发的一个主要版本 Java 8是oracle公司于2014年3月发布&#xff0c;可以看成是自Java 5以来最具革命性的版本。Java 8为Java语言、编译器、类库、开发工具与JVM带来了大量新特性 1.2、新特性思维导图总结 1.…

JS中数组随机排序实现(原地算法sort/shuffle算法)

&#x1f431;个人主页&#xff1a;不叫猫先生 &#x1f64b;‍♂️作者简介&#xff1a;专注于前端领域各种技术&#xff0c;热衷分享&#xff0c;期待你的关注。 &#x1f4ab;系列专栏&#xff1a;vue3从入门到精通 &#x1f4dd;个人签名&#xff1a;不破不立 目录一、原地…

代码随想录刷题|LeetCode 70. 爬楼梯(进阶) 322. 零钱兑换 279.完全平方数 139.单词拆分

目录 70. 爬楼梯 &#xff08;进阶&#xff09; 思路 爬楼梯 1或2步爬楼梯 多步爬楼梯 322. 零钱兑换 思考 1、确定dp数组及其含义 2、确定递推公式 3、初始化dp数组 4、确定遍历顺序 零钱兑换 先遍历物品&#xff0c;再遍历背包 先遍历背包&#xff0c;再遍历物品 279.完全平方…

【操作系统实验】线程的创建+信号量通信

sem_init: 功能&#xff1a;初始化信号量 返回值&#xff1a;创建成功返回0&#xff0c;失败返回-1 参数sem&#xff1a;指向信号量结构的一个指针 参数pshared&#xff1a;不为&#xff10;时此信号量在进程间共享&#xff0c;为0时当前进程的所有线程共享 参数value&#xf…