动手学深度学习(Pytorch版)代码实践 -卷积神经网络-26网络中的网络NiN

news2025/5/24 18:02:24

26网络中的网络NiN

在这里插入图片描述

import torch
from torch import nn
import liliPytorch as lp
import matplotlib.pyplot as plt

# 定义一个NiN块
def nin_block(in_channels, out_channels, kernel_size, strides, padding):
    return nn.Sequential(
        # 传统的卷积层
        nn.Conv2d(in_channels, out_channels, kernel_size, strides, padding),
        nn.ReLU(),  # 激活函数ReLU
        # 1x1卷积层
        nn.Conv2d(out_channels, out_channels, kernel_size=1),
        nn.ReLU(),  
        # 另一个1x1卷积层
        nn.Conv2d(out_channels, out_channels, kernel_size=1),
        nn.ReLU()   
    )

# 设置dropout的概率
dropout = 0.5 

# 定义NiN模型
net = nn.Sequential(
    # 第一个NiN块,输入通道数为1,输出通道数为96
    nin_block(1, 96, kernel_size=11, strides=4, padding=0),
    # 最大池化层
    nn.MaxPool2d(kernel_size=3, stride=2),
    # 第二个NiN块,输入通道数为96,输出通道数为256
    nin_block(96, 256, kernel_size=5, strides=1, padding=2),
    # 最大池化层
    nn.MaxPool2d(kernel_size=3, stride=2),
    # 第三个NiN块,输入通道数为256,输出通道数为384
    nin_block(256, 384, kernel_size=3, strides=1, padding=1),
    # 最大池化层
    nn.MaxPool2d(kernel_size=3, stride=2),
    # Dropout层,用于防止过拟合
    nn.Dropout(dropout),

    # 最后一个NiN块,输入通道数为384,输出通道数为10
    nin_block(384, 10, kernel_size=3, strides=1, padding=1),
    # 全局平均池化层,将特征图的每个通道的空间维度调整为1x1
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten()
)

X = torch.rand(size=(1, 1, 224, 224))
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__,'output shape:\t', X.shape)
"""
Sequential output shape:         torch.Size([1, 96, 54, 54])
MaxPool2d output shape:  torch.Size([1, 96, 26, 26])
Sequential output shape:         torch.Size([1, 256, 26, 26])
MaxPool2d output shape:  torch.Size([1, 256, 12, 12])
Sequential output shape:         torch.Size([1, 384, 12, 12])
MaxPool2d output shape:  torch.Size([1, 384, 5, 5])
Dropout output shape:    torch.Size([1, 384, 5, 5])
Sequential output shape:         torch.Size([1, 10, 5, 5])
AdaptiveAvgPool2d output shape:  torch.Size([1, 10, 1, 1])
Flatten output shape:    torch.Size([1, 10])
"""

lr, num_epochs, batch_size = 0.1, 10, 128
train_iter, test_iter = lp.loda_data_fashion_mnist(batch_size, resize=224)
lp.train_ch6(net, train_iter, test_iter, num_epochs, lr, lp.try_gpu())
plt.show()  # 显示绘图
# loss 0.342, train acc 0.873, test acc 0.871
# 1395.1 examples/sec on cuda:0

运行结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1854156.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

个人成长的利器:复盘教你如何避免重蹈覆辙

前言 📫 大家好,我是南木元元,热爱技术和分享,欢迎大家交流,一起学习进步! 🍅 个人主页:南木元元 最近忙着学习和工作,更新比较少,期间一直在思考如何才能快速…

BLDC无感控制策略

本文根据 BLDC 的电路模型推导了一个简 化磁链方程来估计转子位置,转速适用范围较 广;重点分析了反电动势和换相电流对转矩脉动 的影响;设计了一种BLDC的无速度传感器高速 驱动控制方案。通过试验验证了新型控制策略 的性能。 1 低速时的转子位置检测 图1 为高速无刷直流电…

高职人工智能专业实训课之“图像识别基础”

一、前言 随着人工智能技术的迅猛发展,高职院校对人工智能专业实训课程的需求日益迫切。唯众人工智能教学实训平台作为一所前沿的教育技术平台,致力于为学生提供高效、便捷的人工智能实训环境,特别在“图像识别基础”这一关键课程中&#xf…

四川汇聚荣科技有限公司怎么样?

在探讨一家科技公司的综合实力时,我们往往从多个维度进行考量,包括但不限于公司的发展历程、产品与服务的质量、市场表现、技术创新能力以及企业文化。四川汇聚荣科技有限公司作为一家位于中国西部的科技企业,其表现和影响力自然也受到业界和…

从零开始使用Surya-OCR——检测后的精细化处理框1:降噪二值图下的空白检测框删除

目录 一、动机 二、降噪二值化处理 1.一般二值化处理 2.降噪二值化处理 三、图片区域空白框判断 1.计算区域黑色像素比重 2.设置阈值筛选空白区域 3.可视化检查结果 一、动机 在使用 Surya 检测文本框时,对于一些特殊的文本,尤其是中文的古籍等,存在检测不准确的问题。常常…

国产AI算力训练大模型技术实践

ChatGPT引领AI大模型热潮,国内外模型如雨后春笋,掀起新一轮科技浪潮。然而,国内大模型研发推广亦面临不小挑战。面对机遇与挑战,我们需保持清醒,持续推进技术创新与应用落地。 为应对挑战,我们需从战略高度…

Program-of-Thoughts(PoT):结合Python工具和CoT提升大语言模型数学推理能力

Program of Thoughts Prompting:Disentangling Computation from Reasoning for Numerical Reasoning Tasks github:https://github.com/wenhuchen/Program-of-Thoughts 一、动机 数学运算和金融方面都涉及算术推理。先前方法采用监督训练的形式,但这…

【git1】指令,commit,免密

文章目录 1.常用指令:git branch查看本地分支, -r查看远程分支, -a查看本地和远程,-v查看各分支最后一次提交, -D删除分支2.commit规范:git commit进入vi界面(进入前要git config core.editor vim设一下vi模…

《王者荣耀》国际服全球上线《Honor of Kings》海外下载榜首

原标题:《Honor of Kings》全球上线,国际玩家见证中国游戏魅力 易采游戏网6月23日独家消息:《王者荣耀》国际服《Honor of Kings》正式在全球160多个国家和地区上线,标志着这款源自中国的热门手机游戏迈向了国际舞台。尤其在加拿大…

Java面试八股之JVM永久代会发生垃圾回收吗

JVM永久代会发生垃圾回收吗 JVM的永久代(PermGen)在Java 8之前是存在的一部分,主要用于存储类的元数据、常量池、静态变量等。在这些版本中,永久代确实会发生垃圾回收,尤其是在永久代空间不足或超过某个阈值时&#x…

我在高职教STM32——LCD液晶显示(3)

大家好,我是老耿,高职青椒一枚,一直从事单片机、嵌入式、物联网等课程的教学。对于高职的学生层次,同行应该都懂的,老师在课堂上教学几乎是没什么成就感的。正因如此,才有了借助 CSDN 平台寻求认同感和成就…

【Linux详解】冯诺依曼架构 | 操作系统设计 | 斯坦福经典项目Pintos

目录 一. 冯诺依曼体系结构 (Von Neumann Architecture) 注意事项 存储器的意义:缓冲 数据流动示例 二. 操作系统 (Operating System) 操作系统的概念 操作系统的定位与目的 操作系统的管理 系统调用和库函数 操作系统的管理: sum 三. 系统调…

数据类型 运算符

基本数据类型与引用数据类型的区分 存储内容: 基本数据类型:直接存储实际的数据值,如整数、浮点数、字符等。引用数据类型:存储对象的引用(内存地址),而不是对象本身。 内存分配: 基…

Qt——系统

目录 概述 事件 鼠标事件 进入、离开事件 按下事件 释放事件 双击事件 移动事件 滚轮事件 按键事件 单个按键 组合按键 定时器 QTimerEvent QTimer 窗口事件 文件 输入输出设备 文件读写类 文件和目录信息类 多线程 常用API 线程安全 互斥锁 条件变量…

matplotlib之常见图像种类

Matplotlib 是一个用于绘制图表和数据可视化的 Python 库。它支持多种不同类型的图形,以满足各种数据可视化需求。以下是一些 Matplotlib 支持的主要图形种类: 折线图(Line Plot): 用于显示数据随时间或其他连续变量的…

珈和科技和比昂科技达成战略合作,共创智慧农业领域新篇章

6月14日,四川省水稻、茶叶病虫害监测预警与绿色防控培训班在成都蒲江举办。本次培训班由四川省农业农村厅植物保护站主办,蒲江县农业农村局、成都比昂科技筹办。四川省农业农村厅植物保护站及四川省14个市州36个县植保站负责人进行了观摩学习。 武汉珈…

Python中的性能分析和优化

在前几篇文章中,我们探讨了Python中的异步编程和并发编程,以及如何结合使用这些技术来提升程序性能。今天,我们将深入探讨如何分析以及优化Python代码的性能,确保应用程序的高效运行! 性能分析的基本工具和方法 在进…

[系统运维|Xshell]宿主机无法连接上NAT网络下的虚拟机进行维护?主机ping不通NAT网络下的虚拟机,虚拟机ping的通主机!解决办法

遇到的问题:主机ping不通NAT网络下的虚拟机,虚拟机ping的通主机 服务器:Linux(虚拟机) 主机PC:Windows 虚拟机:vb,vm测试过没问题,vnc没测试不清楚 虚拟机网络&#xff1…

cve-2015-3306-proftpd-vulfocus

1.原理 proftp是用于搭建基于ftp协议的应用软件 ProFTPD是ProFTPD团队的一套开源的FTP服务器软件。该软件具有可配置性强、安全、稳定等特点。 ProFTPD 1.3.5中的mod_copy模块允许远程攻击者通过站点cpfr和site cpto命令读取和写入任意文件。任何未经身份验证的客户端都可以…

牛客周赛Round48

第一题 A-小红的整数自增 链接:登录—专业IT笔试面试备考平台_牛客网 来源:牛客网 小红拿到了三个正整数。她准备进行若干次操作,每次操作选择一个元素加1。小红希望最终三个数相等,请你帮小红求出最小的操作次数。 思路&#x…