即插即用系列 | CVPR 2026 | WDAM:小波域注意力创新!高频引导低频增强,结构纹理双保真,复杂退化场景精准定位! | 代码分享
0. 前言本文介绍了WDAMWavelet-based Directional Attention Module小波方向性注意力模块其通过Haar小波变换将特征图解耦为低频结构分量与水平、垂直、对角三个方向的高频细节分量并创新性地利用高频子带生成引导权重来强化低频注意力对纹理边缘区域的聚焦首次将“高频引导低频”的协同范式引入注意力机制有效破解了传统自注意力对频率成分“一视同仁”导致的高频细节过平滑与伪影残留难题。将其作为即插即用模块轻松助力CNN、YOLO、Transformer等深度学习模型精准增强对方向性显著目标的特征表达、提升模型对模糊边界和密集堆叠场景的感知能力让模型在面对高分辨率航拍图像、遥感影像中的小目标密集分布以及工业缺陷检测中的细微纹理等挑战性场景时依然能够保持锐利的边界刻画能力与稳定的检测精度。专栏链接即插即用系列专栏链接可点击跳转免费订阅目录0. 前言1. WDAM模块简介2. WDAM模块原理与创新点 WDAM模块基本原理 WDAM模块创新点3. 适用范围与模块效果适用范围⚡模块效果4. WDAM模块代码实现1. WDAM模块简介闪烁伪影源于光 照不稳定性和行间曝光不一致在短曝光摄影中构成重大挑战会严重损害图像质量。与噪声、低光等典型伪影不同闪烁是一种具有特定时空特征的结构化退化现象而当前通用修复框架未能充分考虑这些特征导致闪烁抑制效果欠佳并产生重影伪影。本研究揭示闪烁伪影具有周期性和方向性两大固有特征并提出基于变换器架构的Flickerformer模型该模型能有效消除闪烁伪影且不引入重影效应。具体而言Flickerformer包含三个核心组件基于相位的融合模块PFM、自相关前馈网络AFFN以及基于小波的方向性注意力模块WDAM。基于周期性特征PFM通过帧间相位相关性自适应聚合突发特征AFFN则利用自相关技术捕捉帧内结构规律协同提升网络对空间重复模式的感知能力。此外受闪烁伪影方向性特征驱动WDAM通过小波域高频变化引导低频暗区修复实现对闪烁伪影的精准定位。大量实验表明Flickerformer在定量指标和视觉质量上均显著优于现有最先进方法。原始论文https://arxiv.org/pdf/2603.22794原始代码https://github.com/qulishen/Flickerformer2. WDAM模块原理与创新点 WDAM模块基本原理WDAMWavelet-based Directional Attention Module基于小波的方向性注意力模块是一种将小波变换与注意力机制深度融合的轻量化特征增强模块。其核心洞察在于自然图像中的目标结构可分解为低频轮廓信息与高频细节信息而传统注意力机制对所有频率成分同等处理往往导致高频细节被过度平滑、低频结构建模计算量过大。WDAM通过Haar小波变换将特征图显式解耦为频率与方向两个维度再引入“高频引导低频”的协同注意力机制实现了结构保留、细节增强与计算高效的三重目标。具体而言WDAM的实现包含以下关键步骤1小波分解与频率解耦输入特征图首先经过一级Haar离散小波变换被分解为四个子带分量低频近似分量LL保留了图像的整体结构与光照信息水平高频分量LH捕获水平方向的边缘与纹理变化垂直高频分量HL捕获垂直方向的结构特征对角高频分量HH则响应斜向的细节信息。这一分解过程具有完全可逆性且无额外可学习参数。2高频引导权重生成将水平与垂直高频分量进行通道拼接通过轻量级卷积层与Sigmoid激活函数生成方向性引导权重图。该权重图能够自动识别出目标边缘、纹理边界等含有丰富方向信息的区域为后续的注意力计算提供位置先验使模型将有限的计算资源聚焦于真正需要细节保持的关键位置。3低频窗口注意力计算在降采样后的低频分量上执行窗口多头注意力机制。由于LL子图的空间分辨率仅为原始输入的1/2注意力计算量降低约75%。在注意力公式中将第2步生成的方向引导权重以元素乘的方式作用于Value矩阵实现“高频特征告知模型‘看哪里’低频注意力完成‘怎么看’”的协同机制。4逆小波变换与特征重建将注意力增强后的低频分量与原始高频分量可根据需要选择是否增强进行拼接通过逆离散小波变换重建出与输入同尺寸的输出特征图。该输出特征既保留了低频分支建模的全局结构又继承了高频分支保真的边缘细节。 WDAM模块创新点小波驱动的频域解耦机制首次在注意力模块中引入Haar小波变换作为前置处理器将特征显式分解为低频结构与高频细节解决了传统自注意力对频率成分“一视同仁”导致的细节过度平滑问题。高频引导低频的协同注意力创新性地利用高频子带LH/HL生成方向感知权重并以此为先验引导低频注意力对Value特征进行调制实现了“细节定位、结构建模”的双赢显著提升了对细长目标与模糊边界的感知能力。即插即用且计算经济WDAM可无缝替换标准窗口注意力模块参数量增加可忽略不计同时由于注意力仅运行在1/2分辨率的低频子带上计算量与内存占用较传统窗口注意力降低约75%特别适合YOLO系列这类实时检测框架。3. 适用范围与模块效果适用范围WDAM特别适用于通用视觉领域中需要同时建模全局结构和保留高频细节的任务。具体而言其核心优势体现在以下场景高分辨率遥感与航拍图像处理遥感图像中道路、电力线、建筑边缘等具有显著方向性特征的目标WDAM的方向感知能力能够精准捕捉这些细长结构的完整性适用于智慧城市、电网巡检、水利监测等应用。小目标与密集目标检测无人机低空巡检、船舶编队检测等场景中小目标常以密集堆叠形式出现WDAM的高频引导机制能够增强目标边界的区分度有效缓解相邻目标的粘连问题。边缘计算与实时部署WDAM的计算量大降75%使其非常适合部署在算力受限的边缘设备上在保证推理速度的同时提升检测精度。工业缺陷检测产品表面的细微划痕、纹理异常等缺陷往往仅在高频分量中有所体现WDAM的高频引导注意力能够精准定位这些微小异常区域。低光照与复杂退化图像原始论文的闪烁去除任务表明WDAM对光照不稳定、噪声干扰等退化场景具有天然鲁棒性可用于夜间监控、弱光环境下的目标检测。⚡模块效果模块效果模型的性能和视觉效果SOTA。消融实验a) 和 (d) 对比说明MDAM优于ASSA性能对比实验与其他主流的自注意力相比WDAM在保持模型复杂性和计算成本的同时实现性能和视觉效果最优4. WDAM模块代码实现以下为WDAM模块的官方pytorch实现代码# 小波引导方向注意力模块Wavelet-guided Directional Attention Module, WDAM # 核心设计基于“DWT小波分解→方向高频特征提取→高频引导窗口注意力→高低频协同优化→IDWT重建”的流程 # 通过Haar小波将特征解耦为低频结构分量和三个方向的高频细节分量 # 利用水平/垂直高频生成方向引导权重增强注意力对边缘纹理的聚焦能力 # 同时保持全局结构完整性实现高效的特征增强与重建 import torch import torch.nn as nn import torch.nn.functional as F import math from pytorch_wavelets import DWTForward, DWTInverse class DWTWindowAttention(nn.Module): 带相对位置偏置的窗口注意力模块 支持移位窗口机制实现跨窗口信息交互 def __init__(self, dim, num_heads, window_size, shift_size0, biasFalse): super().__init__() self.dim dim self.num_heads num_heads self.window_size window_size self.shift_size shift_size # 温度系数控制注意力分布平滑度 self.temperature nn.Parameter(torch.ones(num_heads, 1, 1)) # 相对位置偏置表 self.relative_position_bias_table nn.Parameter( torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads) ) # 预计算相对位置索引 coords torch.stack(torch.meshgrid(torch.arange(window_size), torch.arange(window_size))) coords_flatten coords.flatten(1) relative_coords coords_flatten[:, :, None] - coords_flatten[:, None, :] relative_coords relative_coords.permute(1, 2, 0).contiguous() relative_coords[:, :, 0] window_size - 1 relative_coords[:, :, 1] window_size - 1 relative_coords[:, :, 0] * 2 * window_size - 1 relative_position_index relative_coords.sum(-1) self.register_buffer(relative_position_index, relative_position_index) def window_partition(self, x, H, W): 将特征划分为不重叠的窗口 B, C, H, W x.shape ws self.window_size x x.view(B, C, H // ws, ws, W // ws, ws) x x.permute(0, 2, 4, 1, 3, 5).contiguous() x x.view(-1, C, ws, ws) return x def window_reverse(self, windows, H, W): 将窗口特征还原为原始特征尺寸 B int(windows.shape[0] / (H * W / self.window_size / self.window_size)) C windows.shape[1] ws self.window_size x windows.view(B, H // ws, W // ws, C, ws, ws) x x.permute(0, 3, 1, 4, 2, 5).contiguous() x x.view(B, C, H, W) return x def forward(self, q, k, v, H, W): 窗口注意力前向传播 q, k, v: [B, num_heads, N, head_dim] B, num_heads, N, head_dim q.shape ws self.window_size # 重塑为2D特征图 q q.transpose(1, 2).reshape(B, N, num_heads * head_dim) q q.transpose(1, 2).reshape(B, -1, H, W) k k.transpose(1, 2).reshape(B, N, num_heads * head_dim) k k.transpose(1, 2).reshape(B, -1, H, W) v v.transpose(1, 2).reshape(B, N, num_heads * head_dim) v v.transpose(1, 2).reshape(B, -1, H, W) # 窗口划分 q_w self.window_partition(q, H, W) k_w self.window_partition(k, H, W) v_w self.window_partition(v, H, W) # 重塑为注意力格式 B_win, C, ws, _ q_w.shape q_w q_w.view(B_win, num_heads, C // num_heads, ws * ws) k_w k_w.view(B_win, num_heads, C // num_heads, ws * ws) v_w v_w.view(B_win, num_heads, C // num_heads, ws * ws) # L2归一化 q_w F.normalize(q_w, dim-2) k_w F.normalize(k_w, dim-2) # 注意力计算 attn torch.matmul(q_w.transpose(-2, -1), k_w) # 添加相对位置偏置 N_win ws * ws relative_position_bias self.relative_position_bias_table[self.relative_position_index.view(-1)] relative_position_bias relative_position_bias.view(N_win, N_win, -1).permute(2, 0, 1).unsqueeze(0) attn attn relative_position_bias # 温度缩放和softmax attn attn * self.temperature attn attn.softmax(dim-1) # 加权求和 out torch.matmul(v_w, attn.transpose(-2, -1)) # 窗口还原 out out.view(B_win, C, ws, ws) out self.window_reverse(out, H, W) # 重塑回原始格式 out out.view(B, -1, H * W).transpose(1, 2) out out.view(B, H * W, num_heads, head_dim).transpose(1, 2) return out class WDAM(nn.Module): 小波引导方向注意力模块Wavelet-guided Directional Attention Module 核心设计 - DWT小波分解Haar小波将特征解耦为低频结构和高频细节分量 - 方向高频提取水平(LH)/垂直(HL)/对角(HH)三方向高频特征独立处理 - 高频引导注意力用水平垂直高频生成权重引导V特征聚焦边缘纹理 - 移位窗口注意力Swin风格窗口注意力降低计算复杂度 - 高低频协同优化三方向高频联合优化后与低频一起重建 - IDWT逆变换融合优化后的高低频分量重建增强特征 Args: dim: 输入/输出特征通道数 num_heads: 注意力头数默认8 window_size: 注意力窗口大小默认8 shift_size: 窗口移位步长默认4 qkv_bias: QKV生成是否带偏置默认False attn_drop: 注意力权重dropout概率默认0. proj_drop: 输出投影dropout概率默认0. def __init__(self, dim, num_heads8, window_size8, shift_size4, qkv_biasFalse, attn_drop0., proj_drop0.): super().__init__() assert dim % num_heads 0, fdim {dim} should be divided by num_heads {num_heads}. self.dim dim self.num_heads num_heads self.head_dim dim // num_heads self.window_size window_size self.shift_size shift_size # DWT/IDWT模块 self.dwt DWTForward(J1, wavehaar) self.idwt DWTInverse(wavehaar) # 方向高频特征处理分支水平垂直融合 self.high_conv nn.Sequential( nn.Conv2d(dim * 2, dim * 2, kernel_size3, padding1, groups2, biasqkv_bias), nn.ReLU(inplaceTrue), nn.Conv2d(dim * 2, dim, kernel_size1, biasqkv_bias), nn.ReLU(inplaceTrue) ) # 三方向高频联合优化分支 self.high_out nn.Sequential( nn.Conv2d(dim * 3, dim * 3, kernel_size3, padding1, groups3, biasqkv_bias), nn.ReLU(inplaceTrue) ) # 低频特征QKV生成 self.qkv nn.Conv2d(dim, dim * 3, kernel_size1, biasqkv_bias) self.qkv_dwconv nn.Conv2d(dim * 3, dim * 3, kernel_size3, stride1, padding1, groupsdim * 3, biasqkv_bias) # 窗口注意力 self.window_attn DWTWindowAttention(dim, num_heads, window_size, shift_size, qkv_bias) # 输出投影 self.project_out nn.Conv2d(dim, dim, kernel_size1, biasqkv_bias) # Dropout self.attn_drop nn.Dropout(attn_drop) self.proj_drop nn.Dropout(proj_drop) # 自适应窗口当输入分辨率小于窗口大小时自动调整 self.adaptive_window True def forward(self, x, HNone, WNone): 前向传播支持两种模式 1. 独立测试模式传入H,W参数 2. Ultralytics框架模式直接传入[B,C,H,W]张量 # 独立测试模式 if H is not None and W is not None: return self.forward_seq(x, H, W) # Ultralytics框架模式 B, C, H, W x.shape return self.forward_2d(x, H, W) def forward_2d(self, x, H, W): 2D特征图前向传播 return self._forward_impl(x, H, W) def forward_seq(self, x_seq, H, W): 序列格式前向传播兼容CGTA接口 B, N, C x_seq.shape # 转换为2D特征图 x x_seq.transpose(1, 2).reshape(B, C, H, W).contiguous() out self._forward_impl(x, H, W) # 转换回序列格式 out_seq out.flatten(2).transpose(1, 2) return out_seq def _forward_impl(self, x, H, W): WDAM核心实现 x: [B, C, H, W] B, C, H, W x.shape # 自适应调整窗口大小 if self.adaptive_window and H // 2 self.window_size: ws H // 2 if ws 0: self.window_size ws self.shift_size 0 # 步骤1小波分解 # 1级Haar小波分解LL低频结构分量Yh高频方向分量 LL, Yh self.dwt(x) Yh Yh[0] # [B, C, 3, H/2, W/2] LH, HL, HH Yh[:, :, 0, :, :], Yh[:, :, 1, :, :], Yh[:, :, 2, :, :] # 步骤2方向高频特征提取 # 融合水平垂直方向高频生成方向引导权重 filter_hv self.high_conv(torch.cat([LH, HL], dim1)) # 步骤3低频QKV生成高频方向引导 # 低频特征生成QKV qkv self.qkv_dwconv(self.qkv(LL)) q, k, v_inp qkv.chunk(3, dim1) # 核心创新用方向高频特征加权V让注意力聚焦于边缘/纹理区域 v v_inp * filter_hv v_inp # 步骤4移位窗口注意力计算 # 窗口循环移位实现跨窗口信息交互 if self.shift_size 0: LL_shifted torch.roll(LL, shifts(-self.shift_size, -self.shift_size), dims(2, 3)) v_shifted torch.roll(v, shifts(-self.shift_size, -self.shift_size), dims(2, 3)) else: LL_shifted LL v_shifted v # 重塑为注意力格式 [B, num_heads, N, head_dim] H_half, W_half H // 2, W // 2 q_attn q.view(B, self.num_heads, self.head_dim, H_half * W_half).transpose(-2, -1) k_attn k.view(B, self.num_heads, self.head_dim, H_half * W_half).transpose(-2, -1) v_attn v_shifted.view(B, self.num_heads, self.head_dim, H_half * W_half).transpose(-2, -1) # 窗口注意力 out_attn self.window_attn(q_attn, k_attn, v_attn, H_half, W_half) # 逆移位 out_attn out_attn.transpose(-2, -1).reshape(B, C, H_half, W_half) if self.shift_size 0: out_attn torch.roll(out_attn, shifts(self.shift_size, self.shift_size), dims(2, 3)) # 注意力输出投影 out self.project_out(out_attn) out self.proj_drop(out) # 步骤5高频分量优化小波逆变换重建 # 三方向高频分量联合优化 Yh_optimized self.high_out(torch.cat([LH, HL, HH], dim1)) LH_opt, HL_opt, HH_opt Yh_optimized.chunk(3, dim1) Yh_opt torch.stack([LH_opt, HL_opt, HH_opt], dim2) # 逆小波变换融合优化后的低频高频分量重建最终特征 x_hat self.idwt((out, [Yh_opt])) return x_hat class WDAM2D(nn.Module): WDAM的简化包装器自动处理维度转换 兼容现有的CGTA2D接口 def __init__(self, dim, num_heads8, window_size8, shift_size4, qkv_biasFalse, attn_drop0., proj_drop0.): super().__init__() self.wdam WDAM( dimdim, num_headsnum_heads, window_sizewindow_size, shift_sizeshift_size, qkv_biasqkv_bias, attn_dropattn_drop, proj_dropproj_drop ) self.dim dim def forward(self, x): 前向传播支持2D特征图输入 x: [B, C, H, W] # 直接调用WDAM的forward自动检测输入格式 return self.wdam(x) if __name__ __main__: device torch.device(cuda:0 if torch.cuda.is_available() else cpu) # 测试12D特征图模式 print( * 50) print(测试12D特征图模式) x_2d torch.randn(1, 64, 32, 32).to(device) model_2d WDAM(dim64, num_heads8, window_size8, shift_size4).to(device) y_2d model_2d(x_2d) print(f输入特征维度{x_2d.shape}) print(f输出特征维度{y_2d.shape}) # 测试2序列模式兼容CGTA接口 print(\n * 50) B, N, C 1, 32 * 32, 64 x_seq torch.randn(B, N, C).to(device) model_seq WDAM(dim64, num_heads8, window_size8, shift_size4).to(device) y_seq model_seq(x_seq, H32, W32) print(f输入特征维度{x_seq.shape}) print(f输出特征维度{y_seq.shape}) # 测试3WDAM2D包装器 print(\n * 50) print(测试3WDAM2D包装器) x_wrap torch.randn(1, 64, 32, 32).to(device) model_wrap WDAM2D(dim64, num_heads8, window_size8, shift_size4).to(device) y_wrap model_wrap(x_wrap) print(f输入特征维度{x_wrap.shape}) print(f输出特征维度{y_wrap.shape}) # 参数统计 print(\n * 50) total_params sum(p.numel() for p in model_2d.parameters()) trainable_params sum(p.numel() for p in model_2d.parameters() if p.requires_grad) print(f总参数量{total_params:,}) print(f可训练参数量{trainable_params:,})结合自己的思路可将其即插即用至任何模型做结构创新设计该模块博主已成功嵌入至YOLO26模型中可订阅博主YOLO系列算法改进或YOLO26自研改进专栏YOLO系列算法改进专栏链接、YOLO26自研改进系列专栏
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2582177.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!