pytorch实现经典神经网络:VGG16模型之复现

news2025/6/20 3:43:27

可以参考https://blog.csdn.net/m0_37867091/article/details/107237671
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
分成 提取特征网络结构+分类结构

模型代码:
此模型写了VGG的几种网络结构

一、官方权重

# official pretrain weights
model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'
}

一、 根据论文中模型结构搭建 提取特征网络1部分

首先写了cfgs这个字典
我们以vgg11为例
他的构建了一个列表
其中数字代表了卷积核个数(通道数)
M代表进入池化工作


cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

以VGG11的特征提取网络部分为例:
所谓11就是 8层卷积层+3层池化
在这里插入图片描述

二、然后是提取特征网络2部分:

此处定义特征提取网络函数
首先参数以列表形式传进来

这里学一下函数定义传入方法
http://www.manongjc.com/detail/51-busecnjmsdoijob.html
https://www.jianshu.com/p/20d1b512b8b2

1、其中

def make_features(cfg: list):

cfg:list表示传入的是一个配置变量。它是一个list类型,因此用的时候只需要传入对应配置的列表即可。在这里插入图片描述
2、我们首先创建一个空列表叫做layers以存放自己的神经网络。
3、初始设定in_channels=3,因为初始图片是3通道
4、遍历我们传入的配置列表,如果是M,代表进入池化操作,因为计算过,池化的卷积核都是2,stride=2。因此在layers列表中自动添加池化
5、如果不是,进入添加卷积核的操作,输入是in_channels,输出是我们传入列表的当前值。(所有卷积核都是3,stride=1)
6、卷积的每一步后面添加ReLU激活函数减少数据量
7、并且把当前的列表值,赋给v作为下一层卷积的输入
8、最神的来了!

return nn.Sequential(*layers)

*将我们的列表layer,以非关键字参数的行书传入nn.Sequential(layers),可以传入任意数量,星号的作用是解包,把序列里面的元素一个个拆开

这是因为Sequential,一般是以这样非关键字形式传入(当然也可以用字典的形式)
在这里插入图片描述

def make_features(cfg: list):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            layers += [conv2d, nn.ReLU(True)]
            in_channels = v
    return nn.Sequential(*layers)

插个楼
https://www.jianshu.com/p/20d1b512b8b2
这里讲解了函数def的一下方法:
在这里插入图片描述
在这里插入图片描述

三、继承nn.moudule的神经网络主框架

class VGG(nn.Module):
    def __init__(self, features, num_classes=1000, init_weights=False):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Linear(512*7*7, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, num_classes)
        )
        if init_weights:
            self._initialize_weights()
首先定义变量features,类别1000种。
类种的features=全局变量种的features(特征)
也就是特征提取部分刚才的 代码二
后面是分类器也就是全连接层
 nn.Dropout(p=0.5)的目的是防止过拟合
 最后全连接层linear输出的是分类的类别个数


在这里插入图片描述
因为全连接层之前最后是7×7×512,所以linear输入是7×7×512

    def forward(self, x):
        # N x 3 x 224 x 224
        x = self.features(x)
        # N x 512 x 7 x 7
        x = torch.flatten(x, start_dim=1)
        # N x 512*7*7
        x = self.classifier(x)
        return x
        定义正向传播函数,我们继承VGG类后,函数传入这里
        首先便利列表传入feature内实现提取特征
        然后展平  **这里start_dim=1是从第1个维度展开**
         flatten就是把(N,C,H,W)的张量,变成(N,C*HW)
        展平后进入上面定义好的分类网络结构

这里start_dim=1是从第1个维度展开
因为第0个维度是batchsize
四个维度是(batchsize,channel,H,W) flatten就是把(N,C,H,W)的张量,变成(N,C*HW)

判断是否需要对网络结构进行参数初始化
如果之前class VGG(nn.Module):
def init(self, features, num_classes=1000, **init_weights=**False):
这里为TRUE时则初始化

        if init_weights:
            self._initialize_weights()

我们再来看一下初始化函数:

首先会便利网络的每一个子模块,
如果当前层是 卷积层 则会进入xavier初始化方法
去初始化卷积核的权重
如果卷积核采用了偏置
则会被置为变量0;
如果当前层是 全连接层(linaer) 则会进入xavier初始化方法
同理如果卷积核采用了偏置
则会被置为变量0def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                # nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

四、实例化我们VGG网络的方法:

通过给定model_name来实例化自己需要的模型
以“vgg16”为例
我们将model_name(vgg16)这个key值传入到定义好的字典中,得到vgg16后面的一系列列表:

def vgg(model_name="vgg16", **kwargs):
    assert model_name in cfgs, "Warning: model number {} not in cfgs dict!".format(model_name)
    cfg = cfgs[model_name]

通过model实例化我们第三步定义的VGG网络:将cfg
的列表传入vgg网络的make_feature内。。
    model = VGG(make_features(cfg), **kwargs)
    return model

**kwargs代表可变长度字典变量,
在这里插入图片描述
就是这些东西都可以传进去

以上,就是我们VGG网络的整体
也就是model.py的内容

整体代码:

import torch.nn as nn
import torch

# official pretrain weights
model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'
}


class VGG(nn.Module):
    def __init__(self, features, num_classes=1000, init_weights=False):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Linear(512*7*7, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, num_classes)
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        # N x 3 x 224 x 224
        x = self.features(x)
        # N x 512 x 7 x 7
        x = torch.flatten(x, start_dim=1)
        # N x 512*7*7
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                # nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


def make_features(cfg: list):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            layers += [conv2d, nn.ReLU(True)]
            in_channels = v
    return nn.Sequential(*layers)


cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def vgg(model_name="vgg16", **kwargs):
    assert model_name in cfgs, "Warning: model number {} not in cfgs dict!".format(model_name)
    cfg = cfgs[model_name]

    model = VGG(make_features(cfg), **kwargs)
    return model

assert in的用法:
assert () // 断言,用于判断一个表达式,在这里,这个表达式是'pre_boxes' in outputs,仅在条件为false时触发,且一般写在代码的开始处。
() = 'pre_boxes' in outputs // in 关键词,用于判断关坚持是否在字典中,存在则返回true,不存在则返回false。
————————————————
版权声明:本文为CSDN博主「蛊惑one」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/guhuoone/article/details/124540721

五、训练部分代码:

 data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

主要讲一下预处理这里:
随即裁剪
随机翻转。
转为tensor格式
发现在一般人写的代码中:
在预处理环节会分别用rgb剪去123.68,116.78,103,94
这三个值对应着 Imgnet数据集所有数据三通道的均值
如果自己采用迁移学习的方式,则需要有这一步。

  nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

这里num_workers线程数,如果是windows就是0
其他ubuntu可以改线程

这里输入自己要用的模型

 model_name = "vgg16"
    net = vgg(model_name=model_name, num_classes=5, init_weights=True)

参数会保存model.py中
在这里插入图片描述

六、一些想法:

1、dropout随即失活
全连接层linear后dropout防止过拟合
在这里插入图片描述
最后一步linear后面不加dropout是由于

linear最后的输出是 类的数量
随即失火反而会出错

2、from tqdm import tqdm
tqdm是运行时动态展示训练的情况、比如进度条啥的

3、导入model.py时红色下划线:
把上级目录设置为根目录即可

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1086845.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vue2.6 和 2.7对可选链的不同支持导致构建失败

有两个vue2项目,构建配置和依赖基本上都一样,但一个可以在 template 模板中使用可选链(?.),另一个使用就报错。 但是报错的那个项目,在另一个同事那又不报错。 已知 node14 之后就支持可选链了,我和同事用的是 node…

【TensorFlow2 之015】 在 TF 2.0 中实现 AlexNet

一、说明 在这篇文章中,我们将展示如何在 TensorFlow 2.0 中实现基本的卷积神经网络 \(AlexNet\)。AlexNet 架构由 Alex Krizhevsky 设计,并与 Ilya Sutskever 和 Geoffrey Hinton 一起发布。并获得Image Net2012竞赛中冠军。 教程概述: 理论…

Spring framework Day13:注解结合Java配置类

前言 前面我们管理 bean 都是在 xml 文件中去管理,本次我们将介绍如何在 Java 配置类中去管理 bean。 注解结合 Java 配置类是一种常见的 Spring 注入 Bean 的方式。通常情况下,开发人员会使用 Java Config 来定义应用程序的配置信息,而在 …

三维地下管线建模软件MagicPipe3D V3.1.3发布

经纬管网建模系统MagicPipe3D V3.1.3持续更新,内容如下: (1)新增管线流向配置,建模生成带流向箭头管道模型; (2)新增建模完成后可以直接载入3DTiles或obj模型功能; &a…

GoLang开发使用gin框架搭建web程序

目录 1.SDK安装 ​2.编辑器下载 3.编辑器准备 4.使用 4.1常见请求方式 1.SDK安装 保证装了Golang的sdk(官网下载windows.zip->解压,安装,配置bin的环境变量) 2.编辑器下载 Download GoLand: A Go IDE with extended support for JavaScript, Ty…

postman 密码rsa加密登录-1获取公钥

fiddler抓包看到:请求系统地址会自动跳转到sso接口,查看200状态的接口返回的html里存在一个encrypt的信息,咨询开发这个就是返回的公钥。 在postman的tests里对该返回进行处理,获取公钥并设为环境变量 //获取公钥 var pubKey re…

Rancher 使用指南

Rancher 使用指南 Rancher 是什么?Rancher 与 OpenShift / Kubesphere 主要区别对比RancherOpenShiftKubesphere 对比 Rancher 和 OpenShift Rancher 安装 Rancher 是什么? 企业级Kubernetes管理平台 Rancher 是供采用容器的团队使用的完整软件堆栈。它解决了管理多个Kuber…

RT-Thread 内核移植(学习)

内核移植 内核移植就是指将RT-Thread内核在不同的芯片架构、不同的板卡上运行起来,能够具备线程管理和调度,内存管理,线程间同步和通信、定时器管理等功能。 移植可分为CPU架构移植和BSP(Board support package,板级…

催交费通知单套打单纸设置说明

2.0系统打印催交费通知单设置尺寸操作展示如下,仅供参考。具体如下: 一、Win7系统 1.找到设备和打印机,选中对应打印机后点击上方打印服务器属性; 2.创建一个宽14cm,高14cm的表单; 二、win10系统 1.找到打印机,点管理,选择打印首选项;

Unity关键词语音识别

一、背景 最近使用unity开发语音交互内容的时候,遇到了这样的需求,就是需要使用语音关键字来唤醒应用程序,然后再和程序做交互,有点像智能音箱的意思。具体的技术方案方面,也找了一些第三方的服务,比如百度…

当涉及到API接口数据分析时,主要可以从以下几个方面展开

当涉及到API接口数据分析时,主要可以从以下几个方面展开: 请求分析:可以统计每个API接口的请求次数、请求成功率、失败率等基础指标。这些指标可以帮助你了解API接口的使用情况,比如哪个API接口被调用的次数最多,哪个…

2023年09月 C/C++(四级)真题解析#中国电子学会#全国青少年软件编程等级考试

C/C编程(1~8级)全部真题・点这里 Python编程(1~6级)全部真题・点这里 第1题:酒鬼 Santo刚刚与房东打赌赢得了一间在New Clondike 的大客厅。今天,他来到这个大客厅欣赏他的奖品。房东摆出了一行瓶子在酒吧上…

《向量数据库指南》——向量数据库与 ANN 算法库的区别

向量数据库与 ANN 算法库的区别 我们经常听到一个这样的错误观念——向量数据库只是在 ANN(approximate nearest neighbor,近似最近邻)算法上封装了一层。但这种说法大错特错。 向量数据库可以处理大规模数据,而 ANN 算法库只能处理小型的数据集 从本质上来看,以 Milvus 为…

Adobe Premiere Pro 和 After Effects 安装出错的解决路径

在有点年头的电脑上安装Premiere Pro 和 After Effects 遇到了前所未有的的麻烦,请了某宝上的小哥进行远程安装,两个软件倒是可以用了,但Win11系统无法正常关机,用了几天系统除了关机时会蓝屏几十秒,其他没有发现毛病&…

centos 7 lamp owncloud

OwnCloud是一款开源的云存储软件,基于PHP的自建网盘。基本上是私人使用,没有用户注册功能,但是有用户添加功能,你可以无限制地添加用户,OwnCloud支持多个平台(windows,MAC,Android&a…

计算机网络 | 物理层

计算机网络 | 物理层 计算机网络 | 物理层基本概念数据通信基本知识(一)一个数据通信流程的例子数据通信相关术语三种通信方式数据传输方式串行传输和并行传输同步传输和异步传输 小结 数据通信基本知识(二)码元(Symbo…

【Java 进阶篇】JavaScript 一元运算符详解

在JavaScript中,一元运算符是一类操作符,它们作用于单一操作数(一个值)。这些运算符执行各种操作,包括递增、递减、类型转换等。本文将详细介绍JavaScript中的一元运算符,解释它们的用途,提供示…

MySQL MVCC详细介绍

MVCC概念 MVCC(Multi-Version Concurrency Control) 多版本并发控制,是一种并发控制机制,用于处理数据库中的并发读写操作,它通过在每个事务中创建数据的快照,实现了读写操作的隔离性,从而避免了读写冲突和数据不一致的问题。 M…

VUE echarts 柱状图、折线图 双Y轴 显示

weekData: [“1周”,“2周”,“3周”,“4周”,“5周”,“6周”,“7周”,“8周”,“9周”,“10周”], //柱状图横轴 jdslData: [150, 220, 430, 360, 450, 680, 100, 450, 680, 200], // 折线图的数据 cyslData: [100, 200, 400, 300, 500, 500, 500, 450, 480, 400], // 柱状图…

基于VScode 使用plantUML 插件设计状态机

本文主要记录本人初次在VScode上使用PlantUML设计 本文只讲述操作的实际方法,假设java已安装成功 。 1. 在VScode下安装如下插件 2. 验证环境是否正常 新建一个文件夹并在目录下面新建文件test.plantuml 其内容如下所示: startuml hello world skinparam Style …