Mnist手写数字

news2025/6/6 10:37:56

运行实现:

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt


class Net(torch.nn.Module):#net类神经网络主体

    def __init__(self):#4个全链接层
        super().__init__()
        self.fc1 = torch.nn.Linear(28*28, 64)#输入为28*28尺寸图像
        self.fc2 = torch.nn.Linear(64, 64)#中间三层都是64个节点
        self.fc3 = torch.nn.Linear(64, 64)
        self.fc4 = torch.nn.Linear(64, 10)#输出为10个数字类别
    
    def forward(self, x):#前向传播
        x = torch.nn.functional.relu(self.fc1(x))#先全连接线性计算,再套上激活函数
        x = torch.nn.functional.relu(self.fc2(x))
        x = torch.nn.functional.relu(self.fc3(x))
        x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)#输出层用softmax做归一化,log_softmax是为了提高计算稳定性,套上了一个对数函数
        return x


def get_data_loader(is_train):#导入数据
    to_tensor = transforms.Compose([transforms.ToTensor()])#导入张量
    data_set = MNIST("", is_train, transform=to_tensor, download=True)#下载文件,”“里面对应的是下载目录,is_train指定导入训练集还是测试集
    return DataLoader(data_set, batch_size=15, shuffle=True)#一个批次15张图片,shuffle=true说明数据是随机打乱的,返回数据加载器


def evaluate(test_data, net):#评估正确率
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        for (x, y) in test_data:#取出数据
            outputs = net.forward(x.view(-1, 28*28))#计算神经网络预测值
            for i, output in enumerate(outputs):#作比较
                if torch.argmax(output) == y[i]:#argmax取最大预测概率的序号
                    n_correct += 1#累加正确的
                n_total += 1
    return n_correct / n_total


def main():

    train_data = get_data_loader(is_train=True)#训练集
    test_data = get_data_loader(is_train=False)#测试集
    net = Net()
    
    print("initial accuracy:", evaluate(test_data, net))#打印初始网络的正确率,接近0.1
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)#以下为pytorch固定写法
    for epoch in range(3):#epoch是轮次
        for (x, y) in train_data:
            net.zero_grad()#初始化
            output = net.forward(x.view(-1, 28*28))#正向传播
            loss = torch.nn.functional.nll_loss(output, y)#计算差值,null_loss对数损失函数,为了匹配前面log_softmax的对数运算
            loss.backward()#反向误差传播
            optimizer.step()#优化网络参数
        print("epoch", epoch, "accuracy:", evaluate(test_data, net))

    for (n, (x, _)) in enumerate(test_data):#抽取4张图像,显示预测结果
        if n > 3:
            break
        predict = torch.argmax(net.forward(x[0].view(-1, 28*28)))
        plt.figure(n)
        plt.imshow(x[0].view(28, 28),cmap='gray')
        plt.title("prediction: " + str(int(predict)))
    plt.show()


if __name__ == "__main__":
    main()

中间可能会报错误:(libiomp5md.dll问题)

OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized.

这个处理就是在anaconda文件夹下面搜索libiomp5md.dll,那bin下面的 libiomp5md.dll文件全部修改命名,就像我这样,两个bin文件夹下面的都改了。

 

运行结果:

两轮精确度如下:

4个数字预测图片如下:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2401549.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

《一生一芯》数字实验三:加法器与ALU

1. 实验目标 设计一个能实现如下功能的4位带符号位的 补码 ALU&#xff1a; Table 4 ALU 功能列表  功能选择 功能 操作 000 加法 AB 001 减法 A-B 010 取反 Not A 011 与 A and B 100 或 A or B 101 异或 A xor B 110 比较大小 If A<B then out1…

三甲医院“AI平台+专家系统”双轮驱动模式的最新编程方向分析

医疗人工智能领域正在经历从“单点技术应用”到“系统性赋能”的深刻转型。在这一转型过程中,国内领先的三甲医院通过探索“AI平台+专家系统”双轮驱动模式,不仅解决了医疗AI落地“最后一公里”的难题,更推动了医疗服务质量与效率的全面提升。本文从技术架构、编程方向、落地…

第12期_网站搭建_几时网络验证1.3二改源码包2024 软件卡密系统 虚拟主机搭建笔记

我用夸克网盘分享了「第12期_网站搭建_几时网络验证1.3二改源码包2024.7z」&#xff0c;点击链接即可保存。打开「夸克APP」&#xff0c;无需下载在线播放视频&#xff0c;畅享原画5倍速&#xff0c;支持电视投屏。 链接&#xff1a;https://pan.quark.cn/s/fe8e7786bd6d

[论文阅读] (38)基于大模型的威胁情报分析与知识图谱构建论文总结(读书笔记)

《娜璋带你读论文》系列主要是督促自己阅读优秀论文及听取学术讲座&#xff0c;并分享给大家&#xff0c;希望您喜欢。由于作者的英文水平和学术能力不高&#xff0c;需要不断提升&#xff0c;所以还请大家批评指正&#xff0c;非常欢迎大家给我留言评论&#xff0c;学术路上期…

回溯算法复习(1)

1.回溯的定义&#xff08;ai&#xff09; 回溯&#xff08;Backtracking&#xff09; 是一种通过搜索所有可能的解空间来求解问题的算法思想&#xff0c;属于试探性求解方法。其核心是在搜索过程中逐步构建解&#xff0c;并在发现当前路径无法得到有效解时&#xff0c;主动回退…

学习路之PHP--webman安装及使用、webman/admin安装

学习路之PHP--webman安装及使用 一、安装webman二、运行三、安装webman/admin四、效果五、配置Nginx反向代理&#xff08;生产环境&#xff1a;可选&#xff09;六、使用 一、安装webman 准备&#xff1a; PHP > 8.1 Composer > 2.0 启用函数&#xff1a; putenv proc_o…

基于cornerstone3D的dicom影像浏览器 第二十八章 LabelTool文字标记,L标记,R标记及标记样式设置

文章目录 前言一、L标记、R标记二、修改工具样式1. 样式的四种级别2. 导入annotation3. 示例1 - 修改toolGroup中的样式4. 示例2 - 修改viewport中的样式 三、可配置样式 前言 cornerstone3D 中的文字标记工具LabelTool&#xff0c;在添加文字标记时会弹出对话框让用户输入文字…

电路图识图基础知识-自耦变压器降压启动电动机控制电路(十六)

自耦变压器降压启动电动机控制电路 自耦变压器降压启动电动机控制电路是将自耦变压器的原边绕组接于电源侧&#xff0c;副边绕组接 于电机侧。电动机定子绕组启动时的电压为自耦变压器降压后得到的电压&#xff0c;这样可以减少电动 机的启动电流和启动力矩&#xff0c;当电动…

神经网络与深度学习 网络优化与正则化

1.网络优化存在的难点 &#xff08;1&#xff09;结构差异大&#xff1a;没有通用的优化算法&#xff1b;超参数多 &#xff08;2&#xff09;非凸优化问题&#xff1a;参数初始化&#xff0c;逃离局部最优 &#xff08;3&#xff09;梯度消失&#xff08;爆炸&#xff09; …

【Git系列】如何同步原始仓库的更新到你的fork仓库?

&#x1f389;&#x1f389;&#x1f389;欢迎来到我们的博客&#xff01;无论您是第一次访问&#xff0c;还是我们的老朋友&#xff0c;我们都由衷地感谢您的到来。无论您是来寻找灵感、获取知识&#xff0c;还是单纯地享受阅读的乐趣&#xff0c;我们都希望您能在这里找到属于…

深度强化学习驱动的智能爬取策略优化:基于网页结构特征的状态表示方法

传统网络爬虫依赖静态规则&#xff08;如广度优先搜索&#xff09;或启发式策略&#xff0c;在面对动态网页&#xff08;如SPA单页应用&#xff09;、复杂层级结构&#xff08;如多层嵌套导航&#xff09;及反爬机制时&#xff0c;常表现出爬取效率低下、覆盖率不足等问题。本文…

如何轻松将视频从安卓设备传输到电脑?

现在&#xff0c;我们可以轻松地使用安卓手机拍摄高分辨率视频。然而&#xff0c;这些视频会占用大量的存储空间。如果您想将视频从安卓设备传输到电脑以释放存储空间、编辑素材或只是备份记忆&#xff0c;可以使用本文介绍的 8 种实用方法来完成视频传输。 第 1 部分&#xff…

时代星光推出战狼W60智能运载无人机,主要性能超市场同类产品一倍!

在刚刚结束的第九届世界无人机大会上&#xff0c;时代星光科技发布了其全新产品战狼W60智能运载无人机&#xff0c;并展示了基于战狼W60无人机平台的多种应用场景解决方案。据了解&#xff0c;该产品作为一款多旋翼无人机&#xff0c;主要性能参数均远超市场同类产品&#xff0…

BUUCTF[极客大挑战 2019]Secret File 1题解

[极客大挑战 2019]Secret File 1 分析&#xff1a;解题界面1&#xff1a;界面二&#xff1a;界面3&#xff1a; 总结: 分析&#xff1a; 事后来看&#xff0c;这道题主打一个走一步看一步。我们只能从题目的标题中猜到&#xff0c;这道题与文件有关。 解题 界面1&#xff1a…

Odoo电子邮件使用配置指南

在Odoo中配置邮件收发功能需要设置SMTP发件服务器和IMAP/POP3收件服务器&#xff0c;并确保DNS记录&#xff08;如SPF、DKIM&#xff09;正确&#xff0c;以避免邮件被标记为垃圾邮件。以下指南是详细配置步骤&#xff1a; 1. 配置出站邮件&#xff08;SMTP&#xff09; 1.1 使…

MacOS解决局域网“没有到达主机的路由 no route to host“

可能原因&#xff1a;MacOS 15新增了"本地网络"访问权限&#xff0c;在 APP 第一次尝试访问本地网络的时候会请求权限&#xff0c;可能顺手选择了关闭。 解决办法&#xff1a;给想要访问本地网络的 APP &#xff08;例如 terminal、Navicat、Ftp&#xff09;添加访问…

找到每一个单词+模拟的思路和算法

如大家所知&#xff0c;我们可以对给定的字符串 sentence 进行一次遍历&#xff0c;找出其中的每一个单词&#xff0c;并根据题目的要求进行操作。 在寻找单词时&#xff0c;我们可以使用语言自带的 split() 函数&#xff0c;将空格作为分割字符&#xff0c;得到所有的单词。为…

2025东南亚跨境选择:Lazada VS. Shopee深度对比

东南亚电商市场持续爆发&#xff0c;2025年预计规模突破2000亿美元。对跨境卖家而言&#xff0c;Lazada与Shopee仍是两大核心战场&#xff0c;但平台生态与竞争格局已悄然变化。深入对比&#xff0c;方能制胜未来。 一、平台基因与核心优势对比 维度 Lazada (阿里系) Shopee …

Java-39 深入浅出 Spring - AOP切面增强 核心概念 通知类型 XML+注解方式 附代码

点一下关注吧&#xff01;&#xff01;&#xff01;非常感谢&#xff01;&#xff01;持续更新&#xff01;&#xff01;&#xff01; &#x1f680; AI篇持续更新中&#xff01;&#xff08;长期更新&#xff09; 目前2025年06月05日更新到&#xff1a; AI炼丹日志-28 - Aud…

.NET 8集成阿里云短信服务完全指南【短信接口】

文章目录 前言一、准备工作1.1 阿里云账号准备1.2 .NET 8项目创建 二、集成阿里云短信SDK2.1 安装NuGet包2.2 配置阿里云短信参数2.3 创建配置类 三、实现短信发送服务3.1 创建短信服务接口3.2 实现短信服务3.3 注册服务 四、创建控制器五、测试与优化5.1 单元测试5.2 性能优化…