即插即用系列 | CVPR 2026 | GSRA:自注意力创新!几何校正空间一致性,语义强化高层关联,特征更精准! | 代码分享
0. 前言本文介绍了GSRAGeometric-Semantic Rectification Attention几何-语义校正注意力其通过跨模态差分注意力机制首次在图像阴影去除领域实现对几何特征与语义特征的精准对齐有效破解了传统方法因物理先验错位导致的边缘模糊与色彩失真难题。将其作为即插即用模块轻松助力CNN、Transformer等深度学习模型精准抑制模态冲突、增强特征一致性让模型在面对复杂环境光照、多光源叠加或间接照明等挑战性场景时依然能够保持清晰的边界感知与稳定的恢复精度。专栏链接即插即用系列专栏链接可点击跳转免费订阅1. GSRA注意力简介Transformer倾向于过度关注不相关的上下文内容。本文提出的差分Transformer通过放大对相关上下文的注意力同时消除噪声来解决这一问题。具体而言差分注意力机制通过计算两个独立的Softmax注意力图之间的差值来生成注意力分数。这种相减操作能够抵消噪声促进稀疏注意力模式的形成。语言建模的实验结果表明Diff Transformer在模型规模扩展和训练token数量等不同设置下均优于Transformer。更引人注目的是它在长上下文建模、关键信息检索、幻觉缓解、上下文学习以及激活异常值减少等实际应用场景中展现出显著优势。通过减少对无关上下文的干扰Diff Transformer能够缓解问答和文本摘要中的幻觉问题。在上下文学习方面Diff Transformer不仅提升了准确性还对顺序排列表现出更强的鲁棒性——这此前一直被认为是Transformer的固有鲁棒性问题。原始论文https://arxiv.org/pdf/2601.17470原始代码https://github.com/ming053l/PhaSR2. GSRA注意力原理与创新点 GSRA注意力基本原理GSRA几何-语义校正注意力的核心设计思想源于差分Transformer的降噪原理但将其从单一模态拓展至跨模态对齐场景。与传统的注意力机制不同GSRA不再简单地将几何信息和语义信息拼接融合而是借鉴了差分放大电路的设计思路——通过两个信号的差值来消除共模噪声。在视觉任务中几何特征如深度、法线对局部光影变化高度敏感能够精确定位阴影边界而语义特征如DINO-v2提取的对象类别信息则在不同光照条件下保持稳定。然而这两种模态常常“各说各话”几何特征在均匀光照区域会引入不必要的噪声语义特征在复杂边界处又容易过度平滑。GSRA通过让两者相互“校正”实现了优势互补。具体而言GSRA的实现包含以下几个关键步骤1多模态先验注入首先从输入图像中提取两种物理先验——通过DepthAnything-V2提取深度和法线图作为几何先验通过DINO-v2提取多尺度语义特征作为语义先验。随后通过可学习的权重参数将这两种先验分别与输入特征融合形成几何增强特征和语义增强特征为后续的跨模态交互做准备。2差异化键值生成将几何增强特征和语义增强特征分别通过独立的线性投影层生成各自专属的键Key和值Value对。这种设计保证了每种模态的特征表达能够保留其特有的统计特性避免在早期阶段就发生特征混淆。3差分注意力计算这是GSRA的核心创新所在。给定共享的查询特征同时计算两个注意力图——几何注意力图和语义注意力图。然后执行减法操作将语义注意力图减去经过可学习系数λ加权的几何注意力图得到校正后的注意力图。这一操作在数学上等价于用几何信息去“过滤”语义注意力中的噪声在几何特征可信的区域如阴影边界几何注意力图具有高响应减法操作会抑制语义注意力的过度平滑在几何特征嘈杂的区域如均匀光照表面几何注意力图响应较弱语义注意力得以保留。4多模态特征聚合将校正后的注意力图分别与几何分支和语义分支的值进行加权求和然后将两个分支的输出拼接起来形成融合特征。这一特征既保留了几何信息对边界的精确刻画能力又继承了语义信息对材质恒常性的稳定判断。 GSRA注意力处理流程GSRA 模块的特征处理流程分为 “模态适配→KV 增强→注意力校正→输出” 四步核心逻辑如下模态适配与投影原始特征图转换为序列格式几何特征3D与语义特征1024D经线性投影至统一维度模态投影特征与基础特征加权融合生成几何增强特征与语义增强特征。双模态 KV 生成与融合增强后的双模态特征经专属投影层生成各自 KV 对拼接版额外生成原始 KV多模态 KV 按可学习权重加权融合平衡不同模态的贡献。几何 - 语义注意力校正原始特征生成的 Q 与融合 KV 计算注意力分数经 Softmax 归一化后加权 V 特征注意力输出经线性投影完成 “结构 语义” 双重校正得到增强序列特征。特征还原与残差输出序列特征还原为特征图格式与原始特征残差融合输出最终校正特征。3. 适用范围与模块效果适用范围GSRA适用于通用视觉任务中需要融合几何与语义信息的场景特别是当两种模态存在响应冲突或噪声干扰时。其核心价值在于通过差分操作实现跨模态的噪声抵消与信号增强。具体而言以下任务场景特别适合应用GSRA1复杂光照下的图像复原如阴影去除、光照归一化、低光照增强等任务。这些场景中几何信息表面走向、深度和语义信息物体类别、材质对光照变化的响应截然不同GSRA能够有效协调两者的矛盾。2多模态融合的视觉任务如RGB-D感知、语义分割、3D场景理解等。当模型中同时包含深度/法线几何分支和语义理解分支时GSRA可作为融合模块提升跨模态协同的鲁棒性。3边界敏感的重建任务如图像抠图、边缘检测、图像修复等。几何特征擅长捕捉边界信息但易受纹理干扰语义特征稳定但边界粗糙GSRA的差分机制能够提取两者的优势生成锐利且语义一致的边界。⚡模块效果根据PhaSR原始论文中的实验结果涉及GSRA模块的关键验证包括1GSRA的消融实验论文Table 6在ISTD数据集上移除GSRA改用标准交叉注意力后PSNR从34.48 dB降至32.56 dBSSIM从0.960降至0.934在WSRD数据集上PSNR从28.44 dB降至26.92 dBSSIM从0.942降至0.920。这一实验验证了GSRA相比于传统跨模态融合方法的显著优势。2差分校正机制的有效性验证论文Table 6将GSRA中的差分校正操作λ设为0即仅使用语义注意力时ISTD数据集上PSNR从34.48 dB降至32.89 dBSSIM从0.960降至0.951。这证明了差分校正机制在模态对齐中的核心作用。3几何先验与语义先验的贡献度分析论文Table 6单独移除几何先验时ISTD数据集PSNR从34.48 dB降至33.52 dB单独移除语义先验时降至33.38 dB。两者均造成性能下降表明几何和语义先验在GSRA框架中相辅相成、缺一不可。4中间特征可视化论文Figure 2对比OmniSR、DenseSR与PhaSR的瓶颈层特征图GSRA能够在复杂环境光照下精准高亮阴影区域而对比方法在瓶颈层已丢失物理先验信息。结论上述实验共同验证了GSRA通过跨模态差分对齐机制有效解决了几何与语义先验的冲突问题在复杂光照条件下实现了更精准的阴影定位和更清晰的边界恢复。4. GSRA注意力代码实现以下为GSRA注意力机制的官方pytorch实现代码# 几何-语义校正注意力模块Geometric-Semantic Rectification Attention, GSRA # 核心设计针对双模态融合场景几何特征语义特征通过模态专属差分投影双KV增强注意力校正的架构 # 将3D几何特征如点云、深度与高维语义特征如DINO预训练特征投影至统一维度 # 生成几何增强KV与语义增强KV结合原始Q实现双模态引导的注意力计算 # 校正特征的几何结构与语义一致性强化双模态互补信息提升特征表达的精准度 import torch import torch.nn as nn from einops import rearrange, repeat import math def lambda_init_fn(depth): return 0.8 - 0.6 * math.exp(-0.3 * depth) class DifferentialLinearProjection(nn.Module): Modal-specific Differential Linear Projection def __init__(self, dim, heads8, dim_head64, dropout0., biasTrue): super().__init__() self.head_dim dim_head inner_dim dim_head * heads self.heads heads # Q projection remains unchanged self.to_q nn.Linear(dim, inner_dim, biasbias) # Create KV projections for two modalities respectively self.to_kv_geometric nn.Linear(dim, inner_dim * 2, biasbias) # KV for geometric branch self.to_kv_semantic nn.Linear(dim, inner_dim * 2, biasbias) # KV for semantic branch # Modal feature projection layer - project different dimension features to unified dimension self.geo_proj nn.Linear(3, dim, biasbias) # Assuming geometric feature is 3D self.dino_proj nn.Linear(1024, dim, biasbias) # Assuming DINO feature is 1024D # Learnable fusion weights self.geo_weight nn.Parameter(torch.tensor(0.1)) self.sem_weight nn.Parameter(torch.tensor(0.1)) self.dim dim self.inner_dim inner_dim print(Modal-specific differential transformer initialized!) def forward(self, x, geo_feat, dino_feat, attn_kvNone): B_, N, C x.shape attn_kv x if attn_kv is None else attn_kv # Q remains as is q self.to_q(x).reshape(B_, N, 1, self.heads, self.head_dim).permute(2, 0, 3, 1, 4) q q[0] # [B_, heads, N, head_dim] # Project modal features to unified dimension geo_feat_proj self.geo_proj(geo_feat) # [B_, N, dim] dino_feat_proj self.dino_proj(dino_feat) # [B_, N, dim] # Simple feature fusion - weighted sum geo_enhanced attn_kv self.geo_weight * geo_feat_proj semantic_enhanced attn_kv self.sem_weight * dino_feat_proj # Calculate two sets of KV respectively kv_geo self.to_kv_geometric(geo_enhanced).reshape(B_, N, 2, self.heads, self.head_dim).permute(2, 0, 3, 1, 4) kv_sem self.to_kv_semantic(semantic_enhanced).reshape(B_, N, 2, self.heads, self.head_dim).permute(2, 0, 3, 1, 4) # Combine into final KV # kv[0] Geometric enhanced KV, kv[1] Semantic enhanced KV k torch.stack([kv_geo[0], kv_sem[0]], dim0) # [2, B_, heads, N, head_dim] v torch.stack([kv_geo[1], kv_sem[1]], dim0) # [2, B_, heads, N, head_dim] return q, k, v class DifferentialLinearProjection_Concat_kv(nn.Module): Concat version of Modal-specific Differential Linear Projection def __init__(self, dim, heads8, dim_head64, dropout0., biasTrue): super().__init__() self.head_dim dim_head inner_dim dim_head * heads self.heads heads # Basic QKV projection self.to_qkv nn.Linear(dim, inner_dim * 3, biasbias) # Additional KV projections for two modalities self.to_kv_geometric nn.Linear(dim, inner_dim * 2, biasbias) self.to_kv_semantic nn.Linear(dim, inner_dim * 2, biasbias) # Modal feature projection layer self.geo_proj nn.Linear(3, dim, biasbias) self.dino_proj nn.Linear(1024, dim, biasbias) # Learnable fusion weights self.geo_weight nn.Parameter(torch.tensor(0.1)) self.sem_weight nn.Parameter(torch.tensor(0.1)) self.dim dim self.inner_dim inner_dim def forward(self, x, geo_feat, dino_feat, attn_kvNone): B_, N, C x.shape attn_kv x if attn_kv is None else attn_kv # Basic QKV qkv_dec self.to_qkv(x).reshape(B_, N, 3, self.heads, self.head_dim).permute(2, 0, 3, 1, 4) q, k_d, v_d qkv_dec[0], qkv_dec[1], qkv_dec[2] # Project modal features to unified dimension geo_feat_proj self.geo_proj(geo_feat) dino_feat_proj self.dino_proj(dino_feat) # Simple feature fusion geo_enhanced attn_kv self.geo_weight * geo_feat_proj semantic_enhanced attn_kv self.sem_weight * dino_feat_proj # Calculate two sets of additional KV respectively kv_geo self.to_kv_geometric(geo_enhanced).reshape(B_, N, 2, self.heads, self.head_dim).permute(2, 0, 3, 1, 4) kv_sem self.to_kv_semantic(semantic_enhanced).reshape(B_, N, 2, self.heads, self.head_dim).permute(2, 0, 3, 1, 4) k_geo, v_geo kv_geo[0], kv_geo[1] k_sem, v_sem kv_sem[0], kv_sem[1] # Concat: [Basic KV, Geometric KV, Semantic KV] k torch.cat((k_d, k_geo, k_sem), dim2) v torch.cat((v_d, v_geo, v_sem), dim2) return q, k, v class DifferentialWindowAttention(nn.Module): 几何-语义校正注意力模块Geometric-Semantic Rectification Attention, GSRA 功能融合双模态增强KV与原始Q通过注意力计算实现几何-语义校正 核心设计 - 双投影适配支持基础版/拼接版差分投影灵活选择融合模式 - 注意力校正通过双模态KV引导Q的注意力分配校正几何结构与语义一致性 - 可学习融合权重动态平衡不同模态KV的贡献 - 残差连接保留原始特征信息避免过度校正 Args: dim: 输入/输出通道数 heads: 注意力头数默认8 dim_head: 每个注意力头的通道数默认64 dropout: dropout概率默认0. bias: 线性层是否带偏置默认True proj_type: 投影模块类型base基础版concat拼接版默认base depth: 网络层深度用于lambda系数初始化默认1 def __init__(self, dim, win_size, num_heads, depth1, token_projectionlinear, qkv_biasTrue, qk_scaleNone, attn_drop0., proj_drop0., se_layerFalse, geo_dim3, dino_dim1024): super().__init__() self.dim dim self.win_size win_size if isinstance(win_size, tuple) else (win_size, win_size) self.num_heads num_heads self.head_dim dim // num_heads self.scale qk_scale or self.head_dim ** -0.5 # Pre-define feature projection layers self.geo_dim geo_dim self.dino_dim dino_dim self.geo_adaptive_proj nn.Linear(geo_dim, 3) if geo_dim ! 3 else nn.Identity() self.dino_adaptive_proj nn.Linear(dino_dim, 1024) if dino_dim ! 1024 else nn.Identity() # Differential parameters self.lambda_init lambda_init_fn(depth) self.lambda_q1 nn.Parameter(torch.ones(num_heads) * 0.5) self.lambda_k1 nn.Parameter(torch.ones(num_heads) * 0.5) self.lambda_q2 nn.Parameter(torch.ones(num_heads) * 0.5) self.lambda_k2 nn.Parameter(torch.ones(num_heads) * 0.5) self.subln nn.LayerNorm(dim) # Use modal-specific projection layers if token_projection linear_concat: self.qkv DifferentialLinearProjection_Concat_kv(dim, num_heads, dim // num_heads, biasqkv_bias) else: self.qkv DifferentialLinearProjection(dim, num_heads, dim // num_heads, biasqkv_bias) # Relative position encoding self.relative_position_bias_table nn.Parameter( torch.zeros((2 * self.win_size[0] - 1) * (2 * self.win_size[1] - 1), num_heads)) coords_h torch.arange(self.win_size[0]) coords_w torch.arange(self.win_size[1]) coords torch.stack(torch.meshgrid([coords_h, coords_w], indexingij)) coords_flatten torch.flatten(coords, 1) relative_coords coords_flatten[:, :, None] - coords_flatten[:, None, :] relative_coords relative_coords.permute(1, 2, 0).contiguous() relative_coords[:, :, 0] self.win_size[0] - 1 relative_coords[:, :, 1] self.win_size[1] - 1 relative_coords[:, :, 0] * 2 * self.win_size[1] - 1 relative_position_index relative_coords.sum(-1) self.register_buffer(relative_position_index, relative_position_index) self.token_projection token_projection self.attn_drop nn.Dropout(attn_drop) self.proj nn.Linear(dim, dim) self.proj_drop nn.Dropout(proj_drop) nn.init.trunc_normal_(self.relative_position_bias_table, std.02) self.softmax nn.Softmax(dim-1) def forward(self, x, dino_mat, point_feature, attn_kvNone, maskNone): B_, N, C x.shape # Prepare modal features dino_mat self.dino_adaptive_proj(dino_mat) point_feature self.geo_adaptive_proj(point_feature) geo_feat point_feature dino_feat dino_mat # QKV projection q, k, v self.qkv(x, geo_feat, dino_feat, attn_kv) q q * self.scale # k, v format: [2, B_, heads, N, head_dim] k_geo, k_sem k[0], k[1] # K for geometry and semantics v_geo, v_sem v[0], v[1] # V for geometry and semantics # All heads calculate both attentions attn_geo torch.matmul(q, k_geo.transpose(-2, -1)) # [B_, heads, N, N] attn_sem torch.matmul(q, k_sem.transpose(-2, -1)) # [B_, heads, N, N] # Add relative position bias relative_position_bias self.relative_position_bias_table[ self.relative_position_index.view(-1) ].view(self.win_size[0] * self.win_size[1], self.win_size[0] * self.win_size[1], -1) relative_position_bias relative_position_bias.permute(2, 0, 1).contiguous() ratio attn_geo.size(-1) // relative_position_bias.size(-1) if ratio 1: relative_position_bias repeat(relative_position_bias, nH l c - nH l (c d), dratio) attn_geo attn_geo relative_position_bias.unsqueeze(0) attn_sem attn_sem relative_position_bias.unsqueeze(0) # Handle mask if mask is not None: nW mask.shape[0] mask repeat(mask, nW m n - nW m (n d), dratio) attn_geo attn_geo.view(B_ // nW, nW, self.num_heads, N, N * ratio) mask.unsqueeze(1).unsqueeze(0) attn_sem attn_sem.view(B_ // nW, nW, self.num_heads, N, N * ratio) mask.unsqueeze(1).unsqueeze(0) attn_geo attn_geo.view(-1, self.num_heads, N, N * ratio) attn_sem attn_sem.view(-1, self.num_heads, N, N * ratio) # Softmax attn_geo self.softmax(attn_geo) attn_sem self.softmax(attn_sem) # Differential Attention: Subtract geometry from semantics lambda_val torch.sigmoid(self.lambda_q1 * self.lambda_k1) self.lambda_init # Expand lambda_val to match attention dimensions lambda_val lambda_val.view(1, self.num_heads, 1, 1) attn_diff attn_sem - lambda_val * attn_geo # Apply attention attn_geo self.attn_drop(attn_geo) attn_diff self.attn_drop(attn_diff) # Outputs of two branches x_geo torch.matmul(attn_geo, v_geo) # Geometric branch x_diff torch.matmul(attn_diff, v_sem) # Difference branch # Weighted fusion x x_geo x_diff # Or use learnable weights x x.transpose(1, 2).contiguous().view(B_, N, C) x self.subln(x) x x * (1 - self.lambda_init) x self.proj(x) x self.proj_drop(x) return x def extra_repr(self) - str: return fdim{self.dim}, win_size{self.win_size}, num_heads{self.num_heads}, \ fhead_dim{self.head_dim}, lambda_init{self.lambda_init:.3f} if __name__ __main__: device torch.device(cuda:0 if torch.cuda.is_available() else cpu) input torch.randn(1, 8*8, 64).to(device) sem torch.randn(1, 8*8, 1024).to(device) geo torch.randn(1, 8*8, 3).to(device) model DifferentialWindowAttention(64, (8, 8), 8).to(device) output model(input, sem, geo) print(输入局部特征维度, input.shape) print(输出特征维度, output.shape)结合自己的思路可将其即插即用至任何模型做结构创新设计该模块博主已成功嵌入至YOLO26模型中可订阅博主YOLO系列算法改进或YOLO26自研改进专栏专栏链接YOLO系列算法改进专栏链接、YOLO26自研改进系列专栏
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2432788.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!