22- Pytorch实现天气分类 (Pytorch系列) (项目二十二)

news2025/7/19 10:13:34

项目要点

  • 4种天气数据的分类:   cloudy,  rain,  shine,  sunrise.
  • all_img_path = glob.glob(r'G:\01-project\08-深度学习\day 56 迁移学习\dataset/*.jpg')     # 指定文件夹   # import glob
  • 获取随机数列: index = np.random.permutation(len(all_img_path))
  • 建立数组和索引的关联: idx_to_species = dict((i, c) for i, c in enumerate(species))
  • transform = transforms.Compose([transforms.Resize((96, 96)), transforms.ToTensor()])    # 转换为tensor  # 定义transform
  • 数据由numpy转换为tensor: torch.from_numpy(np.array(label)).long()
  • 判断图片的通道数: if np.array(img).shape[-1] == 3
  • 打开文件夹图片: img = Image.open(all_img_path[0])
  • 数据转换为ndarray: np.asarray(img).shape
  • train_d1 = torch.utils.data.DataLoader(train_ds, batch_size = 16, shuffle = True, collate_fn = MyDataset.collate_fn, drop_last = True)      # 定义dataloader  # 最后一批数据直接不用

定义模型:

  • 添加卷积层: self.conv1 = nn.Conv2d(3, 32, 3)
  • 添加激活层: x = self.pool(F.relu(self.conv1(x)))
  • 添加BN层: self.bn1 = nn.BatchNorm2d(32)     # x = self.bn1(x)
  • 添加Flatten层: x = nn.Flatten()(x)     # 用来将输入“压平”,即把多维的输入一维化,# 常用在从 卷积层到全连接层的过渡。
  • 添加卷积层: self.fc1 = nn.Linear(64 * 10 * 10, 1024)   
  • 添加激活层: x = F.relu(self.fc1(x))
  • 添加dropout: self.dropout = nn.Dropout()    # 防止过拟合
  • 添加输出层: self.fc3 = nn.Linear(256, 4)   
    • x = self.fc3(x)
  • 定义程序运行位置: device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  • 定义优化器: optimizer = optim.Adam(model.parameters(), lr=0.001)
  • 定义loss: loss_fn = nn.CrossEntropyLoss()
  • 定义梯度下降:
for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        with torch.no_grad():
            y_pred = torch.argmax(y_pred, dim=1)
            correct += (y_pred == y).sum().item()
            total += y.size(0)
            running_loss += loss.item()


一 自定义数据集分类

 4种天气数据的分类: cloudy,  rain,  shine,  sunrise.

1.1 导包

import torch
import numpy as np
from torchvision import transforms
import glob
from PIL import Image
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt

1.2 数据导入

  • 指定文件夹
all_img_path = glob.glob(r'G:\01-project\08-深度学习\day 56 迁移学习\dataset/*.jpg')
  • 打乱顺序
# permutation 排列组合
# 借助ndarray的索引取值的方法, 打乱数据
index = np.random.permutation(len(all_img_path))
index   # array([ 175, 1027,  530, ...,    4,  831,   65])
species = ['cloudy', 'rain', 'shine', 'sunrise']
# 建立类别和索引之间的映射关系
idx_to_species = dict((i, c) for i, c in enumerate(species))
# {0: 'cloudy', 1: 'rain', 2: 'shine', 3: 'sunrise'}
# 生成所有图片的label
all_labels = []

for img in all_img_path:
    for i,c in enumerate(species):
        if c in img:
            all_labels.append(i)
  •  数据格式转换
all_labels = np.array(all_labels, dtype=np.int64)[index]
all_labels   # array([0, 3, 2, ..., 0, 3, 0], dtype=int64)
all_img_path = np.array(all_img_path)[index]
all_img_path

1.3 数据拆分

# 手动划分一下训练数据和测试数据
split = int(len(all_img_path) * 0.8)  # int 只取整数部分

train_imgs = all_img_path[:split]
train_labels = all_labels[:split]

test_imgs = all_img_path[split:]
test_labels = all_labels[split:]
  • 定义 transform
# 定义transform
transform = transforms.Compose([transforms.Resize((96, 96)),
                                transforms.ToTensor()])  # 转换为tensor

1.4 数据处理

class MyDataset(torch.utils.data.Dataset):
    def __init__(self, img_paths, labels, transform):  # 接受初始化数据
        self.imgs = img_paths
        self.labels = labels
        self.transforms = transform
        
    def __getitem__(self, index):   # 取上面的数据
        # 根据index获取item
        img_path = self.imgs[index]
        label = self.labels[index]
        
        # 通过PIL的Image读取图片
        img = Image.open(img_path)
        if np.array(img).shape[-1] == 3:
            data = self.transforms(img)
        
            return data, torch.from_numpy(np.array(label)).long()
        else:
            # 否则为有问题的图片
            print(img_path)
            # print(np.array(img).shape)
            # print(np.array(img))
            return self.__getitem__(index+1)
    
    def __len__(self):  # 调用数据时, 返回长度
        return len(self.imgs)  # 返回个数
    
    # 重写collate_fn
    @staticmethod
    def collate_fn(batch):
        # batch是列表, 长度是batch_size
        # 列表的每个元素是一个元组(x, y)
        # [(x1, y1), (x2, y2).......]
        # collate_fn 的作用, 把所有的x,y分别放到一起, x在一起, y在一起.
        # 把batch中返回值为空的部分过滤掉
        batch = [sample for sample in batch if sample is not None]
        # 简单方法, 直接调用默认的collate方法
        # from torch.utils.data.dataloader import default_collate
        # return default_collate(batch)
        
        # 方式二
        imgs, labels = zip(*batch)
        return torch.stack(imgs, 0), torch.stack(labels, 0)

dataset = MyDataset(all_img_path, all_labels, transform)
len(dataset)   # 1122
train_ds = MyDataset(train_imgs, train_labels, transform)
test_ds = MyDataset(test_imgs, test_labels, transform)
# dataloader
train_d1 = torch.utils.data.DataLoader(train_ds, batch_size = 16,
                                       shuffle = True,
                                       collate_fn=MyDataset.collate_fn,
                                       drop_last = True)  # 最后一批数据直接不用
test_d1 = torch.utils.data.DataLoader(test_ds, batch_size = 16 * 2,
                                      collate_fn=MyDataset.collate_fn, drop_last = True)
for x, y in train_d1:
    print(x.shape,y.shape)

imgs, labels = next(iter(train_d1))
imgs.shape     # torch.Size([16, 3, 96, 96])
labels     # tensor([3, 3, 2, 3, 2, 3, 3, 0, 1, 1, 0, 3, 2, 0, 1, 1])

1.5 定义模型

# 定义模型  # 添加BN层
import torch.nn as nn
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3) # 卷积  # (96, 96, 3)  --> (32, 94, 94)
        # 主要做标准化处理
        self.bn1 = nn.BatchNorm2d(32)
        self.pool = nn.MaxPool2d(2, 2) # 池化   # (32, 47, 47)
        self.conv2 = nn.Conv2d(32, 32, 3)    # (32, 45, 45) --> pooling --> (32, 22, 22)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, 3)   # (64, 22, 22)  --> pooling --> (64, 10, 10)
        self.bn3 = nn.BatchNorm2d(64)
        self.dropout = nn.Dropout()  # 防止过拟合    # 
        
        # batch, channel, height, width, 64
        # 全连接层
        self.fc1 = nn.Linear(64 * 10 * 10, 1024)
        self.bn_fc1 = nn.BatchNorm1d(1024)
        self.fc2 = nn.Linear(1024, 256)
        self.bn_fc2 = nn.BatchNorm1d(256)
        self.fc3 = nn.Linear(256, 4)
        
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.bn1(x)
        x = self.pool(F.relu(self.conv2(x)))
        x = self.bn2(x)
        x = self.pool(F.relu(self.conv3(x)))
        x = self.bn3(x)
        
        # x.view(-1, 64 * 10 * 10)
        # flatten , Flatten层用来将输入“压平”,即把多维的输入一维化,
        # 常用在从卷积层到全连接层的过渡。
        x = nn.Flatten()(x)
        x = F.relu(self.fc1(x))
        x = self.bn_fc1(x)
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.bn_fc2(x)
        x = self.dropout(x)
        x = self.fc3(x)
        return x
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device   # device(type='cpu')

# 生成模型
model = Net()
# 把模型拷贝到GPU
if torch.cuda.is_available():
    model.to(device)

1.6 定义训练

optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()

# 定义训练过程
def fit(epoch, model, train_loader, test_loader):
    correct = 0
    total = 0
    running_loss = 0
    
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        with torch.no_grad():
            y_pred = torch.argmax(y_pred, dim=1)
            correct += (y_pred == y).sum().item()
            total += y.size(0)
            running_loss += loss.item()
            
    epoch_loss = running_loss / len(train_loader.dataset)
    epoch_acc = correct / total
    
    # 测试过程
    test_correct = 0
    test_total = 0
    test_running_loss = 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            y_pred = torch.argmax(y_pred, dim=1)
            test_correct += (y_pred == y).sum().item()
            test_total += y.size(0)
            test_running_loss += loss.item()
    test_epoch_loss = test_running_loss / len(test_loader.dataset)
    test_epoch_acc = test_correct /test_total
    
    print('epoch', epoch,
          'loss', round(epoch_loss, 3),
          'accuracy', round(epoch_acc, 3),
          'test_loss', round(test_epoch_loss, 3),
          'test_accuracy', round(test_epoch_acc, 3))
    return epoch_loss, epoch_acc, test_epoch_loss, test_epoch_acc
  • 指定训练
# 指定训练次数
epochs = 10
train_loss = []
train_acc = []
test_loss = []
test_acc = []

for epoch in range(epochs):
    epoch_loss, epoch_acc, test_epoch_loss, test_epoch_acc = fit(epoch, model,
                                                                 train_d1, test_d1)
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    
    test_loss.append(epoch_loss)
    test_acc.append(epoch_acc)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/395540.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java操作数据库基本原理

- 四年前存稿 Java操作数据库基本原理概述 全称Java Database Connectivity,Java的数据库连接,使用Java语言操作数据库,定义了操作所有关系型数据库规则(接口) 使用步骤 我的mysql是8版本的,使用jar包时必须使用8版本的&#x…

6年Android开发最终被优,事后加入车载开发,开启新起点~

如今传统Android 开发行业的岗位越发紧张了,经过去年一些互联网大厂的部门人员优化、开源截流等操作,加快了内卷的速度,原本坐山观虎斗我,没想到也被卷入其中。 1. Android 开发6年,无情中招 就去年年底,…

RZ/G2UL核心板-40℃低温启动测试

1. 测试对象HD-G2UL-EVM基于HD-G2UL-CORE工业级核心板设计,一路千兆网口、一路CAN-bus、 3路TTL UART、LCD、WiFi、CSI 摄像头接口等,接口丰富,适用于工业现场应用需求,亦方便用户评估核心板及CPU的性能。HD-G2UL-CORE系列工业级核…

铁路与公路

蓝桥杯集训每日一题acwing4074 某国家有 n 个城市(编号 1∼n)和 m 条双向铁路。 每条铁路连接两个不同的城市,没有两条铁路连接同一对城市。 除了铁路以外,该国家还有公路。 对于每对不同的城市 x,y,当且仅当它们之…

Mars3D美丽乡村系统发布

近日,我们基于Mars3D平台做了一个Mars3D美丽乡村应用系统,于2023年3月7日正式发布上线!该系统提供一个PC浏览器端的“样板房”项目模版,包含常用的地图基础功能,可基于该平台开发新项目,更换数据即可快速形…

C++继承派生以及虚基类的内存分布

C中类有3种权限&#xff1a;public、protected、private。&#xff08;本文为《直击招聘》的笔记总结&#xff09;。如果没有指明默认为private&#xff0c;定义class A如下class A {int x; public:void displaya() {cout << "A::x:" << &x <<…

理解进程、通过调用 fork 函数创建进程

文章目录1.理解进程1.1 CPU核的个数与进程数1.2 进程 ID2.通过调用 fork 函数创建进程2.1 fork.c1.理解进程 进程&#xff08;Process&#xff09;&#xff0c;其定义如下&#xff1a;“占用内存空间的正在运行的程序”。 假如各位从网上下载了 LBreakout 游戏并安装到硬盘。…

CS项目实训-Java 银行ATM机

摘 要 本次课程设计主要目的是培养我们面向对象软件开发的思维&#xff0c;初步了解软件开发的一般流程。提高编程的实际动手能力并增强大家对面向对象的了解。这次课程设计的主要内容是开发一个应用程序&#xff0c;我们小组设计的ATM柜员机&#xff0c;它主要是由各个indows窗…

05 | 如何安全、快速地接入OAuth 2.0?

05 | 如何安全、快速地接入OAuth 2.0&#xff1f; 构建第三方软件应用 第一点&#xff0c;注册信息 小兔软件的研发人员提前登录到京东商家开放平台进行手动注册&#xff0c;以便后续使用这些注册的相关信息来请求访问令牌。兔软件需要先拥有自己的 app_id 和 app_serect 等信…

联想笔记本电脑开机后一直转圈无法启动怎么办?

联想笔记本电脑开机后一直转圈无法启动怎么办&#xff1f;在正常开启电脑的过程中&#xff0c;系统进入到加载页面&#xff0c;但是却一直无法正常的启动。进行系统的重新启动依然是无法正常的使用。遇到这个情况需要进行系统的重置。接下来我们来看看详细的解决方法分享吧。 准…

实现用户操作日志记录

Java记录操作日志 java自带的日志框架是java.util.logging&#xff08;JUL&#xff09;&#xff0c;从JDK1.4&#xff08;2002&#xff09;开始捆绑在JDK中。可以使用JUL来记录操作日志。以下是使用JUL记录事务的示例&#xff1a; // java.util.logging java.util.logging.Lo…

网分线缆测试和dc-block

今天的好苹果和坏苹果 好苹果&#xff1a;是校准件和网分都是好的&#xff0c;又给了我一次复盘的机会 网分测试线缆&#xff1a; 1.网分直接复位&#xff0c;如果网分复位是校准状态&#xff0c;且解的是精密转接头&#xff0c;BNC的&#xff0c;可以不校准&#xff0c;结果差…

【高中数学教资】教案设计通用模板

前言 本文针对的是高中数学教师资格证笔试中最后的大题——教案设计&#xff08;含设计意图&#xff0c;文末有2022下半年高中数学教资教案设计大题&#xff09;。并附上高中数学404教资考点大纲&#xff0c;还有在学习中发现的一些可以免费学习网站推荐。 一、高中数学404考…

List系列集合

一. List集合特点、特有API List的实现类的底层原理 ArrayList底层是基于数组实现的&#xff1a;根据索引定位元素快&#xff0c;增删相对慢。LinkedList底层基于双链表实现的&#xff1a;查询元素慢&#xff0c;增删首尾元素是非常快的。public class ListDemo01 {public sta…

SerDes---CDR技术

1、为什么需要CDR 时钟数据恢复主要完成两个工作&#xff0c;一个是时钟恢复&#xff0c;一个是数据重定时&#xff0c;也就是数据的恢复。时钟恢复主要是从接收到的 NRZ&#xff08;非归零码&#xff09;码中将嵌入在数据中的时钟信息提取出来。 2、CDR种类 PLL-Based CDROve…

【信号与系统笔记】第一章 绪论

1.1信号传输系统 信息传输的任务 将带有信息的信号&#xff0c;通过某种系统由发送者传送给接收者。 通信系统的组成 转换器&#xff1a;把消息转换为电信号或者把电信号还原成消息信道&#xff1a;信号传输的通道&#xff0c;广义上来说。发射机和接收机也可以是信道的一部分…

【RabbitMQ】Producer之publisher confirm、transaction - 基于AMQP 0-9-1(二)

上篇文章主要介绍Producer的mandatory参数&#xff0c;备份队列和TTL的内容&#xff0c;这篇文章讲继续介绍Producer端的开发&#xff0c;主要包括发布方确认和事务机制。 发布方确认 消息持久化机制可以保证应服务器出现异常导致消息丢失的问题&#xff0c;但是Producer将消…

线程池ThreadPoolExecutor,从0到0.6

ThreadPoolExecutor是JDK提供的在java.util.concurrent包中的一个用于创建线程池的工具类。 一、ThreadPoolExecutor的7个参数 corePoolSize&#xff1a;核心线程数&#xff0c;线程池中保留的最小的线程数量&#xff0c;即使它们是空闲的也不会被销毁&#xff0c;除非allowCor…

Modbus转profinet网关连接1200PLC在博图组态与驱动器通讯程序案例

本案例给大家介绍由兴达易控modbus转profinet网关连接1200PLC在博图软件无需编程&#xff0c;实现1200Profinet转modbus与驱动器通讯的程序案例 硬件连接&#xff1a;1200PLC一台&#xff1b;英威腾DA180系列驱动器一台&#xff1b;兴达易控modbus转profinet网关一台 下面就是…

【Git】拉取 Pull Requests 测试的两种方法

文章目录前言参考目录方法说明方法一&#xff1a;直接拉取方法二&#xff1a;使用 diff 文件2.1、保存 diff 文件2.2、新建分支并执行文件前言 最近有参与到框架帮忙进行简单的 Pull Requests&#xff08;以下简称 PR&#xff09; 测试&#xff0c;因为也是第一次接触到这种操…