1. pytorch手写数字预测

news2025/6/7 21:57:48

1. pytorch手写数字预测

  • 1.背景
  • 2.准备数据集
  • 2.定义模型
  • 3.dataloader和训练
  • 4.训练模型
  • 5.测试模型
  • 6.保存模型

1.背景

因为自身的研究方向是多模态目标跟踪,突然对其他的视觉方向产生了兴趣,所以心血来潮的回到最经典的视觉任务手写数字预测上来,所以这份教程并不是一份非常详尽的教程,是在一部分pytorch,深度学习基础上的教程,如果需要的是非常保姆级的教程建议看别的文章

2.准备数据集

这里我才用了直接导torchvision中的dataset包来下载Mnist数据集,也算是一个非常经典的数据集了

# 导入数据集
from torchvision.datasets import MNIST
import torch

# 设置随机种子
torch.manual_seed(3306)

# 数据预处理
from torchvision import transforms
# 定义数据转换
transform = transforms.Compose([
    transforms.ToTensor(),  # 转换为 Tensor
    transforms.Normalize((0.1307,), (0.3081,))  # 标准化
])

# 下载 MNIST 数据集
mnist_train = MNIST(root='./dataset_file/mnist_raw', train=True, download=True,transform=transform)
mnist_test = MNIST(root='./dataset_file/mnist_raw', train=False, download=True,transform=transform)
# 查看数据集大小
print(f"MNIST train dataset size: {len(mnist_train)}")
print(f"MNIST test dataset size: {len(mnist_test)}")

其中,MNIST()中的root代表的是数据集存放的位置,download代表是如果当前位置没有数据集是否需要下载。
transformer则是对数据的处理方式,我这里采用了简单地转成tensor和简单地标准化。

不过这样子下载下来的数据集是二进制格式的,无法直接查看图片,当然,如果你需要查看图片,也有办法。

# 查看图片
import matplotlib.pyplot as plt


def show_image(id):
    img, label = mnist_train[id]
    img = img.squeeze().numpy()  # 去掉通道维度
    print(img.shape)
    # print(img)
    plt.imshow(img, cmap='gray')
    plt.title(f"Label: {label}")
    plt.axis('off')
    plt.show()

show_image(1)

效果
在这里插入图片描述

又或者你想要下载的数据集是图片格式,我这里也准备了代码

代码是在别人的基础上改的,其中数据集存放路径是dataset_dir,如果需要修改自行打印然后修改位置就好了。

#!/usr/bin/env python3
# -*- encoding utf-8 -*-

'''
@File: save_mnist_to_jpg.py
@Date: 2024-08-23
@Author: KRISNAT
@Version: 0.0.0
@Email: ****
@Copyright: (C)Copyright 2024, KRISNAT
@Desc:
    1. 通过 torchvision.datasets.MNIST 下载、解压和读取 MNIST 数据集;
    2. 使用 PIL.Image.save 将 MNIST 数据集中的灰度图片以 JPEG 格式保存。
'''

import sys, os
sys.path.insert(0, os.getcwd())

from torchvision.datasets import MNIST
import PIL
from tqdm import tqdm

if __name__ == "__main__":
    home_dir = os.path.abspath('.')
    root = os.path.abspath(os.path.join(home_dir, '../dataset_file'))
    print(root)
    # exit(0)

    # 图片保存路径
    dataset_dir = os.path.join(root, 'mnist_jpg')
    if not os.path.exists(dataset_dir):
        os.makedirs(dataset_dir)

    # 从网络上下载或从本地加载MNIST数据集
    # 训练集60K、测试集10K
    # torchvision.datasets.MNIST接口下载的数据一组元组
    # 每个元组的结构是: (PIL.Image.Image image model=L size=28x28, 标签数字 int)
    training_dataset = MNIST(
        root='mnist',
        train=True,
        download=True,
    )
    test_dataset = MNIST(
        root='mnist',
        train=False,
        download=True,
    )

    # 保存训练集图片
    with tqdm(total=len(training_dataset), ncols=150) as pro_bar:
        for idx, (X, y) in enumerate(training_dataset):
            f = dataset_dir + "/" + "training_" + str(idx) + \
                "_" + str(training_dataset[idx][1] ) + ".jpg"  # 文件路径
            training_dataset[idx][0].save(f)
            pro_bar.update(n=1)

    # 保存测试集图片
    with tqdm(total=len(test_dataset), ncols=150) as pro_bar:
        for idx, (X, y) in enumerate(test_dataset):
            f = dataset_dir + "/" + "test_" + str(idx) + \
                "_" + str(test_dataset[idx][1] ) + ".jpg"  # 文件路径
            test_dataset[idx][0].save(f)
            pro_bar.update(n=1)

2.定义模型

这里我准备了两个模型,一个MLP模型和一个简单地CNN模型,其中MLP模型参数量1M,CNN模型参数量大概8M,当然这俩模型也没有很仔细的规划

import torch
import torch.nn as nn


class DigitLinear(nn.Module):
    def __init__(self):
        super(DigitLinear, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 1000)
        self.fc2 = nn.Linear(1000, 500)
        self.dropout = nn.Dropout(0.3)
        self.fc3 = nn.Linear(500, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = torch.relu(x)
        x = self.fc3(x)
        return x


class DigitCNN(nn.Module):
    def __init__(self):
        super(DigitCNN,self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)

        self.fc1 = nn.Linear(64*28*28, 128)
        self.dropout = nn.Dropout(0.1)
        self.fc2 = nn.Linear(128, 10)
    def forward(self, x):
        # print("x.shape:", x.shape)
        B,N,H,W = x.shape
        x = self.conv1(x)
        x = torch.relu(x)
        x = self.conv2(x)
        x = torch.relu(x)
        x = x.view(B, -1)  # 展平
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

3.dataloader和训练

这里的代码就很简单了,就是一些参数的选择,例如epoch,batchsize。其中的训练函数我写的买有很全面,只是勉强满足了训练功能,还有好多可以优化的点,比如打印fps,断点续训练啥的,不过这个任务提不起劲去干这事,大家可以自行优化。

# 数据加载器
from torch.utils.data import DataLoader
from lib.model.DigitModel import DigitLinear,DigitCNN
# 定义数据加载器
batch_size = 256
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)

epoch = 50
# 训练模型
net = DigitLinear() # 参数量1M 97.50%
# net = DigitCNN() # 参数量8M 98.81%
net.cuda()


# 定义损失函数和优化器
import torch.optim as optim
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# 训练函数

def train_model(model, train_loader, criterion, optimizer, num_epochs=10):
    model.train()  # 设置模型为训练模式
    
    for epoch in range(num_epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        
        for i, (inputs, labels) in enumerate(train_loader):
            inputs= inputs.cuda()
            y = torch.tensor(torch.zeros((inputs.shape[0],10), dtype=torch.float)).cuda()
            y[torch.arange(inputs.shape[0]), labels] = 1
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, y)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels.cuda()).sum().item()
        epoch_loss = running_loss / len(train_loader)
        epoch_acc = 100. * correct / total
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')

# 训练模型
train_model(net, train_loader, criterion, optimizer, num_epochs=epoch)

4.训练模型

有了上面的代码就可以开始训练了,我这里训练的截图是我的MLP模型,效果不是很好,CNN的效果稍微好一点,比MLP高1%,但是图忘记截了。反正够用了,因为本身MNIST的数据就不是很完美,有很多类似于噪声的数据例如:
在这里插入图片描述
这些数字我人眼都分不出是什么玩意。

训练效果如下
在这里插入图片描述

5.测试模型

训练完当然是测试了
最后我的MLP模型跑了97.50%的准确率

代码如下

# 测试模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device).float(), labels.to(device).float()
        outputs = net(inputs)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels.cuda()).sum().item()
        
        # print(f"Predicted: {predicted}, Ground Truth: {targets}")

print(f"Accuracy: {correct / total * 100:.4f} %")

在这里插入图片描述

6.保存模型

保存模型代码就更简单了

# 保存模型
torch.save(net.state_dict(), './digit_model.pth')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2397513.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

AXI 协议补充(二)

axi协议存在slave 和master 之间的数据交互,在ahb ,axi-stream 高速接口 ,叠加大位宽代码逻辑中,往往有时序问题,valid 和ready 的组合电路中的问题引发的时序问题较多。 本文根据axi 协议和现有解决反压造成的时序问题的方法做一个详细的科普。 1. 解决时序问题的方法:…

Linux 基础指令入门指南:解锁命令行的实用密码

文章目录 引言:Linux 下基本指令常用选项ls 指令pwd 命令cd 指令touch 指令mkdir 指令rmdir 指令 && rm 指令man 指令cp 指令mv 指令cat 指令more 指令less 指令head 指令tail 指令date 指令cal 指令find 指令按文件名搜索按文件大小搜索按修改时间搜索按文…

标准精读:2025 《可信数据空间 技术架构》【附全文阅读】

《可信数据空间 技术架构》规范了可信数据空间的技术架构,明确其作为国家数据基础设施的定位,以数字合约和使用控制技术为核心,涵盖功能架构(含服务平台与接入连接器的身份管理、目录管理、数字合约管理等功能)、业务流程(登记、发现、创建空间及数据流通利用)及安全要求…

山东大学软件学院项目实训-基于大模型的模拟面试系统-面试官和面试记录的分享功能(2)

本文记录在发布文章时&#xff0c;可以添加自己创建的面试官和面试记录到文章中这一功能的实现。 前端 首先是在原本的界面的底部添加了两个多选框&#xff08;后期需要美化调整&#xff09; 实现的代码&#xff1a; <el-col style"margin-top: 1rem;"><e…

Webug4.0靶场通关笔记05- 第5关SQL注入之过滤关键字

目录 一、代码审计 1、源码分析 2、SQL注入分析 &#xff08;1&#xff09;大小写绕过 &#xff08;2&#xff09;双写绕过 二、第05关 过滤型注入 1、进入靶场 2、sqlmap渗透 &#xff08;1&#xff09;bp抓包保存报文 &#xff08;2&#xff09;sqlmap渗透 &…

ONLYOFFICE文档API:更强的安全功能

在数字化办公时代&#xff0c;文档的安全性与隐私保护已成为企业和个人用户的核心关切。如何确保信息在存储、传输及协作过程中的安全&#xff0c;是开发者与IT管理者亟需解决的问题。ONLYOFFICE作为一款功能强大的开源办公套件&#xff0c;不仅提供了高效的文档编辑与协作体验…

解析楼宇自控系统:分布式结构的核心特点与优势展现

在建筑智能化发展的进程中&#xff0c;楼宇自控系统作为实现建筑高效运行与管理的关键&#xff0c;其系统结构的选择至关重要。传统的集中式楼宇自控系统在面对日益复杂的建筑环境和多样化的管理需求时&#xff0c;逐渐暴露出诸多弊端&#xff0c;如可靠性低、扩展性差、响应速…

C#数字图像处理(三)

文章目录 前言1.图像平移1.1 图像平移定义1.2 图像平移编程实例 2.图像镜像2.1 图像镜像定义2.2 图像镜像编程实例 3.图像缩放3.1 图像缩放定义3.2 灰度插值法3.3 图像缩放编程实例 4.图像旋转4.1 图像旋转定义4.2 图像旋转编程实例 前言 在某种意义上来说&#xff0c;图像的几…

SQL Transactions(事务)、隔离机制

目录 Why Transactions? Example: Bad Interaction Transactions ACID Transactions COMMIT ROLLBACK How the Transaction Log Works How Data Is Stored Example: Interacting Processes Interleaving of Statements Example: Strange Interleaving Fixing the…

【机器学习基础】机器学习入门核心:Jaccard相似度 (Jaccard Index) 和 Pearson相似度 (Pearson Correlation)

机器学习入门核心&#xff1a;Jaccard相似度 &#xff08;Jaccard Index&#xff09; 和 Pearson相似度 &#xff08;Pearson Correlation&#xff09; 一、算法逻辑Jaccard相似度 (Jaccard Index)**Pearson相似度 (Pearson Correlation)** 二、算法原理与数学推导1. Jaccard相…

QT之头像剪裁效果实现

文章目录 源码地址&#xff0c;环境&#xff1a;QT5.15&#xff0c;MinGW32位效果演示导入图片设置剪裁区域创建剪裁小窗口重写剪裁小窗口的鼠标事件mousePressEventmouseMoveEventmouseReleaseEvent 小窗口移动触发父窗口的重绘事件剪裁效果实现 源码地址&#xff0c;环境&…

【GPT入门】第40课 vllm与ollama特性对比,与模型部署

【GPT入门】第40课 vllm与ollama特性对比&#xff0c;与模型部署 1.两种部署1.1 vllm与ollama特性对比2. vllm部署2.1 服务器准备2.1 下载模型2.2 提供模型服务 1.两种部署 1.1 vllm与ollama特性对比 2. vllm部署 2.1 服务器准备 在autodl 等大模型服务器提供商&#xff0c;…

unity开发棋牌游戏

使用unity开发的棋牌游戏&#xff0c;目前包含麻将、斗地主、比鸡、牛牛四种玩法游戏。 相关技术 客户端&#xff1a;unity 热更新&#xff1a;xlua 服务器&#xff1a;c Web服务器&#xff1a;ruoyi 游戏视频 unity开发棋牌游戏 游戏截图

Nat Commun项目文章 ▏小麦CUTTag助力解析转录因子TaTCP6调控小麦氮磷高效利用机制

今年2月份发表在《Nature Communications》&#xff08;IF14.4&#xff09;的“TaTCP6 is required for efficientand balanced utilization of nitrate and phosphorus in wheat”揭示了TaTCP6在小麦氮磷利用中的关键调控作用&#xff0c;为优化肥料利用和提高作物产量提供了理…

C 语言开发中常见的开发环境

目录 1.Dev-C 2.Visual Studio Code 3.虚拟机 Linux 环境 4.嵌入式 MCU 专用开发环境 1.Dev-C 使用集成的 C/C 开发环境&#xff08;适合基础学习&#xff09;,下载链接Dev-C下载 - 官方正版 - 极客应用 2.Visual Studio Code 结合 C/C 扩展 GCC/MinGW 编译器&#xff0c…

vscode命令行debug

vscode命令行debug 一般命令行debug会在远程连服务器的时候用上&#xff0c;命令行debug的本质是在执行时暴露一个监听端口&#xff0c;通过进入这个端口&#xff0c;像本地调试一样进行。 这里提供两种方式&#xff1a; 直接在命令行中添加debugpy&#xff0c;适用于python…

Matlab作图之 subplot

1. subplot(m, n, p) 将当前图形划分为m*n的网格&#xff0c;在 p 指定的位置创建坐标轴 matlab 按照行号对子图的位置进行编号 第一个子图是第一行第一列&#xff0c;第二个子图是第二行第二列......... 如果指定 p 位置存在坐标轴&#xff0c; 此命令会将已存在的坐标轴设…

【机器学习基础】机器学习入门核心算法:层次聚类算法(AGNES算法和 DIANA算法)

机器学习入门核心算法&#xff1a;层次聚类算法&#xff08;AGNES算法和 DIANA算法&#xff09; 一、算法逻辑二、算法原理与数学推导1. 距离度量2. 簇间距离计算&#xff08;连接标准&#xff09;3. 算法伪代码&#xff08;凝聚式&#xff09; 三、模型评估1. 内部评估指标2. …

Google Play的最新安全变更可能会让一些高级用户无法使用App

喜欢Root或刷机的Android用户要注意了&#xff0c;Google最近全面启用了新版Play Integrity API&#xff0c;可能会导致部分用户面临无法使用某些App的窘境。Play Integrity API是Google提供给开发者的工具&#xff0c;用于验证App是否在“未修改”的设备上运行。 许多重要应用…

React---day5

4、React的组件化 组件的分类&#xff1a; 根据组件的定义方式&#xff0c;可以分为&#xff1a;函数组件(Functional Component )和类组件(Class Component)&#xff1b;根据组件内部是否有状态需要维护&#xff0c;可以分成&#xff1a;无状态组件(Stateless Component )和…