深度学习系列2——Pytorch 图像分类(AlexNet)

news2025/7/23 23:50:21

1. 概述

本文主要是参照 B 站 UP 主 霹雳吧啦Wz 的视频学习笔记,参考的相关资料在文末参照栏给出,包括实现代码和文中用的一些图片。

整个工程已经上传个人的 github https://github.com/lovewinds13/QYQXDeepLearning ,下载即可直接测试,数据集文件因为比较大,已经删除了,按照下文教程下载即可。

论文下载:ImageNet Classification with Deep Convolutional Neural Networks

2. AlexNet

AlexNet 是 2012 年 ISLVRC 2012(ImageNet Large Scale Visual Recognition Challenge) 竞赛的冠军网络, 分类准确率由传统的 70%+ 提升到 80%+。它是由 Hinton 和他的学生Alex Krizhevsky设计的。

AlexNet 2012 年在大规模图像识别中一起绝尘,从此引领了深度神经网络的热潮。另外,AlexNet 提出的 ReLU 激活函数, LRN, GPU 加速,数据增强,Dropout 失活部分神经元等方式,深刻的影响了后续的神经网络。

2.1 网络框架

AlexNet 共有 8 层组成,其中包括 5 个卷积层,3 个全连接层。各层参数如下:

计算公式参考深度学习系列1——Pytorch 图像分类(LeNet),此处不再列出。

其传输框图如下:

卷积层1数据传输如下:


其余层依此类似,只需类推。

2.2 补充

2.2.1 过拟合

根本原因是特征维度过多, 模型假设过于复杂, 参数过多, 训练数据过少, 噪声过多,导致拟合的函数完美的预测训练集, 但对新数据的测试集预测结果差。 过度的拟合了训练数据, 而没有考虑到泛化能力。

过拟合主要受数据量和模型复杂度的影响。

一句话就是:平时作业完成的非常好,但是考试就歇菜了。

2.2.2 Dropout

网络正向传播过程中随机失活一部分神经元,减少过拟合。Dropout 主要用在全连接层。

3. demo 实现

3.1 数据集

本文使用花分类数据集,下载链接: 花分类数据集——http://download.tensorflow.org/example_images/flower_photos.tgz

数据集划分参考这个pytorch图像分类篇:3.搭建AlexNet并训练花分类数据集

3.1 demo 结构:

在 CPU 训练的基础上,为了修改为 GPU 训练,因此单独修改了一个文件 train_gpu.py。

3.2 model.py

"""
模型
"""


import torch.nn as nn
import torch


class AlexNet(nn.Module):
    def __init__(self, num_classes=1000, init_weights=False):
        super(AlexNet, self).__init__()
        """
        特征提取
        """
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),   # 输入[3, 224, 224] 输出[48, 55, 55]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # 输出 [48,27,27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),   # 输出 [128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # 输出 [128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),  # 输出[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1), # 输出[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),  # 输出[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2)   # 输出 [128, 6, 6]
        )
        """
        分类器
        """
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),  # Dropout 随机失活神经元, 比例诶0.5
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes)
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)

        return x

    """
    权重初始化
    """
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0.01)
                nn.init.constant_(m.bias, 0)

"""
测试模型
"""
# if __name__ == '__main__':
#     input1 = torch.rand([224, 3, 224, 224])
#     model_x = AlexNet()
#     print(model_x)
    # output = AlexNet(input1)

此处的网络模型仅使用了一半的参数,即原 AlexNet 两块 GPU 中的一块。

3.2.1 nn.Sequential 介绍

nn.Sequential 是 nn.Module 的容器,用于按顺序包装一组网络层,在模型复杂情况下,使用 nn.Sequential 方法对模块划分。除了 nn.Sequetial,还有 nn.ModuleList 和 nn.ModuleDict。

3.2.2 Tensor 展平

直接使用 torch.flatten 方法展平张量


 x = torch.flatten(x, start_dim=1)	# 二维平坦
 

3.3 model.py

3.3.1 导入包


"""
训练(CPU)
"""
import os
import sys
import json
import time
import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm   # 显示进度条模块

from model import AlexNet

3.3.2 数据集预处理


    data_transform = {
        "train": transforms.Compose([
                                    transforms.RandomResizedCrop(224),  # 随机裁剪, 再缩放为 224*224
                                    transforms.RandomHorizontalFlip(),  # 水平随机翻转
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]),
        "val": transforms.Compose([
                                    transforms.Resize((224, 224)),  # 元组(224, 224)
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
    }

3.3.3 加载数据集

3.3.3.1 读取数据路径


# data_root = os.path.abspath(os.path.join(os.getcwd(), "../..")) # 读取数据路径
data_root = os.path.abspath(os.path.join(os.getcwd(), "./"))
image_path = os.path.join(data_root, "data_set", "flower_data")
# image_path = data_root + "/data_set/flower_data/"
assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

此处相比于 UP 主教程,修改了读取路径。

3.3.3.2 加载训练集


 train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"]
                                         )
 train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=nw
                                               )

3.3.3.3 加载验证集


val_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                       transform=data_transform["val"]
                                       )
val_num = len(val_dataset)
val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=4,
                                             shuffle=False,
                                             num_workers=nw
                                             )

3.3.3.4 保存数据索引


flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    json_str = json.dumps(cla_dict, indent=4)
    with open("calss_indices.json", 'w') as json_file:
        json_file.write(json_str)

3.3.4 训练过程


net = AlexNet(num_classes=5, init_weights=True)     # 实例化网络(5分类)
    # net.to(device)
    net.to("cpu")   # 直接指定 cpu
    loss_function = nn.CrossEntropyLoss()   # 交叉熵损失
    optimizer = optim.Adam(net.parameters(), lr=0.0002)     # 优化器(训练参数, 学习率)

    epochs = 10     # 训练轮数
    save_path = "./AlexNet.pth"
    best_accuracy = 0.0
    train_steps = len(train_loader)
    for epoch in range(epochs):
        net.train()     # 开启Dropout
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)     # 设置进度条图标
        for step, data in enumerate(train_bar):     # 遍历训练集,
            images, labels = data   # 获取训练集图像和标签
            optimizer.zero_grad()   # 清除历史梯度
            outputs = net(images)   # 正向传播
            loss = loss_function(outputs, labels)   # 计算损失值
            loss.backward()     # 方向传播
            optimizer.step()    # 更新优化器参数
            running_loss += loss.item()
            train_bar.desc = "train epoch [{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                      epochs,
                                                                      loss
                                                                      )
        # 验证
        net.eval()      # 关闭Dropout
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(val_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images)
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels).sum().item()
        val_accuracy = acc / val_num
        print("[epoch %d ] train_loss: %3f    val_accurancy: %3f" %
              (epoch + 1, running_loss / train_steps, val_accuracy))
        if val_accuracy > best_accuracy:    # 保存准确率最高的
            best_accuracy = val_accuracy
            torch.save(net.state_dict(), save_path)
    print("Finshed Training.")

训练过程可视化信息输出:


GPU 训练代码: 仅在 CPU 训练的基础上做了数据转换处理。


"""
训练(GPU)
"""
import os
import sys
import json
import time
import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm

from model import AlexNet


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"use device is {device}")

    data_transform = {
        "train": transforms.Compose([
                                    transforms.RandomResizedCrop(224),
                                    transforms.RandomHorizontalFlip(),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]),
        "val": transforms.Compose([
                                    transforms.Resize((224, 224)),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
    }
    # data_root = os.path.abspath(os.path.join(os.getcwd(), "../..")) # 读取数据路径
    data_root = os.path.abspath(os.path.join(os.getcwd(), "./"))
    image_path = os.path.join(data_root, "data_set", "flower_data")
    # image_path = data_root + "/data_set/flower_data/"
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"]
                                         )
    train_num = len(train_dataset)
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    json_str = json.dumps(cla_dict, indent=4)
    with open("calss_indices.json", 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # 线程数计算
    nw = 0
    print(f"Using {nw} dataloader workers every process.")

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=nw
                                               )
    val_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                       transform=data_transform["val"]
                                       )
    val_num = len(val_dataset)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=4,
                                             shuffle=False,
                                             num_workers=nw
                                             )
    print(f"Using {train_num} images for training, {val_num} images for validation.")

    # test_data_iter = iter(val_loader)
    # test_image, test_label = next(test_data_iter)

    """ 测试数据集图片"""
    # def imshow(img):
    #     img = img / 2 + 0.5
    #     np_img = img.numpy()
    #     plt.imshow(np.transpose(np_img, (1, 2, 0)))
    #     plt.show()
    # print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
    # imshow(utils.make_grid(test_image))

    net = AlexNet(num_classes=5, init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0002)

    epochs = 10
    save_path = "./AlexNet.pth"
    best_accuracy = 0.0
    train_steps = len(train_loader)
    for epoch in range(epochs):
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            train_bar.desc = "train epoch [{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                      epochs,
                                                                      loss
                                                                      )
        # 验证
        net.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(val_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
        val_accuracy = acc / val_num
        print("[epoch %d ] train_loss: %3f    val_accurancy: %3f" %
              (epoch + 1, running_loss / train_steps, val_accuracy))
        if val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            torch.save(net.state_dict(), save_path)
    print("Finshed Training.")


3.3.5 结果预测

"""
预测
"""

"""
预测
"""

import os
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from model import AlexNet


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    image_path = "./sunflowers01.jpg"
    img = Image.open(image_path)
    plt.imshow(img)
    img = data_transform(img)   # [N, C H, W]
    img = torch.unsqueeze(img, dim=0)   # 维度扩展
    # print(f"img={img}")
    json_path = "./calss_indices.json"
    with open(json_path, 'r') as f:
        class_indict = json.load(f)

    # model = AlexNet(num_classes=5).to(device)   # GPU
    model = AlexNet(num_classes=5)  # CPU
    weights_path = "./AlexNet.pth"
    model.load_state_dict(torch.load(weights_path))
    model.eval()    # 关闭 Dorpout
    with torch.no_grad():
        # output = torch.squeeze(model(img.to(device))).cpu()   #GPU
        output = torch.squeeze(model(img))      # 维度压缩
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()
        print_res = "class: {}  prob: {:.3}".format(class_indict[str(predict_cla)],
                                                    predict[predict_cla].numpy())
        plt.title(print_res)
        # for i in range(len(predict)):
        #     print("class: {}  prob: {:.3}".format(class_indict[str(predict_cla)],
        #                                             predict[predict_cla].numpy()))
        plt.show()


预测结果如下:

输入向日葵,预测准确率为 1.0 。


欢迎关注公众号:【千艺千寻】,共同成长


参考:

  1. pytorch图像分类篇:3.搭建AlexNet并训练花分类数据集
  2. B站UP主——3.2 使用pytorch搭建AlexNet并训练花分类数据集
  3. TensorFlow2学习十三、实现AlexNet
  4. PyTorch 11:模型容器与 AlexNet 构建
  5. PyTorch 10:模型创建步骤与 nn.Module
  6. 卷积——动图演示
  7. Python3 os.path() 模块

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/35471.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

你了解PMP考试新考纲的内容吗?

2021年新版PMP考纲变化趋势 随着时代发展,PMP认证本身也通过改版不断调整定位,与全球项目管理趋势相匹配,确保在全球项目管理专业领域保持“黄金标准”。 新版本变化如下: 五大过程组变为三大板块。之前一直沿用的“启动、规划…

Transformer时间序列预测

介绍: 提示:Transformer-decoder 总体介绍 本文将介绍一个 Transformer-decoder 架构,用于预测Woodsense提供的湿度时间序列数据集。该项目是先前项目的后续项目,该项目涉及在同一数据集上训练一个简单的 LSTM。人们认为 LSTM 在…

阿里P8总结的Nacos入门笔记,从安装到进阶小白也能轻松学会

前言 都说程序员工资高、待遇好, 2022 金九银十到了,你的小目标是 30K、40K,还是 16薪的 20K?作为一名 Java 开发工程师,当能力可以满足公司业务需求时,拿到超预期的 Offer 并不算难。然而,提升…

GPC规范-SCP02

SPC02 流程 SPC02 指令 命令: 响应: 举例回复: 密钥分散数据: 0000FFFFFFFFFFFFFFFF Key Info: 20 02(scp02) Card挑战数: 001AC6619BE83082 Card加密值: 7…

leetcode刷题(133)——剑指 Offer 07. 重建二叉树

输入某二叉树的前序遍历和中序遍历的结果,请构建该二叉树并返回其根节点。 假设输入的前序遍历和中序遍历的结果中都不含重复的数字。 示例 1: Input: preorder [3,9,20,15,7], inorder [9,3,15,20,7] Output: [3,9,20,null,null,15,7]示例 2: Input: preord…

(十一)笔记.net学习表达式目录树Expression

(十一)笔记.net学习表达式目录树Expression1.什么是表达式目录树(1)Func和表达式的不同(2)表达式树拆解(3)自己拼装表达式目录2.动态拼装表达式目录和扩展应用3.解析表达式目录&…

阿里云服务器采用AMD CPU处理器ECS实例规格详解

阿里云服务器有AMD CPU处理器,阿里云服务器ECS通用型g7a、计算型c7a和内存型r7a采用2.55 GHz主频的AMD EPYCTM MILAN处理器,单核睿频最高3.5 GHz;通用型g6a、计算型c6a和内存型r6a采用2.6 GHz主频的AMD EPYCTM ROME处理器,睿频3.3…

MySQL读取的记录和我想象的不一致——事物隔离级别和MVCC

本篇是《MySQL是怎样运行的》读书笔记,主要分析并发的事务在运行过程中会出现一些可能引发一致性问题的现象。 文章目录1.事务的特性简介1.1 原子性(Atomicity)1.2 隔离性(Isolation)1.3 一致性(Consistenc…

JUC基础

synchronized 复习虚假唤醒什么是虚假唤醒虚假唤醒产生的原因?解决虚假唤醒?Lock接口ReentrantLock 和 synchronized 的区别Lock 实现线程通信Lock 实现线程定制化通信集合线程安全ArrayListHashSetHashMapsynchronized 锁的范围多线程锁公平锁和非公平锁…

CameraMetadata 知识学习整理

一、涉及的相关代码路径 system/media/camera/src/camera_metadata.c // metadata的核心内容,包含metadata内存分配,扩容规则,update, find等 system/media/camera/src/camera_metadata_tag_info.c // 所有android原生tag的在内存里面sect…

22/11/24

1,单调队列; (76条消息) 单调队列专题_Dull丶的博客-CSDN博客 2,kmp算法; 先是自己和自己匹配,求出ne数组,然后和另一串匹配,进行求解; 循环里三步:while&#xff0c…

【Lilishop商城】No2-3.确定软件架构搭建二(本篇包括接口规范、日志处理)

仅涉及后端,全部目录看顶部专栏,代码、文档、接口路径在: 【Lilishop商城】记录一下B2B2C商城系统学习笔记~_清晨敲代码的博客-CSDN博客 全篇只介绍重点架构逻辑,具体编写看源代码就行,读起来也不复杂~ 谨慎&#xf…

【数据聚类】基于粒子群、遗传和差分算法实现数据聚类附matlab代码

✅作者简介:热爱科研的Matlab仿真开发者,修心和技术同步精进,matlab项目合作可私信。 🍎个人主页:Matlab科研工作室 🍊个人信条:格物致知。 更多Matlab仿真内容点击👇 智能优化算法 …

【App自动化测试】(十)特殊控件Toast识别

目录1. toast介绍2. toast定位3. 实例演示前言: 本文为在霍格沃兹测试开发学社中学习到的一些技术写出来分享给大家,希望有志同道合的小伙伴可以一起交流技术,一起进步~ 😘 1. toast介绍 Toast,简易的消息提示框。为了…

CANdelaStudio-从入门到深入目录

前文介绍诊断协议那些事儿专栏,为大家深入介绍了ISO 14229各个服务的基础知识、请求与响应的报文格式,详情可查看:诊断协议那些事儿,从本专题开始,将由浅入深的展开诊断实际开发与测试的数据库编辑,包含大量实际开发过程中的步骤、使用技巧与少量对Autosar标准的解读。希…

HTML5学习笔记(四)

CSS3 颜色样式 在CSS3中,增加了大量定义颜色方面样式的属性,主要包括以下3种。 ▶ opacity透明度 ▶ RGBA颜色 ▶ CSS3渐变 opacity透明度 opacity属性取值是一个数值,取值范围为0.0~1.0。其中0.0表示完全透明,1.0表示完全不透…

I/O模型

网络IO的本质 网络IO的本质就是socket流的读取,通常一次IO读取会涉及两个阶段与两个对象,其中两个对象为:用户进程(线程)Process(Thread)、内核对象(kernel),两个阶段为…

北方地区长乐市污水厂(150000m3d)工艺设计

目 录 1设计说明书 3 1.1概述 3 1.1.1设计题目 3 1.1.2设计任务 3 1.1.3设计阶段(设计程度) 3 1.1.4设计依据 3 1.1.5设计原始资料 3 1.1.6设计工作量 5 1.1.7设计要求 5 1.1.8 毕业设计日期 5 1.2 设计要求 6 1.2.1 设计原则 6 1.2.3 设计内容 6 1.3 水…

PLC中ST编程的比较运算

比较运算符&#xff1a; >大于、 <小于、 >大于等于、 <小于等于、等于、 <>不等于。 BOOL类型的比较是通过1&#xff08;TRUE&#xff09;和0&#xff08;FALSE&#xff09;来比较的&#xff1b; 只有xIn_1为真&#xff0c;xIn_2为假的时候&#xff0c;xRe…

《小猫猫大课堂》1——小喵是如何开启敲代码之路的?

更新不易&#xff0c;麻烦多多点赞&#xff0c;欢迎你的提问&#xff0c;感谢你的转发&#xff0c; 最后的最后&#xff0c;关注我&#xff0c;关注我&#xff0c;关注我&#xff0c;你会看到更多有趣的博客哦&#xff01;&#xff01;&#xff01; 喵喵喵&#xff0c;你对我真…