6.04打卡

news2025/7/22 22:01:17
@浙大疏锦行
DAY 43 复习日

作业:

kaggle找到一个图像数据集,用cnn网络进行训练并且用grad-cam做可视化

进阶:并拆分成多个文件

损失: 0.502 | 准确率: 75.53% 训练完成

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
# 设置随机种子确保结果可复现
torch.manual_seed(42)
np.random.seed(42)
# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题
# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
# 1. 数据预处理
# 训练集:使用多种数据增强方法提高模型泛化能力
train_transform = transforms.Compose([
    # 新增:调整图像大小为统一尺寸
    transforms.Resize((32, 32)),  # 确保所有图像都是32x32像素
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
# 测试集:仅进行必要的标准化,保持数据原始特性
test_transform = transforms.Compose([
    # 新增:调整图像大小为统一尺寸
    transforms.Resize((32, 32)),  # 确保所有图像都是32x32像素
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
# 定义数据集根目录
root = r'day43'
train_dataset = datasets.ImageFolder(
    root=root + '/training_set',  # 指向 train 子文件夹
    transform=train_transform
)
test_dataset = datasets.ImageFolder(
    root=root + '/test_set',  # 指向 test 子文件夹
    transform=test_transform
)
# 打印类别信息,确认数据加载正确
print(f"训练集类别: {train_dataset.classes}")
print(f"测试集类别: {test_dataset.classes}")
# 3. 创建数据加载器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# 定义一个简单的CNN模型
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        # 第一个卷积层,输入通道为3(彩色图像),输出通道为32,卷积核大小为3x3,填充为1以保持图像尺寸不变
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        # 第二个卷积层,输入通道为32,输出通道为64,卷积核大小为3x3,填充为1
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        # 第三个卷积层,输入通道为64,输出通道为128,卷积核大小为3x3,填充为1
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        # 最大池化层,池化核大小为2x2,步长为2,用于下采样,减少数据量并提取主要特征
        self.pool = nn.MaxPool2d(2, 2)
        # 第一个全连接层,输入特征数为128 * 4 * 4(经过前面卷积和池化后的特征维度),输出为512
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        # 第二个全连接层,输入为512,输出为2(对应猫和非猫两个类别)
        self.fc2 = nn.Linear(512, 2)
    def forward(self, x):
        # 第一个卷积层后接ReLU激活函数和最大池化操作,经过池化后图像尺寸变为原来的一半,这里输出尺寸变为16x16
        x = self.pool(F.relu(self.conv1(x)))
        # 第二个卷积层后接ReLU激活函数和最大池化操作,输出尺寸变为8x8
        x = self.pool(F.relu(self.conv2(x)))
        # 第三个卷积层后接ReLU激活函数和最大池化操作,输出尺寸变为4x4
        x = self.pool(F.relu(self.conv3(x)))
        # 将特征图展平为一维向量,以便输入到全连接层
        x = x.view(-1, 128 * 4 * 4)
        # 第一个全连接层后接ReLU激活函数
        x = F.relu(self.fc1(x))
        # 第二个全连接层输出分类结果
        x = self.fc2(x)
        return x
# 初始化模型
model = SimpleCNN()
print("模型已创建")
# 如果有GPU则使用GPU,将模型转移到对应的设备上
model = model.to(device)
# 训练模型
def train_model(model, train_loader, test_loader, epochs=20):
    # 定义损失函数为交叉熵损失,用于分类任务
    criterion = nn.CrossEntropyLoss()
    # 定义优化器为Adam,用于更新模型参数,学习率设置为0.001
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(epochs):
        # 训练阶段
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for i, data in enumerate(train_loader, 0):
            # 从数据加载器中获取图像和标签
            inputs, labels = data
            # 将图像和标签转移到对应的设备(GPU或CPU)上
            inputs, labels = inputs.to(device), labels.to(device)
            # 清空梯度,避免梯度累加
            optimizer.zero_grad()
            # 模型前向传播得到输出
            outputs = model(inputs)
            # 计算损失
            loss = criterion(outputs, labels)
            # 反向传播计算梯度
            loss.backward()
            # 更新模型参数
            optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        # 测试阶段
        model.eval()
        test_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for data in test_loader:
                images, labels = data
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                test_loss += criterion(outputs, labels).item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
        print(f'测试集 [{epoch + 1}] 损失: {test_loss/len(test_loader):.3f} | 准确率: {100.*correct/total:.2f}%')
    print("训练完成")
    return model
# 训练模型
try:
    # 尝试加载预训练模型(如果存在)
    model.load_state_dict(torch.load('cat_classifier.pth'))
    print("已加载预训练模型")
except:
    print("无法加载预训练模型,训练新模型")
    model = train_model(model, train_loader, test_loader, epochs=10)
    # 保存训练后的模型参数
    torch.save(model.state_dict(), 'cat_classifier.pth')
# 设置模型为评估模式
model.eval()
# Grad-CAM实现
class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        # 注册钩子,用于获取目标层的前向传播输出和反向传播梯度
        self.register_hooks()
    def register_hooks(self):
        # 前向钩子函数,在目标层前向传播后被调用,保存目标层的输出(激活值)
        def forward_hook(module, input, output):
            self.activations = output.detach()
        # 反向钩子函数,在目标层反向传播后被调用,保存目标层的梯度
        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0].detach()
        # 在目标层注册前向钩子和反向钩子
        self.target_layer.register_forward_hook(forward_hook)
        self.target_layer.register_backward_hook(backward_hook)
    def generate_cam(self, input_image, target_class=None):
        # 前向传播,得到模型输出
        model_output = self.model(input_image)
        if target_class is None:
            # 如果未指定目标类别,则取模型预测概率最大的类别作为目标类别
            target_class = torch.argmax(model_output, dim=1).item()
        # 清除模型梯度,避免之前的梯度影响
        self.model.zero_grad()
        # 反向传播,构造one-hot向量,使得目标类别对应的梯度为1,其余为0,然后进行反向传播计算梯度
        one_hot = torch.zeros_like(model_output)
        one_hot[0, target_class] = 1
        model_output.backward(gradient=one_hot)
        # 获取之前保存的目标层的梯度和激活值
        gradients = self.gradients
        activations = self.activations
        # 对梯度进行全局平均池化,得到每个通道的权重,用于衡量每个通道的重要性
        weights = torch.mean(gradients, dim=(2, 3), keepdim=True)
        # 加权激活映射,将权重与激活值相乘并求和,得到类激活映射的初步结果
        cam = torch.sum(weights * activations, dim=1, keepdim=True)
        # ReLU激活,只保留对目标类别有正贡献的区域,去除负贡献的影响
        cam = F.relu(cam)
        # 调整大小并归一化,将类激活映射调整为与输入图像相同的尺寸(32x32),并归一化到[0, 1]范围
        cam = F.interpolate(cam, size=(32, 32), mode='bilinear', align_corners=False)
        cam = cam - cam.min()
        cam = cam / cam.max() if cam.max() > 0 else cam
        return cam.cpu().squeeze().numpy(), target_class
# 可视化Grad-CAM结果的函数
import warnings
warnings.filterwarnings("ignore")
import matplotlib.pyplot as plt
# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题
# 选择一个随机图像
# idx = np.random.randint(len(test_dataset))
idx = 50  # 选择测试集中的第49张图片 (索引从0开始)
image, label = test_dataset[idx]
print(f"选择的图像类别: {test_dataset.classes[label]}")
# 转换图像以便可视化
def tensor_to_np(tensor):
    img = tensor.cpu().numpy().transpose(1, 2, 0)
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = std * img + mean
    img = np.clip(img, 0, 1)
    return img
# 添加批次维度并移动到设备
input_tensor = image.unsqueeze(0).to(device)
# 初始化Grad-CAM(选择最后一个卷积层)
grad_cam = GradCAM(model, model.conv3)
# 生成热力图
heatmap, pred_class = grad_cam.generate_cam(input_tensor)
# 可视化
plt.figure(figsize=(12, 4))
# 原始图像
plt.subplot(1, 3, 1)
plt.imshow(tensor_to_np(image))
plt.title(f"原始图像: {test_dataset.classes[label]}")
plt.axis('off')
# 热力图
plt.subplot(1, 3, 2)
plt.imshow(heatmap, cmap='jet')
plt.title(f"Grad-CAM热力图: {test_dataset.classes[pred_class]}")
plt.axis('off')
# 叠加的图像
plt.subplot(1, 3, 3)
img = tensor_to_np(image)
heatmap_resized = np.uint8(255 * heatmap)
heatmap_colored = plt.cm.jet(heatmap_resized)[:, :, :3]
superimposed_img = heatmap_colored * 0.4 + img * 0.6
plt.imshow(superimposed_img)
plt.title("叠加热力图")
plt.axis('off')
plt.tight_layout()
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2401065.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【基于SpringBoot的图书购买系统】操作Jedis对图书图书的增-删-改:从设计到实战的全栈开发指南

引言 在当今互联网应用开发中,缓存技术已成为提升系统性能和用户体验的关键组件。Redis作为一款高性能的键值存储数据库,以其丰富的数据结构、快速的读写能力和灵活的扩展性,被广泛应用于各类系统的缓存层设计。本文将围绕一个基于Redis的图…

Spring Boot微服务架构(十):Docker与K8S部署的区别

Spring Boot微服务在Docker与Kubernetes(K8S)中的部署存在显著差异,主要体现在技术定位、管理能力、扩展性及适用场景等方面。以下是两者的核心区别及实践对比: 一、技术定位与核心功能 Docker 功能:专注于单节点容器化…

vue3:Table组件动态的字段(列)权限、显示隐藏和左侧固定

效果展示 根据后端接口返回&#xff0c;当前登录用户详情中的页面中el-table组件的显示隐藏等功能。根据菜单id查询该菜单下能后显示的列。 后端返回的数据类型: 接收到后端返回的数据后处理数据结构. Table组件文件 <!-- 自己封装的Table组件文件 --> onMounted(()>…

pikachu靶场通关笔记13 XSS关卡09-XSS之href输出

目录 一、href 1、常见取值类型 2、使用示例 3、安全风险 二、源码分析 1、进入靶场 2、代码审计 3、渗透思路 三、渗透实战 1、注入payload1 2、注入payload2 3、注入payload3 本系列为通过《pikachu靶场通关笔记》的XSS关卡(共10关&#xff09;渗透集合&#xff…

MCP客户端Client开发流程

1. uv工具入门使用指南 1.1 uv入门介绍 MCP开发要求借助uv进行虚拟环境创建和依赖管理。 uv 是一个Python 依赖管理工具&#xff0c;类似于pip 和 conda &#xff0c;但它更快、更高效&#xff0c;并且可以更好地管理 Python 虚拟环境和依赖项。它的核心目标是 替代 pip 、…

学习日记-day21-6.3

完成目标&#xff1a; 目录 知识点&#xff1a; 1.集合_哈希表存储过程说明 2.集合_哈希表源码查看 3.集合_哈希表无索引&哈希表有序无序详解 4.集合_TreeSet和TreeMap 5.集合_Hashtable和Vector&Vector源码分析 6.集合_Properties属性集 7.集合_集合嵌套 8.…

C语言探索之旅:深入理解结构体的奥秘

目录 引言 一、什么是结构体&#xff1f; 二、结构体类型的声明和初始化 1、结构体的声明 2、结构体的初始化 3、结构体的特殊声明 4、结构体的自引用 5、结构体的重命名 三、结构体的内存对齐 1、对齐规则 2、为什么存在内存对齐&#xff1f; 3、修改默认对齐数 三…

经典算法回顾之最小生成树

最小生成树&#xff08;Minimum Spanning Tree&#xff0c;简称MST&#xff09;是图论中的一个重要概念&#xff0c;主要用于解决加权无向图中连接所有顶点且总权重最小的树结构问题。本文对两种经典的算法即Prim算法和Kruskal算法进行回顾&#xff0c;并对后者的正确性给出简单…

Ubuntu下实现nginx反向代理

1. 多个ngx实例安装 脚本已经在deepseek的指导下完成啦&#xff01; deepseek写的脚本支持ubuntu/centos两种系统。 ins_prefix"/usr/local/" makefile_gen() {ngx$1 ngx_log_dir"/var/log/"$ngx"/"ngx_temp_path"/var/temp/"${ngx}…

c++ QicsTable使用实例

效果图&#xff1a; #include <QicsTable.h> #include <QicsDataModelDefault.h> #include <QVBoxLayout> Demo1::Demo1(QWidget *parent) : QWidget(parent) { ui.setupUi(this); const int numRows 10; const int numCols 5; // create th…

在WordPress上添加隐私政策页面

在如今的互联网时代&#xff0c;保护用户隐私已经成为每个网站管理员的责任。隐私政策不仅是法律要求&#xff0c;还能提高用户对网站的信任。本文将介绍两种常用方法&#xff0c;帮助你在WordPress上轻松创建并发布隐私政策页面。这些方法简单易行&#xff0c;符合中国用户的阅…

阿里云ACP云计算备考笔记 (3)——云服务器ECS

目录 第一章 整体概览 第二章 ECS简介 1、产品概念 2、ECS对比本地IDC 3、BGP机房优势 第三章 ECS实例 1、实例规格族 2、实例系列 3、应用场景推荐选型 4、实例状态 5、创建实例 ① 完成基础配置 ② 完成网络和安全组配置 ③ 完成管理配置和高级选项 ④ 确认下单…

从零开始:用Tkinter打造你的第一个Python桌面应用

目录 一、界面搭建&#xff1a;像搭积木一样组合控件 二、菜单系统&#xff1a;给应用装上“控制中枢” 三、事件驱动&#xff1a;让界面“活”起来 四、进阶技巧&#xff1a;打造专业级体验 五、部署发布&#xff1a;让作品触手可及 六、学习路径建议 在Python生态中&am…

Web开发主流前后端框架总结

&#x1f5a5; 一、前端主流框架 前端框架的核心是提升用户界面开发效率&#xff0c;实现高交互性应用。当前三大主流框架各有侧重&#xff1a; React (Meta/Facebook) 核心特点&#xff1a;采用组件化架构与虚拟DOM技术&#xff08;减少真实DOM操作&#xff0c;优化渲染性能&…

GlobalSign、DigiCert、Sectigo三种SSL安全证书有什么区别?

‌GlobalSign、DigiCert和Sectigo是三家知名的SSL证书颁发机构&#xff0c;其产品在安全性、功能、价格和适用场景上存在一定差异。选择SSL证书就像为你的网站挑选最合身的“安全盔甲”&#xff0c;核心是匹配你的实际需求&#xff0c;避免过度配置或防护不足。 一、核心特点对…

力扣面试150题--二叉搜索树中第k小的元素

Day 58 题目描述 思路 直接采取中序遍历&#xff0c;不过我们将k参与到中序遍历中&#xff0c;遍历到第k个元素就结束 /*** Definition for a binary tree node.* public class TreeNode {* int val;* TreeNode left;* TreeNode right;* TreeNode() {}* …

SQL Server Agent 不可用怎么办?

在 SQL Server Management Studio (SSMS) 中&#xff0c;SQL Server Agent 通常位于对象资源管理器&#xff08;Object Explorer&#xff09;的树形结构中&#xff0c;作为 SQL Server 实例的子节点。以下是详细说明和可能的原因&#xff1a; 1. SQL Server Agent 的位置 默认路…

css-塞贝尔曲线

文章目录 1、定义2、使用和解释 1、定义 cubic-bezier() 函数定义了一个贝塞尔曲线(Cubic Bezier)语法&#xff1a;cubic-bezier(x1,y1,x2,y2) 2、使用和解释 x1,y1,x2,y2&#xff0c;表示两个点的坐标P1(x1,y1),P2(x2,y2)将以一条直线放在范围只有 1 的坐标轴中&#xff0c;并…

docker使用proxy拉取镜像

前提条件&#xff0c;宿主机可以访问docker hub 虚拟机上telnet 宿主机7890能正常访问 下面的才是关键&#xff0c;上面部分自己想办法~ 3. 编辑 /etc/docker/daemon.json {"proxies": {"http-proxy": "http://192.168.100.1:7890","ht…

服务端定时器的学习(一)

一、定时器 1、定时器是什么&#xff1f; 定时器不仅存在于硬件领域&#xff0c;在软件层面&#xff08;客户端、网页和服务端&#xff09;也普遍应用&#xff0c;核心功能都是高效管理大量延时任务。不同应用场景下&#xff0c;其实现方式和使用方法有所差异。 2、定时器解…