神经网络分类任务(手写数字识别)

news2025/7/19 4:24:15

1.Mnist分类任务

  • 网络基本构建与训练方法,常用函数解析

  • torch.nn.functional模块

  • nn.Module模块

学习方法:边用边查,多打印,duo'gua

使用jupyter的优点,可以打印出每一个步骤。

2.读取数据集

自动下载

%matplotlib inline
#查看本机torch的版本
import torch
print(torch.__version__)#打印torch的版本

加载并读取数据集

from pathlib import Path
import requests

DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"

PATH.mkdir(parents=True, exist_ok=True)

URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"

if not (PATH / FILENAME).exists():
        content = requests.get(URL + FILENAME).content
        (PATH / FILENAME).open("wb").write(content)


import pickle
import gzip

with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")

观察数据的结构

784是mnist数据集每个样本的像素点个数

print(x_train[0].shape)
print(x_train.shape)
y_train

显示一个记录的灰度图

from matplotlib import pyplot
import numpy as np

pyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
print(x_train.shape)

 numpy和torch的区别:
torch->gpu->tensor
numpy->cpu->ndarray

数据转化:

import torch

x_train, y_train, x_valid, y_valid = map(
    torch.tensor, (x_train, y_train, x_valid, y_valid)
)
n, c = x_train.shape
x_train, x_train.shape, y_train.min(), y_train.max()
print(x_train, y_train)
print(x_train.shape)
print(y_train.min(), y_train.max())

3  torch.nn.functional 很多层和函数在这里都会见到

torch.nn.functional中有很多功能,后续会常用的。那什么时候使用nn.Module,什么时候使用nn.functional呢?一般情况下,如果模型有可学习的参数,最好用nn.Module,其他情nn.functional相对更简单一些

import torch.nn.functional as F

loss_func = F.cross_entropy

def model(xb):
    return xb.mm(weights) + bias


bs = 64
xb = x_train[0:bs]  # a mini-batch from x
yb = y_train[0:bs]
weights = torch.randn([784, 10], dtype = torch.float,  requires_grad = True) 
#线性代数的相关知识,weights与输入想乘之后,需要输出的格式为10分类,所以的weights的矩阵为(784,10)
bs = 64
bias = torch.zeros(10, requires_grad=True)#偏执的设置,常数值作为初始化,因为这个东西对模型的影###响不是很大。

print(loss_func(model(xb), yb)):#计算真实值和预测值之间的误差

4  创建一个model来更简化代码

  • 必须继承nn.Module且在其构造函数中需调用nn.Module的构造函数
  • 无需写反向传播函数,nn.Module能够利用autograd自动实现反向传播
  • Module中的可学习参数可以通过named_parameters()或者parameters()返回迭代器
from torch import nn

class Mnist_NN(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden1 = nn.Linear(784, 128)
        self.hidden2 = nn.Linear(128, 256)
        self.out  = nn.Linear(256, 10)
   #定义前向传播,torch有一个优点:前向传播自己定义,反向传播自动实现。
    def forward(self, x):
        x = F.relu(self.hidden1(x))
        x = F.relu(self.hidden2(x))
        x = self.out(x)
        return x
        
net = Mnist_NN()
print(net)

可以打印我们定义好名字里的权重和偏置项

for name, parameter in net.named_parameters():
    print(name, parameter,parameter.size())

5  使用TensorDataset和DataLoader来简化

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)#shuffle=True:洗牌的操作。
valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)

def get_data(train_ds, valid_ds, bs):
    return (
        DataLoader(train_ds, batch_size=bs, shuffle=True),
        DataLoader(valid_ds, batch_size=bs * 2),
    )
  • 一般在训练模型时加上model.train(),这样会正常使用Batch Normalization和 Dropout
  • 测试的时候一般选择model.eval(),这样就不会使用Batch Normalization和 Dropout

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/396484.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

移动设备配置文件管理

什么是移动设备上的设备配置文件 随着移动设备在工作中使用量的迅速增加,有必要将这些设备置于企业管理之下,以确保企业数据安全且设备符合行业标准。移动设备上的配置文件允许 IT 管理员通过对员工使用的智能手机、平板电脑和笔记本电脑实施公司策略和…

三维人脸实践:基于Face3D的渲染、生成与重构 <一>

face3d: Python tools for processing 3D face git code: https://github.com/yfeng95/face3d paper list: PaperWithCode 该方法广泛用于基于三维人脸关键点的人脸生成、属性检测(如位姿、深度、PNCC等),能够快速实现人脸建模与渲染。推荐…

学生使用的台灯该怎么选择?2023适合学生房间的灯推荐

随着社会的进步发展,我们的生活水平越来越高,很多家庭的孩子都开始使用台灯这种家居产品,对于学习任务繁重的他们来说,台灯确实可以起到保护眼睛、提高学习专注度的作用。那么不知道朋友们是否了解过,台灯该怎么选择呢…

本地开发vue项目联调遇到访问接口跨域问题

本地开发vue项目联调遇到访问接口跨域问题 修改本地的localhost 一:按winr打开运行窗口,输入drivers ,然后回车 二:打开etc文件夹,然后用记事本的方式打开里面的hosts文件, 三:这时我们就可…

oneblog_justauth_三方登录配置【QQ】

文章目录oneblog添加第三方平台QQ互联平台创建三方应用完善信息登录oneblog添加第三方平台 1.oneblog管理端,点击左侧菜单 网站管理——>社会化登录配置管理 ,添加一个社会化登录 2.编辑信息如下,选择QQ平台后复制redirectUri,然后去QQ互联平台获取…

UART 串口通信

第18.1讲 UART串口通信原理讲解_哔哩哔哩_bilibili 并行通信 一个周期同时发送8bit的数据,占用引脚资源多 串行通信 串行通信的通信方式: 同步通信 同一时钟下进行数据传输 异步通信 发送设备和接收设备的时钟不同 但是需要约束波特率(…

大数据|HDFS分布式文件系统

前文回顾:Hadoop系统 目录 📚HDFS概述 📚HDFS在设计时的假设和目标 📚HDFS的基本特征 📚HDFS的体系结构 🐇目录节点 🐇数据节点 📚HDFS的副本机制 📚HDFS的数据存…

KD500全自动电容电感测试仪

一、产品特点 1.本仪器采用了先进的测量原理与四端测量技术,可以精确测量、测试重复性能好; 2.能在不拆线的状态下,测量成组并联着的单个电容器的电容量和成组并联着电容器组的总电容量; 3.大屏幕液晶显示屏(320X24…

关于Activiti7审批工作流绘画流程图(2)

文章目录一、25张表详解二、安装插件一.定制流程提示:以下是本篇文章正文内容,下面案例可供参考 一、25张表详解 虽然表很多,但是仔细观察,我们会发现Activiti 使用到的表都是 ACT_ 开头的。表名的第二部分用两个字母表明表的用…

vuex基础之初始化功能、state、mutations、getters、模块化module的使用

vuex基础之初始化功能、state、mutations、getters、模块化module的使用一、Vuex的介绍二、初始化功能三、state3.1 定义state3.2 获取state3.2.1 原始形式获取3.2.2 辅助函数获取(mapState)四、mutations4.1 定义mutations4.2 调用mutations4.2.1 原始形式调用($store)4.2.2 辅…

数据分析方法及名词解释总结_(面试2)

1、用户画像 1.1、什么是用户画像?如何构建用户画像? - 知乎提到用户画像, 很多人都可能存在的错误认知,即把用户画像简单理解成用户各种特征,比如说姓名、性别、…https://www.zhihu.com/question/372802348/answer/2…

spi 子系统

spi在应用层的体现 spi 分为主机模式和从机模式,一般soc 自带的spi 控制器,我们都将它用作主机模式与外挂的从设备通信。从设备例如 oled芯片、flash芯片、陀螺仪芯片等等。 那么spi 驱动和设备,自然也就分为主机驱动、设备和从机驱动、设备…

云原生时代顶流消息中间件Apache Pulsar部署实操之Pulsar IO与Pulsar SQL

文章目录Pulsar IO (Connector连接器)基础定义安装Pulsar和内置连接器连接Pulsar到Cassandra安装cassandra集群配置Cassandra接收器创建Cassandra Sink验证Cassandra Sink结果删除Cassandra Sink连接Pulsar到PostgreSQL安装PostgreSQL集群配置JDBC接收器创建JDBC Sink验证JDBC …

redis cluster配置之read-mode

背景生产部署了redis集群,三台机器(三主三从,主从不在同一台机器上),redission连接使用。当有一个master节点挂掉时,redis整个集群不可用。解决过程运维登上机器上,执行cluster info发现集群OK状…

JAVA开发(数据类型String和HasMap的实现原理)

在JAVA开发中,使用最多的数据类型恐怕是String 和 HasMap两种数据类型。在开发的过程中我们每天都使用的不亦乐乎。但是相信很多人都没有考虑过String数据类型的实现原理或者说是在数据结构中的存储原理,还有一个就是是HashMap,也很少有人去了…

SNAP中根据入射角和相干图使用波段计算计算垂直形变--以门源地震为例

SNAP计算垂直形变0 写在前面1 具体步骤1.1 准备数据1.2 在SNAP中打开波段运算Band Maths1.3 之前计算的水平位移displacement如下图数据的其他处理请参考博文在SNAP中用sentinel-1数据做InSAR测量,以门源地震为例 0 写在前面 如果假设没有平行于传感器视线的水平运…

案例27-单表从9个更新语句调整为2个

目录 一:背景介绍 二:思路&方案 三:过程 1.项目结构 2.准备一个普通的maven项目,部署好mysql数据库 3.在项目中引入pom依赖 5.编写MyBitis配置文件 6.编写Mysql配置类 7.编写通用Update语句 8.项目启动类 四:总…

用GRU实现情感分析:不需要长记忆,也能看懂你的心情

❤️觉得内容不错的话,欢迎点赞收藏加关注😊😊😊,后续会继续输入更多优质内容❤️👉有问题欢迎大家加关注私戳或者评论(包括但不限于NLP算法相关,linux学习相关,读研读博…

JavaScript_Object.keys() Object.values()

目录 一、Object.keys() 二、Object.values() 一、Object.keys() Object.keys( ) 的 用法 : 作用 &#xff1a;遍历对象 { } 返回结果&#xff1a;返回 对象中 每一项 的 key 值 返回值 : 是一个 *** [ 数 组 ] *** 例子 ( 1 ) : <script>// 1. 定义一个对象var obj …

【硬件】P沟道和N沟道MOS管开关电路设计

场效应管做的开关电路一般分为两种&#xff0c;一种是N沟道&#xff0c;另一种是P沟道&#xff0c;如果电路设计中要应用到高端驱动的话&#xff0c;可以采用PMOS来导通。P沟道MOS管开关电路PMOS的特性&#xff0c;Vgs小于一定的值就会导通&#xff0c;当Vgs<0,即Vs>Vg,管…