深度学习Week8-咖啡豆识别(Pytorch)

news2025/7/13 5:15:04

目录

 一、前期准备

1.设置GPU

2. 导入数据

 3. 划分数据集

 二、手动搭建VGG-16模型

1. 搭建模型​编辑

2. 查看模型详情

 三、 训练模型

1. 编写训练函数

2. 编写测试函数

4. 正式训练

四、 结果可视化

1. Loss与Accuracy图

2. 指定图片进行预测

3. 模型评估

*五、优化模型

1.调整学习率和动态学习率


🍨 本文为[🔗365天深度学习训练营]中的学习记录博客
🍦 参考文章:Pytorch实战 | 第P7周:咖啡豆识别(训练营内部成员可读)
🍖 原作者:[K同学啊|接辅导、项目定制]

🍺 要求:

自己搭建VGG-16网络框架 √
调用官方的VGG-16网络框架 √
如何查看模型的参数量以及相关指标 √
🍻 拔高(可选):

验证集准确率达到100%
使用PPT画出VGG-16算法框架图(发论文需要这项技能)
🔎 探索(难度有点大)

在不影响准确率的前提下轻量化模型
● 目前VGG16的Total params是134,276,932
🏡 我的环境:
● 语言环境:Python 3.8
● 编译器:Pycharm
● 深度学习环境:Pytorch

 一、前期准备

1.设置GPU

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasets
import os, PIL, pathlib
 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
print(device)

输出:cuda

2. 导入数据

data_dir = './49-data/'
data_dir = pathlib.Path(data_dir)

data_paths = list(data_dir.glob('*'))
classeNames = [str(path).split("\\")[1] for path in data_paths]
print(classeNames)

['Dark', 'Green', 'Light', 'Medium']

图形变换,输出一下:用到torchvision.transforms.Compose()类,有兴趣的同学可以参考这篇博客:torchvision.transforms.Compose()详解【Pytorch手册】

train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸
    # transforms.RandomHorizontalFlip(), # 随机水平翻转
    transforms.ToTensor(),          # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
    transforms.Normalize(           # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])

test_transform = transforms.Compose([
    transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸
    transforms.ToTensor(),          # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
    transforms.Normalize(           # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])

total_data = datasets.ImageFolder("./49-data/",transform=train_transforms)
print(total_data.class_to_idx)

{'Dark': 0, 'Green': 1, 'Light': 2, 'Medium': 3}

 3. 划分数据集

因为不像week6有已经分好的训练集和测试集,所以这次要想以前那样,分为训练集和测试集.

train_size = int(0.8 * len(total_data))
test_size  = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
batch_size = 32
train_dl = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=0)
test_dl = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          num_workers=0)
for X, y in test_dl:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break

Shape of X [N, C, H, W]:  torch.Size([32, 3, 224, 224])
Shape of y:  torch.Size([32]) torch.int64

 二、手动搭建VGG-16模型

1. 搭建模型

在这里插入图片描述

import torch.nn.functional as F

class vgg16(nn.Module):
    def __init__(self):
        super(vgg16, self).__init__()
        # 卷积块1
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        )
        # 卷积块2
        self.block2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        )
        # 卷积块3
        self.block3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        )
        # 卷积块4
        self.block4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        )
        # 卷积块5
        self.block5 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        )
        
        # 全连接网络层,用于分类
        self.classifier = nn.Sequential(
            nn.Linear(in_features=512*7*7, out_features=4096),
            nn.ReLU(),
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(),
            nn.Linear(in_features=4096, out_features=4)
        )

    def forward(self, x):

        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)

        return x

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))
    
model = vgg16().to(device)

2. 查看模型详情

统计模型参数量以及其他指标

import torchsummary as summary
summary.summary(model, (3, 224, 224))

 三、 训练模型

1. 编写训练函数

训练部分代码和之前cnn网络一样

# 训练循环
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # 训练集的大小
    num_batches = len(dataloader)   # 批次数目, (size/batch_size,向上取整)
 
    train_loss, train_acc = 0, 0  # 初始化训练损失和正确率
    
    for X, y in dataloader:  # 获取图片及其标签
        X, y = X.to(device), y.to(device)
        
        # 计算预测误差
        pred = model(X)          # 网络输出
        loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失
        
        # 反向传播
        optimizer.zero_grad()  # grad属性归零
        loss.backward()        # 反向传播
        optimizer.step()       # 每一步自动更新
        
        # 记录acc与loss
        train_acc  += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()
            
    train_acc  /= size
    train_loss /= num_batches
 
    return train_acc, train_loss

2. 编写测试函数

训练函数和测试函数差别不大,但是由于不进行梯度下降对网络权重进行更新,所以不用优化器

(所以测试函数代码部分和之前几周一样)

def test (dataloader, model, loss_fn):
    size        = len(dataloader.dataset)  # 测试集的大小
    num_batches = len(dataloader)          # 批次数目
    test_loss, test_acc = 0, 0
    
    # 当不进行训练时,停止梯度更新,节省计算内存消耗
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)
            
            # 计算loss
            target_pred = model(imgs)
            loss        = loss_fn(target_pred, target)
            
            test_loss += loss.item()
            test_acc  += (target_pred.argmax(1) == target).type(torch.float).sum().item()
 
    test_acc  /= size
    test_loss /= num_batches
 
    return test_acc, test_loss

4. 正式训练

这里也设置了训练器,结合前几次实验经验,使用Adam模型

import copy

optimizer  = torch.optim.Adam(model.parameters(), lr= 1e-4)
loss_fn    = nn.CrossEntropyLoss() # 创建损失函数

epochs = 40

train_loss = []
train_acc = []
test_loss = []
test_acc = []

best_acc = 0  # 设置一个最佳准确率,作为最佳模型的判别指标

for epoch in range(epochs):
    # 更新学习率(使用自定义学习率时使用)
    # adjust_learning_rate(optimizer, epoch, learn_rate)

    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)
    # scheduler.step() # 更新学习率(调用官方动态学习率接口时使用)

    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    # 保存最佳模型到 best_model
    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc
        best_model = copy.deepcopy(model)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    # 获取当前的学习率
    lr = optimizer.state_dict()['param_groups'][0]['lr']

    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
    print(template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss,
                          epoch_test_acc * 100, epoch_test_loss, lr))

# 保存最佳模型到文件中
PATH = './best_model.pth'  # 保存的参数文件名
torch.save(model.state_dict(), PATH)

print('Done')

....

Epoch:38, Train_acc:96.7%, Train_loss:0.095, Test_acc:96.2%, Test_loss:0.099, Lr:1.00E-04
Epoch:39, Train_acc:96.0%, Train_loss:0.099, Test_acc:96.2%, Test_loss:0.112, Lr:1.00E-04
Epoch:40, Train_acc:96.5%, Train_loss:0.098, Test_acc:97.1%, Test_loss:0.091, Lr:1.00E-04
Done

四、 结果可视化

1. Loss与Accuracy图

import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

2. 指定图片进行预测

from PIL import Image
classes = list(total_data.class_to_idx)

def predict_one_image(image_path, model, transform, classes):
    test_img = Image.open(image_path).convert('RGB')
    plt.imshow(test_img)  # 展示预测的图片

    test_img = transform(test_img)
    img = test_img.to(device).unsqueeze(0)

    model.eval()
    output = model(img)

    _, pred = torch.max(output, 1)
    pred_class = classes[pred]
    print(f'预测结果是:{pred_class}')
 
# 预测训练集中的某张照片
predict_one_image(image_path='./49-data/Green/green (9).png',
                  model=model,
                  transform=train_transforms,
                  classes=classes)

3. 模型评估

以往都是看看最后几轮得到准确率,但是跳动比较大就不太好找准确率最高的一回,所以我们用函数返回进行比较。

best_model.eval()
epoch_test_acc, epoch_test_loss = test(test_dl, best_model, loss_fn)
print(epoch_test_acc, epoch_test_loss)
print(epoch_test_acc)

*五、优化模型

1.调整学习率和动态学习率

以下调用的官方动态学习率接口(上周内容也有提到):

学习率就用一开始的1e-4

动态学习率也是用到上周提到的torch.optim.LambdaLR,调用自己定义的函数更新学习率(lr_lambda)

learn_rate = 1e-4 # 初始学习率
lambda1 = lambda epoch: 0.92 ** (epoch // 10)
optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)

结果98%左右...

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/36650.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[附源码]java毕业设计中达小区物业管理系统

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

学生家乡网页设计作品静态HTML网页模板源码 广西旅游景点网页设计 大学生家乡主题网站制作 简单家乡介绍网页设计成品

家乡旅游景点网页作业制作 网页代码运用了DIV盒子的使用方法,如盒子的嵌套、浮动、margin、border、background等属性的使用,外部大盒子设定居中,内部左中右布局,下方横向浮动排列,大学学习的前端知识点和布局方式都有…

Packet Tracer - 配置 IPv4 和 IPv6 接口

地址分配表 设备 接口 IPv4 地址 子网掩码 默认网关 IPv6 地址/前缀 R1 G0/0 172.16.20.1 255.255.255.128 不适用 G0/1 172.16.20.129 255.255.255.128 不适用 S0/0/0 209.165.200.225 255.255.255.252 不适用 PC1 NIC 172.16.20.10 255.255.255.128 1…

微服务架构的环境搭建及简单测试

目录 一、系统架构的演变过程 1.0 前言 1.1 单体应用架构 1.2 垂直应用架构 1.3 分布式架构 1.4 SOA架构 1.5 微服务架构 二、微服务架构搭建 2.1 微服务架构简介 2.2 微服务案例准备 2.3 创建父工程、基础模块 2.4 创建微服务 一、系统架构的演变过程 1.0 前言 随着互联网的…

【Queue】- 从源码分析PriorityQueue及其常用方法

文章目录PriorityQueue基础知识概述PriorityQueue内部结构PriorityQueue扩容操作PriorityQueue队列的构造方法PriorityQueue队列的常用方法public boolean offer(E e)public E peek()public boolean remove(Object o)public boolean contains(Object o)public Object[] toArray…

【SU-03T离线语音模块】:学习配置使用

前言 时不可以苟遇,道不可以虚行。 一、介绍 1、什么是语音识别模块 语音识别模块是在一种基于嵌入式的语音识别技术的模块,主要包括语音识别芯片和一些其他的附属电路,能够方便的与主控芯片进行通讯,开发者可以方便的将该模块嵌…

Node.js 入门教程 3 如何安装 Node.js

Node.js 入门教程 Node.js官方入门教程 Node.js中文网 本文仅用于学习记录,不存在任何商业用途,如侵删 文章目录Node.js 入门教程3 如何安装 Node.js3 如何安装 Node.js Node.js 可以通过多种方式安装。 所有主流平台的官方软件包都可以在 http://node…

终于见识到了微服务的天花板:阿里内部SpringCloud全线手册,太强了

后台都是在问微服务架构的面试题怎么答,想聊聊微服务架构了。微服务架构一跃成为 IT 领域炙手可热的话题也就这两年的事,大量一线互联网公司因为庞大的业务体量和业务需求,纷纷投入了微服务架构的建设中,像阿里巴巴、百度、美团等…

226. 翻转二叉树

文章目录1.题目2.示例3.答案①递归②迭代1.题目 给你一棵二叉树的根节点 root ,翻转这棵二叉树,并返回其根节点。 2.示例 输入:root [4,2,7,1,3,6,9] 输出:[4,7,2,9,6,3,1]输入:root [4,2,7,1,3,6,9] 输出&#xf…

智慧应急解决方案-最新全套文件

智慧应急解决方案-最新全套文件一、建设背景二、建设思路应急管理信息化发展“四纵四横”总体架构1、两网络2、四体系3、两机制三、建设方案四、获取 - 智慧应急全套最新解决方案合集一、建设背景 建立应急大数据管理体系是应急管理信息化建设中的重要环节,决定了应…

将数组沿指定轴划分为子数组numpy.split()

【小白从小学Python、C、Java】 【计算机等级考试500强双证书】 【Python-数据分析】 将数组沿指定轴划分为子数组 numpy.split() [太阳]选择题 以下python代码输出错误的一项是? import numpy as np xnp.array([1,2,3,4,5,6,7,8,9,10,11,12]) print(【显示】x&#x…

数据库安装记录——Mysql8.0.23 msi 保姆级安装教程

今天遇到现场服务器安装数据库,特意记录安装过程。 本篇记录的服务器系统为:Windows Server 2012 R2 Standard 数据库版本为:Mysql8.0.23 msi 1、官网下载相应版本 2、安装过程 开端不顺,开始就出弹窗: 先插入一…

下载神器-IDM使用教程及下载

软件介绍: IDM是“Internet Download Manager”的简称,意思是“互联网下载管理器”,既是一类软件的统称,也专指一个非常知名的互联网下载器,这个软件的名字就叫IDM,被誉为地表最强下载器,屌丝救…

Flutter 中使用 extension 使项目更具可读性和效率 01

Flutter 中使用 extension 使项目更具可读性和效率 01 原文 https://medium.com/bedirhanssaglam/make-your-flutter-projects-more-readable-and-effective-with-extensions-b7dffd32e2f4 前言 代码的可读性和实用性在《 Flutter 》中非常重要。今天我们将编写一些 extension …

代码行覆盖率学习

【强力推荐】jacoco代码测试覆盖率实战教学全集,7天从入门到精通【理论实战 赶紧拿走】_哔哩哔哩_bilibili on-the-fly: 测试的时候代码是动态的, 需要测试就帮你插桩, 不测就不帮你插桩 offline: 先把被测代码拿到一次性直接插桩, 一运行桩就已经插好了, 直接生成…

如何把一个视频分割成不同时长的多个小视频

大家平时找素材是不是有点困难,如何把一个视频一个分割成多个不同时长的小视频呢,分割视频时能不能按我们需要来分割,今天小编带大伙来了解决下分割视频操作方法和步骤。 先来看下原来视频,原视频时长是比较长的 接下来我们准备一…

掌握这些 Spring Boot 启动扩展点,已经超过 90% 的人了!

1.背景 Spring的核心思想就是容器,当容器refresh的时候,外部看上去风平浪静,其实内部则是一片惊涛骇浪,汪洋一片。Springboot更是封装了Spring,遵循约定大于配置,加上自动装配的机制。很多时候我们只要引用…

Docker - Docker部署war包

使用Docker部署war项目,必须要用容器,我们就用tomcact容器,其实都是将war包丢到tomcat的webapps目录下,tomcat启动的情况下会自动解压war包 部署war包有两种方式 1、在Docker中安装tomcat容器的镜像,然后把war包丢到…

【无人机】模拟一群配备向下摄像头的移动空中代理覆盖平面区域(Matlab代码实现)

👨‍🎓个人主页:研学社的博客 💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜…

极速Go语言入门(超全超详细)-基础篇2

文章目录函数进阶结构体接口继承type值类型与引用类型值传递、引用传递打包、引用包工具类打包文件代码引用包代码方法异常捕捉处理字符串常用函数日期常用函数管道(channel)书接上篇:极速Go语言入门(超全超详细)-基础篇 整个基础篇合计32000字左右,如有遗漏可以私…