day43 python Grad-CAM

news2025/6/3 23:36:28

目录

一、为什么需要 Grad-CAM?

二、Grad-CAM 的原理

三、Grad-CAM 的实现

1. 模块钩子(Module Hooks)

2. Grad-CAM 的实现代码

四、学习总结


在深度学习领域,神经网络模型常常被视为“黑盒”,因为其复杂的内部结构和难以理解的决策过程。然而,随着模型可解释性研究的不断深入,Grad-CAM(Gradient-weighted Class Activation Mapping)作为一种强大的可视化工具,为我们打开了一扇窥探模型决策机制的窗口。

一、为什么需要 Grad-CAM?

在实际的深度学习项目中,我们常常面临这样的问题:模型的预测结果虽然准确,但其背后的决策依据却难以捉摸。例如,在图像分类任务中,模型是如何从一张复杂的图片中识别出特定的类别?它关注了图片的哪些区域?这些问题的答案对于理解模型的行为、优化模型性能以及发现潜在的偏差至关重要。Grad-CAM 正是为了解决这些问题而诞生的。它通过可视化模型对输入图像的关注区域,帮助我们直观地理解模型的决策过程。这种可视化的热力图不仅能够增强我们对模型的信任,还能在模型出现偏差时,提供线索以便我们进行调整和优化。

二、Grad-CAM 的原理

Grad-CAM 的核心思想是利用卷积神经网络(CNN)中卷积层的特征图(Feature Map)和对应的梯度信息,生成类激活映射(Class Activation Mapping)。具体来说,它通过以下步骤实现:

  1. 选择目标层:通常选择最后一个卷积层作为目标层,因为这一层的特征图包含了丰富的语义信息。

  2. 前向传播:将输入图像通过模型进行前向传播,获取目标层的特征图。

  3. 反向传播:对目标类别进行反向传播,计算目标层的梯度。

  4. 生成热力图:将梯度信息与特征图结合,生成热力图。热力图中的高亮区域表示模型在预测目标类别时关注的区域。

Grad-CAM 的关键在于,它利用梯度信息来衡量每个特征图通道对目标类别的贡献程度,并通过对特征图进行加权求和,生成最终的热力图。

三、Grad-CAM 的实现

为了实现 Grad-CAM,我们需要借助 PyTorch 的 hook 机制。hook 是一种强大的工具,允许我们在不修改模型结构的情况下,动态地获取或修改中间层的输出或梯度。

1. 模块钩子(Module Hooks)

模块钩子分为前向钩子(register_forward_hook)和反向钩子(register_backward_hook)。前向钩子用于获取模块的输入和输出,而反向钩子用于获取模块的梯度信息。

以下是一个简单的示例,展示如何使用模块钩子获取卷积层的输出和梯度:

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.conv = nn.Conv2d(1, 2, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(2 * 4 * 4, 10)

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        x = x.view(-1, 2 * 4 * 4)
        x = self.fc(x)
        return x

model = SimpleModel()

# 定义前向钩子
def forward_hook(module, input, output):
    print("前向钩子被调用!")
    print(f"输入形状: {input[0].shape}")
    print(f"输出形状: {output.shape}")

# 注册前向钩子
hook_handle = model.conv.register_forward_hook(forward_hook)

# 创建输入并执行前向传播
x = torch.randn(1, 1, 4, 4)
output = model(x)

# 移除钩子
hook_handle.remove()

通过上述代码,我们可以在卷积层的前向传播过程中获取其输入和输出。类似地,我们可以通过反向钩子获取梯度信息。

2. Grad-CAM 的实现代码

接下来,我们将实现 Grad-CAM 的完整代码。我们将使用 CIFAR-10 数据集,并基于一个简单的 CNN 模型进行实验。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# 定义一个简单的 CNN 模型
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = x.view(-1, 128 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 初始化模型并加载预训练权重
model = SimpleCNN()
model.load_state_dict(torch.load('cifar10_cnn.pth'))
model.eval()

# Grad-CAM 类
class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self.register_hooks()

    def register_hooks(self):
        def forward_hook(module, input, output):
            self.activations = output.detach()

        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0].detach()

        self.target_layer.register_forward_hook(forward_hook)
        self.target_layer.register_backward_hook(backward_hook)

    def generate_cam(self, input_image, target_class=None):
        model_output = self.model(input_image)
        if target_class is None:
            target_class = torch.argmax(model_output, dim=1).item()
        self.model.zero_grad()
        one_hot = torch.zeros_like(model_output)
        one_hot[0, target_class] = 1
        model_output.backward(gradient=one_hot)

        gradients = self.gradients
        activations = self.activations
        weights = torch.mean(gradients, dim=(2, 3), keepdim=True)
        cam = torch.sum(weights * activations, dim=1, keepdim=True)
        cam = F.relu(cam)
        cam = F.interpolate(cam, size=(32, 32), mode='bilinear', align_corners=False)
        cam = cam - cam.min()
        cam = cam / cam.max() if cam.max() > 0 else cam
        return cam.cpu().squeeze().numpy(), target_class

# 选择一张测试图像并生成 Grad-CAM 热力图
image, label = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())[102]
input_tensor = image.unsqueeze(0)

grad_cam = GradCAM(model, model.conv3)
heatmap, pred_class = grad_cam.generate_cam(input_tensor)

# 可视化结果
plt.figure(figsize=(12, 4))
plt.subplot(1, 3, 1)
plt.imshow(image.permute(1, 2, 0).numpy())
plt.title(f"原始图像: {label}")
plt.axis('off')

plt.subplot(1, 3, 2)
plt.imshow(heatmap, cmap='jet')
plt.title(f"Grad-CAM 热力图: {pred_class}")
plt.axis('off')

plt.subplot(1, 3, 3)
img = image.permute(1, 2, 0).numpy()
heatmap_resized = np.uint8(255 * heatmap)
heatmap_colored = plt.cm.jet(heatmap_resized)[:, :, :3]
superimposed_img = heatmap_colored * 0.4 + img * 0.6
plt.imshow(superimposed_img)
plt.title("叠加热力图")
plt.axis('off')

plt.tight_layout()
plt.show()

四、学习总结

通过本次实验,我对 Grad-CAM 的原理和实现有了更深入的理解。Grad-CAM 不仅能够帮助我们可视化模型的决策过程,还能在模型出现偏差时提供线索。例如,在实验中,我们发现模型在识别“青蛙”类别时,主要关注了图像的腿部和头部区域。这表明模型确实能够捕捉到关键的语义特征,但也提醒我们在数据标注和模型训练过程中需要注意潜在的偏差。

@浙大疏锦行

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2395508.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SQL的查询优化

1. 查询优化器 1.1. SQL语句执行需要经历的环节 解析阶段:语法分析和语义检查,确保语句正确;优化阶段:通过优化器生成查询计划;执行阶段:由执行器根据查询计划实际执行操作。 1.2. 查询优化器 查询优化器…

MCU如何从向量表到中断服务

目录 1、中断向量表 2、编写中断服务例程 中断处理的核心是中断向量表(IVT),它是一个存储中断服务例程(ISR)地址的内存结构。当中断发生时,MCU通过IVT找到对应的ISR地址并跳转执行。本文将深入探讨MCU&am…

Linux线程同步实战:多线程程序的同步与调度

个人主页:chian-ocean 文章专栏-Linux Linux线程同步实战:多线程程序的同步与调度 个人主页:chian-ocean文章专栏-Linux 前言:为什么要实现线程同步线程饥饿(Thread Starvation)示例:抢票问题 …

【MySQL】事务及隔离性

目录 一、什么是事务 (一)概念 (二)事务的四大属性 (三)事务的作用 (四)事务的提交方式 二、事务的启动、回滚与提交 (一)事务的启动、回滚与提交 &am…

yolo目标检测助手:具有模型预测、图像标注功能

在人工智能浪潮席卷各行各业的今天,计算机视觉模型(如 YOLO)已成为目标检测领域的标杆。然而,模型的强大能力需要直观的界面和便捷的工具才能充分发挥其演示、验证与迭代优化的价值。为此,我开发了一款基于 WPF 的桌面…

2022 RoboCom 世界机器人开发者大赛(睿抗 caip) -高职组(国赛)解题报告 | 科学家

前言 题解 2022 RoboCom 世界机器人开发者大赛(睿抗 caip) -高职组&#xff08;国赛&#xff09;。 最后一题还考验能力&#xff0c;需要找到合适的剪枝。 RC-v1 智能管家 分值: 20分 签到题&#xff0c;map的简单实用 #include <bits/stdc.h>using namespace std;int…

基于物联网(IoT)的电动汽车(EVs)智能诊断

我是穿拖鞋的汉子&#xff0c;魔都中坚持长期主义的汽车电子工程师。 老规矩&#xff0c;分享一段喜欢的文字&#xff0c;避免自己成为高知识低文化的工程师&#xff1a; 做到欲望极简&#xff0c;了解自己的真实欲望&#xff0c;不受外在潮流的影响&#xff0c;不盲从&#x…

JDBC+HTML+AJAX实现登陆和单表的CRUD

JDBCHTMLAJAX实现登陆和单表的CRUD 导入maven依赖 <?xml version"1.0" encoding"UTF-8"?><project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocatio…

【C++】位图详解(一文彻底搞懂位图的使用方法与底层原理)

&#x1f308; 个人主页&#xff1a;谁在夜里看海. &#x1f525; 个人专栏&#xff1a;《C系列》《Linux系列》 ⛰️ 天高地阔&#xff0c;欲往观之。 目录 1.位图的概念 2.位图的使用方法 定义与创建 设置和清除 位访问和检查 转换为其他格式 3.位图的使用场景 1.快速…

【笔记】开源通用人工智能代理 Suna 部署全流程准备清单(Windows 系统)

#工作记录 一、基础工具与环境 开发工具 Git 或 GitHub Desktop&#xff08;代码管理&#xff09;Docker Desktop&#xff08;需启用 WSL2&#xff0c;容器化部署&#xff09;Python 3.11&#xff08;推荐版本&#xff0c;需添加到系统环境变量&#xff09;Node.js LTS&#xf…

海康工业相机SDK二次开发(VS+QT+海康SDK+C++)

前言 工业相机在现代制造和工业自动化中扮演了至关重要的角色&#xff0c;尤其是在高精度、高速度检测中。海康威视工业相机以其性能稳定、图像质量高、兼容性强而受到广泛青睐。特别是搞机器视觉的小伙伴们跟海康打交道肯定不在少数&#xff0c;笔者在平常项目中跟海康相关人…

深度学习|pytorch基本运算-乘除法和幂运算

【1】引言 前序学习进程中&#xff0c;已经对pytorch张量数据的生成和广播做了详细探究&#xff0c;文章链接为&#xff1a; 深度学习|pytorch基本运算-CSDN博客 深度学习|pytorch基本运算-广播失效-CSDN博客 上述探索的内容还止步于张量的加减法&#xff0c;在此基础上&am…

4.2.4 Spark SQL 数据写入模式

在本节实战中&#xff0c;我们详细探讨了Spark SQL中数据写入的四种模式&#xff1a;ErrorIfExists、Append、Overwrite和Ignore。通过具体案例&#xff0c;我们演示了如何使用mode()方法结合SaveMode枚举类来控制数据写入行为。我们首先读取了一个JSON文件生成DataFrame&#…

论文笔记: Urban Region Embedding via Multi-View Contrastive Prediction

AAAI 2024 1 INTRO 之前基于多视图的region embedding工作大多遵循相同的模式 单独的单视图表示多视图融合 但这种方法存在明显的局限性&#xff1a;忽略了不同视图之间的信息一致性 一个区域的多个视图所携带的信息是高度相关的&#xff0c;因此它们的表示应该是一致的如果能…

初学者如何微调大模型?从0到1详解

本文将手把手带你从0到1&#xff0c;详细解析初学者如何微调大模型&#xff0c;让你也能驾驭这些强大的AI工具。 1. 什么是大模型微调&#xff1f; 想象一下&#xff0c;预训练大模型就像一位博览群书但缺乏专业知识的通才。它掌握了海量的通用知识&#xff0c;但可能无法完美…

西瓜书第十一章——降维与度量学习

文章目录 降维与度量学习k近邻学习原理头歌实战-numpy实现KNNsklearn实现KNN 降维——多维缩放&#xff08;Multidimensional Scaling, MDS&#xff0c;MDS&#xff09;提出背景与原理重述1.**提出背景**2.**数学建模与原理推导**3.**关键推导步骤** Principal Component Analy…

Portainer安装指南:多节点监控的docker管理面板-家庭云计算专家

背景 Portainer 是一个轻量级且功能强大的容器管理面板&#xff0c;专为 Docker 和 Kubernetes 环境设计。它通过直观的 Web 界面简化了容器的部署、管理和监控&#xff0c;即使是非技术用户也能轻松上手。Portainer 支持多节点管理&#xff0c;允许用户从一个中央控制台管理多…

vscode实用配置

前端开发安装插件&#xff1a; 1.可以更好看的显示文件图标 2.用户快速打开文件 使用步骤&#xff1a;在html文件下右键点击 open with live server 即可 刷力扣&#xff1a; 安装这个插件 还需要安装node.js即可

React 项目中封装 Excel 导入导出组件:技术分享与实践

文章目录 前言一、为什么需要封装 Excel 组件&#xff1f;二、技术选型三、核心实现1. 安装依赖2. 封装Excel导出3. 封装导入组件 &#xff08;UploadExcel&#xff09; 总结 前言 在 React 项目中&#xff0c;处理 Excel 文件的导入和导出是常见的业务需求。无论是导出报表数…

【2025CCF中国开源大会】RISC-V 开源生态的挑战与机遇分论坛重磅来袭!共探开源芯片未来

点击蓝字 关注我们 CCF Opensource Development Committee 开源浪潮正从软件席卷硬件领域&#xff0c;RISC-V作为全球瞩目的开源芯片架构&#xff0c;正在重塑计算生态的版图&#xff01;相较于成熟的x86与ARM&#xff0c;RISC-V生态虽处爆发初期&#xff0c;却蕴藏着无限可能。…