自定义的卷积神经网络模型CNN,对图片进行分类并使用图片进行测试模型-适合入门,从模型到训练再到测试,开源项目

news2025/7/27 23:23:15

自定义的卷积神经网络模型CNN,对图片进行分类并使用图片进行测试模型-适合入门,从模型到训练再到测试:开源项目
开源项目完整代码及基础教程:
https://mbd.pub/o/bread/ZZWclp5x

在这里插入图片描述
CNN模型:
在这里插入图片描述

1.导入必要的库和模块:

torch:PyTorch深度学习框架。

torchvision:PyTorch的计算机视觉库,用于处理图像数据。

transforms:包含数据预处理的模块。

nn:PyTorch的神经网络模块。

F:PyTorch的函数模块,包括各种激活函数等。

optim:优化算法模块。

2.数据预处理:

transforms.Compose:将一系列数据预处理步骤组合在一起。

transforms.ToTensor():将图像数据转换为张量。

transforms.Normalize:对图像数据进行归一化处理,以均值0.5和标准差0.5。

定义批处理大小:

batch_size:每个训练批次包含的图像数量。

加载训练集:

trainset:使用CIFAR-10数据集,设置训练标志为True。

torch.utils.data.DataLoader:创建用于加载训练数据的数据加载器,指定批处理大小和其他参数。

加载测试集:

testset:使用CIFAR-10数据集,设置训练标志为False。

torch.utils.data.DataLoader:创建用于加载测试数据的数据加载器,指定批处理大小和其他参数。

定义CNN模型:

My_CNN:自定义的卷积神经网络模型,包括卷积层、池化层和全连接层。

创建CNN模型、损失函数和优化器:

model:创建My_CNN模型的实例。

nn.CrossEntropyLoss():定义用于多分类问题的交叉熵损失函数。

optim.SGD:使用随机梯度下降优化器,指定学习率和动量。

训练模型:

epochs:指定训练轮数。

循环中的嵌套循环:迭代训练数据批次,进行前向传播、反向传播和参数优化。

保存模型:

model_path:指定模型保存的路径。

torch.save:保存训练后的模型。

在测试集上评估模型性能:

计算模型在测试集上的准确率。

计算每个类别的准确率。

具体代码来说:

transform = transforms.Compose(

[transforms.ToTensor(),

 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

解释:

transforms.Compose:这是一个用于组合多个数据预处理步骤的函数。它允许你按顺序应用多个转换,以便将原始数据转换为最终的形式。

transforms.ToTensor():这是一个数据预处理步骤,将图像数据转换为张量(tensor)的格式。在深度学习中,张量是常用的数据表示方式,因此需要将图像数据从常见的图像格式(如JPEG或PNG)转换为张量。

transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)):这是另一个数据预处理步骤,用于对图像进行归一化处理。归一化的目的是将图像的像素值缩放到一个特定的范围,以便神经网络更容易学习。在这里,均值和标准差都被设置为0.5,这将使图像像素值在-1到1之间。

batch_size = 4

 trainset = torchvision.datasets.CIFAR10(root='./data', train=True,

                                        download=True, transform=transform)

 trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,

                                          shuffle=True, num_workers=0)

 testset = torchvision.datasets.CIFAR10(root='./data', train=False,

                                       download=True, transform=transform)

 testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,

                                         shuffle=False, num_workers=0)

解释:

batch_size = 4:定义了每个训练和测试批次中包含的图像数量。在深度学习中,通常将数据分成小批次进行训练,以便更有效地使用计算资源。

trainset 和 testset 的定义:这两行代码加载了CIFAR-10数据集的训练集和测试集,并进行了如下操作:

torchvision.datasets.CIFAR10:使用CIFAR-10数据集,它包括一组包含10个不同类别的图像数据,适用于图像分类任务。

root=‘./data’:指定数据集的存储目录,可以根据需要更改。

train=True 和 train=False:这两个参数分别用于加载训练集和测试集。

download=True:如果数据集尚未下载,会自动下载。

transform=transform:指定了前面定义的数据预处理管道,将在加载数据时应用。

trainloader 和 testloader 的定义:这两行代码创建了数据加载器,将数据集划分为批次以进行训练和测试。

torch.utils.data.DataLoader:这是PyTorch提供的用于加载数据的工具,可以自动处理数据的分批和洗牌等任务。

batch_size=batch_size:指定了每个批次的大小,即每次加载多少图像数据。

shuffle=True 和 shuffle=False:shuffle参数指定是否在每个epoch(训练轮次)之前对数据进行洗牌,以增加数据的随机性。通常在训练时进行洗牌,而在测试时不进行洗牌。

num_workers=0:这个参数指定用于数据加载的线程数。在此代码中,设置为0表示不使用多线程加载数据。如果有多个CPU核心可用,可以将其设置为大于0的值以加速数据加载。

class My_CNN(nn.Module):

def __init__(self):

    super().__init__()

省略部分代码

def forward(self, x):

省略部分代码

    return x

model = My_CNN()

criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

解释:

这部分代码定义了一个卷积神经网络(CNN)模型,并创建了用于训练该模型的损失函数和优化器。让我们逐步解释每一部分:

class My_CNN(nn.Module)::这是一个自定义的CNN模型类的定义。这个类继承自nn.Module,这是PyTorch中构建神经网络模型的基本方式。

def init(self)::这是构造函数,用于初始化CNN模型的各个层。

super().init():调用父类nn.Module的构造函数以确保正确初始化模型。

self.conv1 和 self.conv2:这是两个卷积层的定义,分别具有不同数量的输入和输出通道以及卷积核的大小。

self.pool:这是最大池化层的定义,用于减小特征图的空间尺寸。

self.fc1、self.fc2 和 self.fc3:这是三个全连接层(也称为线性层),用于将卷积层的输出转换为最终的分类结果。

def forward(self, x)::这是前向传播函数,定义了模型的前向传播过程。

在前向传播中,输入x经过卷积层、激活函数(F.relu)、池化层以及全连接层,最终输出分类结果。

torch.flatten(x, 1):这一步将卷积层的输出扁平化,以便将其输入到全连接层。

返回值是模型的输出,表示对输入数据的分类预测。

model = My_CNN():创建了My_CNN类的一个实例,即CNN模型。

criterion = nn.CrossEntropyLoss():定义了损失函数,这里使用的是交叉熵损失函数。它用于衡量模型的预测与实际标签之间的差距,是一个用于监督学习任务的常见损失函数。

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9):定义了优化器,这里使用的是随机梯度下降(SGD)。优化器负责更新模型的参数,以减小损失函数的值。学习率(lr)和动量(momentum)是优化算法的超参数,影响了参数更新的速度和方向。

epochs=5

for epoch in range(epochs): # loop over the dataset multiple times

running_loss = 0.0

for i, data in enumerate(trainloader, 0):

    # get the inputs; data is a list of [inputs, labels]

    inputs, labels = data



    # zero the parameter gradients

    optimizer.zero_grad()

    # forward + backward + optimize

    outputs = model(inputs)

    loss = criterion(outputs, labels)

    loss.backward()

    optimizer.step()

    # print statistics

    running_loss += loss.item()

    if i % 2000 == 1999:    # print every 2000 mini-batches

        print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')

        running_loss = 0.0

print('Finished Training')

name_path = './cnn_model_model.pth'

torch.save(model,name_path)

解释:

epochs=5:定义了训练的轮次(epochs),也就是模型将遍历整个训练数据集的次数。

for epoch in range(epochs)::这是一个循环,遍历每个训练轮次。

running_loss = 0.0:用于追踪每个训练轮次的累积损失。

for i, data in enumerate(trainloader, 0)::这个嵌套循环遍历训练数据集的小批次。

i 表示当前批次的索引。

data 包含了当前批次的输入数据和标签。

optimizer.zero_grad():在每个批次开始时,将优化器的梯度清零,以便准备计算新的梯度。

outputs = model(inputs):进行前向传播,将输入数据传递给模型,得到模型的输出。

loss = criterion(outputs, labels):计算损失,衡量模型的预测与实际标签之间的差距。使用了前面定义的交叉熵损失函数。

loss.backward():进行反向传播,计算模型参数相对于损失的梯度。

optimizer.step():根据计算得到的梯度,更新模型的参数,以减小损失函数的值。

running_loss += loss.item():累积当前批次的损失值,用于后续打印统计信息。

if i % 2000 == 1999::每经过2000个小批次,打印一次统计信息。这是为了跟踪训练进度,查看损失是否在逐渐减小。

print(f’[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}'):打印当前训练轮次和批次的损失值。

running_loss = 0.0:重置累积损失值,以便下一个统计周期。

print(‘Finished Training’):当所有轮次的训练完成后,打印 “Finished Training” 以指示训练结束。

name_path = ‘./cnn_model_model.pth’:指定模型的保存路径。

torch.save(model, name_path):将训练好的模型保存到指定路径。这样可以在之后的任务中加载和使用该模型,而不需要重新训练。

这段代码执行了模型的训练过程,循环遍历多个轮次,每轮次内遍历训练数据的小批次。在每个小批次中,进行前向传播、计算损失、反向传播以及参数更新。训练的目标是通过调整模型参数,减小损失函数的值,从而提高模型的性能。同时,每隔一定数量的小批次,打印训练统计信息以监视训练进度。最后,训练完成后,模型被保存到文件以备将来使用。

classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')

prepare to count predictions for each class

correct_pred = {classname: 0 for classname in classes}

total_pred = {classname: 0 for classname in classes}

again no gradients needed

with torch.no_grad():

for data in testloader:

    images, labels = data

    outputs = model(images)

    _, predictions = torch.max(outputs, 1)

    # collect the correct predictions for each class

    for label, prediction in zip(labels, predictions):

        if label == prediction:

            correct_pred[classes[label]] += 1

        total_pred[classes[label]] += 1

print accuracy for each class

for classname, correct_count in correct_pred.items():

accuracy = 100 * float(correct_count) / total_pred[classname]

print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')

解释:

classes = (‘plane’, ‘car’, ‘bird’, ‘cat’,‘deer’, ‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’):这是数据集中的类别标签,它们代表CIFAR-10数据集中的10个不同类别,分别是飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船和卡车。

correct_pred 和 total_pred:这两个字典用于跟踪每个类别的正确预测数量和总预测数量,初始化为零。

with torch.no_grad()::这个语句块指示在此之后的计算不需要梯度信息。这是因为在测试阶段,我们不需要计算梯度,只是进行前向传播和计算准确度。

for data in testloader::遍历测试数据集的小批次。

images, labels = data:将小批次数据分成图像和对应的标签。

outputs = model(images):使用训练好的模型对图像进行预测,得到模型的输出。

_, predictions = torch.max(outputs, 1):通过 torch.max 函数找到每个样本预测的类别,即具有最高预测分数的类别。

for label, prediction in zip(labels, predictions)::通过 zip 函数,将实际标签和预测标签一一对应起来,以便比较它们。

if label == prediction::比较实际标签和预测标签,如果它们相等,表示模型做出了正确的预测。

correct_pred[classes[label]] += 1:对应类别的正确预测数量加一。

total_pred[classes[label]] += 1:对应类别的总预测数量加一。

for classname, correct_count in correct_pred.items()::遍历每个类别和其正确预测数量。

accuracy = 100 * float(correct_count) / total_pred[classname]:计算每个类别的准确度,即正确预测数量除以总预测数量,以百分比表示。

print(f’Accuracy for class: {classname:5s} is {accuracy:.1f} %'):打印每个类别的准确度,格式化输出。

总结:这段代码的目标是计算并打印出每个类别的分类准确度,以便评估模型在不同类别上的性能。这是在测试阶段对模型性能进行评估的一种方式。

测试模型的代码:

import torch

from PIL import Image

from torch import nn

import torch

import torchvision

import torch.nn.functional as F

device = torch.device('cuda')

image_path=“plane.png”

image =Image.open(image_path)

print(image)

image=image.convert('RGB')

transform=torchvision.transforms.Compose([torchvision.transforms.Resize((32,32)),torchvision.transforms.ToTensor()])

image=transform(image)

print(image.shape)

class My_CNN(nn.Module):

def __init__(self):

    super().__init__()

 省略部分代码        return x

#加载模型

model = torch.load(“cnn_net_model.pth”,map_location=torch.device(‘cuda’))#加载完成网络模型,映射

print(model)#维数不够

image = torch.reshape(image,(1,3,32,32))#这一个很重要,要满足四个通道

image=image.to(device)#做cuda变换,不然报错

model.eval()

with torch.no_grad():#节约内存性能

output=model(image)

#识别类别,数字最大的就是我们的结果

print(output)

解释:

导入必要的库和模块:

torch:PyTorch库,用于构建和运行深度学习模型。

PIL:Python Imaging Library,用于处理图像。

nn:PyTorch的神经网络模块。

F:PyTorch的函数模块。

device:将模型加载到GPU设备。

image_path:待分类的图像文件路径。

Image.open(image_path):使用PIL库打开图像文件。

图像的预处理:

image.convert(‘RGB’):将图像转换为RGB模式,以确保图像通道数为3。

transform:定义了一系列的图像预处理操作,包括将图像缩放到32x32像素大小并将其转换为PyTorch的Tensor数据类型。

image = transform(image):应用上述的预处理操作,将图像准备好以供模型处理。

定义神经网络模型:

My_CNN 类:这是一个自定义的卷积神经网络模型,包括两个卷积层,两个池化层,以及三个全连接层。这个模型与之前训练的CNN模型相似,用于图像分类任务。

加载预训练模型:

model = torch.load(“cnn_net_model.pth”, map_location=torch.device(‘cuda’)):加载之前训练并保存的CNN模型。map_location 参数指定了模型的加载位置,这里指定为CUDA/GPU。

调整输入图像的维度和数据类型:

image = torch.reshape(image, (1, 3, 32, 32)):将输入的图像数据调整为适合模型的维度(1个样本,3个通道,32x32像素大小)。

image = image.to(device):将图像数据移动到GPU设备,以便进行GPU上的推理。

模型推理和分类:

model.eval():将模型切换到推理模式,这意味着模型不再更新梯度。

with torch.no_grad()::在这个块中,不会计算或保存梯度信息,以提高性能和节省内存。

output = model(image):对输入的图像进行前向传播,得到模型的输出。

print(output):打印模型的输出,这是一个包含了不同类别的分数的张量。

测试结果:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1156624.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

大模型时代的人工智能+大数据平台,加速创新涌现

大模型和MaaS概念的出现,定义了以模型为中心的一整套AI开发新范式,而这背后日益增长的巨大算力需求,对AI工程底座提出了新的挑战。今天,大模型时代下的人工智能大数据平台,需要具备计算效率、开发效率、处理效率为一体…

Java日志组件介绍之二

一、前言 Java日志组件介绍之一 主要介绍了JDK内置日志和Apache的common-logging通用日志接口,今天这篇我们继续了解Java其它一些日志组件。 二、slf4j slf4j即Simple Logging Facade for JAVA ,简单日志门面,类似common-logging&#xff0…

RBAC:基于角色的访问控制

1.介绍 RBAC是一种库表设计思想 基于角色的访问控制(RBAC)是实施面向企业安全策略的一种有效的访问控制方式。一种数据库的设计思想,其核心是角色。其基本思想是,对系统操作的各种权限不是直接授予具体的用户,而是在…

element表格自定义筛选

文章目录 前言一、简介二、效果展示三、源码总结 前言 提示:这里可以添加本文要记录的大概内容: …待续 提示:以下是本篇文章正文内容,下面案例可供参考 一、简介 修改el-table的筛选…待续 二、效果展示 三、源码 使用方法…

视频汇聚平台EasyCVR分发的流如何进行token鉴权?具体步骤是什么?

视频监控EasyCVR平台能在复杂的网络环境中,将分散的各类视频资源进行统一汇聚、整合、集中管理,在视频监控播放上,TSINGSEE青犀视频安防监控汇聚平台可支持1、4、9、16个画面窗口播放,可同时播放多路视频流,也能支持视…

喜报!CACTER邮件安全网关荣获2023鲲鹏应用创新大赛广东赛区三等奖

近期,2023鲲鹏应用创新大赛广东赛区暨广东省信息技术应用创新产业联盟创新大赛圆满落幕,Coremail凭借“基于鲲鹏CPU的邮件网关一体机解决方案”,荣获“金融行业方向”三等奖。 ​ 鲲鹏凌粤 展翅湾区 本届大赛广东区域赛以“鲲鹏凌粤 展翅湾…

数据结构与算法-树和森林

🌞 “永远面朝阳光,阴影被甩在身后!” 树和森林 🎈1.线索二叉树🎈2.树和森林🔭2.1树的存储结构🔭2.2双亲表示法🔭2.3孩子链表表示法📝2.3.1孩子链表表示法的实现&#x1…

基于深度学习网络的美食检测系统matlab仿真

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 5.算法完整程序工程 1.算法运行效果图预览 2.算法运行软件版本 matlab2022a 3.部分核心程序 % 图像大小 image_size [224 224 3]; num_classes size(VD,2)-1;% 目标类别数量…

解决:Xshell连接服务器卡在To escape to local shell, press ‘Ctrl+Alt+]‘.很久才能够连接上

如下图:在输入服务器的账号密码后,会卡在这里没有任何反映需要几分钟才能连接上 造成这个情况的原因: 在SSH服务中,UseDNS用于指定当用户SSH登录一个域名时,服务器是否使用DNS来确认该域名对应的IP地址。如果UseDNS设置…

【约会云栖】从初中至大学,我见证了科技变革的历程。

前言 提起阿里云开发者大会, 你一定会觉得陌生;但提起云栖大会,你又会耳熟能详。实际上,云栖大会的前身就是阿里云开发者大会,2015年,它永久落户在杭州市西湖区云栖小镇。 2023年10月31日至11月2日&#xf…

echarts 饼图中心添加图片

需求 问题 - 暂时无法解决(如果图标居中不存在该问题) 由于此处饼图位置不处于当前 echarts 图表容器的中心位置,而是偏左一点,我们需要设置: 中心图片所在靠左位置【见 - 主要代码1】官方手册 https://echarts.apache…

记一次 logback 没有生成独立日志文件问题

背景 在新项目发布后发现日志文件并没有按照期望的方式独立开来&#xff0c;而是都写在了 application.log 文件中。 问题展示 日志文件&#xff1a; 项目引入展示&#xff1a; <include resource"paas/sendinfo/switch/client/sendinfo-paas-switch-client-log.…

全面解析C++ std::move

全面解析C std::move 本篇文章首发知识星球&#xff0c;感兴趣的可以点击下方加入即可。 std::move 是一个非常重要的函数&#xff0c;它提供了一种方式&#xff0c;可以将一个左值对象标记为一个特殊类型的右值对象&#xff0c;即 xvalue。这种转变是为了允许执行移动语义&…

通过wordpress能搭建有影响力的帮助中心

wordpress建站服务是一种提供简单易用的工具和功能&#xff0c;帮助用户轻松创建和管理网站的服务。它适用于各类网站管理员、个人博主和小型企业主&#xff0c;无论是想要搭建个人博客、展示作品集还是开设在线商店&#xff0c;都可以通过wordpress建站服务来实现。 | 一、搭建…

反转链表II(C++解法)

题目 给你单链表的头指针 head 和两个整数 left 和 right &#xff0c;其中 left < right 。请你反转从位置 left 到位置 right 的链表节点&#xff0c;返回 反转后的链表 。 示例 1&#xff1a; 输入&#xff1a;head [1,2,3,4,5], left 2, right 4 输出&#xff1a;[1…

嘴尚绝:健康卤味风靡市场,美味与健康并行

随着人们生活水平的提高&#xff0c;健康饮食成为越来越多人的追求。在卤味市场&#xff0c;传统重口味卤味逐渐被健康卤味所取代。本文将探讨健康卤味如何逐步占领市场&#xff0c;以及其背后的原因和未来的发展趋势。 卤味&#xff0c;作为中国美食的代表之一&#xff0c;有…

【FastBond2阶段1——基于ESP32C3开发的简易IO调试设备】

【FastBond2阶段1——基于ESP32C3开发的简易IO调试设备】 1. 功能介绍2. 主要元器件介绍2.1 主控板&#xff1a;CORE ESP32-C3核心板2.2 传感器2.2.1 旋转编码器&#xff1a;2.2.2 模拟ADC&#xff1a;2.2.3 GPIO接口&#xff1a; 2.3 执行器2.3.1 WS2812:2.3.2 90舵机&#xf…

Linux期末复习——C编程基础

Linux下C语言编译环境概述 编译器&#xff1a;VI 编译器&#xff1a;GCC 调试器&#xff1a;GDB 项目管理器&#xff1a;make vi编辑器 三种模式 命令行模式&#xff1a;默认模式&#xff0c;不可以编辑&#xff0c;只可以上下移动光标“整行删除&#xff0c;删除字符”&…

【C++】多态 ⑨ ( vptr 指针初始化问题 | 构造函数 中 调用 虚函数 - 没有多态效果 )

文章目录 一、vptr 指针初始化问题1、vptr 指针与虚函数表2、vptr 指针初始化时机3、构造函数 中 调用 虚函数 - 没有多态效果4、代码示例 构造函数 的 作用就是 创建对象 , 构造函数 最后 一行代码 执行完成 , 才意味着 对象构建完成 , 对象构建完成后 , 才会将 vptr 指针 指向…

前端面试 面试多起来了

就在昨天 10.17 号,同时收到了三个同学面试的消息。他们的基本情况都是双非院校本科、没有实习经历、不会消息中间件和 Spring Cloud 微服务,做的都是单体项目。但他们投递简历还算积极,从今年 9 月初就开始投递简历了,到现在也有一个多月了。 来看看,这些消息。 为…