现代卷积神经网络(NiN),并使用NIN训练CIFAR10的分类

news2025/7/21 0:58:33

专栏:神经网络复现目录


本章介绍的是现代神经网络的结构和复现,包括深度卷积神经网络(AlexNet),VGG,NiN,GoogleNet,残差网络(ResNet),稠密连接网络(DenseNet)。
文章部分文字和代码来自《动手学深度学习》

文章目录

  • 网络中的网络(NiN)
    • 简介
    • 全局平均汇聚层
    • 和VGG的区别
    • 优点
  • 网络结构
    • 定义
    • 实现
  • 实战(CIFAR10分类)
    • 模型设计
    • 导入模块
    • 数据集
    • 训练和评估
    • 保存模型
    • 测试


网络中的网络(NiN)

简介

NiN,即Network in Network,是一种由Min Lin等人于2013年提出的深度神经网络架构。相较于传统的卷积神经网络,NiN引入了“1x1卷积层”以及全局平均池化层,这两个模块的引入使得NiN模型可以通过不同的网络层抽取多层次、多尺度的特征。

NiN的核心思想是通过在卷积神经网络的最后两层之间添加一个或多个全局平均池化层(global average pooling layer),将每个空间位置的特征合并为一个全局统计值,从而减少模型的参数量,防止过拟合。与全连接层相比,全局平均池化层能够更好地提取特征,因为它不依赖于特定的位置信息,而是能够更好地保留特征图的整体信息。此外,NiN中还引入了1x1卷积层,它的作用在于改变通道数。通过对特征通道进行1x1卷积,可以让不同通道之间的特征在保留其信息的同时,使得这些特征更加明确,进而提高模型的准确率。

NiN的整体结构包括多个NiN块,每个NiN块包括一个卷积层、多个1x1卷积层和全局平均池化层。其中,NiN块的卷积层使用了普通的卷积层,而不是分离卷积层,因此模型可以学习到不同特征尺度之间的相互关系。NiN的最后一层通常是全局平均池化层和全连接层,用于分类。由于NiN中引入了全局平均池化层和1x1卷积层,因此NiN模型相较于传统的卷积神经网络具有更好的性能和更少的参数量。

全局平均汇聚层

全局平均汇聚层(Global Average Pooling Layer)是一种常见的卷积神经网络结构,在最后一层卷积层后面添加一个全局平均汇聚层,用于将卷积层的输出进行降维,从而得到一个固定长度的特征向量。

与全连接层相比,全局平均汇聚层具有以下优点:

  1. 参数少:全局平均汇聚层不包含可训练参数,因此它的计算量比全连接层小,可以避免过拟合。

  2. 降低过拟合:由于全局平均汇聚层可以对每个通道进行平均池化,因此其对空间信息的保留比较好,可以提高网络的泛化能力,从而减少过拟合。

  3. 不依赖于输入大小:由于全局平均汇聚层的输出特征向量的长度与输入大小无关,因此可以处理不同大小的输入。

  4. 更适合于可视化:全局平均汇聚层将卷积层的输出降维到一个固定长度的特征向量,这个特征向量可以更方便地进行可视化,帮助理解模型的行为。

总之,全局平均汇聚层是一种有效的降维方法,可以帮助神经网络提取更加具有代表性的特征,并在一定程度上避免过拟合。

全局平均汇聚层的逻辑如下:

  1. 对于输入的 feature map,每个通道内的所有值进行平均得到一个标量输出。

  2. 对于输入 feature map 的每个通道,都执行第 1 步,得到通道内所有值的平均值。

  3. 将每个通道的平均值串联起来,作为输出。

全局平均汇聚层的输出是对输入特征图的每个通道执行平均池化后的结果,通常是一个大小为 1x1x通道数 的特征图。这个特征图可以被视为每个通道的特征的“汇聚”,因为它将该通道的所有信息转换为一个标量值。这种操作通常在最后的卷积层后使用,以将卷积层输出的高维特征图压缩为低维的特征向量,方便后续的全连接层处理。

和VGG的区别

下图对比了VGG和NIN的区别:
在这里插入图片描述
它们的区别在于:

卷积层的组成:NIN的卷积层采用了一种被称为1x1卷积的技术,这种技术可以增加非线性变换的能力,同时可以减少参数数量;而VGG采用的是更加传统的3x3卷积层,这种卷积层可以保留更多的空间信息,但参数数量相对较多。

池化层的位置:NIN的池化层通常放置在卷积层中间,这样可以使得特征图的大小降低,同时又不会丢失太多的信息;而VGG则采用了较为传统的做法,将池化层放在每两个卷积层之间,这样可以使得特征图的大小减半,同时也会丢失一些信息。

全连接层的设计:NIN使用全局平均池化层代替了传统的全连接层,这样可以有效减少参数数量,同时也可以避免过拟合;而VGG则采用了多个全连接层,可以更好地拟合复杂的非线性映射关系,但也增加了参数数量和过拟合的风险。

总的来说,NIN相对于VGG来说,参数数量更少,计算量更小,同时也可以获得更好的效果。但是,在某些任务中,比如图像分类,VGG表现仍然非常优秀,因为其较为复杂的网络结构可以更好地捕获图像的细节和纹理信息。

优点

NIN的优点包括:

  1. 更少的参数:NIN使用1x1卷积层来代替全连接层,因此参数数量显著减少,使得模型更容易训练,同时也能减轻过拟合。

  2. 网络结构更加简单:NIN的网络结构比较简单,包含了很少的层数和卷积核尺寸,使得它更容易理解和优化。

  3. 提高了模型的表达能力:NIN引入了MLPconv层,使得模型可以在特征空间中进行非线性变换,从而提高了模型的表达能力。

  4. 易于并行计算:由于NIN采用了1x1卷积层代替全连接层,使得卷积计算可以并行化处理,从而提高了计算效率。

网络结构

定义

NIN(Network In Network)是由Min Lin、Qiang Chen、Shuicheng Yan在2013年提出的深度神经网络。它的主要思想是在传统的卷积神经网络中嵌入由一个 MLP(多层感知机)组成的模块,称为“Micro Network”,以增加模型的非线性表达能力。

具体来说,NIN的网络结构包含三个由卷积层、MLP层和全局平均池化层组成的模块,其中MLP层用于替代传统卷积层的非线性映射。整个网络的最后一层是全连接层,用于分类。

以下是NIN的网络结构:

Input -> Conv (kernel size: 11x11, stride: 4, num filters: 96) -> ReLU ->
        Conv (kernel size: 1x1, stride: 1, num filters: 96) -> ReLU ->
        Conv (kernel size: 1x1, stride: 1, num filters: 96) -> ReLU ->
        MaxPool (kernel size: 3x3, stride: 2) -> Dropout (p=0.5) ->

        Conv (kernel size: 5x5, stride: 1, num filters: 256, padding: 2) -> ReLU ->
        Conv (kernel size: 1x1, stride: 1, num filters: 256) -> ReLU ->
        Conv (kernel size: 1x1, stride: 1, num filters: 256) -> ReLU ->
        MaxPool (kernel size: 3x3, stride: 2) -> Dropout (p=0.5) ->

        Conv (kernel size: 3x3, stride: 1, num filters: 384, padding: 1) -> ReLU ->
        Conv (kernel size: 1x1, stride: 1, num filters: 384) -> ReLU ->
        Conv (kernel size: 1x1, stride: 1, num filters: 384) -> ReLU ->
        MaxPool (kernel size: 3x3, stride: 2) -> Dropout (p=0.5) ->

        Conv (kernel size: 3x3, stride: 1, num filters: 1024, padding: 1) -> ReLU ->
        Conv (kernel size: 1x1, stride: 1, num filters: 1024) -> ReLU ->
        Conv (kernel size: 1x1, stride: 1, num filters: 10) -> ReLU ->
        GlobalAvgPool -> Output

其中,每个卷积层之后都跟随一个ReLU激活函数,用于引入非线性表达能力。全局平均池化层用于将每个卷积核的输出值求平均,进而计算输出特征图的每个通道的平均值,从而得到该通道对应的输出值。这种做法可以有效地降低特征图的维度,减少模型的参数量,进而减轻过拟合的风险

实现

def nin_block(in_channels, out_channels, kernel_size, stride, padding):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1),
        nn.ReLU())

net = nn.Sequential(
    nin_block(3, 96, kernel_size=11, stride=4, padding=0),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nin_block(96, 256, kernel_size=5, stride=1, padding=2),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nin_block(256, 384, kernel_size=3, stride=1, padding=1),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nn.Dropout(0.5),
    # 标签类别数是10
    nin_block(384, 10, kernel_size=3, stride=1, padding=1),
    # 全局平均池化层可通过将窗口形状设置成输入的高和宽实现
    nn.AvgPool2d(kernel_size=5),
    # 将四维的输出转成二维的输出,其形状为(批量大小, 10)
    nn.Flatten())

实战(CIFAR10分类)

模型设计

import torch.nn as nn

def nin_block(in_channels, out_channels, kernel_size, stride, padding):
    block = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1),
        nn.ReLU())
    return block

class NiN(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nin_block(3, 96, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=3, stride=2,padding=1),
            nn.Dropout(0.5),
            nin_block(96, 256, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=3, stride=2,padding=1),
            nn.Dropout(0.5),
            # 标签类别数是10
            nin_block(256, 10, kernel_size=3, stride=1, padding=1),
            #全局平均代替最后的全连接层
            nn.AdaptiveAvgPool2d((1,1))
            )

    def forward(self,input):
        x = self.net(input)
        x = x.view(x.size(0), 10)
        #print(x.shape)
        return x

导入模块

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import time


# 获取设备,优先使用GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

数据集

# 加载数据集
def load_data_cifar10(batch_size, resize=None):
    transform_list = [transforms.ToTensor()]
    if resize:
        transform_list.insert(0, transforms.Resize(resize))
    transform = transforms.Compose(transform_list)
    cifar10_train = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    cifar10_test = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(cifar10_train, batch_size=batch_size, shuffle=True, num_workers=4)
    test_loader = torch.utils.data.DataLoader(cifar10_test, batch_size=batch_size, shuffle=False, num_workers=4)
    return train_loader, test_loader

训练和评估

def evaluate_accuracy(data_iter, net, device):
    net.eval()  # 评估模式
    acc_sum, n = 0.0, 0
    with torch.no_grad():
        for X, y in data_iter:
            X, y = X.to(device), y.to(device)
            acc_sum += (net(X).argmax(dim=1) == y).float().sum().cpu().item()
            n += y.shape[0]
    net.train()  # 改回训练模式
    return acc_sum / n

def train_nin(net, train_iter, test_iter, loss, optimizer, device, epochs):
    net = net.to(device)
    print("training on ", device)
    batch_count = 0
    for epoch in range(epochs):
        train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()
        for X, y in train_iter:
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_l_sum += l.cpu().item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
            n += y.shape[0]
            batch_count += 1
        test_acc = evaluate_accuracy(test_iter, net, device)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec' % (
            epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))
net = NiN()
batch_size=128
train_iter,test_iter=load_data_cifar10(batch_size,resize=224)
lr, num_epochs = 0.1, 10
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9)
train_nin(net, train_iter, test_iter, loss, optimizer, device, num_epochs)

保存模型

torch.save(net.state_dict(),"nin.pth")

测试

import torch
import torchvision
import torchvision.transforms as transforms

# 加载数据集并进行预处理
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)

# 定义类别标签
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# 加载模型
net = NiN()
net.load_state_dict(torch.load("nin.pth"))

# 使用CPU或GPU进行预测
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)

# 进行预测并输出结果
net.eval()
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        for i in range(4):
            print('GroundTruth: %s, Predicted: %s' % (classes[labels[i]], classes[predicted[i]]))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/395296.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【3.7】Redis数据类型、CPU缓存一致性、哈希表

文章目录数据类型篇StringListHashSetZsetBitMapHyperLogLogGEOStreamCPU 缓存一致性CPU是如何执行任务的?什么是软中断?为什么0.1 0.2不等于0.3?哈希表数据类型篇 String String 是最基本的 key-value 结构,key 是唯一标识&…

03 | 授权服务:授权码和访问令牌的颁发流程是怎样的? 笔记

03 | 授权服务:授权码和访问令牌的颁发流程是怎样的? 授权服务的工作过程 小兔软件需要去到京东的平台那里”备案“注册,京东商家开放平台就会给小兔软件 app_id 和 app_secret 等信息,以方便后面授权时的各种身份校验&#xff0…

scratch绘制雷达 电子学会图形化编程scratch等级考试三级真题和答案解析2022年9月

目录 scratch绘制雷达 一、题目要求 1、准备工作 2、功能实现 二、案例分析

阶段二12_面向对象高级_继承1

一.继承的入门介绍 (1)继承的概念理解 让类与类之间产生关系(子父类关系),子类可以直接使用父类中非私有的成员 (2)通过extends关键字实现继承 格式:public class 子类名 extends 父类名 { } 范例:public class Zi e…

Grafana 如何使用本地CSV文件作为数据源

Grafana提供了一个插件,可以把CSV文件作为数据源,关于CSV插件的说明,可以参考:https://grafana.com/grafana/plugins/marcusolsson-csv-datasource/?tabinstallation。我是在本地使用命令行grafana-cli plugins install marcusol…

通过45人!1-2月,誉天红帽RHCE学员再创佳绩!

学习的喜悦在于结果,也在于过程;在于取得成功时的豁然开朗,也在于持之以恒后的层层递进。结果固然重要,但在求知过程中获得的满足感,也同样让人乐在其中。 RHCE的学习过程就充满了这样的喜悦。对每一行命令的理解、对每…

【Linux学习】日积月累——调试器gdb的使用教程

一、背景 gdb是一款强大的命令行调试工具,可以形成执行程序、脚本。只需要几个简单的命令,就能够实现Windows环境下VC等IDE的图形化调式工具的功能。 调试的相关常识: 程序的发布方式有两种,debug模式和release模式;L…

197.Spark(四):Spark 案例实操,MVC方式代码编程

一、Spark 案例实操 1.数据准备 电商网站的用户行为数据,主要包含用户的 4 种行为:搜索,点击,下单,支付 样例类: 2. Top10 热门品类 先按照点击数排名,靠前的就排名高;如果点击数相同,再比较下单数;下单数再相同,就比较支付数。 我们有多种写法,越往后性能越…

【Linux开发笔记】《Linux嵌入式开发从0到1》(一):初探Linux——与Linux的初次相遇

1.什么是Linux Linux就是一个操作系统,就是一个开源、自由的操作系统,就是一个免费使用和自由传播的类UNIX操作系统,就是一个基于POSIX的多用户、多任务、支持多线程和多CPU的操作系统。 简单来讲,Linux就是一个操作系统而已… …

React的Hooks

React Hooks useState useMemo 和usecallback Hooks显示的指明因变量有什么好处 当使用时,y与changeX会被缓存下来,只要x不变,始终读取的是缓存的值, 如果不使用时,每次函数组件执行时,实际会基于x&#xf…

计算机写论文时,怎么引用文献? - 易智编译EaseEditing

首先需要清楚哪些引用必须注明[1]: 任何直接引用都要用引号并注明来源; 任何不是自己的口头或书面的观点、解释和结论都应注明来源; 即使不用原话,但是他人的思路、概念或观点也应注明; 不要为了适合你的观点修改原…

机器学习——无监督学习

机器学习的分类一般分为下面几种类别:监督学习( supervised Learning )无监督学习( Unsupervised Learning )强化学习( Reinforcement Learning,增强学习)半监督学习( Semi-supervised Learning )深度学习(Deep Learning)Python Scikit-learn. http: // …

day40|198.打家劫舍、213.打家劫舍II、337.打家劫舍III

198.打家劫舍 你是一个专业的小偷,计划偷窃沿街的房屋。每间房内都藏有一定的现金,影响你偷窃的唯一制约因素就是相邻的房屋装有相互连通的防盗系统,如果两间相邻的房屋在同一晚上被小偷闯入,系统会自动报警。 给定一个代表每个…

软件测试8

一 缺陷 软件缺陷:是指软件或程序中存在的各种问题及错误,会导致软件产品在某种程度上不能满足用户的需求 二 软件缺陷的判定标准 1.软件未达到需求规格说明书中表明的功能 2.软件出现了需求规格说明书不会出现错误的地方 3.软件的功能超出了需求规格…

14 nuxt3学习(布局 渲染模式 插件plugin 生命周期)

布局 布局是围绕包含多个页面的公共用户界面的页面的包装器,例如页眉和页脚显示。 布局是使用slot 组件显示页面内容的Vue文件。 默认情况下使用layouts/default.vue文件。 自定义布局可以设置为页面元数据的一部分。 方式一:默认布局 在layouts目录下…

Xmind快捷键大全

Xmind快捷键大全 1、常用 CtrlShiftL 快捷键助手CtrlHome 返回中心主题Enter 插入主题Tab 插入子主题F2 编辑主题F3 添加/编辑标签F4 添加/编辑备注F6 下钻ShiftF6 上钻Delete 删除Ctrl] 插入摘要CtrlI 插入图片CtrlShiftH 插入超链接Ctrl1,2,3,4,5,6快速添加优先等级图标Ctr…

applicationContext相关加载

spring refresh 概述 refresh是一个方法,spring中所有的ApplicationContext容器都需要通过refresh方法初始化; 处理步骤 其中refresh方法包含12个主要的处理步骤: 1、第1个步骤做前置准备 2、第2~6步骤创建BeanFactory(Appl…

Java中垃圾回收(GC)算法详解

咱们要进行垃圾回收,是不是要知道哪些对象是垃圾,然后针对这些垃圾要怎么回收呢?那本篇文章我们就将垃圾回收分为标记垃圾、清除垃圾两个阶段讲解,详细说明每个阶段都有那些算法。1、标记阶段算法在堆里存放着几乎所有的Java对象实…

2023年交通与智慧城市国际会议(ICoTSC 2023)

2023年交通与智慧城市国际会议(ICoTSC 2023) 重要信息 会议网址:www.icotsc.org 会议时间:2023年7月28-30日 召开地点:长沙 截稿时间:2023年6月15日 录用通知:投稿后2周内 收录检索:EI,Scopus 会议简介…

轻松玩转开源大语言模型bloom(一)

前言 chatgpt已经成为了当下热门,github首页的trending排行榜上天天都有它的相关项目,但背后隐藏的却是openai公司提供的api收费服务。作为一名开源爱好者,我非常不喜欢知识付费或者服务收费的理念,所以便有决心写下此系列&#…