Pytorch从零开始实战03

news2025/5/26 12:57:07

Pytorch从零开始实战——天气识别

本系列来源于365天深度学习训练营

原作者K同学

文章目录

  • Pytorch从零开始实战——天气识别
    • 环境准备
    • 数据集
    • 模型选择
    • 模型训练
    • 数据可视化
    • 总结

环境准备

本文基于Jupyter notebook,使用Python3.8,Pytorch2.0.1+cu118,torchvision0.15.2,需读者自行配置好环境且有一些深度学习理论基础。
第一步,导入常用包。

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision
import torch.nn.functional as F
import torchvision.transforms as transforms
import random
import time
import numpy as np
import pandas as pd
import datetime
import gc
import pathlib
import os
import PIL
os.environ['KMP_DUPLICATE_LIB_OK']='True'  # 用于避免jupyter环境突然关闭
torch.backends.cudnn.benchmark=True  # 用于加速GPU运算的代码

创建设备对象。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device # device(type='cuda')

设置随机数种子

torch.manual_seed(428)
torch.cuda.manual_seed(428)
torch.cuda.manual_seed_all(428)
random.seed(428)
np.random.seed(428)

数据集

本次实验使用的天气图片数据集,共有1127张天气图片,分别存在’cloudy’, ‘sunrise’, ‘shine’, 'rain’四个文件夹中,其中文件夹名代表类别。数据集获取可联系K同学。
导入数据集。
根据自己数据集存放的路径,转换为pathlib.Path对象,然后获取路径下的所有文件路径,使用字符串分割函数获取文件名,也就是类别名。

data_dir = './data/weather_photos'
data_dir = pathlib.Path(data_dir) # 转换为pathlib.Path对象

data_paths = list(data_dir.glob('*')) # 获取data_dir路径下的所有文件路径
data_paths # data/weather_photos/xxxx
classNames = [str(path).split("/")[2] for path in data_paths]
classNames # ['cloudy', 'sunrise', 'shine', 'rain']

对数据集进行预处理。调整到相同的尺寸,转换为张量对象,并进行标准化处理。使用torchvision.datasets.ImageFolder函数读取数据集,并且使用文件名当做数据集的标签。

total_dir = './data/weather_photos'
train_transforms = transforms.Compose([
    transforms.Resize([224, 224]), # 调整相同的尺寸
    transforms.ToTensor(),
    transforms.Normalize(          # 标准化处理-->转换为标准正太分布
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])
total_data = torchvision.datasets.ImageFolder(total_dir, transform=train_transforms) # 通过total_dir下的子文件夹当做标签
total_data

我们根据8:2划分训练集和测试集。

# 划分数据集
train_size = int(0.8 * len(total_data))
test_size = len(total_data) - train_size
train_ds, test_ds = torch.utils.data.random_split(total_data, [train_size, test_size])
len(train_ds), len(test_ds) # (901, 226)

又是前面几篇出现的函数,随机查看五张图片。

def plotsample(data):
    fig, axs = plt.subplots(1, 5, figsize=(10, 10)) #建立子图
    for i in range(5):
        num = random.randint(0, len(data) - 1) #首先选取随机数,随机选取五次
        #抽取数据中对应的图像对象,make_grid函数可将任意格式的图像的通道数升为3,而不改变图像原始的数据
        #而展示图像用的imshow函数最常见的输入格式也是3通道
        npimg = torchvision.utils.make_grid(data[num][0]).numpy()
        nplabel = data[num][1] #提取标签 
        #将图像由(3, weight, height)转化为(weight, height, 3),并放入imshow函数中读取
        axs[i].imshow(np.transpose(npimg, (1, 2, 0))) 
        axs[i].set_title(nplabel) #给每个子图加上标签
        axs[i].axis("off") #消除每个子图的坐标轴

plotsample(train_ds)

在这里插入图片描述
使用DataLoder将它按照batch_size批量划分,并将数据集顺序打乱。

batch_size = 32
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=batch_size, shuffle=True)
for X, y in test_dl:
    print(X.shape) # 32, 3, 224, 224
    print(y) # 3 0 2 0 3 2 0 0 2 1....
    break

模型选择

本文使用卷积神经网络,大致流程是卷积->卷积->池化->卷积->卷积->池化->线性层,并进行数据归一化处理,本文选用的卷积核大小为5 * 5。

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(3, 12, kernel_size=5, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(12)
        self.conv2 = nn.Conv2d(12, 12, kernel_size=5, stride=1, padding=0)
        self.bn2 = nn.BatchNorm2d(12)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(12, 24, kernel_size=5, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(24)
        self.conv4 = nn.Conv2d(24, 24, kernel_size=5, stride=1, padding=0)
        self.bn4 = nn.BatchNorm2d(24)
        self.pool4 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(24 * 50 * 50, len(classNames))

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool2(x)
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = self.pool4(x)
        x = x.view(-1, 24 * 50 * 50)
        x = self.fc1(x)
        return x

请添加图片描述
使用summary展示模型架构。

from torchsummary import summary
# 将模型转移到GPU中
model = Model().to(device)
summary(model, input_size=(3, 224, 224))

请添加图片描述

模型训练

定义超参数,本次选择的学习率为0.0001,经实验,最初设置为0.01效果并不是很好。

loss_fn = nn.CrossEntropyLoss()
learn_rate = 0.0001
opt = torch.optim.SGD(model.parameters(), lr=learn_rate)

训练函数。

def train(dataloader, model, loss_fn, opt):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    train_acc, train_loss = 0, 0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)

        opt.zero_grad()
        loss.backward()
        opt.step()

        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()

    train_acc /= size
    train_loss /= num_batches
    return train_acc, train_loss

测试函数。

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_acc, test_loss = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            loss = loss_fn(pred, y)
    
            test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
            test_loss += loss.item()

    test_acc /= size
    test_loss /= num_batches
    return test_acc, test_loss

开始训练,训练20轮,在测试集准确率达到94.7%,还是很不错的。

import time
epochs = 20
train_loss = []
train_acc = []
test_loss = []
test_acc = []

T1 = time.time()

for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)
    
    model.eval() # 确保模型不会进行训练操作
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)
        
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    
    print("epoch:%d, train_acc:%.1f%%, train_loss:%.3f, test_acc:%.1f%%, test_loss:%.3f"
          % (epoch + 1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc * 100, epoch_test_loss))
print("Done")
T2 = time.time()
print('程序运行时间:%s毫秒' % ((T2 - T1)*1000))

请添加图片描述

数据可视化

使用matplotlib进行训练数据、测试数据的可视化。

import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

请添加图片描述

总结

经过几次实验,发现三个问题:
1.经过实验,将学习率从0.01改为0.0001,模型效果会好很多。
2.有的时候每轮epoch准确率一直为百分之20多,可能是模型陷入局部最小值或鞍点,所以后续可以引入提前停止。
3.无脑的增加层数并不会使模型效果变好。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1010366.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Linux学习笔记】 - 常用指令学习及其验证(上)

前言:本文主要记录对Linux常用指令的使用验证。环境为阿里云服务器CentOS 7.9。关于环境如何搭建等问题,大家可到同平台等各大资源网进行搜索学习,本文不再赘述。 由于本人对Linux学习程度尚且较浅,本文仅介绍验证常用指令的常用…

Writesonic:博客和内容创作者的终极写作助手

【产品介绍】 产品名称 Writesonic 上线时间 成立于2020年 具体介绍 Writesonic是一个强大的人工智能写作助手,它使用自然语言处理(NLP)和机器学习算法来生成内容,这些内容不仅写得好,而且还为SEO和转…

MFC中嵌入显示opencv窗口

在MFC窗体中建立一个Picture Control控件,用于显示opencv窗口 在属性中设置图片控件的资源ID为IDC_PIC1 主要的思路: 使用GetWindowRect可以获取图片控件的区域 使用cv::resizeWindow可以设置opencv窗口的大小,适合图片控件的大小 使用cvGetWindowHandle函数可以获取到ope…

30天入门Python(基础篇)——第3天:【变量】与【输出】与【转义符】(万字解析,建议收藏)

文章目录 专栏导读作者有话说:上一节课补充(Pychaem界面认识)①编写代码区域②运行代码(多种方法,随便选一种,开心就好) 什么是变量(变量的定义)①较标准的回答(引用AI)②大白话解释图文并茂(我…

嵌入式Linux驱动开发(I2C专题)(一)

一、I2C协议 1.1、硬件连接 I2C在硬件上的接法如下所示,主控芯片引出两条线SCL,SDA线,在一条I2C总线上可以接很多I2C设备。 1.2、IIC传输数据的格式 1.2.1、写操作 流程如下: 主芯片要发出一个start信号然后发出一个设备地址(用来确定是…

Java作业-模拟扎金花

要求 实现扑克牌的创建、洗牌、发牌、大小对比,输出赢家牌。 前提条件 首先需要创建三个集合,用于存储牌面值、牌号与比较规则,再创建一个类作为牌。 其次还需要了解到一个工具类,就是Collections类,该类的所有方法…

python,迪卡尔象限中画点

import numpy as np import matplotlib.pyplot as plt circleNum 30 # 同时圆刻度值 pointNum 20 # 点的数量 theta np.linspace(0.0, 2*np.pi, pointNum, endpointFalse) s circleNum * np.random.rand(pointNum) # plt.polar(theta, s, linestyleNone, marker*) # 无连接…

不知道有用没用的Api

encodeURIComponent(https://www.baidu.com/?name啊啊啊) decodeURIComponent(https%3A%2F%2Fwww.baidu.com%2F%3Fname%3D%E5%95%8A%E5%95%8A%E5%95%8A) encodeURI(https://www.baidu.com/?name啊啊啊) decodeURI(https://www.baidu.com/?name%E5%95%8A%E5%95%8A%E5%95%8A) …

​LeetCode解法汇总1222. 可以攻击国王的皇后

目录链接: 力扣编程题-解法汇总_分享记录-CSDN博客 GitHub同步刷题项目: https://github.com/September26/java-algorithms 原题链接:力扣(LeetCode)官网 - 全球极客挚爱的技术成长平台 描述: 在一个 8x…

解决虚拟机重启后ifconfig看不到IP的问题

目录 背景 解决方案 背景 虚拟机,桥接模式,启动后一切正常,但重启后发现终端连不上虚机了,也ping不到,最后检查发现,IP消失了,虚机没有IP了。 解决方案 不论是否重启,只要是看不…

2023-09-14 LeetCode每日一题(可以攻击国王的皇后)

2023-09-14每日一题 一、题目编号 1222. 可以攻击国王的皇后二、题目链接 点击跳转到题目位置 三、题目描述 在一个 8x8 的棋盘上,放置着若干「黑皇后」和一个「白国王」。 给定一个由整数坐标组成的数组 queens ,表示黑皇后的位置;以及…

关于浅克隆和深克隆入门理解

浅克隆:需要类实现Cloneable,并重写clone()方法 一般在重写clone()方法时,将返回值类型强转为自己类,避免每次克隆之后需要强转 public class Test {public static void main(String[] args) throws CloneNotSupportedException {A a1new A();A a2 a1.clone();//克隆之后 a1…

【搭建私人图床】本地PHP搭建简单Imagewheel云图床,在外远程访问

文章目录 1.前言2. Imagewheel网站搭建2.1. Imagewheel下载和安装2.2. Imagewheel网页测试2.3.cpolar的安装和注册 3.本地网页发布3.1.Cpolar临时数据隧道3.2.Cpolar稳定隧道(云端设置)3.3.Cpolar稳定隧道(本地设置) 4.公网访问测…

Linux——(第十章)进程管理

目录 一、概述 二、常用指令 1.ps查看当前系统进程状态 2.kill 终止进程 3.pstree 查看进程树 4.top 实时监控系统进程状态 5.netstat 监控网络状态 一、概述 (1)进程是正在执行的一个程序或命令,每一个进程都是一个运行的实体&#…

Redis 常用命令

目录 全局命令 1)keys 2)exists 3) del(delete) 4)expire 5)type SET命令 GET命令 MSET 和 MGET命令 其他SET命令 计数命令 redis-cli,进入redis 最核心的命令:我们这里只是先介绍 set 和 get 最简单的操作…

IP地址,子网掩码,默认网关,DNS讲解

IP地址:用来标识网络中一个个主机,IP有唯一性,即每台机器的IP在全世界是唯一的。 子网掩码:用来判断任意两台计算机的ip地址是否属于同一子网络的根据。最为简单的理解就是两台计算机各自的ip地址与子网掩码进行and运算后&#x…

常用排序算法

一、插入排序1、直接插入排序2、折半插入排序3、希尔排序 二、交换排序1、冒泡排序2、快速排序 三、选择排序1、简单选择排序2、堆排序(1)调整堆(2)创建堆 四、归并排序五、基数排序六、各种排序方法的比较 将一组杂乱无章的数据按…

Zookeeper应用场景和底层设计

一、什么是zookeeper Zookeeper是一个开源的分布式协调服务框架,它是服务于其它集群式框架的框架。 【简言之】 有一个服务A,以集群的方式提供服务。只需要A专注于它提供的服务就可以,至于它如何以多台服务器协同完成任务的事情&#xff0c…

9.14号作业

仿照vector手动实现自己的myVector&#xff0c;最主要实现二倍扩容功能 有些功能&#xff0c;不会 #include <iostream>using namespace std; //创建vector类 class Vector { private:int *data;int size;int capacity; public://无参构造Vector(){}//拷贝构造Vector(c…

2023年最新 Nonobot2 制作QQ聊天机器人详细教程(每周更新中)

协议端 go-cqhttp 安装 使用 mirai 以及 MiraiGo 开发的 cqhttp golang 原生实现&#xff0c;并在 cqhttp 原版 的基础上做了部分修改和拓展。 测试版下载地址&#xff1a;https://github.com/Mrs4s/go-cqhttp/releases 正式版下载地址&#xff1a;https://github.com/Mrs4s…