DAY43打卡

news2025/6/7 18:29:52

@浙大疏锦行

kaggle找到一个图像数据cnn网络进行训练并且grad-cam可视化

进阶并拆分成多个文件

fruit_cnn_project/
├─ data/                # 存放数据集(需手动创建,后续放入图片)
│  ├─ train/            # 训练集图像
│  └─ val/              # 验证集图像
├─ models/              # 模型定义
│  └─ cnn_model.py      # CNN网络结构
├─ utils/               # 工具函数
│  ├─ dataset_utils.py  # 数据加载与预处理
│  ├─ grad_cam.py       # Grad-CAM可视化
│  └─ train_utils.py    # 训练与评估
├─ main.py              # 主程序
└─ requirements.txt     # 依赖列表(可选)
# 第一部分:导入库
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

# 第二部分:数据加载与预处理
def load_data():
    data_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    train_dataset = datasets.ImageFolder(root='data/train', transform=data_transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_dataset = datasets.ImageFolder(root='data/test', transform=data_transform)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
    return train_loader, test_loader

# 第三部分:模型定义
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(32 * 56 * 56, 128)
        self.fc2 = nn.Linear(128, 2)

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(-1, 32 * 56 * 56)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 第四部分:模型训练
train_loader, _ = load_data()
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 10

for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader)}')

torch.save(model.state_dict(), 'trained_model.pth')

# 第五部分:模型测试
_, test_loader = load_data()
model = SimpleCNN()
model.load_state_dict(torch.load('trained_model.pth'))
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for data in test_loader:
        images, labels = data
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the test images: {100 * correct / total}%')

# 第六部分:Grad-CAM可视化(修复版)
def get_activation():
    activation = {}
    def hook(model, input, output):
        activation['target_layer'] = output.detach()
    return hook, activation

def grad_cam(model, image, target_class_index):
    hook, activation = get_activation()
    target_layer = model.conv2
    target_layer.register_forward_hook(hook)
    
    model.eval()
    image = image.unsqueeze(0)
    image.requires_grad_(True)
    
    output = model(image)
    one_hot = torch.zeros(1, output.size()[-1]).to(image.device)
    one_hot[0][target_class_index] = 1
    
    output.backward(gradient=one_hot, retain_graph=True)
    gradients = image.grad[0].cpu().numpy()
    
    # 从activation字典中获取激活图
    activation_map = activation['target_layer'].cpu().numpy()[0]
    
    weights = np.mean(gradients, axis=(1, 2))
    cam = np.zeros(activation_map.shape[1:], dtype=np.float32)
    
    for i, w in enumerate(weights):
        cam += w * activation_map[i]
    
    cam = np.maximum(cam, 0)
    cam = F.interpolate(
        torch.from_numpy(cam).unsqueeze(0).unsqueeze(0), 
        size=(224, 224), 
        mode='bilinear', 
        align_corners=False
    )[0][0].numpy()
    
    cam = (cam - cam.min()) / (cam.max() - cam.min())
    return cam

# 可视化前几张测试图片
dataiter = iter(test_loader)
images, labels = dataiter.next()

for i in range(5):  # 可视化前5张图片
    image = images[i]
    label = labels[i].item()
    cam = grad_cam(model, image, label)
    
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(image.permute(1, 2, 0).numpy())
    plt.title(f'Original Image (Class: {label})')
    plt.axis('off')
    
    plt.subplot(1, 2, 2)
    plt.imshow(image.permute(1, 2, 0).numpy())
    plt.imshow(cam, cmap='jet', alpha=0.5)
    plt.title('Grad-CAM Visualization')
    plt.axis('off')
    
    plt.tight_layout()
    plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2403221.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Leetcode 1892. 页面推荐Ⅱ

1.题目基本信息 1.1.题目描述 表: Friendship ---------------------- | Column Name | Type | ---------------------- | user1_id | int | | user2_id | int | ---------------------- (user1_id,user2_id) 是 Friendship 表的主键(具有唯一值的列的组合…

进程——环境变量及程序地址空间

目录 环境变量 概念 补充:命令行参数 引入 其它环境变量 理解 程序地址空间 引入 理解 虚拟地址存在意义 环境变量 概念 环境变量一般是指在操作系统中用来指定操作系统运行环境的一些参数。打个比方,就像你布置房间,这些参数就类…

VR视频制作有哪些流程?

VR视频制作流程知识 VR视频制作,作为融合了创意与技术的复杂制作过程,涵盖从初步策划到最终呈现的多个环节。在这个过程中,我们可以结合众趣科技的产品,解析每一环节的实现与优化,揭示背后的奥秘。 VR视频制作有哪些…

Ubuntu 系统部署 MySQL 入门篇

一、安装 MySQL 1.1 更新软件包 在终端中执行以下命令,更新系统软件包列表,确保安装的是最新版本的软件: sudo apt update 1.2 安装 MySQL 执行以下命令安装 MySQL 服务端: sudo apt install mysql-server 在安装过程中&…

【MATLAB代码】制导——平行接近法,三维,目标是运动的,订阅专栏后可直接查看MATLAB源代码

文章目录 运行结果简介代码功能概述运行结果核心模块解析代码特性与优势MATLAB例程代码调整说明相关公式视线角速率约束相对运动学方程导引律加速度指令运动学更新方程拦截条件判定运行结果 运行演示视频: 三维平行接近法导引运行演示 简介 代码功能概述 本代码实现了三维空…

黑马Java面试笔记之 微服务篇(SpringCloud)

一. SpringCloud 5大组件 SpringCloud 5大组件有哪些? 总结 五大件分别有: Eureka:注册中心Ribbon:负载均衡Feign:远程调用Hystrix:服务熔断Zuul/Gateway:网关 如果项目用到了阿里巴巴&#xff…

CLIP多模态大模型的优势及其在边缘计算中的应用

CLIP多模态大模型的优势及其在边缘计算中的应用 CLIP(Contrastive Language-Image Pre-training)模型,是OpenAI开发的一种多模态大模型。该模型通过对比学习的方式,在大规模图像-文本对上进行预训练,成功实现了图像和文…

基于STM32语音识别柔光台灯

基于STM32语音识别柔光台灯 (程序+原理图+PCB+设计报告) 功能介绍 具体功能: 基于语音识别的智能LED柔光台灯设计,主要包括语音识别模块应用,PWM波控制LED柔光灯的亮度&#xff0c…

基于PSO粒子群优化的VMD-GRU时间序列预测算法matlab仿真

目录 1.前言 2.算法运行效果图预览 3.算法运行软件版本 4.部分核心程序 5.算法仿真参数 6.算法理论概述 6.1变分模态分解(VMD) 6.2 门控循环单元(GRU) 6.3 粒子群优化(PSO) 7.参考文献 8.算法完…

探索未知惊喜,盲盒抽卡机小程序系统开发新启航

在消费市场不断追求新鲜感与惊喜体验的当下,盲盒抽卡机以其独特的魅力,迅速成为众多消费者热衷的娱乐与消费方式。我们紧跟这一潮流趋势,专注于盲盒抽卡机小程序系统的开发,致力于为商家和用户打造一个充满趣味与惊喜的数字化平台…

基于开源AI大模型与AI智能名片的S2B2C商城小程序源码优化:企业成本管理与获客留存的新范式

摘要:本文以企业成本管理的两大核心——外部成本与内部成本为切入点,结合开源AI大模型、AI智能名片及S2B2C商城小程序源码技术,构建了企业数字化转型的“技术-成本-运营”三维模型。研究结果表明,通过AI智能名片实现获客留存效率提…

Python----目标检测(YOLO简介)

一、 YOLO简介 [YOLO](You Only Look Once)是一种流行的物体检测和图像分割模型, 由华盛顿大学的约瑟夫-雷德蒙(Joseph Redmon)和阿里-法哈迪(Ali Farhadi)开发,YOLO 于 2015 年推出&#xff0c…

Profinet 协议 IO-Link 主站网关(三格电子)

一、产品概述 1.1 产品用途 SG-PN-IOL-8A-001 网关是 Profinet 从转 IO-Link 主的网关设备 ,可以将 IO-Link 从站设备接入 Profinet 系统,通过该网关可实现传感器及驱动器与控制 器之间的信息交互。网关有两个百兆网口和 8 个 IO-Link 端口,两…

Ubuntu22.04 安装 Miniconda3

Conda 是一个开源的包管理系统和环境管理系统,可用于 Python 环境管理。 Miniconda 是一个轻量级的 Conda 发行版。Miniconda 包含了 Conda、Python和一些基本包,是 Anaconda 的精简版本。 1.下载安装脚本 在 conda官网 找到需要的安装版本&#xff0…

Hubstudio浏览器如何使用Loongproxy?

1. 使用软件 1.1 Loongproxy 1. 顶级ISP资源:Loongproxy是神龙云旗下品牌,依托与全球领先ISP运营商的深度合作,Loongproxy 精选全球优质静态住宅IP资源。 2. IP池庞大:覆盖 100 国家/地区,构建庞大的 70 万 静态IP池…

硬件工程师笔记——555定时器应用Multisim电路仿真实验汇总

目录 一 555定时器基础知识 二、引脚功能 三、工作模式 1. 单稳态模式: 2. 双稳态模式(需要外部电路辅助): 3. 无稳态模式(多谐振荡器): 4. 可控脉冲宽度调制(PWM)模式: 四、典型应用 五、优点 二 555无稳态触发器 三 555单稳态触发器 四 555双稳态触发器…

ComfyUI 对图片进行放大的不同方法

本篇里 ComfyUI Wiki将讲解 ComfyUI 中几种基础的放大图片的办法,我们时常会因为设备性能问题,不能一次性生成大尺寸的图片,通常会先生成小尺寸的图像然后再进行放大。 不同的放大图片方法有不同的特点,以下是本篇教程将会涉及的方法: 像素重新采样SD 二次采样放大使用放…

Elasticsearch最新入门教程

文章目录 Elasticsearch最新入门教程1.Elasticsearch安装2.Kibana安装3.Elasticsearch关键概念4.SpringBoot整合Elasticsearch4.1 导入Elasticsearch数据4.2 创建SpringBoot项目4.3 修改pom.xml文件4.4 创建es实体类4.5 创建es的查询接口 5.DSL语句5.1 无条件查询5.2 指定返回的…

【Linux网络篇】:从HTTP到HTTPS协议---加密原理升级与安全机制的全面解析

✨感谢您阅读本篇文章,文章内容是个人学习笔记的整理,如果哪里有误的话还请您指正噢✨ ✨ 个人主页:余辉zmh–CSDN博客 ✨ 文章所属专栏:Linux篇–CSDN博客 文章目录 HTTPS协议原理一.预备知识1.什么是“加密”2.为什么要“加密”…

字符串 金额转换

package heima.Test09;import java.util.Scanner;public class Money {public static void main(String[] args) {//1。键盘录入一个金额Scanner sc new Scanner(System.in);//请输入一个数据String result "";int money;while (true) {System.out.println("请…