【Python训练营打卡】day43 @浙大疏锦行

news2025/6/7 4:23:55

DAY 43 复习日

作业:

kaggle找到一个图像数据集,用cnn网络进行训练并且用grad-cam做可视化

进阶:并拆分成多个文件

我选择的是music_instruments   链接:Musical Instruments (kaggle.com)

#导包
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os
import warnings
warnings.filterwarnings("ignore")

# 设置随机种子确保结果可复现
torch.manual_seed(42)
np.random.seed(42)
# 数据预处理
transform = transforms.Compose([
    transforms.Resize((64, 64)),  # 调整图像大小
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 加载自定义数据集
dataset_path = r"F:\Program Files\MyPythonProjects\day43\music_instruments"
dataset = ImageFolder(root=dataset_path, transform=transform)

# 划分训练集和测试集
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

# 定义类别名称
classes = ('accordion', 'banjo', 'drum', 'flute', 'guitar', 
           'harmonica', 'saxophone', 'sitar', 'tabla', 'violin')
# 定义一个简单的CNN模型(调整输出类别数为10)
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(128 * 8 * 8, 512)  
        self.fc2 = nn.Linear(512, len(classes))  
        
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # 32x32
        x = self.pool(F.relu(self.conv2(x)))  # 16x16
        x = self.pool(F.relu(self.conv3(x)))  # 8x8
        x = x.view(-1, 128 * 8 * 8)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 初始化模型
model = SimpleCNN()
print("模型已创建")

# 使用GPU或CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# 训练模型
def train_model(model, epochs=10):
    trainloader = torch.utils.data.DataLoader(
        train_dataset, 
        batch_size=32,
        shuffle=True, 
        num_workers=2
    )
    
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            if i % 10 == 9:
                print(f'[{epoch + 1}, {i + 1}] 损失: {running_loss / 10:.3f}')
                running_loss = 0.0
    
    print("训练完成")

# 训练或加载模型
try:
    model.load_state_dict(torch.load('music_instruments_cnn.pth'))
    print("已加载预训练模型")
except:
    print("无法加载预训练模型,使用未训练模型或训练新模型")
    train_model(model, epochs=10)
    torch.save(model.state_dict(), 'music_instruments_cnn.pth')

# 设置模型为评估模式
model.eval()

# Grad-CAM实现
class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self.register_hooks()
        
    def register_hooks(self):
        def forward_hook(module, input, output):
            self.activations = output.detach()
        
        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0].detach()
        
        self.target_layer.register_forward_hook(forward_hook)
        self.target_layer.register_backward_hook(backward_hook)
    
    def generate_cam(self, input_image, target_class=None):
        model_output = self.model(input_image)
        
        if target_class is None:
            target_class = torch.argmax(model_output, dim=1).item()
        
        self.model.zero_grad()
        one_hot = torch.zeros_like(model_output)
        one_hot[0, target_class] = 1
        model_output.backward(gradient=one_hot)
        
        gradients = self.gradients
        activations = self.activations
        
        weights = torch.mean(gradients, dim=(2, 3), keepdim=True)
        cam = torch.sum(weights * activations, dim=1, keepdim=True)
        cam = F.relu(cam)
        cam = F.interpolate(cam, size=(64, 64), mode='bilinear', align_corners=False)
        cam = cam - cam.min()
        cam = cam / cam.max() if cam.max() > 0 else cam
        return cam.cpu().squeeze().numpy(), target_class

# 可视化函数
def tensor_to_np(tensor):
    img = tensor.cpu().numpy().transpose(1, 2, 0)
    mean = np.array([0.5, 0.5, 0.5])
    std = np.array([0.5, 0.5, 0.5])
    img = std * img + mean
    img = np.clip(img, 0, 1)
    return img

# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False

# 选择一个测试图像
idx = 10  
image, label = test_dataset[idx]
print(f"选择的图像类别: {classes[label]}")

# 添加批次维度并移动到设备
input_tensor = image.unsqueeze(0).to(device)

# 初始化Grad-CAM(选择最后一个卷积层)
grad_cam = GradCAM(model, model.conv3)

# 生成热力图
heatmap, pred_class = grad_cam.generate_cam(input_tensor)

# 可视化
plt.figure(figsize=(12, 4))

# 原始图像
plt.subplot(1, 3, 1)
plt.imshow(tensor_to_np(image))
plt.title(f"原始图像: {classes[label]}")
plt.axis('off')

# 热力图
plt.subplot(1, 3, 2)
plt.imshow(heatmap, cmap='jet')
plt.title(f"Grad-CAM热力图: {classes[pred_class]}")
plt.axis('off')

# 叠加的图像
plt.subplot(1, 3, 3)
img = tensor_to_np(image)
heatmap_resized = np.uint8(255 * heatmap)
heatmap_colored = plt.cm.jet(heatmap_resized)[:, :, :3]
superimposed_img = heatmap_colored * 0.4 + img * 0.6
plt.imshow(superimposed_img)
plt.title("叠加热力图")
plt.axis('off')

plt.tight_layout()
plt.savefig('grad_cam_result.png')
plt.show()
print("Grad-CAM可视化完成。已保存为grad_cam_result.png")

@浙大疏锦行

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2400732.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

1-【源码剖析】kafka核心概念

从今天开始开始在csdn上记录学习的笔记,主要包括以下几个方面: kafkaflinkdoris 本系列笔记主要记录Kafka学习相关的内容。在进行kafka源码学习之前,先介绍一下Kafka的核心概念。 消息 消息是kafka中最基本的数据单元,由key和…

思科设备网络实验

一、 总体拓扑图 图 1 总体拓扑图 二、 IP地址规划 表格 1 接口地址规划 设备名称 接口/VLAN IP 功能 PC0 VLAN580 10.80.1.1 访问外网 PC1 VLAN581 10.80.2.1 访问外网 PC2 Fa0 20.80.1.100 端口镜像监控流量 PC3 VLAN585 10.80.6.1 远程登陆多层交换机0…

AWS之数据分析

目录 数据分析产品对比 1. Amazon Athena 3. AWS Lake Formation 4. AWS Glue 5. Amazon OpenSearch Service 6. Amazon Kinesis Data Analytics 7. Amazon Redshift 8.Amazon Redshift Spectrum 搜索服务对比 核心功能与定位对比 适用场景 关键差异总结 注意事项 …

C# Onnx 动漫人物头部检测

目录 效果 模型信息 项目 代码 下载 参考 效果 模型信息 Model Properties ------------------------- date:2024-10-19T12:32:20.920471 description:Ultralytics best model trained on /root/datasets/yolo/anime_head_detection/data.yaml au…

【Ragflow】24.Ragflow-plus开发日志:增加分词逻辑,修复关键词检索失效问题

概述 在RagflowPlus v0.3.0 版本推出之后,反馈比较多的问题是:检索时,召回块显著变少了。 如上图所示,进行检索测试时,关键词相似度得分为0,导致混合相似度(加权相加得到)也被大幅拉低,低于设定…

Zookeeper 集群部署与故障转移

Zookeeper 介绍 Zookeeper 是一个开源的分布式协调服务,由Apache基金会维护,专为分布式应用提供高可用、强一致性的核心基础能力。它通过简单的树形命名空间(称为ZNode树)存储数据节点(ZNode),…

Redis最佳实践——电商应用的性能监控与告警体系设计详解

Redis 在电商应用的性能监控与告警体系设计 一、原子级监控指标深度拆解 1. 内存维度监控 核心指标: # 实时内存组成分析(单位字节) used_memory: 物理内存总量 used_memory_dataset: 数据集占用量 used_memory_overhead: 管理开销内存 us…

区域徘徊检测算法AI智能分析网关V4助力公共场所/工厂等多场景安全升级

一、项目背景 随着数字化安全管理需求激增,重点场所急需强化人员异常行为监测。区域徘徊作为潜在安全威胁的早期征兆,例如校园围墙外的陌生逗留者,都可能引发安全隐患。传统人工监控模式效率低、易疏漏,AI智能分析网关V4的区域徘…

修复与升级suse linux

suse linux enterprise desktop 10提示:xxx service failed when loaded shared lib . error ibgobject.so.2.0:no such file or directory. suse linux enterprise server 12.iso 通过第一启动项引导,按照如下方式直接升级解决。

电力高空作业安全检测(2)数据集构建

数据集构建的重要性 在电力高空作业安全检测领域,利用 计算机视觉技术 进行安全监测需要大量的图像数据,这些数据需要准确标注不同的安全设备与作业人员行为。只有构建出包含真实场景的高质量数据集,才能通过深度学习等算法对高空作业中的潜…

嵌入式开发之STM32学习笔记day18

STM32F103C8T6 SPI通信读写W25Q64 1 W25Q64简介 W25Qxx系列是一种低成本、小型化且易于使用的非易失性存储器(NOR Flash),它广泛应用于需要持久化存储数据的各种场景,如数据存储、字库存储以及固件程序存储等。该系列存储器采用…

[论文阅读]PPT: Backdoor Attacks on Pre-trained Models via Poisoned Prompt Tuning

PPT: Backdoor Attacks on Pre-trained Models via Poisoned Prompt Tuning PPT: Backdoor Attacks on Pre-trained Models via Poisoned Prompt Tuning | IJCAI IJCAI-22 发表于2022年的论文,当时大家还都在做小模型NLP的相关工作(BERT,Ro…

汽车安全:功能安全FuSa、预期功能安全SOTIF与网络安全Cybersecurity 解析

汽车安全的三重防线:深入解析FuSa、SOTIF与网络安全技术 现代汽车已成为装有数千个传感器的移动计算机,安全挑战比传统车辆复杂百倍。 随着汽车智能化、网联化飞速发展,汽车电子电气架构已从简单的分布式控制系统演变为复杂的移动计算平台。现…

【C++高级主题】虚继承

目录 一、菱形继承:虚继承的 “导火索” 1.1 菱形继承的结构与问题 1.2 菱形继承的核心矛盾:多份基类实例 1.3 菱形继承的具体问题:二义性与数据冗余 二、虚继承的语法与核心目标 2.1 虚继承的声明方式 2.2 虚继承的核心目标 三、虚继…

基于 ZYNQ 的实时运动目标检测系统设计

摘 要: 传统视频监控系统在实时运动目标检测时,存在目标检测不完整和目标检测错误的局限 性 。 本研究基于体积小 、 实时性高的需求,提出了一种将动态三帧差分法与 Sobel 边缘检测算法结 合的实时目标检测方法,并基于 ZYNQ 构建了视频…

[华为eNSP] 在eNSP上实现IPv4地址以及IPv4静态路由的配置

设备名称配置 重命名设备以及关闭信息提示 此处以R1演示&#xff0c;R2R3以此类推 <Huawei>system-view [Huawei]sysname R1#关闭提示 undo info-center enable 配置路由接口IP地址 R1 [R1]interface GigabitEthernet 0/0/1[R1-GigabitEthernet0/0/1]ip address 10.0.…

2024年第十五届蓝桥杯青少组c++国赛真题——快速分解质因数

2024年第十五届蓝桥杯青少组c国赛真题——快速分解质因数 题目可点下方去处&#xff0c;支持在线编程&#xff0c;在线测评&#xff5e; 快速分解质因数_C_少儿编程题库学习中心-嗨信奥 题库收集了历届各白名单赛事真题和权威机构考级真题&#xff0c;覆盖初赛—省赛—国赛&am…

【动手学MCP从0到1】2.1 SDK介绍和第一个MCP创建的步骤详解

SDK介绍和第一个MCP 1. 安装SDK2. MCP通信协议3. 基于stdio通信3.1 服务段脚本代码3.2 客户端执行代码3.2.1 客户端的初始化设置3.2.2 创建执行进行的函数3.2.3 代码优化 4. 基于SSE协议通信 1. 安装SDK 开发mcp项目&#xff0c;既可以使用Anthropic官方提供的SDK&#xff0c;…

测试面试题总结一

目录 列表、元组、字典的区别 nvicat连接出现问题如何排查 mysql性能调优 python连接mysql数据库方法 参数化 pytest.mark.parametrize 装饰器 list1 [1,7,4,5,5,6] for i in range(len(list1): assert list1[i] < list1[i1] 这段程序有问题嘛&#xff1f; pytest.i…

【深度学习】14. DL在CV中的应用章:目标检测: R-CNN, Fast R-CNN, Faster R-CNN, MASK R-CNN

深度学习在计算机视觉中的应用介绍 深度卷积神经网络&#xff08;Deep convolutional neural network&#xff0c; DCNN&#xff09;是将深度学习引入计算机视觉发展的关键概念。通过模仿生物神经系统&#xff0c;深度神经网络可以提供前所未有的能力来解释复杂的数据模式&…