day51 python CBAM注意力

news2025/6/10 8:52:27

目录

一、CBAM 模块简介

二、CBAM 模块的实现

(一)通道注意力模块

(二)空间注意力模块

(三)CBAM 模块的组合

三、CBAM 模块的特性

四、CBAM 模块在 CNN 中的应用


一、CBAM 模块简介

在之前的探索中,我们已经了解了 SE(Squeeze-and-Excitation)通道注意力模块,它通过关注特征图的通道维度,提升了模型对重要特征的感知能力。然而,SE 模块仅关注“哪些通道重要”,而忽略了“重要信息在空间中的位置”。这就好比我们只看到了一幅画的整体色调,却没注意到画中最吸引人的细节所在。

而 CBAM 模块正是为了解决这一局限而生。它是一种可以无缝集成到任何卷积神经网络架构中的注意力模块,核心目标是通过学习的方式,自动获取特征图在通道和空间维度上的重要性,并对特征图进行自适应调整。简单来说,它就像是给模型装上了一副“智能眼镜”,让模型能够更精准地聚焦于图像中关键的部分,从而提升模型的特征表达能力和性能。

CBAM 由两个主要部分组成:通道注意力模块(Channel Attention Module)和空间注意力模块(Spatial Attention Module)。这两个模块顺序连接,共同作用于输入的特征图。通道注意力模块负责分析“哪些通道的特征更关键”,例如图像中的颜色、纹理通道等;空间注意力模块则定位“关键特征在图像中的具体位置”,比如物体所在区域。二者的结合,让模型同时学会“关注什么”和“关注哪里”,极大地提升了特征表达能力。

二、CBAM 模块的实现

(一)通道注意力模块

通道注意力模块的实现思路是通过全局池化操作(包括全局平均池化和全局最大池化)分别提取特征图的全局平均信息和全局最大信息,然后将这两种信息分别送入共享的全连接层网络中进行学习,最后将两个分支的结果相加并通过 Sigmoid 函数得到通道权重,再将权重与原始特征图相乘,实现对重要通道的增强和不重要通道的抑制。

以下是通道注意力模块的代码实现:

class ChannelAttention(nn.Module):
    def __init__(self, in_channels, ratio=16):
        """
        通道注意力机制初始化
        参数:
            in_channels: 输入特征图的通道数
            ratio: 降维比例,用于减少参数量,默认为16
        """
        super().__init__()
        # 全局平均池化,将每个通道的特征图压缩为1x1,保留通道间的平均值信息
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # 全局最大池化,将每个通道的特征图压缩为1x1,保留通道间的最显著特征
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # 共享全连接层,用于学习通道间的关系
        # 先降维(除以ratio),再通过ReLU激活,最后升维回原始通道数
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // ratio, bias=False),  # 降维层
            nn.ReLU(),  # 非线性激活函数
            nn.Linear(in_channels // ratio, in_channels, bias=False)   # 升维层
        )
        # Sigmoid函数将输出映射到0-1之间,作为各通道的权重
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """
        前向传播函数
        参数:
            x: 输入特征图,形状为 [batch_size, channels, height, width]
        返回:
            调整后的特征图,通道权重已应用
        """
        # 获取输入特征图的维度信息,这是一种元组的解包写法
        b, c, h, w = x.shape
        # 对平均池化结果进行处理:展平后通过全连接网络
        avg_out = self.fc(self.avg_pool(x).view(b, c))
        # 对最大池化结果进行处理:展平后通过全连接网络
        max_out = self.fc(self.max_pool(x).view(b, c))
        # 将平均池化和最大池化的结果相加并通过sigmoid函数得到通道权重
        attention = self.sigmoid(avg_out + max_out).view(b, c, 1, 1)
        # 将注意力权重与原始特征相乘,增强重要通道,抑制不重要通道
        return x * attention  # 这个运算是pytorch的广播机制

(二)空间注意力模块

空间注意力模块的核心思想是先对特征图进行通道维度的池化操作(包括平均池化和最大池化),得到两个描述特征图空间分布的特征图,然后将这两个特征图拼接在一起,通过一个卷积层进行特征提取,最后将得到的空间注意力权重与原始特征图相乘,实现对空间位置的加权。

以下是空间注意力模块的代码实现:

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # 通道维度池化
        avg_out = torch.mean(x, dim=1, keepdim=True)  # 平均池化:(B,1,H,W)
        max_out, _ = torch.max(x, dim=1, keepdim=True)  # 最大池化:(B,1,H,W)
        pool_out = torch.cat([avg_out, max_out], dim=1)  # 拼接:(B,2,H,W)
        attention = self.conv(pool_out)  # 卷积提取空间特征
        return x * self.sigmoid(attention)  # 特征与空间权重相乘

(三)CBAM 模块的组合

CBAM 模块就是将通道注意力模块和空间注意力模块按照顺序连接起来,先进行通道注意力的调整,再进行空间注意力的调整。这种串行的结构使得特征图先在通道维度上被优化,然后再在空间维度上被优化,从而实现对特征图的全面增强。

以下是 CBAM 模块的代码实现:

class CBAM(nn.Module):
    def __init__(self, in_channels, ratio=16, kernel_size=7):
        super().__init__()
        self.channel_attn = ChannelAttention(in_channels, ratio)
        self.spatial_attn = SpatialAttention(kernel_size)

    def forward(self, x):
        x = self.channel_attn(x)
        x = self.spatial_attn(x)
        return x

三、CBAM 模块的特性

  1. 轻量级设计:CBAM 模块仅增加了少量的计算量,主要是全局池化和简单的卷积操作,因此它非常适合嵌入到各种 CNN 架构中,例如 ResNet、YOLO 等,不会对模型的性能产生过大的负担。

  2. 即插即用:CBAM 模块的设计非常灵活,无需修改原有模型的主体结构,可以直接作为模块插入到卷积层之间,方便我们对现有的模型进行升级和优化。

  3. 双重优化:CBAM 模块同时对通道和空间维度的特征进行优化,能够提升模型在复杂场景下的性能,例如小目标检测、语义分割等任务。通过增强重要特征和抑制不重要特征,模型能够更好地捕捉到图像中的关键信息,从而提高特征表达能力。

四、CBAM 模块在 CNN 中的应用

为了验证 CBAM 模块的效果,我将其应用到了一个简单的 CNN 模型中,并使用 CIFAR-10 数据集进行了训练和测试。以下是带有 CBAM 模块的 CNN 模型的代码实现:

class CBAM_CNN(nn.Module):
    def __init__(self):
        super(CBAM_CNN, self).__init__()
        
        # ---------------------- 第一个卷积块(带CBAM) ----------------------
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32) # 批归一化
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.cbam1 = CBAM(in_channels=32)  # 在第一个卷积块后添加CBAM
        
        # ---------------------- 第二个卷积块(带CBAM) ----------------------
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.cbam2 = CBAM(in_channels=64)  # 在第二个卷积块后添加CBAM
        
        # ---------------------- 第三个卷积块(带CBAM) ----------------------
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=2)
        self.cbam3 = CBAM(in_channels=128)  # 在第三个卷积块后添加CBAM
        
        # ---------------------- 全连接层 ----------------------
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        # 第一个卷积块
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.cbam1(x)  # 应用CBAM
        
        # 第二个卷积块
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        x = self.cbam2(x)  # 应用CBAM
        
        # 第三个卷积块
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu3(x)
        x = self.pool3(x)
        x = self.cbam3(x)  # 应用CBAM
        
        # 全连接层
        x = x.view(-1, 128 * 4 * 4)
        x = self.fc1(x)
        x = self.relu3(x)
        x = self.dropout(x)
        x = self.fc2(x)
        
        return x

在训练过程中,我设置了 50 个 epoch,并使用了 Adam 优化器和学习率调度器。以下是训练过程中的部分输出信息:

Epoch: 1/50 | Batch: 100/782 | 单Batch损失: 1.8068 | 累计平均损失: 1.9504
Epoch: 1/50 | Batch: 200/782 | 单Batch损失: 1.6703 | 累计平均损失: 1.8310
……
Epoch 1/50 完成 | 训练准确率: 42.49% | 测试准确率: 58.60%
Epoch 2/50 完成 | 训练准确率: 56.67% | 测试准确率: 67.22%
……

最终,经过 50 个 epoch 的训练,模型的测试准确率达到了 85.98%。通过观察训练过程中的损失值和准确率变化,可以看到模型在训练过程中逐渐收敛,准确率稳步提升。

@浙大疏锦行

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2406419.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用VMware克隆功能快速搭建集群

自己搭建的虚拟机,后续不管是学习java还是大数据,都需要集群,java需要分布式的微服务,大数据Hadoop的计算集群,如果从头开始搭建虚拟机会比较费时费力,这里分享一下如何使用克隆功能快速搭建一个集群 先把…

篇章一 论坛系统——前置知识

目录 1.软件开发 1.1 软件的生命周期 1.2 面向对象 1.3 CS、BS架构 1.CS架构​编辑 2.BS架构 1.4 软件需求 1.需求分类 2.需求获取 1.5 需求分析 1. 工作内容 1.6 面向对象分析 1.OOA的任务 2.统一建模语言UML 3. 用例模型 3.1 用例图的元素 3.2 建立用例模型 …

Qt/C++学习系列之列表使用记录

Qt/C学习系列之列表使用记录 前言列表的初始化界面初始化设置名称获取简单设置 单元格存储总结 前言 列表的使用主要基于QTableWidget控件,同步使用QTableWidgetItem进行单元格的设置,最后可以使用QAxObject进行单元格的数据读出将数据进行存储。接下来…

基于django+vue的健身房管理系统-vue

开发语言:Python框架:djangoPython版本:python3.8数据库:mysql 5.7数据库工具:Navicat12开发软件:PyCharm 系统展示 会员信息管理 员工信息管理 会员卡类型管理 健身项目管理 会员卡管理 摘要 健身房管理…

Yolo11改进策略:Block改进|FCM,特征互补映射模块|AAAI 2025|即插即用

1 论文信息 FBRT-YOLO(Faster and Better for Real-Time Aerial Image Detection)是由北京理工大学团队提出的专用于航拍图像实时目标检测的创新框架,发表于AAAI 2025。论文针对航拍场景中小目标检测的核心难题展开研究,重点解决…

简单聊下阿里云DNS劫持事件

阿里云域名被DNS劫持事件 事件总结 根据ICANN规则,域名注册商(Verisign)认定aliyuncs.com域名下的部分网站被用于非法活动(如传播恶意软件);顶级域名DNS服务器将aliyuncs.com域名的DNS记录统一解析到shado…

循环语句之while

While语句包括一个循环条件和一段代码块&#xff0c;只要条件为真&#xff0c;就不断 循环执行代码块。 1 2 3 while (条件) { 语句 ; } var i 0; while (i < 100) {console.log(i 当前为&#xff1a; i); i i 1; } 下面的例子是一个无限循环&#xff0c;因…

机器学习复习3--模型评估

误差与过拟合 我们将学习器对样本的实际预测结果与样本的真实值之间的差异称为&#xff1a;误差&#xff08;error&#xff09;。 误差定义&#xff1a; ①在训练集上的误差称为训练误差&#xff08;training error&#xff09;或经验误差&#xff08;empirical error&#x…

联邦学习带宽资源分配

带宽资源分配是指在网络中如何合理分配有限的带宽资源&#xff0c;以满足各个通信任务和用户的需求&#xff0c;尤其是在多用户共享带宽的情况下&#xff0c;如何确保各个设备或用户的通信需求得到高效且公平的满足。带宽是网络中的一个重要资源&#xff0c;通常指的是单位时间…

今日行情明日机会——20250609

上证指数放量上涨&#xff0c;接近3400点&#xff0c;个股涨多跌少。 深证放量上涨&#xff0c;但有个小上影线&#xff0c;相对上证走势更弱。 2025年6月9日涨停股主要行业方向分析&#xff08;基于最新图片数据&#xff09; 1. 医药&#xff08;11家涨停&#xff09; 代表标…

GC1808:高性能音频ADC的卓越之选

在音频处理领域&#xff0c;高质量的音频模数转换器&#xff08;ADC&#xff09;是实现精准音频数字化的关键。GC1808&#xff0c;一款96kHz、24bit立体声音频ADC&#xff0c;以其卓越的性能和高性价比脱颖而出&#xff0c;成为众多音频设备制造商的理想选择。 GC1808集成了64倍…

生产管理系统开发:专业软件开发公司的实践与思考

生产管理系统开发的关键点 在当前制造业智能化升级的转型背景下&#xff0c;生产管理系统开发正逐步成为企业优化生产流程的重要技术手段。不同行业、不同规模的企业在推进生产管理数字化转型过程中&#xff0c;面临的挑战存在显著差异。本文结合具体实践案例&#xff0c;分析…

VASP软件在第一性原理计算中的应用-测试GO

VASP软件在第一性原理计算中的应用 VASP是由维也纳大学Hafner小组开发的一款功能强大的第一性原理计算软件&#xff0c;广泛应用于材料科学、凝聚态物理、化学和纳米技术等领域。 VASP的核心功能与应用 1. 电子结构计算 VASP最突出的功能是进行高精度的电子结构计算&#xff…

Centos 7 服务器部署多网站

一、准备工作 安装 Apache bash sudo yum install httpd -y sudo systemctl start httpd sudo systemctl enable httpd创建网站目录 假设部署 2 个网站&#xff0c;目录结构如下&#xff1a; bash sudo mkdir -p /var/www/site1/html sudo mkdir -p /var/www/site2/html添加测试…

从数据报表到决策大脑:AI重构电商决策链条

在传统电商运营中&#xff0c;决策链条往往止步于“数据报表层”&#xff1a;BI工具整合历史数据&#xff0c;生成滞后一周甚至更久的销售分析&#xff0c;运营团队凭经验预判需求。当爆款突然断货、促销库存积压时&#xff0c;企业才惊觉标准化BI的决策时差正成为增长瓶颈。 一…

(12)-Fiddler抓包-Fiddler设置IOS手机抓包

1.简介 Fiddler不但能截获各种浏览器发出的 HTTP 请求&#xff0c;也可以截获各种智能手机发出的HTTP/ HTTPS 请求。 Fiddler 能捕获Android 和 Windows Phone 等设备发出的 HTTP/HTTPS 请求。同理也可以截获iOS设备发出的请求&#xff0c;比如 iPhone、iPad 和 MacBook 等苹…

第2课 SiC MOSFET与 Si IGBT 静态特性对比

2.1 输出特性对比 2.2 转移特性对比 2.1 输出特性对比 器件的输出特性描述了当温度和栅源电压(栅射电压)为某一具体数值时,漏极电流(集电极电流

MCP和Function Calling

MCP MCP&#xff08;Model Context Protocol&#xff0c;模型上下文协议&#xff09; &#xff0c;2024年11月底&#xff0c;由 Anthropic 推出的一种开放标准&#xff0c;旨在统一大模型与外部数据源和工具之间的通信协议。MCP 的主要目的在于解决当前 AI 模型因数据孤岛限制而…

解密鸿蒙系统的隐私护城河:从权限动态管控到生物数据加密的全链路防护

摘要 本文以健康管理应用为例&#xff0c;展示鸿蒙系统如何通过细粒度权限控制、动态权限授予、数据隔离和加密存储四大核心机制&#xff0c;实现复杂场景下的用户隐私保护。我们将通过完整的权限请求流程和敏感数据处理代码&#xff0c;演示鸿蒙系统如何平衡功能需求与隐私安…

SFTrack:面向警务无人机的自适应多目标跟踪算法——突破小尺度高速运动目标的追踪瓶颈

【导读】 本文针对无人机&#xff08;UAV&#xff09;视频中目标尺寸小、运动快导致的多目标跟踪难题&#xff0c;提出一种更简单高效的方法。核心创新在于从低置信度检测启动跟踪&#xff08;贴合无人机场景特性&#xff09;&#xff0c;并改进传统外观匹配算法以关联此类检测…