Pytorch-MLP-Mnist

news2025/5/25 4:29:01

文章目录

  • model.py
  • main.py
  • 参数设置
  • 注意事项
    • 初始化权重
    • 如果发现loss和acc不变
    • 关于数据下载
    • 关于输出格式
  • 运行图

model.py

import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

class MLP_cls(nn.Module):
    def __init__(self,in_dim=28*28):
        super(MLP_cls,self).__init__()
        self.lin1 = nn.Linear(in_dim,128)
        self.lin2 = nn.Linear(128,64)
        self.lin3 = nn.Linear(64,10)
        self.relu = nn.ReLU()
        init.xavier_uniform_(self.lin1.weight)
        init.xavier_uniform_(self.lin2.weight)
        init.xavier_uniform_(self.lin3.weight)

    def forward(self,x):
        x = x.view(-1,28*28)
        x = self.lin1(x)
        x = self.relu(x)
        x = self.lin2(x)
        x = self.relu(x)
        x = self.lin3(x)
        x = self.relu(x)
        return x

main.py

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
import torch.optim as optim
from model import MLP_cls


seed = 42
torch.manual_seed(seed)
batch_size_train = 64
batch_size_test  = 64
epochs = 10
learning_rate = 0.01
momentum = 0.5
mlp_net = MLP_cls()

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./data/', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.5,), (0.5,))
                               ])),
    batch_size=batch_size_train, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.5,), (0.5,))
                               ])),
    batch_size=batch_size_test, shuffle=True)

optimizer = optim.SGD(mlp_net.parameters(), lr=learning_rate,momentum=momentum)
criterion = nn.CrossEntropyLoss()

print("****************Begin Training****************")
mlp_net.train()
for epoch in range(epochs):
    run_loss = 0
    correct_num = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        out = mlp_net(data)
        _,pred = torch.max(out,dim=1)
        optimizer.zero_grad()
        loss = criterion(out,target)
        loss.backward()
        run_loss += loss
        optimizer.step()
        correct_num  += torch.sum(pred==target)
    print('epoch',epoch,'loss {:.2f}'.format(run_loss.item()/len(train_loader)),'accuracy {:.2f}'.format(correct_num.item()/(len(train_loader)*batch_size_train)))



print("****************Begin Testing****************")
mlp_net.eval()
test_loss = 0
test_correct_num = 0
for batch_idx, (data, target) in enumerate(test_loader):
    out = mlp_net(data)
    _,pred = torch.max(out,dim=1)
    test_loss += criterion(out,target)
    test_correct_num  += torch.sum(pred==target)
print('loss {:.2f}'.format(test_loss.item()/len(test_loader)),'accuracy {:.2f}'.format(test_correct_num.item()/(len(test_loader)*batch_size_test)))

参数设置

'./data/' #数据保存路径
seed = 42 #随机种子
batch_size_train = 64
batch_size_test  = 64
epochs = 10

optim --> SGD
learning_rate = 0.01
momentum = 0.5

注意事项

初始化权重

这里使用这种方式

        init.xavier_uniform_(self.lin1.weight)
        init.xavier_uniform_(self.lin2.weight)
        init.xavier_uniform_(self.lin3.weight)

如果发现loss和acc不变

检查一下是不是忘记写optimizer.step()了

关于数据下载

数据在download=True时,会下载在./data文件夹下

关于输出格式

这里用‘xxx {:.2f}'.format(xxx),保留两位小数。注意中间的空格,区分:.2f和%2f

运行图

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1027778.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Redis之hash类型

文章目录 Redis之hash类型1. 设置一个字段/获取一个字段2. 获取所有字段值3. 判断字段是否存在4. 设置多个字段/获取多个字段5. 只获取字段名/字段值6. 获取某个key内全部数量7. 增加数字8. 删除key内字段9. 字段不存在时赋值10. 应用场景 Redis之hash类型 redis的hash类型&…

Google Play上线规范及流程

将应用发布到 Google Play 商店需要遵循一系列规范和流程,以确保应用的质量、安全性和用户体验。以下是发布应用到 Google Play 的一般规范和流程,希望对大家有所帮助。北京木奇移动技术有限公司,专业的软件外包开发公司,欢迎交流…

【linux基础(七)】Linux中的开发工具(下)--make/makefile和git

💓博主CSDN主页:杭电码农-NEO💓   ⏩专栏分类:Linux从入门到开通⏪   🚚代码仓库:NEO的学习日记🚚   🌹关注我🫵带你学更多操作系统知识   🔝🔝 Linux中的开发工具 1. 前言2.…

代码随想录算法训练营第一天(C)| 704. 二分查找 27. 移除元素

文章目录 前言一、704. 二分查找二、27. 移除元素三、34. 在排序数组中查找元素的第一个和最后一个位置总结 前言 这次是C; 代码随想录算法训练营第一天| 704. 二分查找、27. 移除元素_愚者__的博客-CSDN博客 (java) 一、704. 二分查找 的优…

【HarmonyOS】【FAQ】HarmonyOS应用开发相关问题解答(四)

贴接上回。。。 【往期FAQ参考】 【HarmonyOS】【FAQ】HarmonyOS应用开发相关问题解答(一) 【HarmonyOS】【FAQ】HarmonyOS应用开发相关问题解答(二) 【HarmonyOS】【FAQ】HarmonyOS应用开发相关问题解答(三&#x…

ELFK之zookeeper+kafka

目录 kafkazookeeper的系统架构 Zookeeper 一、zookeeper概述 二、zookeeper特点 三、zookeeper选举机制 四、应用场景 五、zookeeper实验实例 Kafka 一、概述 为什么需要消息队列(MQ) 使用消息队列的好处 消息队列的两种模式 Kafka 定义 二、Kafka 的特性 三、Ka…

【Linux系统编程】通过系统调用获取进程标识符 及 创建子进程(fork)

文章目录 1. 通过系统调用获取进程标示符(PID)1.1 进程id(PID)1.2 父进程id(PPID) 2. bash也是一个进程3. 通过系统调用创建进程-fork初识3.1 批量化注释3.2 取消注释3.3 fork创建子进程3.4 fork的返回值3.…

【AD】【PCB封装规范计划】 -CON排针类

像这种CON,排针的。画PCB封装的时候,要把数字用丝印标出来!!!

浏览器调用本地exe

本地新建 .reg 文件添加注册表信息 修改路径和自定义协议名称 双击运行reg文件添加注册表信息 各参数说明,路径需要多加一个\转义 reg文件样例 Windows Registry Editor Version 5.00 [HKEY_CLASSES_ROOT\localexe] "URL Protocol""C:\\Use…

java集合之迭代器遍历元素

集合遍历 遍历、迭代、逐个获取容器中的元素 Iterable接口 实现了Iterable接口的类是可以遍历的,因为Iterable接口是Collection接口的父接口,而所有单列集合类都实现了Collection接口,从而也都实现了Iterable接口,所以所有单列集…

电压放大器在电子测试中的应用有哪些方面

电压放大器是一种常见的电子设备,广泛应用于各种测试和测量应用中。以下是电压放大器在电子测试中的几个主要方面应用的简要介绍。 信号采集与处理:电压放大器通常用于信号采集和处理,在测试过程中将低电平信号放大到适合进一步处理或分析的水…

【python基础】编写/运行hello world项目

1.编写hello world项目 编程界每种语言的第一个程序往往都是输出hello world。因此我们来看看,如何用Python输出hello world。 1.如果你是初学者,main.py中的代码暂时是无法看懂的,所以可以把main中的源代码直接删除。如下所示 这里我们要…

Blender Morph Targets

推荐:用 NSDT编辑器 快速搭建可编程3D场景 在 Blender 中,Morph Target被称为Shape Key,即形状键,是将网格从一种形状变形为另一种形状的工具。 每个对象都被分配了一个基本形状,然后可以有许多可以变形的形状键。 形…

C# linq初探 使用linq查询数组中元素

使用linq进行数组查询 输出数组中全部的偶数并升序输出结果 写法1: int[] numbers { 5, 10, 8, 3, 6, 12 }; //查询的数组var numqurey from num in numberswhere num % 2 0 //按照条件过滤orderby numselect num;foreach (var num in numqurey){Console.Writ…

面试官:你是怎么理解ES6中 Decorator 的?使用场景?

🎬 岸边的风:个人主页 🔥 个人专栏 :《 VUE 》 《 javaScript 》 ⛺️ 生活的理想,就是为了理想的生活 ! 目录 一、介绍 二、用法 类的装饰 类属性的装饰 注意 三、使用场景 antobind readonly deprecate 一、介绍 Dec…

OSI模型与数据的封装

1、OSI模型 上层|| 七层模型 四层模型|| 应用层| 表示层 应用层 http/ftp/ssh/ftps| 会话层 -----------------------------------------------------------------------| 传输层 传输层 tcp/udp ------------------------------…

Java基于SpringBoot的财务管理系统,附源码,教程

博主介绍:✌程序员徐师兄、7年大厂程序员经历。全网粉丝30W,Csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 文章目录 一 简介第二.主要技术第三、部分效果图第四章 系统设计4.1功能结构4.2 数据库设计4.2.1 数据库E/R…

软件设计原则扩展

一、引言 经典的软件设计7大原则 开闭原则(Open Close Principle, OCP) 依赖倒置原则(Dependence Inversion Principle, DIP) 单一职责原则(Simple Responsibility Principle, SRP) 接口隔离原则&#xf…

力扣刷题-82. 删除排序链表中的重复元素

题目来源:力扣82 题目描述: 代码及思路: class Solution {public ListNode deleteDuplicates(ListNode head) {/**判断下一个节点的值与下下个的值是否相同,相同:循环到没有相同 下一个节点指到值不同的节点上不同&a…

C语言进阶第三课-----------指针的进阶----------后续版

作者前言 🎂 ✨✨✨✨✨✨🍧🍧🍧🍧🍧🍧🍧🎂 ​🎂 作者介绍: 🎂🎂 🎂 🎉🎉&#x1f389…