即插即用系列 | TGRS 2026 | CGTA:曲率引导标记注意力!线性复杂度全局建模,几何结构保真与长程关联双突破 | 代码分享
0. 前言本文介绍了CGTA曲率引导标记注意力模块其通过曲率感知的标记选择策略与全局稀疏注意力机制首次在遥感图像超分辨率领域实现对细长曲线结构与重复纹理的高保真重建有效破解了传统注意力机制在处理曲线拓扑时容易产生锯齿边缘与结构断裂的难题。将其作为即插即用模块轻松助力CNN、YOLO、Transformer等深度学习模型在保持近线性计算复杂度的同时增强对关键几何区域的全局关联能力让模型在面对蜿蜒河流、盘山公路、电力线、海岸线等复杂拓扑目标时依然能够保持清晰的几何连续性与精准的结构感知。专栏链接即插即用系列专栏链接可点击跳转免费订阅目录0. 前言1. CGTA模块简介2. CGTA模块基本原理与创新点 CGTA模块的基本原理 CGTA模块主要创新点3. 适用范围与模块效果适用范围⚡模块效果4. CGTA模块代码实现1. CGTA模块简介遥感图像超分辨率支持制图、监测和识别等下游任务然而编码场景拓扑的曲线结构和重复纹理在下采样时会产生混叠效应。缺乏几何先验的通用注意力机制往往导致锯齿状边缘和细节流失。为了解决这一问题本文提出了曲率引导注意力CGA它在标准Transformer框架内将曲率线索注入局部窗口注意力和基于标记的全局聚合中。局部曲率引导注意力LCGA在窗口内强化边缘和线条的连续性而曲率引导标记注意力CGTA则以近线性复杂度选择全局信息丰富的标记避免了重量级全局模块和额外分支。所设计的结构能够在不损害自然区域的情况下保持结构布局和纹理规律性并与传统残差组RGs集成。原始论文https://ieeexplore.ieee.org/stamp/stamp.jsp?tparnumber11373098原始代码https://github.com/mwaleedaslam/CGA2. CGTA模块基本原理与创新点 CGTA模块的基本原理CGTACurvature-Guided Token Attention是一种曲率引导的全局标记注意力机制其核心设计理念是通过曲率感知的标记选择策略将全局注意力从全图密集计算转化为稀疏的、面向关键几何区域的关联计算。该模块的总体原理是首先利用可学习的曲率代理从特征图中提取几何显著性信息然后基于该显著性选择最具代表性的top-k个标记最后仅对这些精选标记执行交叉注意力计算从而实现线性复杂度的全局几何建模。具体而言CGTA的实现包含以下几个关键步骤1曲率感知的标记选择模块首先复用LCGA中构建的可学习曲率代理对全图特征生成曲率显著性图该图能够突出道路边界、屋顶转角、河流弯曲等几何结构丰富的区域。同时引入一个保留门控retention gate生成可靠性图用于抑制噪声激活。两者结合形成综合评分确保选出的标记既具有几何重要性又具有特征可靠性。2自适应标记预算分配CGTA采用一种称为KTockets的策略根据输入图像尺寸自适应计算路由密度ρ和标记预算k。具体地通过输入图像的高宽计算时间层级time_level然后ρ 4^time_level最终k (H/ρ)×(W/ρ)。这种设计使得k随图像尺寸呈亚线性增长确保在大尺寸遥感图像上仍保持高效计算。3精选标记的压缩与变换对被选中的k个标记CGTA首先通过1×1卷积压缩通道维度再经线性投影得到键K和值V。查询Q则针对所有N个位置计算但投影维度减半以提升效率。同时值会依据选择置信度进行门控调制使得弱选标记对聚合贡献较小。4曲率调制的混合注意力在计算注意力时CGTA同时生成标准注意力权重和曲率调制注意力权重后者通过对原始logits施加曲率引导的乘法调制实现。最终通过可学习的融合参数β将两种注意力进行凸组合形成混合注意力权重既保留标准注意力的稳定性又增强曲率区域的结构关联。 CGTA模块主要创新点曲率引导的稀疏全局建模首次将曲率几何先验引入全局标记选择过程使注意力计算聚焦于结构关键区域避免在平坦无纹理区域浪费计算资源。线性复杂度的全局关联通过top-k标记选择策略将全局注意力的O(N²)复杂度降低为O(Nk)在k≪N时实现近线性扩展使大尺寸遥感图像的高效全局建模成为可能。曲率-标准混合注意力机制采用可学习融合参数动态平衡曲率调制注意力与标准注意力在强化几何结构的同时保持非结构区域的稳定性避免注意力过度集中导致的特征漂移。自适应的标记预算分配KTockets策略根据输入图像尺寸动态调整标记数量使模型在不同分辨率的输入上都能保持计算效率与表征能力的平衡。3. 适用范围与模块效果适用范围CGTA模块适用于各类视觉任务特别是需要高效全局建模且对几何结构保真度要求较高的场景。具体包括但不限于遥感图像超分辨率作为CGTA的原始应用场景在处理道路、河流、海岸线等细长曲线结构时表现出色。目标检测与实例分割尤其适合检测细长、弯曲目标如蜿蜒道路、盘山公路、电力弧垂、管道裂缝等能有效提升对这类目标的召回率与定位精度。语义分割与边缘检测需要保持几何连续性的任务CGTA可增强模型对物体边界的感知能力。高分辨率图像处理由于具有近线性复杂度CGTA特别适合处理大尺寸遥感图像、航拍图像和卫星图像等场景。⚡模块效果CGTA的Token数256最优消融研究说明CGTA的有效性第二行和第四行对比蓝色轮廓表示HR脊带。左比较Std和曲线注意力的顶部优势轮廓。右同一波段内的混合注意力热图。曲线经常在道路分支和路口边界周围占据主导地位而混合则将注意力集中在这些山脊附近并保持稳定。4. CGTA模块代码实现以下为CGTA模块的官方pytorch实现代码# 曲率引导令牌注意力模块Curvature-guided Token Attention, CGTA # 核心设计基于“曲率特征提取→动态Top-K令牌筛选→跨注意力计算→特征增强”的流程 # 通过特征曲率捕捉结构变化结合保留门控筛选关键令牌在降低计算成本的同时 # 让注意力聚焦结构关键区域强化特征的结构依赖性与表达精准性 import torch import torch.nn as nn from torch.nn import functional as F import math class RetentionGate(nn.Module): def __init__(self, dim, hidden_dim64): super().__init__() self.ret_gate nn.Sequential( nn.Linear(dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 1), nn.Sigmoid() ) def forward(self, x): return self.ret_gate(x).squeeze(-1) def get_dynamic_topk_tokens(H, W, training): h_val max(H // 16, 1) w_val max(W // 16, 1) if training: time_level max(int(math.log(H // 4, 4)), int(math.log(W // 4, 4))) else: time_level max(2, max(int(math.log(h_val, 4)), int(math.log(w_val, 4)))) # 防止 scale 过大 time_level min(time_level, 3) scale 4 ** time_level k_tokens max(1, (H // scale) * (W // scale)) return k_tokens, scale class CGTA(nn.Module): def __init__(self, dim, num_heads8, qkv_biasFalse, qk_scaleNone, attn_drop0., proj_drop0., c_ratio0.5): super().__init__() assert dim % num_heads 0, fdim {dim} should be divided by num_heads {num_heads}. self.dim dim self.num_heads num_heads head_dim dim // num_heads self.cr int(dim * c_ratio) self.cr (self.cr // num_heads) * num_heads if self.cr 0: self.cr num_heads self.scale qk_scale or (head_dim * c_ratio) ** -0.5 self.curvature_conv nn.Conv2d(dim, dim, 3, 1, 1, groupsdim, biasFalse) self.gate RetentionGate(dim, hidden_dimdim // 2) self.q nn.Linear(dim, self.cr, biasqkv_bias) self.kv_reduce nn.Linear(dim, self.cr, biasqkv_bias) self.k nn.Linear(self.cr, self.cr, biasqkv_bias) self.v nn.Linear(self.cr, dim, biasqkv_bias) self.norm_act nn.Sequential( nn.LayerNorm(self.cr), nn.GELU() ) self.cpe nn.Conv2d(dim, dim, kernel_size3, stride1, padding1, groupsdim) self.proj nn.Linear(dim, dim) self.attn_drop nn.Dropout(attn_drop) self.proj_drop nn.Dropout(proj_drop) self.alpha nn.Parameter(torch.zeros(1)) self.beta nn.Parameter(torch.tensor(0.5)) def forward(self, x, HNone, WNone): # 独立测试模式 if H is not None and W is not None: return self.forward_seq(x, H, W) # Ultralytics 框架模式 B, C, H, W x.shape x_seq x.flatten(2).transpose(1, 2) out_seq self.forward_seq(x_seq, H, W) out out_seq.transpose(1, 2).view(B, C, H, W).contiguous() return out def forward_seq(self, x, H, W): B, N, C x.shape _x x.permute(0, 2, 1).reshape(B, C, H, W).contiguous() curvature self.curvature_conv(_x).mean(dim1, keepdimTrue).view(B, -1) curvature F.layer_norm(curvature, curvature.shape[1:]) gate_score self.gate(x) score (curvature.abs() gate_score) / 2 k_tokens, _scale get_dynamic_topk_tokens(H, W, self.training) k_tokens min(k_tokens, N) topk_scores, topk_indices score.topk(k_tokens, dim-1, largestTrue, sortedFalse) x_topk torch.gather(x, dim1, indextopk_indices.unsqueeze(-1).expand(-1, -1, C)) score_topk torch.gather(score, 1, topk_indices) q self.q(x).reshape(B, N, self.num_heads, self.cr // self.num_heads).permute(0, 2, 1, 3) kv_compressed self.kv_reduce(x_topk) kv_compressed self.norm_act(kv_compressed) k self.k(kv_compressed).reshape(B, k_tokens, self.num_heads, self.cr // self.num_heads).permute(0, 2, 1, 3) v self.v(kv_compressed).reshape(B, k_tokens, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) v v * score_topk.unsqueeze(1).unsqueeze(-1) curvature_topk torch.gather(curvature, 1, topk_indices).unsqueeze(1) attn_logits (q k.transpose(-2, -1)) * self.scale attn_mod attn_logits * (1 self.alpha * curvature_topk.unsqueeze(1)) attn_std F.softmax(attn_logits, dim-1) attn_cga F.softmax(attn_mod, dim-1) attn self.beta * attn_std (1 - self.beta) * attn_cga attn self.attn_drop(attn) try: v_reshape v.transpose(1, 2).reshape(B, k_tokens, C) kh, kw min(int(math.sqrt(k_tokens)), H), min(int(math.sqrt(k_tokens)), W) v_reshape v_reshape.transpose(1, 2).view(B, C, kh, kw) cpe_out self.cpe(v_reshape) v v cpe_out.view(B, C, -1).transpose(1, 2).reshape(B, k_tokens, self.num_heads, C//self.num_heads).transpose(1,2) except: pass x_out (attn v).transpose(1, 2).reshape(B, N, C) x_out self.proj(x_out) x_out self.proj_drop(x_out) return x_out if __name__ __main__: device torch.device(cuda:0 if torch.cuda.is_available() else cpu) # B, N, C x torch.randn(1, 32*32, 64).to(device) model CGTA(64, 8).to(device) y model(x, 32, 32) print(输入特征维度, x.shape) print(输出特征维度, y.shape)结合自己的思路可将其即插即用至任何模型做结构创新设计该模块博主已成功嵌入至YOLO26模型中可订阅博主YOLO系列算法改进或YOLO26自研改进专栏YOLO系列算法改进专栏链接、YOLO26自研改进系列专栏
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2470091.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!