人工智能学习:ResNet神经网络(8)

news2025/7/19 9:01:29

ResNet是一种非常有效的图像分类识别的模型,可以参考如下的链接
https://blog.csdn.net/qq_45649076/article/details/120494328

ResNet网络由残差(Residual)结构的基本模块构成,每一个基本模块包含几个卷积层。其中,除了网络的推理输出,基本模块的输入也被直接加成到模块的输出。这种设计可以防止网络在深度加大之后产生退化的现象。

通常所见的有Resnet-18,Resnet-34,Resnet-50,Resnet-152等多层神经网络。作为示例,以下通过Tensorflow来构建Resnet-18模型。

首先,导入需要的模块

import numpy as np

import tensorflow as tf
from tensorflow import keras
from keras import models, layers

import matplotlib.pyplot as plt

然后,定义Resnet的基本模块

# define class of basic block
class ResBlock(keras.Model):
    def __init__(self, filters, strides=1, down_sample=False):
        super().__init__()
        
        self.down_sample = down_sample
        
        self.conv1 = layers.Conv2D(filters, (3,3), strides=strides, padding='same', use_bias=False)
        self.bn1 = layers.BatchNormalization()
        self.relu1 = layers.Activation('relu')
        
        self.conv2 = layers.Conv2D(filters, (3,3), strides=1, padding='same', use_bias=False)
        self.bn2 = layers.BatchNormalization()

        if self.down_sample:
            self.down_conv = layers.Conv2D(filters, (1,1), strides=strides, padding='same', use_bias=False)
            self.down_bn = layers.BatchNormalization()
            
        self.relu2 = layers.Activation('relu')

    def call(self, inputs):
        net = self.conv1(inputs)
        net = self.bn1(net)
        net = self.relu1(net)

        net = self.conv2(net)
        net = self.bn2(net)

        # down sample inputs if dimension changes
        if self.down_sample:
            identity_out = self.down_conv(inputs)
            identity_out = self.down_bn(identity_out)
        else:
            identity_out = inputs

        net = self.relu2(net+identity_out)
        
        return net

ResBlock由两个卷积层组成,每一个卷积层后面跟BatchNormalization层和Relu层。输入层的数据与第二层的卷积输出相加,通过Relu层产生模块的输出。这里分为两类,如果模块中的第一个卷积层进行了stride>1(通常为2)的降维卷积,那么输入也需要进行kernel_size为1的降维操作。

然后根据ResBlock来构建Resnet网络

# define class of Resnet-18 model
class Resnet18(keras.Model):
    def __init__(self, initial_filters=64):
        # each item in block_list represent number of base blocks(ResBlock) in that block
        super().__init__()
        
        filters = initial_filters
        
        # input layers
        self.input_layer = models.Sequential([
            layers.Conv2D(filters, (3,3), strides=1, padding='same', use_bias=False),
            layers.BatchNormalization(),
            layers.Activation('relu')
        ])
        
        # first layers, no down sample
        self.layer1 = models.Sequential([
            ResBlock(filters),
            ResBlock(filters)
        ])
        
        # second layer, filters doubles
        filters *= 2
        self.layer2 = models.Sequential([
            ResBlock(filters, strides=2, down_sample=True),
            ResBlock(filters)
        ])
        
        # third layer
        filters *= 2
        self.layer3 = models.Sequential([
            ResBlock(filters, strides=2, down_sample=True),
            ResBlock(filters)
        ])

        # third layer
        filters *= 2
        self.layer4 = models.Sequential([
            ResBlock(filters, strides=2, down_sample=True),
            ResBlock(filters)
        ])

        # output layer
        self.output_layer = models.Sequential([
            layers.GlobalAveragePooling2D(),
            layers.Dense(10, activation='softmax')
        ])
        
    def call(self, inputs):
        # input layer
        net = self.input_layer(inputs)
        
        # Resnet layers
        net = self.layer1(net)
        net = self.layer2(net)
        net = self.layer3(net)
        net = self.layer4(net)
        
        # output layer
        net = self.output_layer(net)
        
        return net

Resnet18由一个输入层、4个中间层和一个输出层组成,每一个中间层包含两基本模块,除了第一个中间层,每一层的第一个ResBlock为降维操作,第二个ResBlock为同维的卷积操作。每到下一个中间层,卷积特征的个数倍增。输入层为一个卷积层和一个BN层、一个Relu层组成。有些地方用7x7的降维卷积层和池化层,这里准备用于尺寸较小的CIFAR-10数据集,不进行降维操作。输出层为一个均值池化层和全连接层的连接。

下面构建网络并在CIFAR-10数据集上进行测试。
构建Resnet18网络,选择优化器

# build model, Resnet18
model = Resnet18()
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])

载入CIFAR-10数据

# test on CIFAR-10 data, load CIFAR-10 dataset
(train_images, train_labels), (test_images, test_labels) = keras.datasets.cifar10.load_data()

# train_images: 50000*32*32*3, train_labels: 50000*1
# test_images: 10000*32*32*3, test_labels: 10000*1

# pre-process data
train_input = train_images/255.0
test_input = test_images/255.0

train_output = train_labels
test_output = test_labels

定义数据处理器

# define data generator
data_generator = keras.preprocessing.image.ImageDataGenerator(
    featurewise_center=False,              #set mean of data to 0, by feature
    samplewise_center=False,               #set mean of data to 0, by sample
    featurewise_std_normalization=False,   #normalize by std, by feature
    samplewise_std_normalization=False,    #normalize by std, by sample
    zca_whitening=False,                   #zca whitening
    #zca_epsilon                           #zca epsilon, default 1e-6
    rotation_range=15,                     #degree of random rotation (integer, 0-180)
    width_shift_range=0.1,                 #probability of horizontal shift
    height_shift_range=0.1,                #probability of vertical shift
    horizontal_flip=True,                  #if random horizantal flip
    vertical_flip=False                    #if random vertical flip
)

data_generator.fit(train_input)

进行训练

# train, with batch size and epochs
epochs = 60
batch_size = 128

history = model.fit(data_generator.flow(train_input, train_output, batch_size=batch_size), epochs=epochs,
                      steps_per_epoch=len(train_input)//batch_size, validation_data=(test_input, test_output))

结果如下

Epoch 1/60
390/390 [==============================] - 47s 113ms/step - loss: 1.4509 - sparse_categorical_accuracy: 0.4840 - val_loss: 2.3173 - val_sparse_categorical_accuracy: 0.3550
Epoch 2/60
390/390 [==============================] - 43s 110ms/step - loss: 0.9640 - sparse_categorical_accuracy: 0.6601 - val_loss: 1.2524 - val_sparse_categorical_accuracy: 0.5868
Epoch 3/60
390/390 [==============================] - 43s 111ms/step - loss: 0.7742 - sparse_categorical_accuracy: 0.7273 - val_loss: 1.4201 - val_sparse_categorical_accuracy: 0.5968
...
Epoch 59/60
390/390 [==============================] - 43s 110ms/step - loss: 0.0483 - sparse_categorical_accuracy: 0.9824 - val_loss: 0.3880 - val_sparse_categorical_accuracy: 0.9106
Epoch 60/60
390/390 [==============================] - 43s 110ms/step - loss: 0.0501 - sparse_categorical_accuracy: 0.9830 - val_loss: 0.4432 - val_sparse_categorical_accuracy: 0.9006

最后达到训练精度98.3%,测试精度90.06%。
绘制训练曲线

# plot train history
loss = history.history['loss']
val_loss = history.history['val_loss']
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']

plt.figure(figsize=(11,3.5))

plt.subplot(1,2,1)
plt.plot(loss, color='blue', label='train')
plt.plot(val_loss, color='red', label='test')
plt.ylabel('loss')
plt.legend()

plt.subplot(1,2,2)
plt.plot(acc, color='blue', label='train')
plt.plot(val_acc, color='red', label='test')
plt.ylabel('accuracy')
plt.legend()

结果如下
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/17440.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【MySQL数据库笔记 - 进阶篇】(五)锁

✍个人博客:https://blog.csdn.net/Newin2020?spm1011.2415.3001.5343 📚专栏地址:暂定 📝视频地址:黑马程序员 MySQL数据库入门到精通 📣专栏定位:这个专栏我将会整理 B 站黑马程序员的 MySQL…

硬件科普系列之显示篇——LCD与OLED知多少

前言 无论是手机还是电脑,作为机器与人交互最为频繁的硬件设备,显示屏一直是决定用户体验最为关键的因素之一。大家近几年在购买手机的时候,可以发现目前大部分手机都在使用OLED屏幕,那么你有没有思考过为什么各大厂商都在大力推…

jupuyter的背景主题

jupuyter的背景主题一.背景主题安装查看可用主题1.主题安装2. **查看可用主题**3.更换主题,字体等其他设置4.其他命令,还原原本主题二.每个主题的效果1.chesterish2. grade33.gruvboxd4.oceans165.onedork6.solarizedd7.solarizedl一.背景主题安装查看可…

上帝视角看Vue源码整体架构+相关源码问答

前言 这段时间利用课余时间夹杂了很多很多事把 Vue2 源码学习了一遍,但很多都是跟着视频大概过了一遍,也都画了自己的思维导图。但还是对详情的感念模糊不清,故这段时间对源码进行了总结梳理。 本篇文章更合适于已看过 Vue2 源码&#xff0c…

使用NNI对DLASeg剪枝的失败记录

本文希望对CenterNet算法的Backbone暨DLASeg进行剪枝。 剪枝试验涉及3个文件,分别为: DCN可变性卷积dcn_v2.py,因为DLASeg依赖DCN。 #!/usr/bin/env python from __future__ import absolute_import from __future__ import print_functio…

如何在 Windows 10上修复0x000006ba错误

修复0x000006ba错误 可能导致此错误代码的原因已确认的可行的解决办法运行打印机疑难解答重新启动后台打印程序服务清除 PRINTERS 文件夹运行 SFC 和 DISM 扫描启用打印机共享某些 Windows 10 在尝试在 Windows 10 上打印新文档时遇到0x000006ba错误代码。其他用户在尝试使用 W…

【面试题】line-height继承问题

1. line-height为具体数值 当父元素line-height的值为具体数值的时候&#xff0c;例如30px&#xff0c;则子元素的line-height直接继承该数值。 <style>body{font-size: 20px;line-height: 50px;}p{background-color: #ccc;font-size: 16px;} </style><body&g…

类和对象的初步介绍

文章目录面向对象的初步认识什么是面向对象面向对象与面向过程类定义和使用简单认识类类的定义格式随堂练习定义一个学生类类的实例化什么是实例化类和对象的说明this 引用为什么要有this引用什么时this引用this引用的特性对象的构造和初始化构造方法概念特性默认初始化就地初始…

Shell脚本学习指南(三)——文本处理工具

文章目录排序文本行的排序以字段的排序文本块排序sort的效率sort的稳定性sort小结删除重复重新格式化段落计算行数、字数以及字符数打印打印技术的演化其他打印软件提取开头或结尾数行排序文本 含有独立数据记录的文本文恶剪&#xff0c;通常都可以拿来排序。一个可预期的记录…

Vue3 - 组件通信(父传子)

前言 在 Vue3 中&#xff0c;父组件向子组件传参的方法。 与 Vue2 相比&#xff0c;还是有一些区别的。 基础示例 现在我们的需求是&#xff0c;要通过父组件&#xff0c;传递一个标题来让子组件显示。 子组件 Com.vue&#xff1a; <template><div>{{ title }}&l…

大数据工程师必备之数据可视化技术

可视化技术 数据&#xff1a; 偏耀明 7800 高军鹏 8000 代欣 8800 王国庆 20000 ​ 应对现在数据可视化的趋势&#xff0c;越来越多企业需要在很多场景(营销数据、生产数据、用户数据)下使用&#xff0c;可视化图表来展示体现数据&#xff0c;让数据更加直观&#xff0c;数…

tp6使用redis消息队列

尾部写入 for ($i1;$i<1000;$i){Cache::store(redis)->rpush(list,date("Y-m-d H:i:s")."消息{$i}"); }头部读取消息队列并删除 $list Cache::store(redis)->lpop(list); 1、新建个方法运行写入队列 public function hello(){for ($i1;$i<…

C++ Reference: Standard C++ Library reference: Containers: deque: deque: erase

C官网参考链接&#xff1a;https://cplusplus.com/reference/deque/deque/erase/ 公有成员函数 <deque> std::deque::erase C98 iterator erase (iterator position); iterator erase (iterator first, iterator last); C11 iterator erase (const_iterator position )…

Android 后台服务启动Actvity

一、问题背景 相机自动化测试需求&#xff0c;测试apk通过bindService绑定相机apk里面的一个服务&#xff0c;通过AIDL接口的方式向相机apk发送命令&#xff0c;服务接收到命令之后会拉起相机的Activity。原本没有人为干预的情况下是可以拉起这个Activity的&#xff0c;但是拉…

基于PYTHON游乐场服务管理系统的设计与实现

摘要 项目门票是游乐园必不可少的一个部分。在游乐园发展的整个过程中&#xff0c;项目门票担负着最重要的角色。为满足如今日益复杂的管理需求&#xff0c;各类管理系统程序也在不断改进。本课题所设计的游乐场服务管理系统&#xff0c;使用Django框架&#xff0c;Python语言进…

如何优雅部署OpenStack私有云I--Kolla

为方便大数据平台与管理工具的研发&#xff0c;在公司成本不额外增加的情况下&#xff0c;从公司仓库里拉了几台下线物理机来做大数据平台的实验环境。但整体物理机性能都偏高&#xff0c;单独安装一个大数据服务&#xff0c;很豪&#xff0c;但是也很浪费。而且主机台数不是很…

优先级队列(堆)——小记

文章目录堆概念堆的创建堆向下调整堆的插入堆的删除堆排序整体代码&#xff08;创建堆&#xff08;向下调整&#xff09;&#xff0c;堆的插入&#xff0c;堆的删除&#xff0c;堆排序&#xff09;TOPKPriorityQueue特性堆 概念 如果有一个关键码的集合Kk0&#xff0c;k1&…

48 基于 jdk9 编译的 jdk8 的字节码报错

前言 呵呵 大概是由于最近的这个 “Apache Log4j被曝存在严重高危险级别远程代码执行漏洞” 昨天晚上 编译了一下 logging-log4j2-log4j-2.15.0-rc2, 项目需要一个 toolchain.xml 的一个配置, 里面需要配置为 jdk9 因此 我的项目配置的 jdk 为 jdk9 然后 idea 里面默认…

【计算机毕业设计】校园二手市场平台+vue源码

一、系统截图&#xff08;需要演示视频可以私聊&#xff09; 摘 要 21世纪的今天&#xff0c;随着社会的不断发展与进步&#xff0c;人们对于信息科学化的认识&#xff0c;已由低层次向高层次发展&#xff0c;由原来的感性认识向理性认识提高&#xff0c;管理工作的重要性已逐…

校园跑腿系统小程序怎么用_校园跑腿系统小程序的基本功能是什么

大学可能是人生中最可能的阶段&#xff0c;而大学也是创业的最佳选择。近年来&#xff0c;在微信小程序的红利生态圈下&#xff0c;校园跑腿系统系统已经成为大学校园创业的第一热点。 随着大学生人数的增加&#xff0c;消费水平也在不断地提高&#xff0c;大学校园内代取快递、…