学习pytorch19 pytorch使用GPU训练

news2025/5/24 13:45:12

pytorch使用GPU进行训练

  • 1. 数据 模型 损失函数调用cuda()
  • 2. 使用谷歌免费GPU gogle colab 需要创建谷歌账号登录使用, 网络能访问谷歌
  • 3. 执行
  • 4. 代码

B站土堆学习视频: https://www.bilibili.com/video/BV1hE411t7RN/?p=30&spm_id_from=pageDriver&vd_source=9607a6d9d829b667f8f0ccaaaa142fcb

1. 数据 模型 损失函数调用cuda()

if torch.cuda.is_available():
    net = net.cuda()
if torch.cuda.is_available():
    loss_fn = loss_fn.cuda()
if torch.cuda.is_available():
    imgs = imgs.cuda()
    targets = targets.cuda()

2. 使用谷歌免费GPU gogle colab 需要创建谷歌账号登录使用, 网络能访问谷歌

土堆YYDS, colab省事省力好多

3. 执行

运行报错
https://stackoverflow.com/questions/59013109/runtimeerror-input-type-torch-floattensor-and-weight-type-torch-cuda-floatte
You get this error because your model is on the GPU, but your data is on the CPU. So, you need to send your input tensors to the GPU.
意思是数据在cpu但是模型在gpu导致报错
查看代码测试数据做模型预测的时候,有个变量名写错了,改正后正常运行。
在这里插入图片描述
在这里插入图片描述

4. 代码

import torch
import torchvision
from torch import nn
from torch.utils.tensorboard import SummaryWriter
import time
# from p24_model import *

# 1. 准备数据集
# 训练数据
from torch.utils.data import DataLoader

train_data = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)
# 测试数据
test_data = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=torchvision.transforms.ToTensor(),
                                         download=True)

# 查看数据大小--size
print("训练数据集大小:", len(train_data))
print("测试数据集大小:", len(test_data))
# 利用DataLoader来加载数据集
train_loader = DataLoader(dataset=train_data, batch_size=64)
test_loader = DataLoader(dataset=test_data, batch_size=64)

# 2. 导入模型结构 创建模型
class Cifar10Net(nn.Module):
    def __init__(self):
        super(Cifar10Net, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(64*4*4, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.net(x)
        return x

net = Cifar10Net()
if torch.cuda.is_available():
    net = net.cuda()

# 3. 创建损失函数  分类问题--交叉熵
loss_fn = nn.CrossEntropyLoss()
if torch.cuda.is_available():
    loss_fn = loss_fn.cuda()

# 4. 创建优化器
# learing_rate = 0.01
# 1e-2 = 1 * 10^(-2) = 0.01
learing_rate = 1e-2
print(learing_rate)
optimizer = torch.optim.SGD(net.parameters(), lr=learing_rate)

# 设置训练网络的一些参数
epoch = 10   # 记录训练的轮数
total_train_step = 0  # 记录训练的次数
total_test_step = 0   # 记录测试的次数

# 利用tensorboard显示训练loss趋势
writer = SummaryWriter('./train_logs')

start_time = time.time()
print('start_time: ', start_time)
print(torch.cuda.is_available())
for i in range(epoch):
    # 训练步骤开始
    net.train()  # 可以加可以不加  只有当模型结构有 Dropout BatchNorml层才会起作用
    for data in train_loader:
        imgs, targets = data  # 获取数据
        if torch.cuda.is_available():
            imgs = imgs.cuda()
            targets = targets.cuda()
        output = net(imgs)    # 数据输入模型
        loss = loss_fn(output, targets)  # 损失函数计算损失 看计算的输出和真实的标签误差是多少
        # 优化器开始优化模型  1.梯度清零  2.反向传播  3.参数优化
        optimizer.zero_grad()  # 利用优化器把梯度清零 全部设置为0
        loss.backward()        # 设置计算的损失值,调用损失的反向传播,计算每个参数结点的参数
        optimizer.step()       # 调用优化器的step()方法 对其中的参数进行优化
        # 优化一次 认为训练了一次
        total_train_step += 1
        if total_train_step % 100 == 0:
            print('训练次数: {}   loss: {}'.format(total_train_step, loss))
            end_time = time.time()
            print('训练100次需要的时间:', end_time-start_time)
        # 直接打印loss是tensor数据类型,打印loss.item()是打印的int或float真实数值, 真实数值方便做数据可视化【损失可视化】
        # print('训练次数: {}   loss: {}'.format(total_train_step, loss.item()))
        writer.add_scalar('train-loss', loss.item(), global_step=total_train_step)

    # 利用现有模型做模型测试
    # 测试步骤开始
    total_test_loss = 0
    accuracy = 0
    net.eval()  # 可以加可以不加  只有当模型结构有 Dropout BatchNorml层才会起作用
    with torch.no_grad():
        for data in test_loader:
            imags, targets = data
            if torch.cuda.is_available():
                imags = imags.cuda()
                targets = targets.cuda()
            output = net(imags)
            loss = loss_fn(output, targets)
            total_test_loss += loss.item()
            # 计算测试集的正确率
            preds = (output.argmax(1)==targets).sum()
            accuracy += preds
    # writer.add_scalar('test-loss', total_test_loss, global_step=i+1)
    writer.add_scalar('test-loss', total_test_loss, global_step=total_test_step)
    writer.add_scalar('test-accracy', accuracy/len(test_data), total_test_step)
    total_test_step += 1
    print("---------test loss: {}--------------".format(total_test_loss))
    print("---------test accuracy: {}--------------".format(accuracy/len(test_data)))
    # 保存每一个epoch训练得到的模型
    torch.save(net, './net_epoch{}.pth'.format(i))

writer.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1295740.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python Authlib库:构建安全可靠的身份验证系统

更多资料获取 📚 个人网站:ipengtao.com 在现代应用程序中,安全性是至关重要的,特别是在处理用户身份验证时。Authlib库为Python开发者提供了一套强大的工具,用于简化和增强身份验证和授权流程。本文将深入探讨Authli…

js/jQuery常见操作 之各种语法例子(包括jQuery中常见的与索引相关的选择器)

js/jQuery常见操作 之各种语法例子(包括jQuery中常见的与索引相关的选择器) 1. 操作table常见的1.1 动态给table添加title(指定td)1.1.1 给td动态添加title(含:获取tr的第几个td)1.1.2 动态加工…

elasticsearch聚合、自动补全、数据同步

目录 一、数据聚合1.1 聚合的种类1.2 DSL实现聚合1.2.1 Bucket聚合语法1.2.2 聚合结果排序1.2.3 限定聚合范围1.2.4 Metric聚合语法 1.3 RestAPI实现聚合 二、自动补全2.1 拼音分词器2.2 自定义分词器2.3 自动补全查询2.4 RestAPI实现自动补全 三、数据同步3.1 思路分析3.1.1 同…

Java来实现二叉树算法,将一个二叉树左右倒置(左右孩子节点互换)

文章目录 二叉树算法二叉树左右变换数据 今天来和大家谈谈常用的二叉树算法 二叉树算法 二叉树左右变换数据 举个例子: Java来实现二叉树算法,将一个二叉树左右倒置(左右孩子节点互换)如下图所示 实现的代码如下:以…

CMake ‘3.10.2‘ was not found in PATH or by cmake.dir property.

在部署Yolov5到安卓端的过程中出现:CMake ‘3.10.2’ was not found in PATH or by cmake.dir property. 原因: cmake版本太高,需要安装低版本的cmake 最开始下载的是默认最高版本的cmake,默认是3.22.1,解决方案是,下载…

AIGC创作系统ChatGPT网站源码,Midjourney绘画,GPT联网提问/即将支持TSS语音对话功能

一、AI创作系统 SparkAi创作系统是基于ChatGPT进行开发的Ai智能问答系统和Midjourney绘画系统,支持OpenAI-GPT全模型国内AI全模型。本期针对源码系统整体测试下来非常完美,可以说SparkAi是目前国内一款的ChatGPT对接OpenAI软件系统。那么如何搭建部署AI…

C++初阶 | [七] (上) string类

摘要:标准库中的string类的常用函数 C语言中,字符串是以\0结尾的一些字符的集合,为了操作方便,C标准库中提供了一些str系列的库函数, 但是这些库函数与字符串是分离开的,不太符合OOP(面向对象)的思想&#…

区块链创新应用场景不断拓展,实现去中心化

小编介绍:10年专注商业模式设计及软件开发,擅长企业生态商业模式,商业零售会员增长裂变模式策划、商业闭环模式设计及方案落地;扶持10余个电商平台做到营收过千万,数百个平台达到百万会员,欢迎咨询。 区块…

网络知识学习(笔记三)(传输层的TCP)

前面已经介绍了传输层的UDP协议的报文以及一下相关的知识点,本次主要是传输层的TCP协议,包括TCP报文的详细介绍;可靠传输、流量控制、拥塞控制等;建立连接、释放连接。 一、TCP基本知识点介绍 1.1、TCP协议的几个重要的知识点 …

联想电脑重装系统Win10步骤和详细教程

联想电脑拥有强大的性能,很多用户办公都喜欢用联想电脑。有使用联想电脑的用户反映系统出现问题了,想重新安装一个正常的系统,但是不知道重新系统的具体步骤。接下来小编详细介绍给联想电脑重新安装Win10系统系统的方法步骤。 推荐下载 系统之…

Elastcsearch:通过 Serverless 提供更多服务

作者:Ken Exner 人们使用 Elasticsearch 解决最大数据挑战的方式一直令我们感到惊讶。 从超过 40 亿次下载、70,000 次提交、1,800 名贡献者以及我们全球社区的反馈中可以清楚地看出这一点。 Elastic 在广泛的用例中发挥的作用促使我们简化复杂性,让搜索…

2023年12月8日:UI登陆界面

作业 头文件 #ifndef WIDGET_H #define WIDGET_H#include <QWidget> #include <QMovie> #include <QPushButton> #include <QDebug>QT_BEGIN_NAMESPACE namespace Ui { class Widget; } QT_END_NAMESPACEclass Widget : public QWidget {Q_OBJECTpub…

Liunx系统使用超详细(五)~命令符号

目录 一、逻辑符号 1.1&& 1.2|| 二、连接符号 2.1| 2.2> 2.3>> 2.4< 三、分隔符号 3.1 &#xff1b; 在Linux中&#xff0c;逻辑符号和连接符号常用于构建命令行中的逻辑操作和管道操作。下面对这两种符号进行总结描述。 一、逻辑符号 1.1&…

运维知识点-Nginx

Nginx Nginx解析安全实战预备知识实验目的#制作图片木马# web服务器-Nginx服务命令及配置centOS7安装安装所需插件安装gccpcre、pcre-devel安装zlib安装安装openssl Nginx解析安全实战 预备知识 NginxPHP/FastCGI构建的WEB服务器工作原理 Nginx|FastCGI简介 Nginx (“engin…

Spring--10--Spring Bean的生命周期

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 1.Spring Bean1.1 什么是 Bean简而言之&#xff0c;bean 是由 Spring IoC 容器实例化、组装和管理的对象。 1.2 Spring框架管理Bean对象的优势 2.Bean的生命周期实例…

Linux shell编程学习笔记34:eval 命令

0 前言 在JavaScript语言中&#xff0c;有一个很特别的函数eval&#xff0c;eval函数可以将字符串当做 JavaScript 代码执行&#xff0c;返回表达式或值。 在Linux Shell 中也提供了内建命令eval&#xff0c;它是否具有JavaScript语言中eval函数的功能呢&#xff1f; 1 eval命…

MuJoCo机器人动力学仿真平台安装与教程

MuJoCo是一个机器人动力学仿真平台&#xff0c;它包括一系列的物理引擎、可视化工具和机器人模拟器等工具&#xff0c;用于研究和模拟机器人的运动和动力学特性。以下是MuJoCo的安装教程&#xff1a; 下载和安装MuJoCo Pro。可以从MuJoCo的官方网站上下载最新版本的安装包。根…

QT Creator 保存(Ctrl+S)时,会将Tab制表符转换为空格

今天在写makefile文件时&#xff0c;发现QT Creator 保存(CtrlS)时&#xff0c;会将Tab制表符转换为空格&#xff0c;之前没有发现&#xff0c;略坑&#xff0c;官网上也有说明&#xff0c;点这里 简单来说&#xff0c;解决办法如下 依次点击&#xff1a;Tools ->Options-&g…

数据结构与算法(六)分支限界法(Java)

目录 一、简介1.1 定义1.2 知识回顾1.3 两种解空间树1.4 三种分支限界法1.5 回溯法与分支线定法对比1.6 使用步骤 二、经典示例&#xff1a;0-1背包问题2.1 题目2.2 分析1&#xff09;暴力枚举2&#xff09;分支限界法 2.3 代码实现1&#xff09;实现广度优先策略遍历2&#xf…

SpringBoot系列之集成Jedis教程

SpringBoot系列之集成Jedis教程&#xff0c;Jedis是老牌的redis客户端框架&#xff0c;提供了比较齐全的redis使用命令&#xff0c;是一款开源的Java 客户端框架&#xff0c;本文使用Jedis3.1.0加上Springboot2.0&#xff0c;配合spring-boot-starter-data-redis使用&#xff0…