DSU-Net

news2025/5/19 5:36:30

目录

Abstract

摘要

DSU-Net

模型框架

编码器

轻量级适配器模块

特征融合与协作

解码器

模型优势

实验

代码

总结


Abstract

DSU-Net is an improved U-Net model based on DINOv2 and SAM2. It addresses the limitations of existing image segmentation models in specific downstream tasks, such as camouflaged object segmentation and salient object segmentation, through multi-scale cross-model feature collaboration and lightweight adapter modules. By leveraging the high-dimensional semantic features from DINOv2 to enhance the multi-scale feature fusion of SAM2, and utilizing attention mechanisms for adaptive aggregation of multi-granularity features, DSU-Net significantly improves segmentation accuracy. Experimental results demonstrate that DSU-Net outperforms existing state-of-the-art methods on multiple benchmark datasets. Moreover, by freezing pre-trained parameters, it reduces training costs, showcasing high efficiency and broad applicability.

摘要

DSU-Net是一种基于DINOv2和SAM2改进的U-Net模型,通过多尺度跨模型特征协作和轻量级适配器模块,解决了现有图像分割模型在特定下游任务中表现不足的问题,如:伪装目标分割和显著目标分割。它利用DINOv2的高维语义特征增强SAM2的多尺度特征融合,同时通过注意力机制实现多粒度特征的自适应聚合,显著提升了分割精度。实验结果表明,DSU-Net在多个基准数据集上超越了现有最先进方法,并且通过冻结预训练参数降低了训练成本,展现出高效性和广泛的适用性。

DSU-Net

DSU-Net是一种改进型的U-Net模型,旨在通过结合DINOv2和SAM2的优势,实现多尺度特征的跨模型协作增强。该模型针对大规模预训练基础模型在特定领域中表现不足的问题,提出了一种高效的解决方案。

论文地址:https://arxiv.org/abs/2503.21187

模型框架

编码器

编码器部分是DSU-Net的核心,它结合了SAM2的Hiera模块和DINOv2的ViT模块,以实现多尺度特征的提取和融合。

  • SAM2 Hiera模块:Hiera模块是SAM2的核心特征提取器,它能够提取高质量的语义特征,适用于通用图像分割任务。DSU-Net将Hiera模块作为主干网络,用于提取图像的多尺度特征。
  • DINOv2 ViT模块:DINOv2的ViT模块通过自监督学习提取高维语义特征,这些特征在捕捉图像的全局语义信息方面表现出色,如下图所示。DSU-Net将DINOv2的特征图注入到Hiera模块的特征图中,以增强语义信息。

轻量级适配器模块

为了缓解训练数据集与预训练模型数据集之间的域差异,DSU-Net引入了轻量级适配器模块。该模块通过少量参数对Hiera模块的特征进行调整,使其能够快速适应新的数据集。

  • 特征降采样与上采样:适配器模块通过线性层和激活函数对Hiera模块的特征进行降采样和上采样,以匹配DINOv2的特征尺度。
  • 参数高效性:适配器模块仅引入少量参数,显著降低了训练成本,同时保持了模型的灵活性。

特征融合与协作

DSU-Net通过多尺度特征融合和注意力机制,实现了DINOv2和SAM2特征的有效协作。

  • 内容引导注意力(Content-Guided Attention)模块:CGA模块利用DINOv2的语义特征作为引导,通过注意力机制增强SAM2的特征表示。该模块动态调整特征图的权重,突出重要的语义信息。
  • 多尺度特征融合:DSU-Net在多个尺度上融合DINOv2和SAM2的特征,通过特征金字塔网络结构,将不同尺度的特征进行交互和融合,以获得更丰富的语义信息。

解码器

解码器部分负责将编码器提取的多尺度特征上采样并生成最终的分割掩码。

  • 空间特征融合模块:SFF模块动态调整不同尺度特征图的权重,通过空间注意力机制增强特征的空间一致性。
  • 分割头:分割头采用1x1卷积和双线性插值上采样,将融合后的特征图转换为高分辨率的分割掩码。

模型优势

高效性:通过轻量级适配器模块和冻结预训练参数,DSU-Net显著降低了训练成本,适合在资源受限的设备上高效训练。

适应性:通过多尺度特征融合和注意力机制,DSU-Net能够适应多种复杂的图像分割任务,展现出强大的泛化能力。

高精度:实验结果表明,DSU-Net在多个基准数据集上超越了现有的最先进方法,显著提升了分割精度。

实验

DUTS数据集测试:

PASCAL数据集测试:

代码

DSU-Net:

import torch
import torch.nn as nn
import torch.nn.functional as F

from backbone import dinov2_extract, sam2hiera
from fusion import CGAFusion, sff
from modules import updown, wtconv, RFB
from torchinfo import summary


class DGSUNet(nn.Module):
    def __init__(self,dino_model_name=None,dino_hub_dir=None,sam_config_file=None,sam_ckpt_path=None):
        super(DGSUNet, self).__init__()
        if dino_model_name is None:
            print("No model_name specified, using default")
            dino_model_name = 'dinov2_vitl14'
        if dino_hub_dir is None:
            print("No dino_hub_dir specified, using default")
            dino_hub_dir = 'facebookresearch/dinov2'
        if sam_config_file is None:
            print("No sam_config_file specified, using default")
            # Replace with your own SAM configuration file path
            sam_config_file = r'G:\MyProjectCode\SAM2DINO-Seg\sam2_configs\sam2.1_hiera_l.yaml'
        if sam_ckpt_path is None:
            print("No sam_ckpt_path specified, using default")
            # Replace with your own SAM pt file path
            sam_ckpt_path = r'G:\MyProjectCode\SAM2DINO-Seg\checkpoints\sam2.1_hiera_large.pt'
        # Backbone Feature Extractor
        self.backbone_dino = dinov2_extract.DinoV2FeatureExtractor(dino_model_name, dino_hub_dir)
        self.backbone_sam = sam2hiera.sam2hiera(sam_config_file,sam_ckpt_path)
        # Feature Fusion
        self.fusion4 = CGAFusion.CGAFusion(1152)
        # (1024,37,37)->(1024,11,11)
        self.dino2sam_down4 = updown.interpolate_upsample(11)
        # (1024,11,11)->(1152,11,11)
        self.dino2sam_down14 = wtconv.DepthwiseSeparableConvWithWTConv2d(in_channels=1024, out_channels=1152)
        self.rfb1 = RFB.RFB_modified(144, 64)
        self.rfb2 = RFB.RFB_modified(288, 64)
        self.rfb3 = RFB.RFB_modified(576, 64)
        self.rfb4 = RFB.RFB_modified(1152, 64)
        self.decoder1 = sff.SFF(64)
        self.decoder2 = sff.SFF(64)
        self.decoder3 = sff.SFF(64)
        self.side1 = nn.Conv2d(64, 1, kernel_size=1)
        self.side2 = nn.Conv2d(64, 1, kernel_size=1)
        self.head = nn.Conv2d(64, 1, kernel_size=1)

    def forward(self, x_dino, x_sam):
        # Backbone Feature Extractor
        x1, x2, x3, x4 = self.backbone_sam(x_sam)
        x_dino = self.backbone_dino(x_dino)
        # change dino feature map size and dimension
        x_dino4 = self.dino2sam_down4(x_dino)
        x_dino4 = self.dino2sam_down14(x_dino4)
        # Feature Fusion(sam & dino)
        x4 = self.fusion4(x4, x_dino4)
        # change fusion feature map dimension->(64,11/22/44/88,11/22/44/88)
        x1, x2, x3, x4 = self.rfb1(x1), self.rfb2(x2), self.rfb3(x3), self.rfb4(x4)
        x = self.decoder1(x4,x3)
        out1 = F.interpolate(self.side1(x), scale_factor=16, mode='bilinear')
        x = self.decoder2(x,x2)
        out2 = F.interpolate(self.side2(x), scale_factor=8, mode='bilinear')
        x = self.decoder3(x,x1)
        out3 = F.interpolate(self.head(x), scale_factor=4, mode='bilinear')
        return out1,out2,out3

######################################################################################################

if __name__ == "__main__":
    with torch.no_grad():
        model = DGSUNet().cuda()
        x_dino = torch.randn(1, 3, 518, 518).cuda()
        x_sam = torch.randn(1, 3, 352, 352).cuda()
        # print(model)
        summary(model, input_data=(x_dino, x_sam))
        out, out1, out2 = model(x_dino,x_sam)
        print(out.shape, out1.shape, out2.shape)

sam2hiera:

import torch
import torch.nn as nn
from sam2.build_sam import build_sam2
from matplotlib import rcParams
from sam2dino_seg.self_transforms.preprocess_image import transforms_image
from sam2dino_seg.modules import adapter
from visualize.features_vis import visualize_feature_maps_mean, visualize_feature_maps_pca, visualize_feature_maps_tsne


# 设置全局字体为 SimHei(黑体)
rcParams['font.sans-serif'] = ['SimHei']  # 指定默认字体
rcParams['axes.unicode_minus'] = False    # 解决负号 '-' 显示为方块的问题
class sam2hiera(nn.Module):
    def __init__(self, config_file=None, ckpt_path=None) -> None:
        super().__init__()
        if config_file is None:
            print("No config file provided, using default config")
            config_file = "./sam2_configs/sam2.1_hiera_l.yaml"
        if ckpt_path is None:
            model = build_sam2(config_file)
        else:
            model = build_sam2(config_file, ckpt_path)
        del model.sam_mask_decoder
        del model.sam_prompt_encoder
        del model.memory_encoder
        del model.memory_attention
        del model.mask_downsample
        del model.obj_ptr_tpos_proj
        del model.obj_ptr_proj
        del model.image_encoder.neck
        self.sam_encoder = model.image_encoder.trunk

        for param in self.sam_encoder.parameters():
            param.requires_grad = False
        # Adapter
        blocks = []
        for block in self.sam_encoder.blocks:
            blocks.append(
                adapter.Adapter(block)
            )
        self.sam_encoder.blocks = nn.Sequential(
            *blocks
        )
    def forward(self, x):
        out = self.sam_encoder(x)
        return out
    
if __name__ == "__main__":
    config_file = r"G:\MyProjectCode\SAM2DINO-Seg\sam2_configs\sam2.1_hiera_l.yaml"
    ckpt_path = r"G:\MyProjectCode\SAM2DINO-Seg\checkpoints\sam2.1_hiera_large.pt"
    # 预处理图像
    image_path = r"G:\MyProjectCode\SAM2DINO-Seg\data\images\COD10K-CAM-1-Aquatic-3-Crab-29.jpg"  # 替换为您的图像路径
    x = transforms_image(image_path, image_size=352)
    with torch.no_grad():
        model = sam2hiera(config_file, ckpt_path).cuda()
        if torch.cuda.is_available():
            x = x.cuda()
        out= model(x)
        # 组合为字典
        # features = {
        #     'high_level': out['backbone_fpn'][2],
        #     'mid_level': out['backbone_fpn'][1],
        #     'low_level': out['backbone_fpn'][0]
        # }
        features = {
            'top_level': out[3],
            'high_level': out[2],
            'mid_level': out[1],
            'low_level': out[0]
        }

        # 打印各特征形状
        print(f"顶级特征形状 (全局尺度): {features['top_level'].shape}")
        # print(f"高级特征形状 (高等尺度): {features['high_level']}")
        print(f"高级特征形状 (高等尺度): {features['high_level'].shape}")
        print(f"中级特征形状 (中等尺度): {features['mid_level'].shape}")
        print(f"低级特征形状 (局部尺度): {features['low_level'].shape}")

        # 均值可视化特征
        visualize_feature_maps_mean(features,backbone_name='SAM2')

        # PCA可视化
        visualize_feature_maps_pca(features,backbone_name='SAM2')

        # T-SNE可视化
        visualize_feature_maps_tsne(features, backbone_name='SAM2')

        print("Hiera多尺度特征提取完成!")

dinov2_extract:

import torch
import torch.nn as nn
import numpy as np
from matplotlib import rcParams
from sam2dino_seg.self_transforms.preprocess_image import transforms_image

# 设置全局字体为 SimHei(黑体)
rcParams['font.sans-serif'] = ['SimHei']  # 指定默认字体
rcParams['axes.unicode_minus'] = False    # 解决负号 '-' 显示为方块的问题

class DinoV2FeatureExtractor(nn.Module):
    def __init__(self, model_name=None, hub_dir=None) -> None:
        super().__init__()
        if hub_dir is None:
            print("No hub_dir specified, using default")
            hub_dir = 'facebookresearch/dinov2'
        if model_name is None:
            print("No model_name specified, using default")
            model_name = 'dinov2_vitl14'
        model = torch.hub.load(hub_dir, model_name, pretrained=True)
        self.dino_encoder = model
        self.patchsize = 14

        for param in self.dino_encoder.parameters():
            param.requires_grad = False

    def forward(self, x):
        output = self.dino_encoder.forward_features(x)
        dino_feature = output['x_norm_patchtokens']
        # print(dino_feature.shape)
        # 转换为空间特征图
        img_size = int(x.shape[-1])
        batch_size = int(x.shape[0])
        feature_size = int((img_size / self.patchsize) ** 2)
        # 验证获取的特征大小
        assert dino_feature.shape[1] == feature_size, f"特征大小不匹配: {dino_feature.shape[1]} vs {feature_size}"
        # 重新构建为2D特征图
        side_length = int(np.sqrt(feature_size))
        dino_feature_map = dino_feature.reshape(batch_size, side_length, side_length, -1).permute(0, 3, 1, 2)

        return dino_feature_map
# 示例使用
if __name__ == "__main__":
    # 预处理图像
    # image_path = r"G:\MyProjectCode\SAM2DINO-Seg\data\images\R-C.jpg"  # 替换为您的图像路径
    # x = transforms_image(image_path, image_size=518)
    x = torch.randn(12, 3, 518, 518)
    with torch.no_grad():
        model = DinoV2FeatureExtractor().cuda()
        if torch.cuda.is_available():
            x = x.cuda()
        out = model(x)
        print(out.shape)
        # print(out)

输入图像:

预测结果:

总结

DSU-Net通过结合DINOv2和SAM2的优势,解决了现有图像分割模型在特定下游任务中表现不足的问题。它利用多尺度跨模型特征协作和轻量级适配器模块,显著提升了模型的分割精度和泛化能力,同时降低了训练成本,使其能够在资源受限的设备上高效运行。DSU-Net的成功为未来图像分割领域的研究提供了重要启发,特别是在跨模型协作、轻量级适配器设计以及多尺度特征融合等方面,为开发更高效、更适应多样化任务的模型奠定了基础。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2379043.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2025年- H30-Lc138- 141.环形链表(快慢指针,快2慢1)---java版

1.题目描述 2.思路 弗洛伊德算法(快慢指针 3.代码实现 public boolean hasCycle(ListNode head) {//1.如果空节点或者只有一个节点,都说明没有环,返回falseif(headnull||head.nextnull){return false;}//2.定义快慢指针,都从头…

LoadBarWorks:一款赛博风加载动画生成器的构建旅程

我正在参加CodeBuddy「首席试玩官」内容创作大赛,本文所使用的 CodeBuddy 免费下载链接:腾讯云代码助手 CodeBuddy - AI 时代的智能编程伙伴 项目缘起:赛博与实用的结合 在日常开发中,我经常需要为不同的项目添加加载动画&#x…

SAP集团内部公司间交易自动开票

SAP集团内部公司间交易自动开票(非STO/EDI模式) 集团内部公司间采购与销售业务,在确认相应单据无误后,为减少人工开票业务, 可以用系统标准功能来实现自动开票。 1.采购发票自动开票(ERS) T-CODE:BP,勾选“基于收货的发票校验”、“自动G…

【YOLO(txt)格式转VOC(xml)格式数据集】以及【制作VOC格式数据集 】

1.txt—>xml转化代码 如果我们手里只有YOLO标签的数据集,我们要进行VOC格式数据集的制作首先要进行标签的转化,以下是标签转化的脚本。 其中picPath为图片所在文件夹路径; txtPath为你的YOLO标签对应的txt文件所在路径; xmlPa…

Linux 的 UDP 网络编程 -- 回显服务器,翻译服务器

目录 1. 回显服务器 -- echo server 1.1 相关函数介绍 1.1.1 socket() 1.1.2 bind() 1.1.3 recvfrom() 1.1.4 sendto() 1.1.5 inet_ntoa() 1.1.6 inet_addr() 1.2 Udp 服务端的封装 -- UdpServer.hpp 1.3 服务端代码 -- UdpServer.cc 1.4 客户端代码 -- UdpClient.…

C++笔试题(金山科技新未来训练营):

题目分布: 17道单选(每题3分)3道多选题(全对3分,部分对1分)2道编程题(每一道20分)。 不过题目太多,就记得一部分了: 单选题: static变量的初始…

【RabbitMQ】 RabbitMQ高级特性(二)

文章目录 一、重试机制1.1、重试配置1.2、配置交换机&队列1.3、发送消息1.4、消费消息1.5、运行程序1.6、 手动确认 二、TTL2.1、设置消息的TTL2.2、设置队列的TTL2.3、两者区别 三 、死信队列6.1 死信的概念3.2 代码示例3.2.1、声明队列和交换机3.2.2、正常队列绑定死信交…

电子电路:什么是电流离散性特征?

关于电荷的量子化,即电荷的最小单位是电子的电荷量e。在宏观电路中,由于电子数量极大,电流看起来是连续的。但在微观层面,比如纳米器件或单电子晶体管中,单个电子的移动就会引起可观测的离散电流。 还要提到散粒噪声,这是电流离散性的表现之一。当电流非常小时,例如在二…

深入理解位图(Bit - set):概念、实现与应用

目录 引言 一、位图概念 (一)基本原理 (二)适用场景 二、位图的实现(C 代码示例) 三、位图应用 1. 快速查找某个数据是否在一个集合中 2. 排序 去重 3. 求两个集合的交集、并集等 4. 操作系…

猫番阅读APP:丰富资源,优质体验,满足你的阅读需求

猫番阅读APP是一款专为书籍爱好者设计的移动阅读应用,致力于提供丰富的阅读体验和多样化的书籍资源。它不仅涵盖了小说、非虚构、杂志等多个领域的电子书,还提供了个性化推荐、书架管理、离线下载等功能,满足不同读者的阅读需求。无论是通勤路…

MetaMask安装及使用-使用水龙头获取测试币的坑?

常见的异常有: 1.unable to request drip, please try again later. 2.You must hold at least 1 LINK on Ethereum Mainnet to request native tokens. 3.The address provided does not have sufficient historical activity or balance on the Ethereum Mainne…

AI:OpenAI论坛分享—《AI重塑未来:技术、经济与战略》

AI:OpenAI论坛分享—《AI重塑未来:技术、经济与战略》 导读:2025年4月24日,OpenAI论坛全面探讨了 AI 的发展趋势、技术范式、地缘政治影响以及对经济和社会的广泛影响。强调了 AI 的通用性、可扩展性和高级推理能力,以…

Linux配置vimplus

配置vimplus CentOS的配置方案很简单,但是Ubuntu的解决方案网上也很多但是有效的很少,尤其是22和24的解决方案,在此我整理了一下我遇到的问题解决方法 CentOS7 一键配置VimForCPP 基本上不会有什么特别难解决的报错 sudo yum install vims…

服务端HttpServletRequest、HttpServletResponse、HttpSession

一、概述 在JavaWeb 开发中,获取客户端传递的参数至关重要。http请求是客户端向服务端发起数据传输协议,主要包含包含请求行、请求头、空行和请求体四个部分,在这四部分中分别携带客户端传递到服务端的数据。常见的http请求方式有get、post、…

实验九视图索引

设计性实验 1. 创建视图V_A包括学号,姓名,性别,课程号,课程名、成绩; 一个语句把学号103 课程号3-105 的姓名改为陆君茹1,性别为女 ,然后查看学生表的信息变化,再把上述数据改为原…

git 本地提交后修改注释

dos命令行进入目录,idea可以点击Terminal 进入命令行 git commit --amend -m "修改内容"

面向具身智能的视觉-语言-动作模型(VLA)综述

具身智能被广泛认为是通用人工智能(AGI)的关键要素,因为它涉及控制具身智能体在物理世界中执行任务。在大语言模型和视觉语言模型成功的基础上,一种新的多模态模型——视觉语言动作模型(VLA)已经出现&#…

计算机发展的历程

计算机系统的概述 一, 计算机系统的定义 计算机系统的概念 计算机系统 硬件 软件 硬件的概念 计算机的实体, 如主机, 外设等 计算机系统的物理基础 决定了计算机系统的天花板瓶颈 软件的概念 由具有各类特殊功能的程序组成 决定了把硬件的性能发挥到什么程度 软件的分类…

深度学习驱动下的目标检测技术:原理、算法与应用创新(三)

五、基于深度学习的目标检测代码实现 5.1 开发环境搭建 开发基于深度学习的目标检测项目,首先需要搭建合适的开发环境,确保所需的工具和库能够正常运行。以下将详细介绍 Python、PyTorch 等关键开发工具和库的安装与配置过程。 Python 是一种广泛应用于…

jenkins流水线常规配置教程!

Jenkins流水线是在工作中实现CI/CD常用的工具。以下是一些我在工作和学习中总结出来常用的一些流水线配置:变量需要加双引号括起来 "${main}" 一 引用无账号的凭据 使用变量方式引用,这种方式只适合只由密码,没有用户名的凭证。例…