作业:对比不同卷积层热图可视化的结果
核心差异总结
-
浅层卷积层(如第 1-3 层)
- 关注细节:聚焦输入图像的边缘、纹理、颜色块等基础特征(例:猫脸的胡须边缘、树叶的脉络)。
- 热图特点:区域小而分散,高激活区域多为局部细节,分辨率接近原图。
-
中层卷积层(如第 4-6 层)
- 关注局部组合:提取形状、部件组合等中级特征(例:猫的耳朵轮廓、椅子的椅腿结构)。
- 热图特点:区域稍大,激活区域开始整合局部信息,分辨率略降低。
-
深层卷积层(如第 7 层及以上)
- 关注全局语义:聚焦目标整体、类别核心特征(例:整只猫的轮廓、“椅子” 的整体结构)。
- 热图特点:区域更大且集中,高激活区域覆盖目标主体,分辨率较低但语义更明确。
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
# 加载预训练VGG16(仅取卷积层,去掉全连接)
model = models.vgg16(pretrained=True).features.eval() # .features包含44层卷积+池化
# 图像预处理(适配VGG输入要求)
transform = transforms.Compose([
transforms.Resize((224, 224)), # 调整尺寸
transforms.ToTensor(), # 转Tensor
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 标准化
])
def generate_heatmap(image_path, layer_indices):
# 读取图像并预处理
img = Image.open(image_path).convert('RGB')
input_tensor = transform(img).unsqueeze(0) # 增加batch维度 [1, 3, 224, 224]
# 注册各层钩子,获取特征图
features = {}
def hook(module, input, output):
features[module.__class__.__name__ + str(layer_idx)] = output.detach() # 保存特征图
heatmaps = []
for layer_idx in layer_indices:
# 注册当前层钩子
handle = model[layer_idx].register_forward_hook(hook)
# 前向传播
_ = model(input_tensor)
handle.remove() # 移除钩子,避免重复注册
# 提取特征图并生成热图
feat_map = features[model[layer_idx].__class__.__name__ + str(layer_idx)]
feat_map = feat_map.squeeze(0).cpu().numpy() # 维度:[C, H, W]
# 对通道维度求平均(简单可视化,也可取最大激活通道)
heatmap = np.mean(feat_map, axis=0)
# 归一化到0-1
heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)
heatmaps.append(heatmap)
return heatmaps
# 图层索引(VGG16的卷积层索引:0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28, 31, 33, 35)
layer_indices = [0, 10, 20] # 浅层(第1卷积层)、中层(第10层)、深层(第20层)
image_path = "test_image.jpg" # 替换为你的图片路径
# 生成热图
heatmaps = generate_heatmap(image_path, layer_indices)
# 可视化对比
plt.figure(figsize=(12, 8))
for i, heatmap in enumerate(heatmaps):
plt.subplot(1, 3, i+1)
plt.imshow(heatmap, cmap='jet')
plt.title(f"Layer {layer_indices[i]}")
plt.axis('off')
plt.show()
@浙大疏锦行