基于PaddlePaddle的图片分类实战 | 深度学习基础任务教程系列

news2025/6/21 1:13:17

 

       图像相比文字能够提供更加生动、容易理解及更具艺术感的信息,图像分类是根据图像的语义信息将不同类别图像区分开来,是图像检测、图像分割、物体跟踪、行为分析等其他高层视觉任务的基础。图像分类在安防、交通、互联网、医学等领域有着广泛的应用。

  一般来说,图像分类通过手工提取特征或特征学习方法对整个图像进行全部描述,然后使用分类器判别物体类别,因此如何提取图像的特征至关重要。基于深度学习的图像分类方法,可以通过有监督或无监督的方式学习层次化的特征描述,从而取代了手工设计或选择图像特征的工作。深度学习模型中的卷积神经网络(Convolution Neural Network, CNN) 直接利用图像像素信息作为输入,最大程度上保留了输入图像的所有信息,通过卷积操作进行特征的提取和高层抽象,模型输出直接是图像识别的结果。这种基于"输入-输出"直接端到端的学习方法取得了非常好的效果。

  本教程主要介绍图像分类的深度学习模型,以及如何使用PaddlePaddle在CIFAR10数据集上快速实现CNN模型。

  项目地址:http://paddlepaddle.org/documentation/docs/zh/1.3/beginners_guide/basics/image_classification/index.html

  基于ImageNet数据集训练的更多图像分类模型,及对应的预训练模型、finetune操作详情请参照Github:https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/image_classification/README_cn.md

  效果展示

  图像分类包括通用图像分类、细粒度图像分类等。图1展示了通用图像分类效果,即模型可以正确识别图像上的主要物体。

  图1. 通用图像分类展示

  图2展示了细粒度图像分类-花卉识别的效果,要求模型可以正确识别花的类别。

  图2. 细粒度图像分类展示

  一个好的模型既要对不同类别识别正确,同时也应该能够对不同视角、光照、背景、变形或部分遮挡的图像正确识别(这里我们统一称作图像扰动)。图3展示了一些图像的扰动,较好的模型会像聪明的人类一样能够正确识别。

  图3. 扰动图片展示[7]

  模型概览

  CNN:传统CNN包含卷积层、全连接层等组件,并采用softmax多类别分类器和多类交叉熵损失函数,一个典型的卷积神经网络如图4所示,我们先介绍用来构造CNN的常见组件。

  图4. CNN网络示例[5]

  • 卷积层(convolution layer): 执行卷积操作提取底层到高层的特征,发掘出图片局部关联性质和空间不变性质。

  • 池化层(pooling layer): 执行降采样操作。通过取卷积输出特征图中局部区块的最大值(max-pooling)或者均值(avg-pooling)。降采样也是图像处理中常见的一种操作,可以过滤掉一些不重要的高频信息。

  • 全连接层(fully-connected layer,或者fc layer): 输入层到隐藏层的神经元是全部连接的。

  • 非线性变化: 卷积层、全连接层后面一般都会接非线性变化函数,例如Sigmoid、Tanh、ReLu等来增强网络的表达能力,在CNN里最常使用的为ReLu激活函数。

  • Dropout [1] : 在模型训练阶段随机让一些隐层节点权重不工作,提高网络的泛化能力,一定程度上防止过拟合。

  接下来我们主要介绍VGG,ResNet网络结构。

  VGG:牛津大学VGG(Visual Geometry Group)组在2014年ILSVRC提出的模型被称作VGG模型 [2] 。该模型相比以往模型进一步加宽和加深了网络结构,它的核心是五组卷积操作,每两组之间做Max-Pooling空间降维。同一组内采用多次连续的3X3卷积,卷积核的数目由较浅组的64增多到最深组的512,同一组内的卷积核数目是一样的。卷积之后接两层全连接层,之后是分类层。由于每组内卷积层的不同,有11、13、16、19层这几种模型,下图展示一个16层的网络结构。VGG模型结构相对简洁,提出之后也有很多文章基于此模型进行研究,如在ImageNet上首次公开超过人眼识别的模型[4]就是借鉴VGG模型的结构。

  图5. 基于ImageNet的VGG16模型

  ResNet:ResNet(Residual Network) [3] 是2015年ImageNet图像分类、图像物体定位和图像物体检测比赛的冠军。针对随着网络训练加深导致准确度下降的问题,ResNet提出了残差学习方法来减轻训练深层网络的困难。在已有设计思路(BN, 小卷积核,全卷积网络)的基础上,引入了残差模块。每个残差模块包含两条路径,其中一条路径是输入特征的直连通路,另一条路径对该特征做两到三次卷积操作得到该特征的残差,最后再将两条路径上的特征相加。

  残差模块如图7所示,左边是基本模块连接方式,由两个输出通道数相同的3x3卷积组成。右边是瓶颈模块(Bottleneck)连接方式,之所以称为瓶颈,是因为上面的1x1卷积用来降维(图示例即256->64),下面的1x1卷积用来升维(图示例即64->256),这样中间3x3卷积的输入和输出通道数都较小(图示例即64->64)。

  图7. 残差模块

  数据准备

  由于ImageNet数据集较大,下载和训练较慢,为了方便大家学习,我们使用CIFAR10数据集。CIFAR10数据集包含60,000张32x32的彩色图片,10个类别,每个类包含6,000张。其中50,000张图片作为训练集,10000张作为测试集。图11从每个类别中随机抽取了10张图片,展示了所有的类别。

  图11. CIFAR10数据集[6]

  Paddle API提供了自动加载cifar数据集模块 paddle.dataset.cifar。

  通过输入python train.py,就可以开始训练模型了,以下小节将详细介绍train.py的相关内容。

  模型结构

  Paddle 初始化

  让我们从导入 Paddle Fluid API 和辅助模块开始。

  本教程中我们提供了VGG和ResNet两个模型的配置。

  VGG

  首先介绍VGG模型结构,由于CIFAR10图片大小和数量相比ImageNet数据小很多,因此这里的模型针对CIFAR10数据做了一定的适配。卷积部分引入了BN和Dropout操作。 VGG核心模块的输入是数据层,vgg_bn_drop 定义了16层VGG结构,每层卷积后面引入BN层和Dropout层,详细的定义如下:

  首先定义了一组卷积网络,即conv_block。卷积核大小为3x3,池化窗口大小为2x2,窗口滑动大小为2,groups决定每组VGG模块是几次连续的卷积操作,dropouts指定Dropout操作的概率。所使用的img_conv_group是在paddle.fluit.net中预定义的模块,由若干组 Conv->BN->ReLu->Dropout 和 一组 Pooling 组成。

  五组卷积操作,即 5个conv_block。 第一、二组采用两次连续的卷积操作。第三、四、五组采用三次连续的卷积操作。每组最后一个卷积后面Dropout概率为0,即不使用Dropout操作。

  最后接两层512维的全连接。

  在这里,VGG网络首先提取高层特征,随后在全连接层中将其映射到和类别维度大小一致的向量上,最后通过Softmax方法计算图片划为每个类别的概率。

  ResNet

  ResNet模型的第1、3、4步和VGG模型相同,这里不再介绍。主要介绍第2步即CIFAR10数据集上ResNet核心模块。

  先介绍resnet_cifar10中的一些基本函数,再介绍网络连接过程。

  • conv_bn_layer : 带BN的卷积层。

  • shortcut : 残差模块的"直连"路径,"直连"实际分两种形式:残差模块输入和输出特征通道数不等时,采用1x1卷积的升维操作;残差模块输入和输出通道相等时,采用直连操作。

  • basicblock : 一个基础残差模块,即图9左边所示,由两组3x3卷积组成的路径和一条"直连"路径组成。

  • layer_warp : 一组残差模块,由若干个残差模块堆积而成。每组中第一个残差模块滑动窗口大小与其他可以不同,以用来减少特征图在垂直和水平方向的大小。

resnet_cifar10 的连接结构主要有以下几个过程。

  底层输入连接一层 conv_bn_layer,即带BN的卷积层。

  然后连接3组残差模块即下面配置3组 layer_warp ,每组采用图 10 左边残差模块组成。

  最后对网络做均值池化并返回该层。

  注意:除第一层卷积层和最后一层全连接层之外,要求三组 layer_warp 总的含参层数能够被6整除,即 resnet_cifar10 的 depth 要满足

  Infererence配置

  网络输入定义为 data_layer (数据层),在图像分类中即为图像像素信息。CIFRAR10是RGB 3通道32x32大小的彩色图,因此输入数据大小为3072(3x32x32)。

  Train 配置

  然后我们需要设置训练程序 train_network。它首先从推理程序中进行预测。 在训练期间,它将从预测中计算 avg_cost。 在有监督训练中需要输入图像对应的类别信息,同样通过fluid.layers.data来定义。训练中采用多类交叉熵作为损失函数,并作为网络的输出,预测阶段定义网络的输出为分类器得到的概率信息。

  注意: 训练程序应该返回一个数组,第一个返回参数必须是 avg_cost。训练器使用它来计算梯度。

  Optimizer 配置

  在下面的 Adam optimizer,learning_rate 是学习率,与网络的训练收敛速度有关系。

  def optimizer_program():

  return fluid.optimizer.Adam(learning_rate=0.001)

  训练模型

  Data Feeders 配置

  cifar.train10() 每次产生一条样本,在完成shuffle和batch之后,作为训练的输入。

  Trainer 程序的实现

  我们需要为训练过程制定一个main_program, 同样的,还需要为测试程序配置一个test_program。定义训练的 place ,并使用先前定义的优化器。

  训练主循环以及过程输出

  在接下来的主训练循环中,我们将通过输出来来观察训练过程,或进行测试等。

  训练

  通过trainer_loop函数训练, 这里我们只进行了2个Epoch, 一般我们在实际应用上会执行上百个以上Epoch

  注意: CPU,每个 Epoch 将花费大约15~20分钟。这部分可能需要一段时间。请随意修改代码,在GPU上运行测试,以提高训练速度。

  train_loop()

  一轮训练log示例如下所示,经过1个pass, 训练集上平均 Accuracy 为0.59 ,测试集上平均 Accuracy 为0.6 。

图13是训练的分类错误率曲线图,运行到第200个pass后基本收敛,最终得到测试集上分类错误率为8.54%。

  图13. CIFAR10数据集上VGG模型的分类错误率

  应用模型

  可以使用训练好的模型对图片进行分类,下面程序展示了如何加载已经训练好的网络和参数进行推断。

  生成预测输入数据

  dog.png 是一张小狗的图片. 我们将它转换成 numpy 数组以满足feeder的格式.

  Inferencer 配置和预测

  与训练过程类似,inferencer需要构建相应的过程。我们从params_dirname 加载网络和经过训练的参数。 我们可以简单地插入前面定义的推理程序。 现在我们准备做预测。

  总结

  传统图像分类方法由多个阶段构成,框架较为复杂,而端到端的CNN模型结构可一步到位,而且大幅度提升了分类准确率。本文我们首先介绍VGG、GoogleNet、ResNet三个经典的模型;然后基于CIFAR10数据集,介绍如何使用PaddlePaddle配置和训练CNN模型,尤其是VGG和ResNet模型;最后介绍如何使用PaddlePaddle的API接口对图片进行预测和特征提取。对于其他数据集比如ImageNet,配置和训练流程是同样的,请参照Github https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/image_classification/README_cn.md。

  参考文献

  [1] G.E. Hinton, N. Srivastava, A. Krizhevsky, I. Sutskever, and R.R. Salakhutdinov. Improving neural networks by preventing co-adaptation of feature detectors. arXiv preprint arXiv:1207.0580, 2012.

  [2] K. Chatfield, K. Simonyan, A. Vedaldi, A. Zisserman. Return of the Devil in the Details: Delving Deep into Convolutional Nets. BMVC, 2014。

  [3] K. He, X. Zhang, S. Ren, J. Sun. Deep Residual Learning for Image Recognition. CVPR 2016.

  [4] He, K., Zhang, X., Ren, S., and Sun, J. Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification. ArXiv e-prints, February 2015.

  [5] http://deeplearning.net/tutorial/lenet.html

  [6] https://www.cs.toronto.edu/~kriz/cifar.html

  [7] http://cs231n.github.io/classification/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/411424.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LeetCode:77. 组合——回溯法,是暴力法?

🍎道阻且长,行则将至。🍓 🌻算法,不如说它是一种思考方式🍀算法专栏: 👉🏻123 一、🌱77. 组合 题目描述:给定两个整数 n 和 k,返回范…

风场数据抓取程序实现(java+python实现)

一、数据源参数定义 关键参数代码: package com.grab.catchWindData.pram;/*** ClassName: DevPrams* Description: TODO**/ public class DevPrams {public static String lev_0to0p1_m_below_ground "lev_0-0.1_m_below_ground";public static Stri…

【微服务笔记08】微服务组件之Hystrix实现请求合并功能

这篇文章,主要介绍微服务组件之Hystrix实现请求合并功能。 目录 一、Hystrix请求合并 1.1、什么是请求合并 1.2、请求合并的实现 (1)引入依赖 (2)编写服务提供者 (3)消费者(Se…

React | 认识React开发

✨ 个人主页:CoderHing 🖥️ Node.js专栏:Node.js 初级知识 🙋‍♂️ 个人简介:一个不甘平庸的平凡人🍬 💫 系列专栏:吊打面试官系列 16天学会Vue 11天学会React Node专栏 &#x…

【分享】免梯子的GPT,玩 ChatGPT 的正确姿势

火了一周的 ChatGPT,HG 不允许还有小伙伴不知道这个东西是什么?简单来说就是,你可以让它扮演任何事物,据说已经有人用它开始了颜色文学创作。因为它太火了,所以,本周特推在几十个带有“chatgpt”的项目中选…

双交叉注意学习用于细粒度视觉分类和目标重新识别

目录Dual Cross-Attention Learning for Fine-Grained Visual Categorization and Object Re-Identification摘要本文方法消融实验Dual Cross-Attention Learning for Fine-Grained Visual Categorization and Object Re-Identification 摘要 目的: 探索了如何扩展…

JDK8——新增时间类、有关时间数据的交互问题

目录 一、实体类 二、数据库 三、数据交换 四、关于LocalDateTime类型 (java 8) 4.1 旧版本日期时间问题 4.2 新版日期时间API介绍 4.2.1 LocalDate、LocalTime、LocalDateTime 4.2.2 日期时间的修改与比较 4.2.3 格式化和解析操作 4.2.4 Instant: 时间戳 4.2.5 Duration 与…

Doris(6):数据导入(Load)之Stream Load

Broker load是一个同步的导入方式,用户通过发送HTTP协议将本地文件或者数据流导入到Doris中,Stream Load同步执行导入并返回结果,用户可以通过返回判断导入是否成功。 1 适用场景 Stream load 主要适用于导入本地文件,或通过程序…

小厂实习要不要去?

大家好,我是帅地。 最近暑假实习招聘,不少 训练营 学员都拿到了小厂实习来保底,但是很多小厂基本要求一周内给答复,中大厂就还在流程之中,所以很纠结小厂实习要不要去。 不知道你是否有这样的纠结,今天帅地…

【测试面试汇总2】

目录Linux操作系统1.Linux操作命令2.在Linux中find和grep的区别?3.绝对路径用什么符号表示?4.当前目录、上层目录用什么表示?5.主目录用什么表示?6.怎么查看进程信息?7.保存文件并退出vi 编辑?8.怎么查看当前用户id&a…

【Python从入门到进阶】15、函数的定义和使用

接上篇《14、字典高级应用》 上一篇我们学习了有关字典的高级应用操作(字典的增删改查),本篇我们来学习Python中函数的定义和使用,包括函数的参数、返回值、局部变量和全景变量等操作。 一、一个思考 例如这里有一段大东北洗浴中…

2023年PMP报考时间安排攻略!

1.2023年PMP考试时间 PMP一年开考4次,分别为3月、6月、9月、12月,预计2023年PMP第一次考试时间在2023年3月左右,具体以基金会官方通知为准。 1)为什么考PMP? 大部分人考 PMP 无非以下几个原因,总的来说&…

运行时内存数据区之程序计数器

内存是非常重要的系统资源,是硬盘和CPU的中间仓库及桥梁,承载着操作系统和应用程序的实时选行。JVM内存布局规定了Java在运行过程中内存申请、分配、管理的策略,保证了JVM的高效稳定运行。 不同的VM对于内存的划分方式和管理机制存在着部分差…

算法时间复杂度计算

目录 1.时间复杂度计算 1.1 时间复杂度例题 1.1.1例题 1.1.2例题 1.1.3例题 1.1.4例题 1.2时间复杂度leetcode例题 1.时间复杂度计算 首先,我们需要了解时间复杂度是什么:算法的时间复杂度是指算法在编写成可执行程序后,运行时需要耗费…

一天吃透操作系统八股文

操作系统的四个特性? 并发:同一段时间内多个程序执行(与并行区分,并行指的是同一时刻有多个事件,多处理器系统可以使程序并行执行) 共享:系统中的资源可以被内存中多个并发执行的进线程共同使…

MATLAB | 给热图整点花哨操作(三角,树状图,分组图)

前段时间写的特殊热图绘制函数迎来大更新,基础使用教程可以看看这一篇: https://slandarer.blog.csdn.net/article/details/129292679 原本的绘图代码几乎完全不变,主要是增添了很多新的功能!!! 工具函数完…

FastChat开放,媲美ChatGPT的90%能力——从下载到安装、部署

FastChat开放,媲美ChatGPT的90%能力——从下载到安装、部署前言两个前置软件创建FastChat虚拟环境安装PyTorch安装 FastChat下载 LLaMA,并转换生成FastChat对应的模型Vicuna启动FastChat的命令行交互将模型部署为一个服务,提供Web GUI前言 最…

Cesium:自定义MaterialProperty

在项目中应用Cesium.js时,时常遇到需要对Cesium.js的Material材质或者MaterialProperty材质属性进行拓展的应用场景。如果对GLSL(openGL Shading Language ),即:OpenGL着色语言熟悉的话,参考Cesium官方文档,构建一个新的Material必定不是难事。而MaterialProperty材质属…

【C语言进阶:动态内存管理】动态内存函数的介绍

本节重点内容: malloc 和 free 函数calloc 函数realloc 函数🌸为什么存在动态内存分配 到目前为止,我们已经掌握的内存开辟方式有两种: 创建变量:int val 20; //在栈空间上开辟四个字节 创建数组&#xff1…

Html5钢琴块游戏制作与分享(音游可玩)

当年一款手机节奏音游,相信不少人都玩过或见过。最近也是将其做了出来分享给大家。 游戏的基本玩法:点击下落的黑色方块,弹奏音乐。(下落的速度会越来越快) 可以进行试玩,手机玩起来效果会更好些。 点击…