深度学习之使用BP神经网络识别MNIST数据集

news2025/6/23 14:01:15

目录

补充知识点

torch.nn.LogSoftmax()

 torchvision.transforms

transforms.Compose

transforms.ToTensor

transforms.Normalize(mean, std)

torchvision.datasets

MNIST(手写数字数据集)

torch.utils.data.DataLoader

torch.nn.NLLLoss() 

torch.nn.CrossEntropyLoss()

torch.nn.NLLLoss() 

enumerate() 函数

next() 函数

pytorch中.detach()

torch.squeeze()

torch.optim.SGD 

独热编码

代码实现

bp网络的搭建

建立我们的神经网络对象

预测

完整代码

结果极其图像显示


补充知识点

torch.nn.LogSoftmax()

torch.nn.LogSoftmax()和我们的softmax差不多只不过就是最后加入了一个log,关于softmax的详情大家可以看这一篇博客深度学习之感知机,激活函数,梯度消失,BP神经网络

我们只需要知道里面的一些常用参数就好,一般就是这个dim

下面是softmax的logsoftmax和他一样。

 torchvision.transforms

这个是一个功能强大的数据处理集合库

这里主要说一下transforms.Compose还有transforms.ToTensor以及transforms.Normalize(mean, std)

transforms.Compose

这个功能函数可以看作一个功能函数容器,它里面可以放多个功能函数(多个的话应该把这些功能函数放在一个列表内),当该功能函数定义好后,其内部的其他功能函数也会随之按我们给定的要求定义好,当我们调用compose的实例的时候,就会按照我们在容器内部摆放的顺序从左至右的依次调用功能函数

transforms.ToTensor

用于对载入的图片数据进行类型转换,将之前PIL图片的数据(应该是np.array()类型的)转换成Tensor数据类型的张量,以便于我们后续的数据使用

(补充:PIL默认输出的图片格式为 RGB)

(下面图片内容来自网络,侵权必删。)

transforms.Normalize(mean, std)

数据归一化处理。

下面是我在学习过程中遇到的问题:

1.归一化就是要把图片3个通道中的数据整理到[-1, 1]或者[0,1]区间,x = (x - mean(x))/std(x)只要输入数据集x就可以直接算出来,为什么Normalize()函数的mean和std还需要我们手动输入数值呢?

我的理解是,我们一开始就算好可以极大的减少运算量,如果我们自动的让他算的话,我们这个每一个图片都要算,这样运算量就极大。

2.RGB单个通道的值是[0, 255],所以一个通道的均值应该在127附近才对。我们接下来的代码如下图所示

所填的是0.5,0.5,这是为什么?

 因为我们应用了torchvision.transforms.ToTensor,他会将数据归一化到[0,1](是将数据除以255),transforms.ToTensor( )会把HWC会变成C *H *W(拓展:格式为(h,w,c),像素顺序为RGB),所以我们就应该输入0.5,0.5

(该图片内容来自网络,侵权璧必删)

我们一般来说只需了解一些常用的库函数,我们这里只是提一下我们这次会用到的函数,其余的函数若想了解的话,推荐大家看这位作者写的博客PyTorch之torchvision.tra()nsforms详解[原理+代码实现]-CSDN博客

torchvision.datasets

这里面包含了一些我们目前常用的一些数据集

我们这里主要讲一下mnist 

MNIST(手写数字数据集)

MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据。

在我们pytorch里面的torchvision里面是有的,我们可以直接用。

torchvision.datasets.MNIST(root, train=True, transform=None, target_transform=None, download=False)

相关参数:

root:就是我们从网上下载的数据集所放的目录,也就是文件路径

train:train=True意思就是下载的是训练集,如果train=Flase,那就是下载的是测试集

transform: 因为我们的数据直接下载过后一般都是要进行处理,所以这个后面跟的是一个transforms数据处理函数的一个实例化对象

download:为True就是从网络上下载数据集保存在我们的root路径中,为False就是不下载。

其余参数可自行查阅

torch.utils.data.DataLoader

(图片内容来自网络,侵权必删)

(其实我们记住常用的,dataset,batch_size,shuffle这三个常用的参数就好了)

torch.nn.NLLLoss() 

torch.nn.CrossEntropyLoss()

首先我们要了解交叉熵损失函数,torch.nn.CrossEntropyLoss()

什么是熵?

熵是用来描述一个系统的混乱程度,通过交叉熵我们就能够确定预测数据与真是数据之间的相近程度。交叉熵越小,表示数据越接近真实样本。

(预测的概率就是我们的预测值的准确值)

torch.nn.NLLLoss() 

 torch.nn.NLLLoss输入是一个对数概率向量和一个目标标签,它与torch.nn.CrossEntropyLoss的关系可以描述为:

假设有张量x,先softmax(x)得到y,然后再log(y)得到z,然后我们已知标签b,则:

NLLLoss(z,b)=CrossEntropyLoss(x,b)

代码:

nllloss = nn.NLLLoss()
predict = torch.Tensor([[2, 3, 1],
                        [3, 7, 9]])
predict = torch.log(torch.softmax(predict, dim=-1))
label = torch.tensor([1, 2])
nllloss(predict, label)
运行结果:tensor(0.2684)

而我们用 torch.nn.CrossEntropyLoss

cross_loss = nn.CrossEntropyLoss()

predict = torch.Tensor([[2, 3, 1],
                        [3, 7, 9]])
label = torch.tensor([1, 2])
cross_loss(predict, label)
运行结果: tensor(0.2684)

enumerate() 函数

这是python的一个内置函数。

enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。

Python 2.3. 以上版本可用,2.6 添加 start 参数。

enumerate(sequence, [start=0])
sequence -- 一个序列、迭代器或其他支持迭代对象。
start -- 下标起始位置的值。
例如:
>>> seasons = ['Spring', 'Summer', 'Fall', 'Winter']
>>> list(enumerate(seasons))
[(0, 'Spring'), (1, 'Summer'), (2, 'Fall'), (3, 'Winter')]
>>> list(enumerate(seasons, start=1))       # 下标从 1 开始
[(1, 'Spring'), (2, 'Summer'), (3, 'Fall'), (4, 'Winter')]

next() 函数

python的内置函数

next() 返回迭代器的下一个项目。

next() 函数要和生成迭代器的 iter() 函数一起使用。

返回值:返回下一个项目。

next(iterable[, default])
iterable -- 可迭代对象
default -- 可选,用于设置在没有下一个元素时返回该默认值,如果不设置,又没有下一个元素则会触发 StopIteration 异常。
例如:
#!/usr/bin/python
# -*- coding: UTF-8 -*-
 
# 首先获得Iterator对象:
it = iter([1, 2, 3, 4, 5])
# 循环:
while True:
    try:
        # 获得下一个值:
        x = next(it)
        print(x)
    except StopIteration:
        # 遇到StopIteration就退出循环
        break
结果:
1
2
3
4
5

pytorch中.detach()

在 PyTorch 中,detach() 方法用于返回一个新的 Tensor,这个 Tensor 和原来的 Tensor 共享相同的内存空间,但是不会被计算图所追踪,也就是说它不会参与反向传播,不会影响到原有的计算图,这使得它成为处理中间结果的一种有效方式,通常在以下两种情况下使用:

第一:在计算图中间,需要截断反向传播的梯度计算时。例如,当计算某个 Tensor 的梯度时,我们希望在此处截断反向传播,而不是将梯度一直传递到计算图的顶部,从而减少计算量和内存占用。此时可以使用 detach() 方法将 Tensor 分离出来。

第二:在将 Tensor 从 GPU 上拷贝到 CPU 上时,由于 Tensor 默认是在 GPU 上存储的,所以直接进行拷贝可能会导致内存不一致的问题。此时可以使用 detach() 方法先将 Tensor 分离出来,然后再将分离出来的 Tensor 拷贝到 CPU 上。

torch.squeeze()

详情请看这篇博客深度学习之张量的处理(代码笔记)

torch.optim.SGD 

torch.optim.SGD 是 PyTorch 中用于实现随机梯度下降(Stochastic Gradient Descent,SGD)优化算法的类。SGD 是一种常用的优化算法。

原理部份可以看我的这篇博客:机器学习优化算法(深度学习)-CSDN博客

我们主要介绍一下常用的参数:

torch.optim.SGD(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False)

(图片内容来自网络,侵权必删) 

独热编码

就是如果是0-9这十个数,那我们就用[1,0,0,0,0,0,0,0,0,0]表示0,[0,1,0,0,0,0,0,0,0,0]表示1,等等

其余同理

代码实现

关于bp神经网络的原理,我们可以看这篇博客:深度学习之感知机,激活函数,梯度消失,BP神经网络-CSDN博客

我们接下来就只是讲解代码实现。

bp网络的搭建

#搭建bp神经网络
class BPNetwork(torch.nn.Module):
    def __init__(self):
        super(BPNetwork,self).__init__()
        #我们的每张图片都是28*28也就是784个像素点
        #第一个隐藏层
        self.linear1=torch.nn.Linear(784,128)
        #激活函数,这里选择Relu
        self.relu1=torch.nn.ReLU()
        #第二个隐藏层
        self.linear2=torch.nn.Linear(128,64)
        #激活函数
        self.relu2=torch.nn.ReLU()
        #第三个隐藏层:
        self.linear3=torch.nn.Linear(64,32)
        # 激活函数
        self.relu3 = torch.nn.ReLU()
        #输出层
        self.linear4=torch.nn.Linear(32,10)
        # 激活函数
        self.softmax=torch.nn.LogSoftmax()
    #前向传播
    def forward(self,x):
        #修改每一个批次的样本集尺寸,修改为64*784,因为我们的图片是28*28
        x=x.reshape(x.shape[0],-1)
        #前向传播
        x=self.linear1(x)#784*128
        x=self.relu1(x)
        x=self.linear2(x)#128*64
        x=self.relu2(x)
        x=self.linear3(x)#64*32
        x=self.relu3(x)
        x=self.linear4(x)#输出层32*10
        x=self.softmax(x)#最后输出的数值我们需要利用到独热编码的思想
        #上面的这些都可以这几使用x=self.model(x)来代替,为什么能用它,我的理解是,我们继承的class moudle 然后对立面写好的模型框架进行定义,而这个方法就是可以直接调用我们定义好的神经网络
        return x

一些关键点都在注释上

搭建这次的BP神经网络我的隐藏层有三层,分别是128,64,32个神经元,因为我们的图片是28*28=784得,我们需要把其展开成一维,所以第一层网络是784*128得,这样输入层中每一行代表一个样本(或者说一张图片得所有像素点),因为我们的每个神经元都有一个参数,第一层网络中每一列都是一个神经元对应的参数,所以最后就是n*784和784*128两个矩阵相乘,最后得到n*128得矩阵,以此类推,最后因为我们的输出要用到独热编码的思想,所以我们的输出层调整为10个神经元,或者说最后得线性网络是32*10。

建立我们的神经网络对象

#建立我们的神经网络对象
model=BPNetwork()
#定义损失函数
critimizer=torch.nn.NLLLoss()
#定义优化器
optimizer=torch.optim.SGD(model.parameters(),lr=0.003,momentum=0.9)
epochs=15#循环得轮数
#每轮抽取次数的和
a=0
loss_=[]
a_=[]
font = font_manager.FontProperties(fname="C:\\Users\\ASUS\\Desktop\\Fonts\\STZHONGS.TTF")
for i in range(epochs):
    # 损失值参数
    sumloss = 0
    for imges,labels in trainload:
        a+=1
        #前向传播
        output=model(imges)
        #反向传播
        loss=critimizer(output,labels)
        loss.backward()
        #参数更新
        optimizer.step()
        #梯度清零
        optimizer.zero_grad()
        #损失累加
        sumloss+=loss.item()
    loss_.append(sumloss)
    a_.append(a)
    print(f"第{i+1}轮的损失:{sumloss},抽取次数和:{a}")
plt.figure()
plt.plot(a_,loss_)
plt.title('损失值随着抽取总次数得变化情况:',fontproperties=font, fontsize=18)
plt.show()

注意看注释。

SGD还有损失函数等等上面的补充内容里面都有说明,这里不在多加阐述。

预测

#开始预测
example=enumerate(testLoad)#从测试集里面随机抽64份并且记录下来里面的内容和下标
batch_index,(imagess,labelss)=next(example)
# bath_index=0
# imagess=0
# labelss=0
# for i,j in example:
#     bath_index=i
#     (imagess, labelss)=j
fig=plt.figure()
for i in range(64):
    pre=model(imagess[i])#预测
    #第一张图片对应的pre得格式:
    # print(pre)
    # tensor([[-2.7053e+01, -1.1105e-03, -1.2767e+01, -1.1126e+01, -1.6005e+01,
    #          -2.0953e+01, -2.3342e+01, -6.8246e+00, -1.2127e+01, -1.6131e+01]],
    #        grad_fn= < LogSoftmaxBackward0 >)
    #接下来我们要用到独热编码的思想,我们取最大的数,也就是最高的概率对应得下标,就相当于这个最高概率对应得独热编码里面的1,其他是0
    pro = list(pre.detach().numpy()[0])
    pre_label=pro.index(max(pro))
    #print(pre_label)

注意看注释

我们实际上也可以不用next()我们可以直接用for循环比如这样。

import d2l.torch as d2l
import math
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5,0.5)])
traindata=torchvision.datasets.MNIST(root='D:\learn_pytorch\数据集',train=True,download=True,transform=transform)#训练集60,000张用于训练
testdata=torchvision.datasets.MNIST(root='D:\learn_pytorch\数据集',train=False,download=True,transform=transform)#测试集10,000张用于测试
#利用DataLoader加载数据集
trainload=DataLoader(dataset=traindata,shuffle=True,batch_size=64)
testLoad=DataLoader(dataset=testdata,shuffle=False,batch_size=64)
example=enumerate(testLoad)#从测试集里面随机抽64份并且记录下来里面的内容和下标
a=0
images=0
labelss=0
for i,j in example:
    a+=1
    index = i
    (imagess, labelss) = j
    print(imagess[0])
    print('数据集中抽取的64份数据的纯数据部份的尺寸:',imagess.shape)
    print(imagess[0].shape)
    print(labelss[0])
    print(labelss[0].shape)

    if a==1:
        break

这样得到的结果我们可以看到:

D:\Anaconda3\envs\pytorch\python.exe D:\learn_pytorch\prictice.py 
tensor([[[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.3412,
           0.4510,  0.2471,  0.1843, -0.5294, -0.7176, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,  0.7412,
           0.9922,  0.9922,  0.9922,  0.9922,  0.8902,  0.5529,  0.5529,
           0.5529,  0.5529,  0.5529,  0.5529,  0.5529,  0.5529,  0.3333,
          -0.5922, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.4745,
          -0.1059, -0.4353, -0.1059,  0.2784,  0.7804,  0.9922,  0.7647,
           0.9922,  0.9922,  0.9922,  0.9608,  0.7961,  0.9922,  0.9922,
           0.0980, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -0.8667, -0.4824, -0.8902,
          -0.4745, -0.4745, -0.4745, -0.5373, -0.8353,  0.8510,  0.9922,
          -0.1686, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -0.3490,  0.9843,  0.6392,
          -0.8588, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -0.8275,  0.8275,  1.0000, -0.3490,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000,  0.0118,  0.9922,  0.8667, -0.6549,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -0.5373,  0.9529,  0.9922, -0.5137, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000,  0.0431,  0.9922,  0.4667, -0.9608, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -0.9294,  0.6078,  0.9451, -0.5451, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -0.0118,  0.9922,  0.4275, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -0.4118,  0.9686,  0.8824, -0.5529, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.8510,
           0.7333,  0.9922,  0.3020, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.9765,  0.5922,
           0.9922,  0.7176, -0.7255, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.7020,  0.9922,
           0.9922, -0.3961, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -0.7569,  0.7569,  0.9922,
          -0.0980, -0.9922, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000,  0.0431,  0.9922,  0.9922,
          -0.5922, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -0.5216,  0.8980,  0.9922,  0.9922,
          -0.5922, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -0.0510,  0.9922,  0.9922,  0.7176,
          -0.6863, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -0.0510,  0.9922,  0.6235, -0.8588,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000]]])
数据集中抽取的64份数据的纯数据部份的尺寸: torch.Size([64, 1, 28, 28])
torch.Size([1, 28, 28])
tensor(7)
torch.Size([])

进程已结束,退出代码0

 我们可以看到一张图片得数据格式(这里面是已经归一化处理过的),因为我们的手写字体识别是单通道的灰度图,所以size是[1,28,28],这里很正确,对于彩色图三通道得来说,会有些不一样。

完整代码

import matplotlib.pyplot as plt
from matplotlib import font_manager

print('BP识别MNIST任务说明---------------------')
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
#导入数据集并且进行数据处理
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5,0.5)])
traindata=torchvision.datasets.MNIST(root='D:\learn_pytorch\数据集',train=True,download=True,transform=transform)#训练集60,000张用于训练
testdata=torchvision.datasets.MNIST(root='D:\learn_pytorch\数据集',train=False,download=True,transform=transform)#测试集10,000张用于测试
#利用DataLoader加载数据集
trainload=DataLoader(dataset=traindata,shuffle=True,batch_size=64)
testLoad=DataLoader(dataset=testdata,shuffle=False,batch_size=64)
#搭建bp神经网络
class BPNetwork(torch.nn.Module):
    def __init__(self):
        super(BPNetwork,self).__init__()
        #我们的每张图片都是28*28也就是784个像素点
        #第一个隐藏层
        self.linear1=torch.nn.Linear(784,128)
        #激活函数,这里选择Relu
        self.relu1=torch.nn.ReLU()
        #第二个隐藏层
        self.linear2=torch.nn.Linear(128,64)
        #激活函数
        self.relu2=torch.nn.ReLU()
        #第三个隐藏层:
        self.linear3=torch.nn.Linear(64,32)
        # 激活函数
        self.relu3 = torch.nn.ReLU()
        #输出层
        self.linear4=torch.nn.Linear(32,10)
        # 激活函数
        self.softmax=torch.nn.LogSoftmax()
    #前向传播
    def forward(self,x):
        #修改每一个批次的样本集尺寸,修改为64*784,因为我们的图片是28*28
        x=x.reshape(x.shape[0],-1)
        #前向传播
        x=self.linear1(x)#784*128
        x=self.relu1(x)
        x=self.linear2(x)#128*64
        x=self.relu2(x)
        x=self.linear3(x)#64*32
        x=self.relu3(x)
        x=self.linear4(x)#输出层32*10
        x=self.softmax(x)#最后输出的数值我们需要利用到独热编码的思想
        #上面的这些都可以这几使用x=self.model(x)来代替,为什么能用它,我的理解是,我们继承的class moudle 然后对立面写好的模型框架进行定义,而这个方法就是可以直接调用我们定义好的神经网络
        return x
#建立我们的神经网络对象
model=BPNetwork()
#定义损失函数
critimizer=torch.nn.NLLLoss()
#定义优化器
optimizer=torch.optim.SGD(model.parameters(),lr=0.003,momentum=0.9)
epochs=15
#每轮抽取次数的和
a=0
loss_=[]
a_=[]
font = font_manager.FontProperties(fname="C:\\Users\\ASUS\\Desktop\\Fonts\\STZHONGS.TTF")
for i in range(epochs):
    # 损失值参数
    sumloss = 0
    for imges,labels in trainload:
        a+=1
        #前向传播
        output=model(imges)
        #反向传播
        loss=critimizer(output,labels)
        loss.backward()
        #参数更新
        optimizer.step()
        #梯度清零
        optimizer.zero_grad()
        #损失累加
        sumloss+=loss.item()
    loss_.append(sumloss)
    a_.append(a)
    print(f"第{i+1}轮的损失:{sumloss},抽取次数和:{a}")
plt.figure()
plt.plot(a_,loss_)
plt.title('损失值随着抽取总次数得变化情况:',fontproperties=font, fontsize=18)
plt.show()
#开始预测
example=enumerate(testLoad)#从测试集里面随机抽64份并且记录下来里面的内容和下标
batch_index,(imagess,labelss)=next(example)
# bath_index=0
# imagess=0
# labelss=0
# for i,j in example:
#     bath_index=i
#     (imagess, labelss)=j
fig=plt.figure()
for i in range(64):
    pre=model(imagess[i])#预测
    #第一张图片对应的pre得格式:
    # print(pre)
    # tensor([[-2.7053e+01, -1.1105e-03, -1.2767e+01, -1.1126e+01, -1.6005e+01,
    #          -2.0953e+01, -2.3342e+01, -6.8246e+00, -1.2127e+01, -1.6131e+01]],
    #        grad_fn= < LogSoftmaxBackward0 >)
    #接下来我们要用到独热编码的思想,我们取最大的数,也就是最高的概率对应得下标,就相当于这个最高概率对应得独热编码里面的1,其他是0
    pro = list(pre.detach().numpy()[0])
    pre_label=pro.index(max(pro))
    #print(pre_label)
    #图像显示
    img=torch.squeeze(imagess[i]).numpy()

    plt.subplot(8,8,i+1)
    plt.tight_layout()
    plt.imshow(img,cmap='gray',interpolation='none')
    plt.title(f"预测值:{pre_label}",fontproperties=font, fontsize=9)
    plt.xticks([])
    plt.yticks([])
plt.show()

















结果及其图像显示

D:\Anaconda3\envs\pytorch\python.exe D:\learn_pytorch\学习过程\第四周的代码\代码一:BP识别MNIST任务说明.py 
BP识别MNIST任务说明---------------------
D:\learn_pytorch\学习过程\第四周的代码\代码一:BP识别MNIST任务说明.py:49: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.
  x=self.softmax(x)#最后输出的数值我们需要利用到独热编码的思想
第1轮的损失:815.2986399680376,抽取次数和:938
第2轮的损失:281.06414164602757,抽取次数和:1876
第3轮的损失:205.76270231604576,抽取次数和:2814
第4轮的损失:159.9431014917791,抽取次数和:3752
第5轮的损失:131.47989665158093,抽取次数和:4690
第6轮的损失:109.93954652175307,抽取次数和:5628
第7轮的损失:95.26143277343363,抽取次数和:6566
第8轮的损失:85.17149402853101,抽取次数和:7504
第9轮的损失:75.1239058477804,抽取次数和:8442
第10轮的损失:68.23363681556657,抽取次数和:9380
第11轮的损失:60.05981844640337,抽取次数和:10318
第12轮的损失:54.82598690944724,抽取次数和:11256
第13轮的损失:51.70861432119273,抽取次数和:12194
第14轮的损失:46.613128249999136,抽取次数和:13132
第15轮的损失:43.05269447225146,抽取次数和:14070

进程已结束,退出代码0

 可以看到,经过15轮得训练后,我们基本上已经完全能识别出来了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1584546.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LVM逻辑卷管理器

LVM是Linux系统对硬盘分区进行管理的一种机制&#xff0c;在硬盘分区和文件系统之间添加了一个逻辑层&#xff0c;它提供了一个抽象的卷组&#xff0c;可以把多块硬盘进行卷组合并。这样&#xff0c;用户无需关心物理硬盘设备的底层架构和布局&#xff0c;就可以实现对硬盘分区…

智过网:注册安全工程师注册有效期与周期解析

在职业领域&#xff0c;各种专业资格认证不仅是对从业者专业能力的认可&#xff0c;也是保障行业安全、规范发展的重要手段。其中&#xff0c;注册安全工程师证书在安全生产领域具有举足轻重的地位。那么&#xff0c;注册安全工程师的注册有效期是多久呢&#xff1f;又是几年一…

Unity 九宫格

1. 把图片拖拽进资源文件夹 2.选中图片&#xff0c;然后设置图片 3.设置九宫格 4.使用图片&#xff0c;在界面上创建2个相同的Image,然后使用图片&#xff0c;修改Image Type 为Sliced

图书推荐:《和AI一起编程》

《Coding with AI For Dummies》这本书由Chris Minnick撰写&#xff0c;主要分为四个部分&#xff0c;涵盖了与AI相结合的编程技术、AI编码工具的应用、利用AI编写代码的具体实践&#xff0c;以及测试、文档编制和维护代码的相关内容。 克里斯明尼克(Chris Minnick)&#xff1a…

【论文阅读】MCTformer: 弱监督语义分割的多类令牌转换器

【论文阅读】MCTformer: 弱监督语义分割的多类令牌转换器 文章目录 【论文阅读】MCTformer: 弱监督语义分割的多类令牌转换器一、介绍二、联系工作三、方法四、实验结果 Multi-class Token Transformer for Weakly Supervised Semantic Segmentation 本文提出了一种新的基于变换…

数字化浪潮下,制造业如何乘势而上实现精益生产

随着数字化技术的迅猛发展&#xff0c;制造业正迎来前所未有的变革机遇。本文将探讨如何利用数字化手段助推制造业实现精益生产&#xff0c;从而在激烈的市场竞争中脱颖而出。 1、构建智能化生产系统 借助物联网技术&#xff0c;实现设备之间的互联互通&#xff0c;构建智能化…

最祥解决python 将Dataframe格式数据上传数据库所碰到的问题

碰到的问题 上传Datafrane格式的数据到数据库 会碰见很多错误 举几个很普遍遇到的问题(主要以SqlServer举例) 这里解释下 将截断字符串或二进制数据 这个是字符长度超过数据库设置的长度 然后还有字符转int失败 或者字符串转换日期/或时间失败 这个是碰到的需要解决的最多的问…

比特币减半后 牛市爆发

作者&#xff1a;Arthur Hayes of Co-Founder of 100x 编译&#xff1a;Qin jin of ccvalue (以下内容仅代表作者个人观点&#xff0c;不应作为投资决策依据&#xff0c;也不应被视为参与投资交易的建议或意见&#xff09;。 Ping PingPing&#xff0c;我的手机发出的声音&…

词频统计程序

使用Hadoop MapReduce处理文本文件&#xff0c;Mapper负责将文本分割为单词&#xff0c;然后Reducer对每个单词进行计数&#xff0c;最后将结果写入输出文件。 // 定义WordCount公共类 public class WordCount {// 主入口方法&#xff0c;处理命令行参数public static void m…

C语言进阶|顺序表

✈顺序表的概念及结构 线性表&#xff08;linear list&#xff09;是n个具有相同特性的数据元素的有限序列。 线性表是一种在实际中广泛使 用的数据结构&#xff0c;常见的线性表&#xff1a;顺序表、链表、栈、队列、字符串.. 线性表在逻辑上是线性结构&#xff0c;也就说是连…

推荐学习什么编程语言?

选择编程语言学习时&#xff0c;除了就业因素外&#xff0c;还可以考虑以下几个方面来决定学习哪些编程语言&#xff1a; 个人兴趣与目标&#xff1a;如果你对某个特定领域感兴趣&#xff0c;比如游戏开发、数据分析、人工智能等&#xff0c;可以选择与该领域紧密相关的编程语言…

Python---【re库的使用】

目录&#xff1a; 一.re库简介 二.match方法 三.Match对象方法 四.使用search()方法进行匹配 五.使用findall()方法进行匹配 六.使用sub()方法替换字符串 七.使用split()方法分割字符串 一.re库简介 re库是Python用来实现“正则表达式”的库&#xff0c;并且re库在Pyth…

使用 nginx 服务器部署Vue项目

安装nginx 文本代理服务器 centos下载 注意需要root权限 在CentOS服务器上下载Nginx可以通过以下步骤完成&#xff1a; 更新系统软件包列表&#xff1a; yum update 安装EPEL存储库&#xff08;Extra Packages for Enterprise Linux&#xff09;&#xff1a; yum install…

分类预测 | Matlab实现KPCA-IDBO-LSSVM基于核主成分分析和改进蜣螂优化算法优化最小二乘支持向量机分类预测

分类预测 | Matlab实现KPCA-IDBO-LSSVM基于核主成分分析和改进蜣螂优化算法优化最小二乘支持向量机分类预测 目录 分类预测 | Matlab实现KPCA-IDBO-LSSVM基于核主成分分析和改进蜣螂优化算法优化最小二乘支持向量机分类预测分类效果基本描述程序设计参考资料 分类效果 基本描述…

Visual Studio C++ 正确创建项目与更改文件名

1、创建项目 1&#xff09;打开Visual Studio&#xff0c;选择创建新项目。 2&#xff09;创建空项目 3&#xff09;配置新项目&#xff0c;注意不要勾选 " 将解决方案和项目放在同一目录中 " 。并将位置的文件夹设为与解决方案同名&#xff0c;方便管理。项目名称则…

《零秒思考》像麦肯锡精英一样思考 - 三余书屋 3ysw.net

零秒思考&#xff1a;像麦肯锡精英一样思考 大家好&#xff0c;今天我们要深入探讨的著作是《零秒思考》。在领导提出问题时&#xff0c;我们常常会陷入沉思&#xff0c;却依然难以有所进展&#xff0c;仿佛原地踏步&#xff0c;但是身边的同事却能够立即给出清晰的回答。这种…

乡村智慧化升级:数字乡村打造农村生活新品质

目录 一、乡村智慧化升级的内涵与意义 二、乡村智慧化升级的具体实践 1、加强农村信息基础设施建设 2、推广智慧农业应用 3、提升乡村治理智慧化水平 4、丰富智慧乡村生活内容 三、数字乡村打造农村生活新品质的成果展现 1、农业生产效率与质量双提升 2、农民收入与消…

学习Rust的第二天:Cargo

We dive into Cargo, the powerful and convenient build system and package manager for Rust. 基于Steve Klabnik的《The Rust Programming Language》一书&#xff0c;我们深入了解Cargo&#xff0c;这是Rust强大而方便的构建系统和包管理器。 Cargo is a robust and effic…

Linux文件与目录的默认权限和隐藏权限

&#x1f407;明明跟你说过&#xff1a;个人主页 &#x1f3c5;个人专栏&#xff1a;《Kubernetes航线图&#xff1a;从船长到K8s掌舵者》 &#x1f3c5; &#x1f516;行路有良友&#xff0c;便是天堂&#x1f516; 目录 一、Linux的起源与发展 二、文件默认权限&#xf…

vitepress/vite vue3 怎么实现vue模版字符串实时编译

如果是vue模版字符串的话&#xff0c;先解析成模版对象 另一篇文章里有vue模版字符串解析成vue模版对象-CSDN博客 //vue3写法&#xff08;vue2可以用new Vue.extend(vue模版对象)来实现&#xff09;import { createApp, defineComponent } from vue;// 定义一个简单的Vue组件c…