动手学深度学习:CNN和LeNet

news2025/5/15 20:40:33

前言

该篇文章记述从零如何实现CNN,以及LeNet对于之前数据集分类的提升效果。

从零实现卷积核

import torch
def conv2d(X,k):
    h,w=k.shape
    Y=torch.zeros((X.shape[0]-h+1,X.shape[1]-w+1))
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            Y[i,j]=(X[i:i+h,j:j+w]*k).sum()
    return Y
X=torch.tensor([[0.,1.,2.],[3.,4.,5.],[6.,7.,8.]])
k=torch.tensor([[0.,1.],[2.,3.]])
conv2d(X,k)

在这里插入图片描述

卷积层

from torch import nn
class Conv2D(nn.Module):
    def __init__(self,kernel_size):
        super.__init__()
        self.weight=nn.Parameter(torch.rand(kernel_size))
        self.bias=nn.Parameter(torch.zeros(1))
    def forward(self,x):
        return conv2d(x,self.weight)+self.bias

验证卷积层对于图像的检测作用

x=torch.ones((6,8))
x[:,2:6]=0
x

在这里插入图片描述

k=torch.tensor([[1.0,-1.0]])
y=conv2d(x,k)
y

在这里插入图片描述
很明显这个卷积核提取到了垂直上的特征

conv2d(x.t(),k)

在这里插入图片描述
并没有学习到水平特征

学习卷积核

我们可以让卷积核自己学习里面的参数以达到对不同图像提取的作用

conv2d=nn.Conv2d(1,1,kernel_size=(1,2),bias=False)

x=x.reshape((1,1,x.shape[0],x.shape[1]))
y=y.reshape((1,1,6,7))
lr=3e-2

for i in range(10):
    y_hat=conv2d(x)
    l=(y_hat-y)**2
    conv2d.zero_grad()
    l.sum().backward()
    
    conv2d.weight.data[:]-=lr*conv2d.weight.grad
    print(f"第{i}轮,loss为{l.sum()}")

在这里插入图片描述

conv2d.weight.data

在这里插入图片描述

填充

def comp_conv2d(conv2d,x):
    #(1,1)添加batch大小和通道数
    x=x.reshape((1,1)+x.shape)
    y=conv2d(x)
    return y.reshape(y.shape[2:])
conv2d=nn.Conv2d(1,1,kernel_size=3,padding=1)
x=torch.rand(size=(8,8))
comp_conv2d(conv2d,x).shape

在这里插入图片描述

conv2d=nn.Conv2d(1,1,kernel_size=(5,3),padding=(2,1))
x=torch.rand(size=(8,8))
comp_conv2d(conv2d,x).shape

在这里插入图片描述

步幅

conv2d=nn.Conv2d(1,1,kernel_size=(3,3),padding=1,stride=2)
x=torch.rand(size=(8,8))
comp_conv2d(conv2d,x).shape

在这里插入图片描述

多通道

from d2l import torch as d2l
def corr2d_multi_in(X,K):
    return sum(d2l.corr2d(x,k) for x,k in zip(X,K))
x=torch.randn(size=(4,2,3))
k=torch.randn(size=(4,1,3))
corr2d_multi_in(x,k)

在这里插入图片描述

多输出通道

def corr2d_multi_in_out(X,K):
    return torch.stack([corr2d_multi_in(X,k)for k in K],0)
K=torch.stack((k,k+1,k+2),0)
K.shape

在这里插入图片描述

corr2d_multi_in_out(x,K)

在这里插入图片描述

1x1卷积

def corr2d_multi_in_out_1x1(X,K):
    c_i,h,w=X.shape
    c_o=K.shape[0]
    X=X.reshape((c_i,h*w))
    K=K.reshape((c_o,c_i))
    Y=torch.matmul(K,X)
    return Y.reshape((c_o,h,w))
X=torch.normal(0,1,(3,3,3))
K=torch.normal(0,1,(2,3,1,1))
Y1=corr2d_multi_in_out_1x1(X,K)
Y2=corr2d_multi_in_out(X,K)
Y1==Y2

在这里插入图片描述

汇聚层

def pool2d(x,pool_size,mode='max'):
    p_h,p_w=pool_size
    Y=torch.zeros((X.shape[0]-p_h+1,X.shape[1]-p_w+1))
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            if mode=='max':
                Y[i,j]=X[i:i+p_h,j:j+p_w].max()
            elif mode=='avg':
                Y[i,j]=X[i:i+p_h,j:j+p_w].mean()
    return Y
X=torch.tensor([[0.0,1.,2.],[3.,4.,5.],[6.,7.,8.]])
pool2d(X,(2,2))

在这里插入图片描述

pool2d(X,(2,2),'avg')

在这里插入图片描述

LeNet

这是最早的神经网络,根据我的测试,这个模型在我的数据集上的效果比MLP要提高了1%以上,在这段时间里面,我页发现了原有数据集在分类上存在问题,所以重新制作了一份,在这份数据集上,随着我数据量的提升以及模型的修改,准确率达到了99.7%,且无过拟合现象。

原始的LeNet

from torch import nn
net=nn.Sequential(
nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),
nn.AvgPool2d(kernel_size=2,stride=2),
nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),
nn.AvgPool2d(kernel_size=2,stride=2),
nn.Flatten(),
nn.Linear(16*3*5,120),nn.Sigmoid(),
nn.Linear(120,84),nn.Sigmoid(),
nn.Linear(84,9))
def init_weight(m):
    if type(m)==nn.Linear or type(m)==nn.Conv2d:
        nn.init.xavier_uniform_(m.weight)
net.apply(init_weight)

在这里插入图片描述

测试结果

我忘记截图了,效果达到了99%以上,同样的数据集在MLP上是98%

改进后的LeNet

第一版

我将平均池化层改成了最大池化层

net=nn.Sequential(
nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),
nn.MaxPool2d(kernel_size=2,stride=2),
nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),
nn.MaxPool2d(kernel_size=2,stride=2),
nn.Flatten(),
nn.Linear(16*3*5,120),nn.Sigmoid(),
nn.Linear(120,84),nn.Sigmoid(),
nn.Linear(84,9))
def init_weight(m):
    if type(m)==nn.Linear or type(m)==nn.Conv2d:
        nn.init.xavier_uniform_(m.weight)
net.apply(init_weight)

在这里插入图片描述

训练修改

我在训练过程中添加了记录test的loss最低时,保存pt和onnx,用于后续推理。

epochs_num=100
train_len=len(train_iter.dataset)
all_acc=[]
all_loss=[]
test_all_acc=[]
shape=None
for epoch in range(epochs_num):
    acc=0
    loss=0
    for x,y in train_iter:
        hat_y=net(x)
        l=loss_fn(hat_y,y)
        loss+=l
        optimer.zero_grad()
        l.backward()
        optimer.step()
        acc+=(hat_y.argmax(1)==y).sum()
    all_acc.append(acc/train_len)
    all_loss.append(loss.detach().numpy())
    test_acc=0
    test_loss=0
    test_len=len(test_iter.dataset)
    with torch.no_grad():
        for x,y in test_iter:
            shape=x.shape
            hat_y=net(x)
            test_loss+=loss_fn(hat_y,y)
            test_acc+=(hat_y.argmax(1)==y).sum()
    test_all_acc.append(test_acc/test_len)
    print(f'{epoch}的test的acc{test_acc/test_len}')
    # 保存测试损失最小的模型
    if test_loss < best_test_loss:
        best_test_loss = test_loss
        torch.save(net, best_model_path)
        dummy_input = torch.randn(shape)  
        torch.onnx.export(net, dummy_input, "./models/LeNet5.onnx", opset_version=11)
        print(f'Saved better model with Test Loss: {best_test_loss:.4f}')

在这里插入图片描述

损失函数可视化

plt.plot(range(1,epochs_num+1),all_loss,'.-',label='train_loss')
plt.text(epochs_num, all_loss[-1], f'{all_loss[-1]:.4f}', fontsize=12, verticalalignment='bottom')

在这里插入图片描述

准确率可视化

plt.plot(range(1,epochs_num+1),all_acc,'-',label='train_acc')
plt.text(epochs_num, all_acc[-1], f'{all_acc[-1]:.4f}', fontsize=12, verticalalignment='bottom')
plt.plot(range(1,epochs_num+1),test_all_acc,'-.',label='test_acc')
plt.legend()

在这里插入图片描述

预测结果

import numpy as np
with torch.no_grad():
    all_num=5
    index=1
    plt.figure(figsize=(12,5))
    for i,label in zip(test_data_path,test_labels):
        if index<=all_num:
            img=cv2.imread(i)
            input_img=cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
            img=cv2.cvtColor(input_img,cv2.COLOR_BGR2RGB)
            input_img = np.expand_dims(input_img, axis=2)  # 增加通道维度,形状变为 [1, H, W]
            input_img=transforms.ToTensor()(input_img)
            input_img = input_img.unsqueeze(0)  # 增加批量维度,形状变为 [1, 1, 28, 20]
            print(input_img.shape)
            result=net(input_img).argmax(1)
            plt.subplot(1,all_num,index)
            plt.imshow(img)
            plt.title(f'true{label},predict{result.detach().numpy()}')
            plt.axis("off")
            index+=1

在这里插入图片描述

第二版

我将sigmoid激活函数换成了ReLU函数,发现最终的收敛速度极快,损失值也下降了一些,test的acc上升了0.04%,虽然不多,但是训练时间极大减少。

net=nn.Sequential(
nn.Conv2d(1,6,kernel_size=5,padding=2),nn.ReLU(),
nn.MaxPool2d(kernel_size=2,stride=2),
nn.Conv2d(6,16,kernel_size=5),nn.ReLU(),
nn.MaxPool2d(kernel_size=2,stride=2),
nn.Flatten(),
nn.Linear(16*3*5,120),nn.ReLU(),
nn.Linear(120,84),nn.ReLU(),
nn.Linear(84,9))
def init_weight(m):
    if type(m)==nn.Linear or type(m)==nn.Conv2d:
        nn.init.xavier_uniform_(m.weight)
net.apply(init_weight)

其他部分同之前一样。

训练过程

在这里插入图片描述

损失函数可视化

在这里插入图片描述

准确率可视化

在这里插入图片描述

总结

数据集收集过程中遇到了部分麻烦,数据集还不够完整。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2317857.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

删除排序链表中的重复元素(js实现,LeetCode:83)

看到这道题的第一反应是使用快慢指针&#xff0c;之前做过类似的题&#xff1a;删除有序数组中的重复项&#xff08;js实现&#xff0c;LeetCode&#xff1a;26&#xff09;原理都是一样,区别是这题需要将重复项删除&#xff0c;所以只需要走一遍单循环就可以实现 /*** Defini…

单片机自学总结

自从工作以来&#xff0c;一直努力耕耘单片机&#xff0c;至今&#xff0c;颇有收获。从51单片机&#xff0c;PIC单片机&#xff0c;直到STM32&#xff0c;以及RTOS和Linux&#xff0c;几乎天天在搞:51单片机&#xff0c;STM8S207单片机&#xff0c;PY32F003单片机&#xff0c;…

Unity教程(二十二)技能系统 分身技能

Unity开发2D类银河恶魔城游戏学习笔记 Unity教程&#xff08;零&#xff09;Unity和VS的使用相关内容 Unity教程&#xff08;一&#xff09;开始学习状态机 Unity教程&#xff08;二&#xff09;角色移动的实现 Unity教程&#xff08;三&#xff09;角色跳跃的实现 Unity教程&…

HTML5扫雷游戏开发实战

HTML5扫雷游戏开发实战 这里写目录标题 HTML5扫雷游戏开发实战项目介绍技术栈项目架构1. 游戏界面设计2. 核心类设计 核心功能实现1. 游戏初始化2. 地雷布置算法3. 数字计算逻辑4. 扫雷功能实现 性能优化1. DOM操作优化2. 算法优化 项目亮点技术难点突破1. 首次点击保护2. 连锁…

【Git学习笔记】Git分支管理策略及其结构原理分析

【Git学习笔记】Git分支管理策略及其结构原理分析 &#x1f525;个人主页&#xff1a;大白的编程日记 &#x1f525;专栏&#xff1a;Git学习笔记 文章目录 【Git学习笔记】Git分支管理策略及其结构原理分析前言一.合并冲突二. 分支管理策略2.1 分支策略2.2 bug分支2.3 删除临…

Spring Cloud Alibaba Nacos 2023.X 配置问题

文章目录 问题现象&#xff08;一&#xff09;解决方法&#xff08;一&#xff09;问题现象&#xff08;二&#xff09;解决方法&#xff08;二&#xff09;问题现象&#xff08;三&#xff09;解决方法&#xff08;三&#xff09; 问题现象&#xff08;一&#xff09; Spring…

厨卫行业供应链产销协同前中后大平台现状需求分析报告+P120(120页PPT)(文末有下载方式)

资料解读&#xff1a;厨卫行业供应链产销协同前中后大平台现状需求分析报告 详细资料请看本解读文章的最后内容。在当前厨卫行业竞争激烈的市场环境下&#xff0c;企业的发展战略和业务模式创新至关重要。本次解读的报告围绕某厨卫企业展开&#xff0c;深入探讨其供应链产销协同…

我在哪,要去哪

在直播间听到一首好听的歌《我在哪&#xff0c;要去哪》-汤倩。 遇见的事&#xff1a;21~24号抽调去招生。 感受到的情绪&#xff1a;公假吗&#xff1f;给工作量吗&#xff1f;月工作量不够扣钱吗&#xff1f;报销方便吗&#xff1f;有事情&#xff0c;从来不解决后顾&#x…

SpringBoot-2整合MyBatis以及基本的使用方法

目录 1.引入依赖 2.数据库表的创建 3.数据源的配置 4.编写pojo类 5.编写controller类 6.编写接口 7.编写接口的实现类 8.编写mapper 1.引入依赖 在pom.xml引入依赖 <!-- mysql--><dependency><groupId>com.mysql</groupId><artifac…

本周安全速报(2025.3.11~3.17)

合规速递 01 瑞士出台新规&#xff1a;关基设施遭遇网络攻击需在24小时内上报 原文: https://www.bleepingcomputer.com/news/security/swiss-critical-sector-faces-new-24-hour-cyberattack-reporting-rule/ 新规要求&#xff0c;关键基础设施组织发现网络攻击后&…

【css酷炫效果】纯CSS实现瀑布流加载动画

【css酷炫效果】纯CSS实现瀑布流加载动画 缘创作背景html结构css样式完整代码基础版进阶版(无限往复加载) 效果图 想直接拿走的老板&#xff0c;链接放在这里&#xff1a;https://download.csdn.net/download/u011561335/90492012 缘 创作随缘&#xff0c;不定时更新。 创作…

咖啡点单小程序毕业设计(JAVA+SpringBoot+微信小程序+完整源码+论文)

✌全网粉丝20W,csdn特邀作者、博客专家、CSDN新星计划导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取项目下载方式&#x1f345; 一、项目背景介绍&#xff1a; 随着社会的快速发展和…

网络编程套接字【端口号/TCPUDP/网络字节序/socket编程接口/UDPTCP网络实验】

网络编程套接字 0. 前言1. 认识端口号2. 认识TCP和UDP协议3. 网络字节序4. socket编程接口5. 实现一个简单的UDP网络程序5.1 需求分析5.2 头文件准备5.3 服务器端设计5.4 客户端设计5.5 本地测试5.6 跨网络测试5.7 UDP小应用——客户端输入命令&#xff0c;服务器端执行 6. 地址…

DeepSeek 3FS 与 JuiceFS:架构与特性比较

近期&#xff0c;DeepSeek 开源了其文件系统 Fire-Flyer File System (3FS)&#xff0c;使得文件系统这一有着 70 多年历时的“古老”的技术&#xff0c;又获得了各方的关注。在 AI 业务中&#xff0c;企业需要处理大量的文本、图像、视频等非结构化数据&#xff0c;还需要应对…

Unity WebGL项目访问时自动全屏

Unity WebGL项目访问时自动全屏 打开TemplateData/style.css文件 在文件最下方添加红色框内的两行代码 使用vscode或者其他编辑器打开index.html 将按钮注释掉&#xff0c;并且更改为默认全屏

C# WPF编程-Menu

C# WPF编程-Menu 布局&#xff1a;代码&#xff1a;效果 在WPF&#xff08;Windows Presentation Foundation&#xff09;中&#xff0c;Menu控件用于创建下拉菜单或上下文菜单&#xff0c;它提供了丰富的定制选项来满足不同的应用需求。下面将介绍如何在WPF应用程序中使用Menu…

Docker和containerd之概览(Overview of Docker and Containerd)

Docker和containerd之概览 容器本质上就是一个进程。 Namespace是一种逻辑分组机制&#xff0c;允许您将集群资源划分为独立的虚拟环境。每个 Namespace 为资源提供了一个范围&#xff0c;使得不同的团队、应用程序或环境可以在同一集群中共存&#xff0c;而不会相互干扰。 C…

【多线程】线程不安全问题

文章目录 多线程不安全的原因大的层面->多线程是随机调度的容易产生死锁 小的层面->内存不可见性引入volatile关键字 指令重排序不是原子性带来的隐患 synchronized锁的互斥性及作用可重入性——解决死锁 wait()和notify()两个突然迸发出的疑问 多线程不安全的原因 大的…

【C++】树和二叉树的实现(下)

本篇博客给大家带来的是用C语言来实现数据结构树和二叉树的实现&#xff01; &#x1f41f;&#x1f41f;文章专栏&#xff1a;数据结构 &#x1f680;&#x1f680;若有问题评论区下讨论&#xff0c;我会及时回答 ❤❤欢迎大家点赞、收藏、分享&#xff01; 今日思想&#xff…

kafka指北

为自己总结一下kafka指北&#xff0c;会持续更新。创作不易&#xff0c;转载请注明出处。 目录 集群controller选举过程broker启动流程 主题创建副本分布ISRleader副本选举机制LEO 生产数据流程同步发送和异步发送 分区策略ack应答生产者发送消息的幂等性跨分区幂等性问题&…