Python打卡训练营学习记录Day43

news2025/6/7 15:56:13

作业:

kaggle找到一个图像数据集,用cnn网络进行训练并且用grad-cam做可视化

进阶:并拆分成多个文件

从谷歌图片中拍摄的 10 种不同类别的动物图片

数据预处理

import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

def load_data(data_dir, batch_size):
    # 数据预处理
    data_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # 加载数据集
    image_dataset = datasets.ImageFolder(data_dir, data_transform)

    # 划分训练集和验证集
    train_size = int(0.8 * len(image_dataset))
    val_size = len(image_dataset) - train_size
    train_dataset, val_dataset = random_split(image_dataset, [train_size, val_size])

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    dataloaders = {'train': train_dataloader, 'val': val_dataloader}
    dataset_sizes = {'train': train_size, 'val': val_size}
    class_names = image_dataset.classes

    return dataloaders, dataset_sizes, class_names

构建并训练 CNN 模型

import torch.nn as nn

class SimpleCNN(nn.Module):
    def __init__(self, num_classes):
        super(SimpleCNN, self).__init__()
        # 定义特征提取层
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        # 定义分类层
        self.classifier = nn.Sequential(
            nn.Linear(64 * 28 * 28, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        # 前向传播,先通过特征提取层,再通过分类层
        x = self.features(x)
        x = x.view(-1, 64 * 28 * 28)
        x = self.classifier(x)
        return x

模型训练模块

import torch
import torch.nn as nn
import torch.optim as optim

def train_model(model, dataloaders, dataset_sizes, criterion, optimizer, num_epochs=25):
    # 判断是否有可用的 GPU,若有则使用 GPU 进行训练
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    for epoch in range(num_epochs):
        print(f'第 {epoch} 个 epoch,共 {num_epochs - 1} 个 epochs')
        print('-' * 10)

        # 每个 epoch 都有一个训练和验证阶段
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # 训练模式
            else:
                model.eval()   # 评估模式

            running_loss = 0.0
            running_corrects = 0

            # 迭代数据
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # 零参数梯度
                optimizer.zero_grad()

                # 前向传播
                # 只有在训练时才跟踪历史
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # 只有在训练阶段才进行反向传播和优化
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # 统计
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print(f'{phase} 阶段:损失值: {epoch_loss:.4f} 准确率: {epoch_acc:.4f}')

    return model

Grad-CAM可视化模块 

import torch
import torch.nn.functional as F
import numpy as np
import cv2

class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        # 反向传播钩子函数,用于捕获梯度
        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0]

        # 前向传播钩子函数,用于捕获激活值
        def forward_hook(module, input, output):
            self.activations = output

        target_layer.register_forward_hook(forward_hook)
        target_layer.register_backward_hook(backward_hook)

    def forward(self, input_tensor):
        # 将模型设置为评估模式并进行前向传播
        self.model.eval()
        output = self.model(input_tensor)
        return output

    def generate_cam(self, input_tensor, target_class=None):
        # 进行前向传播
        output = self.forward(input_tensor)

        # 如果未指定目标类别,则选择输出概率最大的类别
        if target_class is None:
            target_class = torch.argmax(output, dim=1).item()

        one_hot = torch.zeros_like(output)
        one_hot[:, target_class] = 1
        one_hot.requires_grad_(True)

        # 清零模型参数的梯度
        self.model.zero_grad()
        # 计算损失并进行反向传播
        (one_hot * output).sum().backward(retain_graph=True)

        gradients = self.gradients[0]
        activations = self.activations[0]

        # 对梯度进行全局平均池化
        pooled_gradients = torch.mean(gradients, dim=[1, 2])

        for i in range(activations.shape[0]):
            activations[i, :, :] *= pooled_gradients[i]

        # 对激活值求和生成 CAM 图
        cam = torch.sum(activations, dim=0).detach().cpu().numpy()
        # 取 CAM 图的正值部分
        cam = np.maximum(cam, 0)
        # 调整 CAM 图的大小以匹配输入图像
        cam = cv2.resize(cam, (input_tensor.shape[3], input_tensor.shape[2]))
        # 归一化 CAM 图
        cam = cam - np.min(cam)
        cam = cam / np.max(cam)
        return cam

 主程序

from data_loader import load_data
from model import SimpleCNN
from train import train_model
from grad_cam import GradCAM
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from PIL import Image
import numpy as np
import cv2
import sys
import os
# 将当前目录添加到 Python 模块搜索路径中
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

if __name__ == '__main__':
    # 加载数据,设置批次大小,你可以根据需要调整该值
    batch_size = 32
    # 修改解包操作以处理所有返回值
    dataloaders, dataset_sizes, class_names = load_data('raw-img', batch_size)

    # 获取类别数量
    num_classes = len(class_names)

    # 使用类别数量初始化模型
    model = SimpleCNN(num_classes)

    # 定义损失函数和优化器
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # 训练模型
    trained_model = train_model(model, dataloaders, dataset_sizes, criterion, optimizer, num_epochs=5)

    # 生成 Grad - CAM 可视化结果
    # 修改此处,选择实际存在的卷积层
    # grad_cam = GradCAM(model, target_layer=model.conv2)
    grad_cam = GradCAM(model, target_layer=model.features[0])
    img_path = 'path/to/your/image.jpg'
    img = Image.open(img_path).convert('RGB')
    cam = grad_cam(img)
    plt.imshow(img)
    plt.imshow(cam, alpha=0.5, cmap='jet')
    plt.axis('off')
    plt.savefig('grad_cam_result.jpg')
    plt.show()

 @浙大疏锦行

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2403081.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Android基础回顾】二:handler消息机制

Android 的 Handler 机制 是 Android 应用中实现线程间通信、任务调度、消息分发的核心机制之一,它基于 消息队列(MessageQueue) 消息循环(Looper) 消息处理器(Handler) 组成。 1 handler的使用…

每日Prompt:每天上班的状态

提示词 一个穿着清朝官服的僵尸脸上贴着符纸,在电脑面前办公,房间阴暗,电脑桌面很乱,烟灰缸里面满是烟头

C++11 右值引用:从入门到精通

文章目录 一、引言二、左值和右值(一)概念(二)区别和判断方法 三、左值引用和右值引用(一)左值引用(二)右值引用 四、移动语义(一)概念和必要性(二…

.net 使用MQTT订阅消息

在nuGet下载M2Mqtt V4.3.0版本。(支持.net framework) 订阅主题 public void LoadMQQCData() {string enpoint "xxx.xxx.x.x";//ip地址int port 1883;//端口string user "usrname";//用户名string pwd "pwd";//密码…

【递归、搜索与回溯】综合练习(四)

📝前言说明: 本专栏主要记录本人递归,搜索与回溯算法的学习以及LeetCode刷题记录,按专题划分每题主要记录:(1)本人解法 本人屎山代码;(2)优质解法 优质代码…

强化学习入门:Gym实现CartPole随机智能体

前言 最近想开一个关于强化学习专栏,因为DeepSeek-R1很火,但本人对于LLM连门都没入。因此,只是记录一些类似的读书笔记,内容不深,大多数只是一些概念的东西,数学公式也不会太多,还望读者多多指教…

STM32:CAN总线精髓:特性、电路、帧格式与波形分析详解

声明:此博客是我的学习笔记,所看课程是江协科技的CAN总线课程,知识点都大同小异,我仅进行总结并加上了我自己的理解,所引案例也都是课程中的案例,希望对你的理解有所帮助! 知识点1【CAN总线的概…

贝叶斯深度学习!华科大《Nat. Commun.》发表BNN重大突破!

华科大提出基于贝叶斯深度学习的超分辨率成像,成功被Nat. Commun.收录。可以说,这是贝叶斯神经网络BNN近期最值得关注的成果之一了。另外还有AAAI 2025上的Bella新框架,计算成本降低了99.7%,也非常值得研读。 显然鉴于BNN“不确定…

【大模型LLM学习】Flash-Attention的学习记录

【大模型LLM学习】Flash-Attention的学习记录 0. 前言1. flash-attention原理简述2. 从softmax到online softmax2.1 safe-softmax2.2 3-pass safe softmax2.3 Online softmax2.4 Flash-attention2.5 Flash-attention tiling 0. 前言 Flash Attention可以节约模型训练和推理时间…

物联网数据归档之数据存储方案选择分析

在上一篇文章中《物联网数据归档方案选择分析》中凯哥分析了归档设计的两种方案,并对两种方案进行了对比。这篇文章咱们就来分析分析,归档后数据应该存储在哪里?及存储方案对比。 这里就选择常用的mysql及taos数据库来存储归档后的数据吧。 你在处理设备归档表存储方案时对…

【C语言】C语言经典小游戏:贪吃蛇(上)

文章目录 一、游戏背景及其功能二、Win32 API介绍1、Win32 API2、控制台程序3、定位坐标(COORD)4、获得句柄(GetStdHandle)5、获得光标属性(GetConsoleCursorInfo)1)描述光标属性(CO…

vue2中使用jspdf插件实现页面自定义块pdf下载

pdf下载 实现pdf下载的环境安装jspdf插件在项目中使用 实现pdf下载的环境 项目需求案例背景,点击【pdf下载】按钮,弹出pdf下载弹窗,显示需要下载四个模块的下载进度,下载完成后,关闭弹窗即可! 项目使用的是…

如何防止服务器被用于僵尸网络(Botnet)攻击 ?

防止服务器被用于僵尸网络(Botnet)攻击是关键的网络安全措施之一。僵尸网络是黑客利用大量被感染的计算机、服务器或物联网设备来发起攻击的网络。以下是关于如何防止服务器被用于僵尸网络攻击的技术文章: 防止服务器被用于僵尸网络&#xff…

基于cornerstone3D的dicom影像浏览器 第二十九章 自定义菜单组件

文章目录 前言一、程序结构1. 菜单数据结构2. XMenu.vue3. XSubMenu.vue4. XSubMenuSlot.vue5. XMenuItem.vue 二、调用流程总结 前言 菜单用于组织程序功能,为用户提供导航。是用户与程序交互非常重要的接口。 开源组件库像Element Plus和Ant Design中都提供了功能…

【Block总结】DBlock,结合膨胀空间注意模块(Di-SpAM)和频域模块Gated-FFN|即插即用|CVPR2025

论文信息 标题: DarkIR: Robust Low-Light Image Restoration 作者: Daniel Feijoo, Juan C. Benito, Alvaro Garcia, Marcos Conde 论文链接:https://arxiv.org/pdf/2412.13443 GitHub链接:https://github.com/cidautai/DarkIR 创新点 DarkIR提出了…

口罩佩戴检测算法AI智能分析网关V4工厂/工业等多场景守护公共卫生安全

一、引言​ 在公共卫生安全日益受到重视的当下,口罩佩戴成为预防病毒传播、保障人员健康的重要措施。为了高效、精准地实现对人员口罩佩戴情况的监测,AI智能分析网关V4口罩检测方案应运而生。该方案依托先进的人工智能技术与强大的硬件性能,…

Double/Debiased Machine Learning

独立同步分布的观测数据 { W i ( Y i , D i , X i ) ∣ i ∈ { 1 , . . . , n } } \{W_i(Y_i,D_i,X_i)| i\in \{1,...,n\}\} {Wi​(Yi​,Di​,Xi​)∣i∈{1,...,n}},其中 Y i Y_i Yi​表示结果变量, D i D_i Di​表示因变量, X i X_i Xi​表…

HarmonyOS Next 弹窗系列教程(4)

HarmonyOS Next 弹窗系列教程(4) 介绍 本章主要介绍和用户点击关联更加密切的菜单控制(Menu) 和 气泡提示(Popup) 它们出现显示弹窗出现的位置都是在用户点击屏幕的位置相关 菜单控制(Menu&…

【C】-递归

1、递归概念 递归(Recursion)是编程中一种重要的解决问题的方法,其核心思想是函数通过调用自身来解决规模更小的子问题,直到达到最小的、可以直接解决的基准情形(Base Case)。 核心:自己调用…

飞马LiDAR500雷达数据预处理

0 引言 在使用飞马D2000无人机搭载LiDAR500进行作业完成后,需要对数据进行预处理,方便给内业人员开展点云分类等工作。在开始操作前,先了解一下使用的软硬件及整体流程。 0.1 外业测量设备 无人机:飞马D2000S激光模块&#xff…