使用钩子函数的方式提取视觉特征

news2024/10/13 20:59:00

通过注册钩子函数,可以在模型的计算过程中插入需要执行的任意代码片段。在视觉特征提取过程中可以根据模型的结构,将正向钩子函数注册到指定的层中,然后通过读取该层的输入或输出数据,将视觉特征提取出来。

找到目标层,可以通过模型的源码找到指定的目标层,也可以通过print函数将模型对象输出并从中选取要注册钩子函数的目标层。

import torch
import torchvision
from PIL import Image
import torchvision.transforms as T

transforms = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

img = Image.open('./test/dog/dog.55.jpg')
img_input = transforms(img)
img_input = img_input.unsqueeze(0)
img_input = img_input.to(device)

model = torchvision.models.resnet18()               # 没加载预训练模型
in_feat_num = model.fc.in_features
model.fc = torch.nn.Linear(in_features=in_feat_num, out_features=2)   # 在resnet18-f37072fd.pth的预训练模型基础上finetune的dog/cat模型
model.load_state_dict(torch.load('resnet18-f37072fd_finetune_allLayer.pth'))   # 加载finetune之后的权重
model.to(device)
# model.eval()   # 设置为eval模式

in_list = []     # 存放输入目标层的特征
out_list = []    # 存放输出目标层的特征
def hook(module_placeholder, input, output):
    print('in', len(input))   # 输入项是传入该层的参数,元组类型,输出是1,说明该层只有一个输入
    for val in input:         # 遍历每个输入项
        print(f"input val:{val.size()}")  # 输出的形状是torch.Size([1, 512, 7, 7])
    for i in range(input[0].size()[0]):  # 遍历多batch的每个图片的特征
        in_list.append(input[0][i].cpu().numpy())         # 保存单张图片的特征
        print(f'in, {input[0][i].cpu().numpy().shape}')   # 输出特征形状(512, 7, 7)
    print('out', len(output))   # 输出项是特征张量,值等于batchsize = 1
    for i in range(output.size(0)):
        out_list.append(output[i].cpu().numpy())    # 保存单张图片的特征
        print(f'out, {output[i].cpu().numpy().shape}')  # 输出特征的形状(512,1,1)

model.avgpool.register_forward_hook(hook)
with torch.no_grad():
    y_pred = model(img_input)

print('Done.')

输出:

in 1
input val:torch.Size([1, 512, 7, 7])
in, (512, 7, 7)
out 1
out, (512, 1, 1)
Done.

需要注意的是,钩子函数的输入项核输出项内容定义并不一致。输入项是一个元组,元组中的元素个数与该层的输入参数个数一致,每个元素才是真正的特征数据,而输出项直接就是该层处理后的特征数据。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2185227.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

3 个简单的微分段项目

与许多大型网络安全项目一样,微分段似乎很复杂、耗时且成本高昂。 它涉及管理有关设备间服务连接的复杂细节。 一台 Web 服务器应连接到特定数据库,但不连接到其他数据库,或者负载平衡器应连接到某些 Web 服务器,同时限制与其他…

图解大模型计算加速系列:vLLM源码解析1,整体架构

整个vLLM代码读下来,给我最深的感觉就是:代码呈现上非常干净历练,但是逻辑比较复杂,环环嵌套,毕竟它是一个耦合了工程调度和模型架构改进的巨大工程。 所以在源码解读的第一篇,我想先写一下对整个代码架构…

Golang | Leetcode Golang题解之第449题序列化和反序列化二叉搜索树

题目: 题解: type Codec struct{}func Constructor() (_ Codec) { return }func (Codec) serialize(root *TreeNode) string {arr : []string{}var postOrder func(*TreeNode)postOrder func(node *TreeNode) {if node nil {return}postOrder(node.Le…

java基础 day1

学习视频链接 人机交互的小故事 微软和乔布斯借鉴了施乐实现了如今的图形化界面 图形化界面对于用户来说,操作更加容易上手,但是也存在一些问题。使用图形化界面需要加载许多图片,所以消耗内存;此外运行的速度没有命令行快 Wi…

针对考研的C语言学习(2019链表大题)

题目解析: 【考】双指针算法,逆置法,归并法。 解析:因为题目要求空间复杂度为O(1),即不能再开辟一条链表,因此我们只能用变量来整体挪动原链表。 第一步先找出中间节点 typedef NODE* Node; Node find_m…

latex有哪些颜色中文叫什么,Python绘制出来

latex有哪些颜色中文叫什么,Python绘制出来 为了展示xcolor包预定义的颜色及其对应的中文名称,并使用Python打印出来,我们可以先列出常见的预定义颜色名称,然后将它们翻译成中文,并最后用Python打印出来。 步骤 列出…

家庭记账本的设计与实现+ssm(lw+演示+源码+运行)

摘 要 随着我国经济迅速发展,人们对手机的需求越来越大,各种手机软件也都在被广泛应用,但是对于手机进行数据信息管理,对于手机的各种软件也是备受用户的喜爱,家庭记账本小程序被用户普遍使用,为方便用户能…

MySQL高阶2066-账户余额

目录 题目 准备数据 分析数据 总结 题目 请写出能够返回用户每次交易完成后的账户余额. 我们约定所有用户在进行交易前的账户余额都为0, 并且保证所有交易行为后的余额不为负数。 返回的结果请依次按照 账户(account_id), 日期( day ) 进行升序排序…

leetcode_238:除自身以外数组的乘积

给你一个整数数组 nums,返回 数组 answer ,其中 answer[i] 等于 nums 中除 nums[i] 之外其余各元素的乘积 。 题目数据 保证 数组 nums之中任意元素的全部前缀元素和后缀的乘积都在 32 位 整数范围内。 请 不要使用除法,且在 O(n) 时间复杂…

Conditional Generative Adversarial Nets

条件生成对抗网络 1.生成对抗网络 生成对网络由两个“对抗性”模型组成:一个生成模型 G,用于捕获数据分布,另一个判别模型 D,用于估计样本来自训练数据而不是 G 的概率。G 和 D 都可以是非线性映射函数。 为了学习数据 x 上的生…

设计模式-生成器模式/建造者模式Builder

构建起模式:将一个复杂类的表示与其构造分离,使得相同的构建过程能够得出不同的表示。(建造者其实和工厂模式差不多) 详细的UML类图 图文说明:距离相同的构建过程 得出不同的展示。此时就用两个类(文本生成…

探索未来:hbmqtt,Python中的AI驱动MQTT

文章目录 **探索未来:hbmqtt,Python中的AI驱动MQTT**1. 背景介绍2. hbmqtt是什么?3. 安装hbmqtt4. 简单的库函数使用方法4.1 连接到MQTT服务器4.2 发布消息4.3 订阅主题4.4 接收消息4.5 断开连接 5. 应用场景示例5.1 智能家居控制5.2 环境监测…

WebGIS之Cesium三维软件开发

目录 第 1 章 三维 WebGIS 概述 1.1 Google Earth 1 1.2 SkylineGlobe 2 1.3 LocaSpace Viewe 2 1.4 Cesium 3 1.5 Cesium API 概要 4 第 2 章 Cesium 快速入门 2.1 Cesium 环境搭建 7 2.1.1 安装 Node.js 环境 7 2.1.2 配置 Cesium 依赖 8 2.2 搭建第一个 Cesi…

【2006.07】UMLS工具——MetaMap原理深度解析

文献:《MetaMap: Mapping Text to the UMLS Metathesaurus》2006 年 7 月 14 日 https://lhncbc.nlm.nih.gov/ii/information/Papers/metamap06.pdf MetaMap:将文本映射到 UMLS 元数据库 总结 解决的问题 自动概念映射问题:解决如何将文本…

Vue3丨进一步了解这 20 个响应式 API,写码如有神

前面说的话 在 Vue2 中,个人觉得对于数据的操作比较 “黑盒” 。而 Vue3 把响应式系统更显式地暴露出来,使得我们对数据的操作有了更多的灵活性。所以,对于 Vue3 的几个响应式的 API ,我们需要更加的理解掌握,才能在实…

【MySQL】子查询、合并查询、表的连接

目录 一、子查询 1、单行子查询 显示SMITH同一部门的员工信息 2、多行子查询 in关键字 查询和10号部门的工作岗位相同的雇员的名字、岗位、工资、部门号,但是筛选出的雇员的部门不能有10号部门 all关键字 查询工资比30号部门中所有雇员工资高的雇员的姓名、…

TS(type,属性修饰符,抽象类,interface)一次性全部总结

目录 1.type 1.基本用法 2.联合类型 3.交叉类型 2.属性修饰符 1.public 属性修饰符 属性的简写形式 2.proteced 属性修饰符 3.private 属性修饰符 4.readonly 属性修饰符 3.抽象类 4.interface 1.定义类结构 2.定义对象结构 3.定义函数结构 4.接口之间的继…

postgresql|数据库|postgis编译完成后的插件迁移应该如何做(postgis插件最终章)

一、 本文的写作理由 postgis插件一般是编译安装,编译安装的原因是可以选择自己喜欢的版本,但编译的难度也是比较高的,因为有各种依赖,依赖之间还有依赖,非常容易形成依赖循环,因此,失败率是比…

【Python】CSVKit:强大的命令行CSV工具套件

CSVKit 是一个基于命令行的工具集,用于简化 CSV 文件的处理和管理。它提供了从数据转换、筛选、格式化到分析的全方位支持,特别适合需要处理复杂表格数据的用户。相比传统的 Excel 操作,CSVKit 更高效且功能更强大,非常适合数据分…

VSOMEIP代码阅读整理(1) - 网卡状态监听

一. 概述 在routing进程所使用的配置文件中,存在如下配置项目:{"unicast" : "192.168.56.101",..."service-discovery" :{"enable" : "true","multicast" : "224.244.224.245",…