知识点回顾:
- 通道注意力模块复习
- 空间注意力模块
- CBAM的定义
作业:尝试对今天的模型检查参数数目,并用tensorboard查看训练过程
一、通道注意力模块复习 & CBAM实现
import torch
import torch.nn as nn
class CBAM(nn.Module):
def __init__(self, channels, reduction=16):
super().__init__()
# 通道注意力
self.channel_attention = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(channels, channels//reduction, 1),
nn.ReLU(),
nn.Conv2d(channels//reduction, channels, 1),
nn.Sigmoid()
)
# 空间注意力
self.spatial_attention = nn.Sequential(
nn.Conv2d(2, 1, 7, padding=3),
nn.Sigmoid()
)
def forward(self, x):
# 通道注意力
ca = self.channel_attention(x)
x = x * ca
# 空间注意力
sa = torch.cat([torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)], dim=1)
sa = self.spatial_attention(sa)
return x * sa
# 在ResNet中插入CBAM
model = resnet18(pretrained=True)
model.layer1[0].add_module("cbam", CBAM(64))
二、参数统计方法
from torchsummary import summary
# 检查模型参数
summary(model.to(Config.DEVICE), (3, 224, 224))
三、TensorBoard监控增强
# 在训练循环中添加
writer.add_scalar('Loss/train', running_loss/100, epoch*len(trainloader)+i)
writer.add_scalar('Accuracy/test', accuracy, epoch)
# 启动TensorBoard
# 在命令行中运行:tensorboard --logdir=runs
关键点说明:
1. CBAM模块包含通道和空间注意力分支
2. 使用summary函数可显示参数量
3. TensorBoard记录需保持writer实例的持续使用