python打卡day42

news2025/6/5 2:37:57

Grad-CAM与Hook函数

知识点回顾

  1. 回调函数
  2. lambda函数
  3. hook函数的模块钩子和张量钩子
  4. Grad-CAM的示例

在深度学习中,我们经常需要查看或修改模型中间层的输出或梯度,但标准的前向传播和反向传播过程通常是一个黑盒,很难直接访问中间层的信息。PyTorch 提供了一种强大的工具——hook 函数,它允许我们在不修改模型结构的情况下,获取或修改中间层的信息。常用场景如下:

  1. 调试与可视化中间层输出
  2. 特征提取:如在图像分类模型中提取高层语义特征用于下游任务
  3. 梯度分析与修改: 在训练过程中,对某些层进行梯度裁剪或缩放,以改变模型训练的动态
  4. 模型压缩:在推理阶段对特定层的输出应用掩码(如剪枝后的模型权重掩码),实现轻量化推理

1、回调函数

Hook本质是回调函数,所以我们先介绍一下回调函数。回调函数是作为参数传递给其他函数的函数,其目的是在某个特定事件发生时被调用执行。这种机制允许代码在运行时动态指定需要执行的逻辑,其中回调函数作为参数传入,所以在定义的时候一般用callback来命名

在 PyTorch 的 Hook API 中,回调参数通常命名为 hook,PyTorch 的 Hook 机制基于其动态计算图系统:

  1. 当你注册一个 Hook 时,PyTorch 会在计算图的特定节点(如模块或张量)上添加一个回调函数
  2. 当计算图执行到该节点时(前向或反向传播),自动触发对应的 Hook 函数
  3. Hook 函数可以访问或修改流经该节点的数据(如输入、输出或梯度)

2、lambda函数

在hook中常常用到lambda函数,它是一种匿名函数(没有正式名称的函数),最大特点是用完即弃,无需提前命名和定义。它的语法形式非常简约,仅需一行即可完成定义,格式:lambda 参数列表: 表达式

  • 参数列表:可以是单个参数、多个参数或无参数
  • 表达式:函数的返回值(无需 return 语句,表达式结果直接返回)

举个例子

# 定义匿名函数:计算平方
square = lambda x: x ** 2

# 调用
print(square(5))  # 输出: 25

3、hook函数

PyTorch 提供了两种主要的 hook:

  • Module Hooks(模块钩子):用于监听整个模块的输入和输出
  • Tensor Hooks:用于监听张量的梯度

(1)模块钩子

允许我们在模块的输入或输出经过时进行监听。PyTorch 提供了两种模块钩子:

  1. register_forward_hook:在模块的前向传播完成后立即被调用,这个函数可以访问模块的输入和输出,但不能修改
  2. register_backward_hook:在反向传播过程中被调用的,可以用来获取或修改梯度信息

前向钩子举个例子

# 创建模型实例
model = SimpleModel()

# 创建一个列表用于存储中间层的输出
conv_outputs = []

# 定义前向钩子函数 - 用于在模型前向传播过程中获取中间层信息
def forward_hook(module, input, output):
    print(f"钩子被调用!模块类型: {type(module)}")
    print(f"输入形状: {input[0].shape}") #  input是一个元组,对应 (image, label)
    print(f"输出形状: {output.shape}")
    
    # 保存卷积层的输出用于后续分析
    # 使用detach()避免追踪梯度,防止内存泄漏
    conv_outputs.append(output.detach())

# 在卷积层注册前向钩子
# register_forward_hook返回一个句柄,用于后续移除钩子
hook_handle = model.conv.register_forward_hook(forward_hook)

# 创建一个随机输入张量 (批次大小=1, 通道=1, 高度=4, 宽度=4)
x = torch.randn(1, 1, 4, 4)

# 执行前向传播 - 此时会自动触发钩子函数
output = model(x)

# 释放钩子 - 重要!防止在后续模型使用中持续调用钩子造成意外行为或内存泄漏
hook_handle.remove()

反向钩子

# 定义一个存储梯度的列表
conv_gradients = []

# 定义反向钩子函数
def backward_hook(module, grad_input, grad_output):
    print(f"反向钩子被调用!模块类型: {type(module)}")
    print(f"输入梯度数量: {len(grad_input)}")
    print(f"输出梯度数量: {len(grad_output)}")
    
    # 保存梯度供后续分析
    conv_gradients.append((grad_input, grad_output))

# 在卷积层注册反向钩子
hook_handle = model.conv.register_backward_hook(backward_hook)

# 创建一个随机输入并进行前向传播
x = torch.randn(1, 1, 4, 4, requires_grad=True)
output = model(x)

# 定义一个简单的损失函数并进行反向传播
loss = output.sum()
loss.backward()

# 释放钩子
hook_handle.remove()

(2)张量钩子

PyTorch 还提供了张量钩子,允许我们直接监听和修改张量的梯度。张量钩子有两种:

  1. register_hook:用于监听张量的梯度
  2. register_full_backward_hook:用于在完整的反向传播过程中监听张量的梯度
# 创建一个需要计算梯度的张量
x = torch.tensor([2.0], requires_grad=True)
y = x ** 2
z = y ** 3

# 定义一个钩子函数,用于修改梯度
def tensor_hook(grad):
    print(f"原始梯度: {grad}")
    # 修改梯度,例如将梯度减半
    return grad / 2

# 在y上注册钩子
hook_handle = y.register_hook(tensor_hook)

# 计算梯度,梯度会从z反向传播经过y到x,此时调用钩子函数
z.backward()

print(f"x的梯度: {x.grad}")

# 释放钩子
hook_handle.remove()

4、Grad-CAM

一个可视化算法,通过梯度信息用热力图显示图片中哪些区域让CNN做出了某个分类决定(比如为什么认为这是“猫”),原理:

  • 梯度计算:看最后几层特征图的梯度,哪个特征图对预测“猫”的贡献大
  • 加权融合:把重要的特征图合并成一张热力图(重要区域更亮)
  • 叠加显示:把热力图盖在原图上,一眼看出猫的脸/耳朵等关键部位被高亮了
# Grad-CAM实现
class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        
        # 注册钩子,用于获取目标层的前向传播输出和反向传播梯度
        self.register_hooks()
        
    def register_hooks(self):
        # 前向钩子函数,在目标层前向传播后被调用,保存目标层的输出(激活值)
        def forward_hook(module, input, output):
            self.activations = output.detach()
        
        # 反向钩子函数,在目标层反向传播后被调用,保存目标层的梯度
        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0].detach()
        
        # 在目标层注册前向钩子和反向钩子
        self.target_layer.register_forward_hook(forward_hook)
        self.target_layer.register_backward_hook(backward_hook)
    
    def generate_cam(self, input_image, target_class=None):
        # 前向传播,得到模型输出
        model_output = self.model(input_image)
        
        if target_class is None:
            # 如果未指定目标类别,则取模型预测概率最大的类别作为目标类别
            target_class = torch.argmax(model_output, dim=1).item()
        
        # 清除模型梯度,避免之前的梯度影响
        self.model.zero_grad()
        
        # 反向传播,构造one-hot向量,使得目标类别对应的梯度为1,其余为0,然后进行反向传播计算梯度
        one_hot = torch.zeros_like(model_output)
        one_hot[0, target_class] = 1
        model_output.backward(gradient=one_hot)
        
        # 获取之前保存的目标层的梯度和激活值
        gradients = self.gradients
        activations = self.activations
        
        # 对梯度进行全局平均池化,得到每个通道的权重,用于衡量每个通道的重要性
        weights = torch.mean(gradients, dim=(2, 3), keepdim=True)
        
        # 加权激活映射,将权重与激活值相乘并求和,得到类激活映射的初步结果
        cam = torch.sum(weights * activations, dim=1, keepdim=True)
        
        # ReLU激活,只保留对目标类别有正贡献的区域,去除负贡献的影响
        cam = F.relu(cam)
        
        # 调整大小并归一化,将类激活映射调整为与输入图像相同的尺寸(32x32),并归一化到[0, 1]范围
        cam = F.interpolate(cam, size=(32, 32), mode='bilinear', align_corners=False)
        cam = cam - cam.min()
        cam = cam / cam.max() if cam.max() > 0 else cam
        
        return cam.cpu().squeeze().numpy(), target_class

idx = 102  # 选择测试集中的第101张图片 (索引从0开始)
image, label = testset[idx]
print(f"选择的图像类别: {classes[label]}")

# 转换图像以便可视化
def tensor_to_np(tensor):
    img = tensor.cpu().numpy().transpose(1, 2, 0)
    mean = np.array([0.5, 0.5, 0.5])
    std = np.array([0.5, 0.5, 0.5])
    img = std * img + mean
    img = np.clip(img, 0, 1)
    return img

# 添加批次维度并移动到设备
input_tensor = image.unsqueeze(0).to(device)

# 初始化Grad-CAM(选择最后一个卷积层)
grad_cam = GradCAM(model, model.conv3)

# 生成热力图
heatmap, pred_class = grad_cam.generate_cam(input_tensor)

# 可视化
plt.figure(figsize=(12, 4))

# 原始图像
plt.subplot(1, 3, 1)
plt.imshow(tensor_to_np(image))
plt.title(f"原始图像: {classes[label]}")
plt.axis('off')

# 热力图
plt.subplot(1, 3, 2)
plt.imshow(heatmap, cmap='jet')
plt.title(f"Grad-CAM热力图: {classes[pred_class]}")
plt.axis('off')

# 叠加的图像
plt.subplot(1, 3, 3)
img = tensor_to_np(image)
heatmap_resized = np.uint8(255 * heatmap)
heatmap_colored = plt.cm.jet(heatmap_resized)[:, :, :3]
superimposed_img = heatmap_colored * 0.4 + img * 0.6
plt.imshow(superimposed_img)
plt.title("叠加热力图")
plt.axis('off')

plt.tight_layout()
plt.savefig('grad_cam_result.png')
plt.show()

@浙大疏锦行

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2396638.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

XMOS以全新智能音频及边缘AI技术亮相广州国际专业灯光音响展

全球领先的边缘AI和智能音频解决方案提供商XMOS于5月27-30日亮相第23届广州国际专业灯光、音响展览会(prolight sound Guangzhou,以下简称“广州展”,XMOS展位号:5.2A66)。在本届展会上,XMOS将展出先进的音…

Playwright 测试框架 - Node.js

🚀超全实战:基于 Playwright + Node.js 的自动化测试项目教程【附源码】 📌 本文适合自动化测试入门者 & 前端测试实战者。从零开始手把手教你搭建一个 Playwright + Node.js 项目,涵盖配置、测试用例编写、运行与调试、报告生成以及实用进阶技巧。建议收藏!👍 �…

机器学习有监督学习sklearn实战二:六种算法对鸢尾花(Iris)数据集进行分类和特征可视化

本项目代码在个人github链接:https://github.com/KLWU07/Machine-learning-Project-practice 六种分类算法分别为逻辑回归LR、线性判别分析LDA、K近邻KNN、决策树CART、朴素贝叶斯NB、支持向量机SVM。 一、项目代码描述 1.数据准备和分析可视化 加载鸢尾花数据集&…

vr中风--数据处理模型搭建与训练2

位置http://localhost:8888/notebooks/Untitled1-Copy1.ipynb # -*- coding: utf-8 -*- """ MUSED-I康复评估系统(增强版) 包含:多通道sEMG数据增强、混合模型架构、标准化处理 """ import numpy as np impor…

鸿蒙next系统以后会取代安卓吗?

点击上方关注 “终端研发部” 设为“星标”,和你一起掌握更多数据库知识 官方可没说过取代谁谁,三足鼎立不好吗?三分天下,并立共存。 鸿蒙基于Linux,有人说套壳;ios/macos基于Unix,说它ios开源了…

PolyGen:一个用于 3D 网格的自回归生成模型 论文阅读

[2002.10880] PolyGen:一个用于 3D 网格的自回归生成模型 --- [2002.10880] PolyGen: An Autoregressive Generative Model of 3D Meshes 图 2:PolyGen 首先生成网格顶点(左侧),然后基于这些顶点生成网格面&#xff0…

系统思考:成长与投资不足

最近认识了一位95后年轻创业者,短短2年时间,他的公司从十几个人发展到几百人,规模迅速扩大。随着团队壮大,用户池也在持续扩大,但令人困惑的是,业绩增长却没有明显提升,甚至人效持续下滑。尽管公…

快手可灵视频V1.6模型API如何接入免费AI开源项目工具

全球领先的视频生成大模型:可灵是首个效果对标 Sora 、面向用户开放的视频生成大模型,目前在国内及国际上均处于领先地位。快手视频生成大模型“可灵”(Kling),是全球首个真正用户可用的视频生成大模型,自面…

数学建模期末速成 最短路径

关键词:Dijkstra算法 Floyd算法 例题 已知有6个村庄,各村的小学生人数如表所列,各村庄间的距离如图所示。现在计划建造一所医院和一所小学,问医院应建在哪个村庄才能使最远村庄的人到医院看病所走的路最短?又问小学建…

Java开发经验——阿里巴巴编码规范实践解析7

摘要 本文主要解析了阿里巴巴 Java 开发中的 SQL 编码规范,涉及 SQL 查询优化、索引建立、字符集选择、分页查询处理、外键与存储过程的使用等多个方面,旨在帮助开发者提高代码质量和数据库操作性能,避免常见错误和性能陷阱。 1. 【强制】业…

权威认证与质量保障:第三方检测在科技成果鉴定测试中的核心作用

科技成果鉴定测试是衡量科研成果技术价值与应用潜力的关键环节,其核心目标在于通过科学验证确保成果的可靠性、创新性和市场适配性。第三方检测机构凭借其独立性、专业性和权威性,成为科技成果鉴定测试的核心支撑主体。本文从测试流程、第三方检测的价值…

Perforce P4产品简介:无限扩展+全球协作+安全管控+工具集成(附下载)

本产品简介由Perforce中国授权合作伙伴——龙智编辑整理,旨在带您快速了解Perforce P4版本控制系统的强大之处。 世界级无限可扩展的版本控制系统 Perforce P4(原Helix Core)是业界领先的版本控制平台,备受19家全球Top20 AAA级游…

网络协议入门:TCP/IP五层模型如何实现全球数据传输?

🔍 开发者资源导航 🔍🏷️ 博客主页: 个人主页📚 专栏订阅: JavaEE全栈专栏 内容: 网络初识什么是网络?关键概念认识协议五元组 协议分层OSI七层模型TCP/IP五层(四层&…

Docker安装Redis集群(3主3从+动态扩容、缩容)保姆级教程含踩坑及安装中遇到的问题解决

前言 部署集群前,我们需要先掌握Redis分布式存储的核心算法。了解这些算法能帮助我们在实际工作中做出合理选择,同时清晰认识各方案的优缺点。 一、分布式存储算法 我们通过一道大厂面试题来进行阐述。 如下:1-2亿条数据需要缓存&#xff…

企业级 AI 开发新范式:Spring AI 深度解析与实践

一、Spring AI 的核心架构与设计哲学 1.1 技术定位与价值主张 Spring AI 作为 Spring 生态系统的重要组成部分,其核心使命是将人工智能能力无缝注入企业级 Java 应用。它通过标准化的 API 抽象和 Spring Boot 的自动装配机制,让开发者能够以熟悉的 Spr…

如何用docker部署ELK?

环境: ELK 8.8.0 Ubuntu20.04 问题描述: 如何用docker部署ELK? 解决方案: 一、环境准备 (一)主机设置 安装 Docker Engine :版本需为 18.06.0 或更新。可通过命令 docker --version 检查…

Redis最佳实践——安全与稳定性保障之高可用架构详解

全面详解 Java 中 Redis 在电商应用的高可用架构设计 一、高可用架构核心模型 1. 多层级高可用体系 #mermaid-svg-Ffzq72Onkv7wgNKQ {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-Ffzq72Onkv7wgNKQ .error-icon{f…

【Python 算法零基础 4.排序 ⑥ 快速排序】

既有锦绣前程可奔赴,亦有往日岁月可回首 —— 25.5.25 选择排序回顾 ① 遍历数组:从索引 0 到 n-1(n 为数组长度)。 ② 每轮确定最小值:假设当前索引 i 为最小值索引 min_index。从 i1 到 n-1 遍历,若找到…

Go 即时通讯系统:日志模块重构,并从main函数开始

重构logger 上次写的logger.go过于繁琐,有很多没用到的功能;重构后只提供了简洁的日志接口,支持日志轮转、多级别日志记录等功能,并采用单例模式确保全局只有一个日志实例 全局变量 var (once sync.Once // 用于实现…

MYSQL MGR高可用

1,MYSQL MGR高可用是什么 简单来说,MySQL MGR 的核心目标就是:确保数据库服务在部分节点(服务器)发生故障时,整个数据库集群依然能够继续提供读写服务,最大限度地减少停机时间。 2. 核心优势 v…