PyTorch——优化器(9)

news2025/6/6 12:50:52

优化器根据梯度调整参数,以达到降低误差

import torch.optim
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoader

# 加载CIFAR10测试数据集,设置transform将图像转换为Tensor
dataset = torchvision.datasets.CIFAR10("./data", train=False, transform=torchvision.transforms.ToTensor(),
                                       download=True)
# 创建数据加载器,设置批量大小为64
dataloader = DataLoader(dataset, batch_size=64)

# 定义卷积神经网络模型
class TY(nn.Module):
    def __init__(self):
        super(TY, self).__init__()
        # 构建网络结构:3个卷积层+池化层组合,2个全连接层
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),    # 输入3通道,输出32通道,卷积核5x5
            MaxPool2d(2),                   # 最大池化,步长2
            Conv2d(32, 32, 5, padding=2),   # 第二层卷积
            MaxPool2d(2),                   # 第二次池化
            Conv2d(32, 64, 5, padding=2),   # 第三层卷积
            MaxPool2d(2),                   # 第三次池化
            Flatten(),                      # 将多维张量展平为向量
            Linear(1024, 64),               # 全连接层,输入1024维,输出64维
            Linear(64, 10),                 # 输出层,10个类别对应10个输出
        )

    def forward(self, x):
        # 定义前向传播路径
        x = self.model1(x)
        return x

# 定义损失函数(交叉熵损失适用于多分类问题)
loss = nn.CrossEntropyLoss()
# 实例化模型
ty = TY()
# 定义优化器(随机梯度下降),设置学习率为0.01
optim = torch.optim.SGD(ty.parameters(), lr=0.01)

# 训练20个完整轮次
for epoch in range(20):
    running_loss = 0.0  # 初始化本轮累计损失
    
    # 遍历数据加载器中的每个批次
    for data in dataloader:
        imgs, targets = data  # 获取图像和标签
        outputs = ty(imgs)    # 前向传播
        result_loss = loss(outputs, targets)  # 计算损失
        
        optim.zero_grad()     # 梯度清零,防止累积
        result_loss.backward()  # 反向传播计算梯度
        optim.step()          # 更新模型参数
        
        running_loss += result_loss  # 累加损失值
    
    # 打印本轮训练的累计损失
    print(f"Epoch {epoch+1}, Loss: {running_loss}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2399314.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

07 APP 自动化- appium+pytest+allure框架封装

文章目录 一、PO二、代码简单实现项目框架预览:base_page.pydir_config.pyget_data.pylogger.pystart_session.pyconfig.yamlkey_code.yamllaunch_page_loc.pylogin_page_loc.pylaunch_page.pylogin_page.pytest_login.pypytest.inirun.py 一、PO PO 分为四层 &…

英国2025年战略防御评估报告:网络与电磁域成现代战争核心

英国 2025 年战略防御评估 (SDR) 详细制定了一项计划,通过加强使用网络、人工智能和数字战争来整合其军事防御和进攻能力。 与美国一样,英国也被认为(尽管未被公开证实)会开展进攻性网络行动,甚至针对盟友。斯诺登泄露…

基于QPSK调制解调+Polar编译码(SCL译码)的matlab性能仿真,并对比BPSK

目录 1.引言 2.算法仿真效果演示 3.数据集格式或算法参数简介 4.MATLAB核心程序 5.算法涉及理论知识概要 6.参考文献 7.完整算法代码文件获得 1.引言 Polar码由土耳其教授Erdal Arikan于2008年提出,是第一种被严格证明可以达到香农极限的构造性编码方法。其核…

Glide NoResultEncoderAvailableException异常解决

首先将解决方法提出来:缓存策略DiskCacheStrategy.DATA。 使用Glide加载图片,版本是4.15.0,有天发现无法显示gif图片,原始代码如下: Glide.with(context).load(本地资源路径).diskCacheStrategy(DiskCacheStrategy.A…

无人机巡检智能边缘计算终端技术方案‌‌——基于EFISH-SCB-RK3588工控机/SAIL-RK3588核心板的国产化替代方案‌

一、方案核心价值‌ ‌实时AI处理‌:6TOPS NPU实现无人机影像的实时缺陷检测(延迟<50ms)‌全国产化‌:芯片、操作系统、算法工具链100%自主可控‌极端环境适配‌:-40℃~85℃稳定运行,IP65防护等…

相机--相机成像原理和基础概念

教程 成像原理 基础概念 焦距(物理焦距) 镜头的光学中心到感光元件之间的距离,用f表示,单位:mm;。 像素焦距 相机内参矩阵中的 fx​ 和 fy​ 是将物理焦距转换到像素坐标系的产物,可能不同。…

2025-0604学习记录17——文献阅读与分享(2)

最近不是失踪了!也不是弃坑了...这不是马上要毕业了嘛!所以最近在忙毕业论文答辩、毕业去向填报、户档去向填报等等,事情太多了,没顾得上博客。现在这些事基本上都解决完了,也有时间静下心来写写文字了~ 想要写的内容…

图解浏览器多进程渲染:从DNS到GPU合成的完整旅程

目录 浅谈浏览器进程 浏览器进程架构的演化 进程和线程关系图示 进程(Process) 线程(Thread) 协程(Coroutine) 进程&线程&协程核心对比 单进程和多进程浏览器 单进程浏览器​编辑 单进程…

【计算机网络】第3章:传输层—TCP 拥塞控制

目录 一、PPT 二、总结 TCP 拥塞控制详解 ⭐ 核心机制与算法 1. 慢启动(Slow Start) 2. 拥塞避免(Congestion Avoidance) 3. 快速重传(Fast Retransmit) 4. 快速恢复(Fast Recovery&…

idea不识别lombok---实体类报没有getter方法

介绍 本篇文章,主要讲idea引入lombok后,在实体类中加注解Data,在项目启动的时候,编译不通过,报错xxx.java没有getXxxx()方法。 原因有以下几种 1. idea没有开启lombok插件 2. 使用idea-2023…

SAP学习笔记 - 开发15 - 前端Fiori开发 Boostrap,Controls,MVC(Model,View,Controller),Modules

上一章讲了Fiori开发的准备,以及宇宙至简之HelloWorld。 SAP学习笔记 - 开发14 - 前端Fiori开发 HelloWorld-CSDN博客 本章继续学习 Fiori 开发的知识: Bootstrap,Controls,MVC(Model,View,Controller&a…

基于SDN环境下的DDoS异常攻击的检测与缓解

参考以下两篇博客,最后成功: 基于SDN的DDoS攻击检测和防御方法_基于sdn的ddos攻击检测与防御-CSDN博客 利用mininet模拟SDN架构并进行DDoS攻击与防御模拟(Ryumininetsflowpostman)_mininet模拟dos攻击-CSDN博客 需求 H2 模拟f…

如何轻松地将文件从 PC 传输到 iPhone?

传统上,您可以使用 iTunes 将文件从 PC 传输到 iPhone,但现在,使用 iTunes 已不再是唯一的选择。现在有多种不同且有效的方法可以帮助您传输文件。在今天的指南中,您可以找到 8 种使用或不使用 iTunes 传输文件的方法,…

Bresenham算法

一 Bresenham 绘直线 使用 Bresenham 算法,可以在显示器上绘制一直线段。该算法主要思想如下: 1 给出直线段上两个端点 ,根据端点求出直线在X,Y方向上变化速率 ; 2 当 时,X 方向上变化速率快于 Y 方向上变化速率&am…

【从GEO数据库批量下载数据】

从GEO数据库批量下载数据 1:进入GEO DataSets拿到所需要下载的数据的srr.list,上传到linux, 就可以使用prefetch这个函数来下载 2:操作步骤如下: conda 安装sra-tools conda create -n sra-env -c bioconda -c co…

day 44

使用DenseNet预训练模型对cifar10数据集进行训练 import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms, models from torch.utils.data import DataLoader import matplotlib.pyplot as plt import os# 设置中文字体…

NER实践总结,记录一下自己实践遇到的各种问题。

更。 没卡,跑个模型休息好几天,又闲又急。 一开始直接套用了别人的代码进行实体识别,结果很差,原因是他的词表没有我需要的东西,我是用的医学文本。代码直接在github找了改的,用的是BERT的Chinese版本。 然…

微信小程序实现运动能耗计算

微信小程序实现运动能耗计算 近我做了一个挺有意思的微信小程序,能够实现运动能耗的计算。只需要输入性别、年龄、体重、运动时长和运动类型这些信息,就能算出对应的消耗热量。 具体来说,在小程序里,性别不同,身体基…

iTunes 无法备份 iPhone:10 种解决方法

Apple 设备是移动设备市场上最先进的产品之一,但有些人遇到过 iTunes 因出现错误而无法备份 iPhone 的情况。iTunes 拒绝备份 iPhone 时,可能会令人非常沮丧。不过,幸运的是,我们有 10 种有效的方法可以解决这个问题。您可以按照以…

LangChain4J 使用实践

这里写目录标题 大模型应用场景&#xff1a;创建一个测试示例AIService聊天记忆实现简单实现聊天记录记忆MessageWindowChatMemory实现聊天记忆 隔离聊天记忆聊天记忆持久化 添加AI提示词 大模型应用场景&#xff1a; 创建一个测试示例 导入依赖 <dependency><groupI…