python学习打卡day40

news2025/6/4 8:33:42
DAY 40 训练和测试的规范写法

知识点回顾:

  1. 彩色和灰度图片测试和训练的规范写法:封装在函数中
  2. 展平操作:除第一个维度batchsize外全部展平
  3. dropout操作:训练阶段随机丢弃神经元,测试阶段eval模式关闭dropout

作业:仔细学习下测试和训练代码的逻辑,这是基础,这个代码框架后续会一直沿用,后续的重点慢慢就是转向模型定义阶段了。

1.彩色和灰度图片测试和训练的规范写法:封装在函数中

# 先继续之前的代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具
from torchvision import datasets, transforms # torchvision 是一个用于计算机视觉的库,datasets 和 transforms 是其中的模块
import matplotlib.pyplot as plt
import warnings
# 忽略警告信息
warnings.filterwarnings("ignore")
# 设置随机种子,确保结果可复现
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

灰度图像的规范写法:

# 1. 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),  # 转换为张量并归一化到[0,1]
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST数据集的均值和标准差
])

# 2. 加载MNIST数据集
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    transform=transform
)


#3. 创建数据加载器
train_loader = DataLoader(
    train_dataset,
    batch_size=64, # 每个批次64张图片,一般是2的幂次方,这与GPU的计算效率有关
    shuffle=True # 随机打乱数据
)

test_loader = DataLoader(
    test_dataset,
    batch_size=1000 # 每个批次1000张图片
    # shuffle=False # 测试时不需要打乱数据
)
# 4. 定义模型、损失函数和优化器
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.flatten = nn.Flatten()  # 将28x28的图像展平为784维向量
        self.layer1 = nn.Linear(784, 128)  # 第一层:784个输入,128个神经元
        self.relu = nn.ReLU()  # 激活函数
        self.layer2 = nn.Linear(128, 10)  # 第二层:128个输入,10个输出(对应10个数字类别)
        
    def forward(self, x):
        x = self.flatten(x)  # 展平图像
        x = self.layer1(x)   # 第一层线性变换
        x = self.relu(x)     # 应用ReLU激活函数
        x = self.layer2(x)   # 第二层线性变换,输出logits
        return x

# 初始化模型
model = MLP()
model = model.to(device)  # 将模型移至GPU(如果可用)

# from torchsummary import summary  # 导入torchsummary库
# print("\n模型结构信息:")
# summary(model, input_size=(1, 28, 28))  # 输入尺寸为MNIST图像尺寸

criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数,适用于多分类问题
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam优化器

彩色图像的规范写法:

# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题

# 1. 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),                # 转换为张量
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化处理
])

# 2. 加载CIFAR-10数据集
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    transform=transform
)

# 3. 创建数据加载器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# 4. 定义MLP模型(适应CIFAR-10的输入尺寸)
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.flatten = nn.Flatten()  # 将3x32x32的图像展平为3072维向量
        self.layer1 = nn.Linear(3072, 512)  # 第一层:3072个输入,512个神经元
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.2)  # 添加Dropout防止过拟合
        self.layer2 = nn.Linear(512, 256)  # 第二层:512个输入,256个神经元
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.2)
        self.layer3 = nn.Linear(256, 10)  # 输出层:10个类别
        
    def forward(self, x):
        # 第一步:将输入图像展平为一维向量
        x = self.flatten(x)  # 输入尺寸: [batch_size, 3, 32, 32] → [batch_size, 3072]
        
        # 第一层全连接 + 激活 + Dropout
        x = self.layer1(x)   # 线性变换: [batch_size, 3072] → [batch_size, 512]
        x = self.relu1(x)    # 应用ReLU激活函数
        x = self.dropout1(x) # 训练时随机丢弃部分神经元输出
        
        # 第二层全连接 + 激活 + Dropout
        x = self.layer2(x)   # 线性变换: [batch_size, 512] → [batch_size, 256]
        x = self.relu2(x)    # 应用ReLU激活函数
        x = self.dropout2(x) # 训练时随机丢弃部分神经元输出
        
        # 第三层(输出层)全连接
        x = self.layer3(x)   # 线性变换: [batch_size, 256] → [batch_size, 10]
        
        return x  # 返回未经过Softmax的logits

# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 初始化模型
model = MLP()
model = model.to(device)  # 将模型移至GPU(如果可用)

criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam优化器

训练模型:

# 5. 训练模型(记录每个 iteration 的损失)
def train(model, train_loader, test_loader, criterion, optimizer, device, epochs):
    model.train()  # 设置为训练模式
    
    # 记录每个 iteration 的损失
    all_iter_losses = []  # 存储所有 batch 的损失
    iter_indices = []     # 存储 iteration 序号
    
    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)  # 移至GPU
            
            optimizer.zero_grad()  # 梯度清零
            output = model(data)  # 前向传播
            loss = criterion(output, target)  # 计算损失
            loss.backward()  # 反向传播
            optimizer.step()  # 更新参数
            
            # 记录当前 iteration 的损失
            iter_loss = loss.item()
            all_iter_losses.append(iter_loss)
            iter_indices.append(epoch * len(train_loader) + batch_idx + 1)
            
            # 统计准确率和损失
            running_loss += iter_loss
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
            
            # 每100个批次打印一次训练信息
            if (batch_idx + 1) % 100 == 0:
                print(f'Epoch: {epoch+1}/{epochs} | Batch: {batch_idx+1}/{len(train_loader)} '
                      f'| 单Batch损失: {iter_loss:.4f} | 累计平均损失: {running_loss/(batch_idx+1):.4f}')
        

模型测试

# 计算当前epoch的平均训练损失和准确率
        epoch_train_loss = running_loss / len(train_loader)
        epoch_train_acc = 100. * correct / total
        
        # 测试阶段
        model.eval()  # 设置为评估模式
        test_loss = 0
        correct_test = 0
        total_test = 0
        
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                test_loss += criterion(output, target).item()
                _, predicted = output.max(1)
                total_test += target.size(0)
                correct_test += predicted.eq(target).sum().item()
        
        epoch_test_loss = test_loss / len(test_loader)
        epoch_test_acc = 100. * correct_test / total_test
        
        print(f'Epoch {epoch+1}/{epochs} 完成 | 训练准确率: {epoch_train_acc:.2f}% | 测试准确率: {epoch_test_acc:.2f}%')
    
    # 绘制所有 iteration 的损失曲线
    plot_iter_losses(all_iter_losses, iter_indices)
    
    return epoch_test_acc  # 返回最终测试准确率

绘制损失曲线

# 6. 绘制每个 iteration 的损失曲线
def plot_iter_losses(losses, indices):
    plt.figure(figsize=(10, 4))
    plt.plot(indices, losses, 'b-', alpha=0.7, label='Iteration Loss')
    plt.xlabel('Iteration(Batch序号)')
    plt.ylabel('损失值')
    plt.title('每个 Iteration 的训练损失')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

执行训练和测试:

# 7. 执行训练和测试
epochs = 20  # 增加训练轮次以获得更好效果
print("开始训练模型...")
final_accuracy = train(model, train_loader, test_loader, criterion, optimizer, device, epochs)
print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")

@浙大疏锦行

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2396503.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Live Helper Chat 安装部署

Live Helper Chat(LHC)是一款开源的实时客服聊天系统,适用于网站和应用,帮助企业与访问者即时沟通。它功能丰富、灵活、可自托管,常被用于在线客户支持、销售咨询以及技术支持场景。 🧰 系统要求 安装要求 您提供的链接指向 Live Helper Chat 的官方安装指南页面,详细…

ARXML解析与可视化工具

随着汽车电子行业的快速发展,AUTOSAR标准在车辆软件架构中发挥着越来越重要的作用。然而,传统的ARXML文件处理工具往往存在高昂的许可费用、封闭的数据格式和复杂的使用门槛等问题。本文介绍一种基于TXT格式输出的ARXML解析方案,为开发团队提供了一个高效的替代解决方案。 …

PnP(Perspective-n-Point)算法 | 用于求解已知n个3D点及其对应2D投影点的相机位姿

什么是PnP算法? PnP 全称是 Perspective-n-Point,中文叫“n点透视问题”。它的目标是: 已知一些空间中已知3D点的位置(世界坐标)和它们对应的2D图像像素坐标,求解摄像机的姿态(位置和平移&…

在日常管理服务器中如何防止SQL注入与XSS攻击?

在日常管理服务器时,防止SQL注入(Structured Query Language Injection)和XSS(Cross-Site Scripting)攻击是至关重要的,这些攻击可能会导致数据泄露、系统崩溃和信息泄露。以下是一份技术文章,介…

Wkhtmltopdf使用

Wkhtmltopdf使用 1.windows本地使用2.golangwindows环境使用3.golangdocker容器中使用 1.windows本地使用 官网地址 https://wkhtmltopdf.org/,直接去里面下载自己想要的版本,这里以windows版本为例2.golangwindows环境使用 1.安装扩展go get -u githu…

ArcGIS Pro 创建渔网格网过大,只有几个格网的解决方案

之前用ArcGIS Pro创建渔网的时候,发现创建出来格网过大,只有几个格网。 后来查阅资料,发现是坐标不对,导致设置格网大小时单位为度,而不是米,因此需要进行坐标系转换,网上有很多资料讲了ArcGIS …

重学计算机网络之以太网

一:历史发展进程 DIX EtherNet V2 战胜IEEE802.3成为主流版本。总线型交换机拓扑机构代替集线器星型拓扑机构 1990年IEEE制定出星形以太网10BASE-T的标准**802.3i**。“10”代表10 Mbit/s 的数据率,BASE表示连接线上的信号是基带信号,T代表…

《深度解构现代云原生微服务架构的七大支柱》

☁️《深度解构现代云原生微服务架构的七大支柱》 一线架构师实战总结,系统性拆解现代微服务架构中最核心的 7 大支柱模块,涵盖通信协议、容器编排、服务网格、弹性伸缩、安全治理、可观测性、CI/CD 等。文内附架构图、实操路径与真实案例,适…

使用SCSS实现随机大小的方块在页面滚动

目录 一、scss中的插值语法 二、方块在界面上滚动的动画 一、scss中的插值语法 插值语法 #{}‌ 是一种动态注入变量或表达式到选择器、属性名、属性值等位置的机制 .类名:nth-child(n) 表示需同时满足为父元素的第n个元素且类名为给定条件 效果图&#xff1a; <div class…

AI 眼镜新纪元:贴片式TF卡与 SOC 芯片的黄金组合破局智能穿戴

目录 一、SD NAND&#xff1a;智能眼镜的“记忆中枢”突破空间限制的存储革命性能与可靠性的双重保障 二、SOC芯片&#xff1a;AI眼镜的“智慧大脑”从性能到能效的全面跃升多模态交互的底层支撑 三、SD NANDSOC&#xff1a;11&#xff1e;2的协同效应数据流水线的高效协同成本…

论文阅读(六)Open Set Video HOI detection from Action-centric Chain-of-Look Prompting

论文来源&#xff1a;ICCV&#xff08;2023&#xff09; 项目地址&#xff1a;https://github.com/southnx/ACoLP 1.研究背景与问题 开放集场景下的泛化性&#xff1a;传统 HOI 检测假设训练集包含所有测试类别&#xff0c;但现实中存在大量未见过的 HOI 类别&#xff08;如…

算法学习--持续更新

算法 2025年5月24日 完成&#xff1a;快速排序、快速排序基数优化、尾递归优化 快排 public class QuickSort {public void sort(int[] nums, int left, int right) {if(left>right){return;}int partiton quickSort(nums,left,right);sort(nums,left,partiton-1);sort(nu…

Postman 发送 SOAP 请求步骤 归档

0.来源 https://apifox.com/apiskills/sending-soap-requests-with-postman/?utm_sourceopr&utm_mediuma2bobzhang&utm_contentpostman 再加上自己一点实践经验 1. 创建一个新的POST请求 postman 创建一个post请求, 请求url 怎么来的可以看第三步 2. post请求设…

Python Day39 学习(复习日志Day4)

复习Day4日志内容 浙大疏锦行 补充: 关于“类”和“类的实例”的通俗易懂的例子 补充&#xff1a;如何判断是用“众数”还是“中位数”填补空缺值&#xff1f; 今日复习了日志Day4的内容&#xff0c;感觉还是得在纸上写一写印象更深刻&#xff0c;接下来几日都采取“纸质化复…

[Python] Python自动化:PyAutoGUI的基本操作

初次学习&#xff0c;如有错误还请指正 目录 PyAutoGUI介绍 PyAutoGUI安装 鼠标相关操作 鼠标移动 鼠标偏移 获取屏幕分辨率 获取鼠标位置 案例&#xff1a;实时获取鼠标位置 鼠标点击 左键单击 点击次数 多次有时间间隔的点击 右键/中键点击 移动时间 总结 鼠…

应急响应靶机-web2-知攻善防实验室

题目&#xff1a; 前景需要&#xff1a;小李在某单位驻场值守&#xff0c;深夜12点&#xff0c;甲方已经回家了&#xff0c;小李刚偷偷摸鱼后&#xff0c;发现安全设备有告警&#xff0c;于是立刻停掉了机器开始排查。 这是他的服务器系统&#xff0c;请你找出以下内容&#…

comfyui利用 SkyReels-V2直接生成长视频本地部署问题总结 1

在通过桌面版comfyUI 安装ComfyUI-WanVideoWrapper 进行SkyReels-V2 生成长视频的过程中&#xff0c;出现了&#xff0c;很多错误。 总结一下&#xff0c;让大家少走点弯路 下面是基于搜索结果的 ComfyUI 本地部署 SkyReels-V2 实现长视频生成的完整指南&#xff0c;涵盖环境配…

YOLOv8 实战指南:如何实现视频区域内的目标统计与计数

文章目录 YOLOv8改进 | 进阶实战篇&#xff1a;利用YOLOv8进行视频划定区域目标统计计数1. 引言2. YOLOv8基础回顾2.1 YOLOv8架构概述2.2 YOLOv8的安装与基本使用 3. 视频划定区域目标统计的实现3.1 核心思路3.2 完整实现代码 4. 代码深度解析4.1 关键组件分析4.2 性能优化技巧…

matlab实现VMD去噪、SVD去噪,源代码详解

为了更好的利用MATLAB自带的vmd、svd函数&#xff0c;本期作者将详细讲解一下MATLAB自带的这两个分解函数如何使用&#xff0c;以及如何画漂亮的模态分解图。 VMD函数用法详解 首先给出官方vmd函数的调用格式。 [imf,residual,info] vmd(x) 函数的输入&#xff1a; 这里的x是待…

SQLite软件架构与实现源代码浅析

概述 SQLite 是一个用 C 语言编写的库&#xff0c;它成功打造出了一款小型、快速、独立、具备高可靠性且功能完备的 SQL 数据库引擎。本文档将为您简要介绍其架构、关键组件及其协同运作模式。 SQLite 显著特点之一是无服务器架构。不同于常规数据库&#xff0c;它并非以单独进…