pytorch代码实现之动态蛇形卷积模块DySnakeConv

news2025/7/19 10:17:36

动态蛇形卷积模块DySnakeConv

血管、道路等拓扑管状结构的精确分割在各个领域都至关重要,确保下游任务的准确性和效率。 然而,许多因素使任务变得复杂,包括薄的局部结构和可变的全局形态。在这项工作中,我们注意到管状结构的特殊性,并利用这些知识来指导我们的 DSCNet 在三个阶段同时增强感知:特征提取、特征融合、 和损失约束。 首先,我们提出了一种动态蛇卷积,通过自适应地关注细长和曲折的局部结构来准确捕获管状结构的特征。 随后,我们提出了一种多视图特征融合策略,以补充特征融合过程中多角度对特征的关注,确保保留来自不同全局形态的重要信息。 最后,提出了一种基于持久同源性的连续性约束损失函数,以更好地约束分割的拓扑连续性。 2D 和 3D 数据集上的实验表明,与多种方法相比,我们的 DSCNet 在管状结构分割任务上提供了更好的准确性和连续性。 我们的代码是公开的。

主要的挑战源于细长微弱的局部结构特征与复杂多变的全局形态特征。本文关注到管状结构细长连续的特点,并利用这一信息在神经网络以下三个阶段同时增强感知:特征提取、特征融合和损失约束。分别设计了动态蛇形卷积(Dynamic Snake Convolution),多视角特征融合策略与连续性拓扑约束损失。

原文地址:Dynamic Snake Convolution based on Topological Geometric Constraints for Tubular Structure Segmentation

结构图

pytorch代码实现

import torch
import torch.nn as nn

class DySnakeConv(nn.Module):
    def __init__(self, inc, ouc, k=3, act=True) -> None:
        super().__init__()
        
        self.conv_0 = Conv(inc, ouc, k, act=act)
        self.conv_x = DSConv(inc, ouc, 0, k)
        self.conv_y = DSConv(inc, ouc, 1, k)
        self.conv_1x1 = Conv(ouc * 3, ouc, 1, act=act)
    
    def forward(self, x):
        return self.conv_1x1(torch.cat([self.conv_0(x), self.conv_x(x), self.conv_y(x)], dim=1))

class DSConv(nn.Module):
    def __init__(self, in_ch, out_ch, morph, kernel_size=3, if_offset=True, extend_scope=1):
        """
        The Dynamic Snake Convolution
        :param in_ch: input channel
        :param out_ch: output channel
        :param kernel_size: the size of kernel
        :param extend_scope: the range to expand (default 1 for this method)
        :param morph: the morphology of the convolution kernel is mainly divided into two types
                        along the x-axis (0) and the y-axis (1) (see the paper for details)
        :param if_offset: whether deformation is required, if it is False, it is the standard convolution kernel
        """
        super(DSConv, self).__init__()
        # use the <offset_conv> to learn the deformable offset
        self.offset_conv = nn.Conv2d(in_ch, 2 * kernel_size, 3, padding=1)
        self.bn = nn.BatchNorm2d(2 * kernel_size)
        self.kernel_size = kernel_size

        # two types of the DSConv (along x-axis and y-axis)
        self.dsc_conv_x = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=(kernel_size, 1),
            stride=(kernel_size, 1),
            padding=0,
        )
        self.dsc_conv_y = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=(1, kernel_size),
            stride=(1, kernel_size),
            padding=0,
        )

        self.gn = nn.GroupNorm(out_ch // 4, out_ch)
        self.act = Conv.default_act

        self.extend_scope = extend_scope
        self.morph = morph
        self.if_offset = if_offset

    def forward(self, f):
        offset = self.offset_conv(f)
        offset = self.bn(offset)
        # We need a range of deformation between -1 and 1 to mimic the snake's swing
        offset = torch.tanh(offset)
        input_shape = f.shape
        dsc = DSC(input_shape, self.kernel_size, self.extend_scope, self.morph)
        deformed_feature = dsc.deform_conv(f, offset, self.if_offset)
        if self.morph == 0:
            x = self.dsc_conv_x(deformed_feature.type(f.dtype))
            x = self.gn(x)
            x = self.act(x)
            return x
        else:
            x = self.dsc_conv_y(deformed_feature.type(f.dtype))
            x = self.gn(x)
            x = self.act(x)
            return x


# Core code, for ease of understanding, we mark the dimensions of input and output next to the code
class DSC(object):
    def __init__(self, input_shape, kernel_size, extend_scope, morph):
        self.num_points = kernel_size
        self.width = input_shape[2]
        self.height = input_shape[3]
        self.morph = morph
        self.extend_scope = extend_scope  # offset (-1 ~ 1) * extend_scope

        # define feature map shape
        """
        B: Batch size  C: Channel  W: Width  H: Height
        """
        self.num_batch = input_shape[0]
        self.num_channels = input_shape[1]

    """
    input: offset [B,2*K,W,H]  K: Kernel size (2*K: 2D image, deformation contains <x_offset> and <y_offset>)
    output_x: [B,1,W,K*H]   coordinate map
    output_y: [B,1,K*W,H]   coordinate map
    """

    def _coordinate_map_3D(self, offset, if_offset):
        device = offset.device
        # offset
        y_offset, x_offset = torch.split(offset, self.num_points, dim=1)

        y_center = torch.arange(0, self.width).repeat([self.height])
        y_center = y_center.reshape(self.height, self.width)
        y_center = y_center.permute(1, 0)
        y_center = y_center.reshape([-1, self.width, self.height])
        y_center = y_center.repeat([self.num_points, 1, 1]).float()
        y_center = y_center.unsqueeze(0)

        x_center = torch.arange(0, self.height).repeat([self.width])
        x_center = x_center.reshape(self.width, self.height)
        x_center = x_center.permute(0, 1)
        x_center = x_center.reshape([-1, self.width, self.height])
        x_center = x_center.repeat([self.num_points, 1, 1]).float()
        x_center = x_center.unsqueeze(0)

        if self.morph == 0:
            """
            Initialize the kernel and flatten the kernel
                y: only need 0
                x: -num_points//2 ~ num_points//2 (Determined by the kernel size)
                !!! The related PPT will be submitted later, and the PPT will contain the whole changes of each step
            """
            y = torch.linspace(0, 0, 1)
            x = torch.linspace(
                -int(self.num_points // 2),
                int(self.num_points // 2),
                int(self.num_points),
            )

            y, x = torch.meshgrid(y, x)
            y_spread = y.reshape(-1, 1)
            x_spread = x.reshape(-1, 1)

            y_grid = y_spread.repeat([1, self.width * self.height])
            y_grid = y_grid.reshape([self.num_points, self.width, self.height])
            y_grid = y_grid.unsqueeze(0)  # [B*K*K, W,H]

            x_grid = x_spread.repeat([1, self.width * self.height])
            x_grid = x_grid.reshape([self.num_points, self.width, self.height])
            x_grid = x_grid.unsqueeze(0)  # [B*K*K, W,H]

            y_new = y_center + y_grid
            x_new = x_center + x_grid

            y_new = y_new.repeat(self.num_batch, 1, 1, 1).to(device)
            x_new = x_new.repeat(self.num_batch, 1, 1, 1).to(device)

            y_offset_new = y_offset.detach().clone()

            if if_offset:
                y_offset = y_offset.permute(1, 0, 2, 3)
                y_offset_new = y_offset_new.permute(1, 0, 2, 3)
                center = int(self.num_points // 2)

                # The center position remains unchanged and the rest of the positions begin to swing
                # This part is quite simple. The main idea is that "offset is an iterative process"
                y_offset_new[center] = 0
                for index in range(1, center):
                    y_offset_new[center + index] = (y_offset_new[center + index - 1] + y_offset[center + index])
                    y_offset_new[center - index] = (y_offset_new[center - index + 1] + y_offset[center - index])
                y_offset_new = y_offset_new.permute(1, 0, 2, 3).to(device)
                y_new = y_new.add(y_offset_new.mul(self.extend_scope))

            y_new = y_new.reshape(
                [self.num_batch, self.num_points, 1, self.width, self.height])
            y_new = y_new.permute(0, 3, 1, 4, 2)
            y_new = y_new.reshape([
                self.num_batch, self.num_points * self.width, 1 * self.height
            ])
            x_new = x_new.reshape(
                [self.num_batch, self.num_points, 1, self.width, self.height])
            x_new = x_new.permute(0, 3, 1, 4, 2)
            x_new = x_new.reshape([
                self.num_batch, self.num_points * self.width, 1 * self.height
            ])
            return y_new, x_new

        else:
            """
            Initialize the kernel and flatten the kernel
                y: -num_points//2 ~ num_points//2 (Determined by the kernel size)
                x: only need 0
            """
            y = torch.linspace(
                -int(self.num_points // 2),
                int(self.num_points // 2),
                int(self.num_points),
            )
            x = torch.linspace(0, 0, 1)

            y, x = torch.meshgrid(y, x)
            y_spread = y.reshape(-1, 1)
            x_spread = x.reshape(-1, 1)

            y_grid = y_spread.repeat([1, self.width * self.height])
            y_grid = y_grid.reshape([self.num_points, self.width, self.height])
            y_grid = y_grid.unsqueeze(0)

            x_grid = x_spread.repeat([1, self.width * self.height])
            x_grid = x_grid.reshape([self.num_points, self.width, self.height])
            x_grid = x_grid.unsqueeze(0)

            y_new = y_center + y_grid
            x_new = x_center + x_grid

            y_new = y_new.repeat(self.num_batch, 1, 1, 1)
            x_new = x_new.repeat(self.num_batch, 1, 1, 1)

            y_new = y_new.to(device)
            x_new = x_new.to(device)
            x_offset_new = x_offset.detach().clone()

            if if_offset:
                x_offset = x_offset.permute(1, 0, 2, 3)
                x_offset_new = x_offset_new.permute(1, 0, 2, 3)
                center = int(self.num_points // 2)
                x_offset_new[center] = 0
                for index in range(1, center):
                    x_offset_new[center + index] = (x_offset_new[center + index - 1] + x_offset[center + index])
                    x_offset_new[center - index] = (x_offset_new[center - index + 1] + x_offset[center - index])
                x_offset_new = x_offset_new.permute(1, 0, 2, 3).to(device)
                x_new = x_new.add(x_offset_new.mul(self.extend_scope))

            y_new = y_new.reshape(
                [self.num_batch, 1, self.num_points, self.width, self.height])
            y_new = y_new.permute(0, 3, 1, 4, 2)
            y_new = y_new.reshape([
                self.num_batch, 1 * self.width, self.num_points * self.height
            ])
            x_new = x_new.reshape(
                [self.num_batch, 1, self.num_points, self.width, self.height])
            x_new = x_new.permute(0, 3, 1, 4, 2)
            x_new = x_new.reshape([
                self.num_batch, 1 * self.width, self.num_points * self.height
            ])
            return y_new, x_new

    """
    input: input feature map [N,C,D,W,H];coordinate map [N,K*D,K*W,K*H] 
    output: [N,1,K*D,K*W,K*H]  deformed feature map
    """
    def _bilinear_interpolate_3D(self, input_feature, y, x):
        device = input_feature.device
        y = y.reshape([-1]).float()
        x = x.reshape([-1]).float()

        zero = torch.zeros([]).int()
        max_y = self.width - 1
        max_x = self.height - 1

        # find 8 grid locations
        y0 = torch.floor(y).int()
        y1 = y0 + 1
        x0 = torch.floor(x).int()
        x1 = x0 + 1

        # clip out coordinates exceeding feature map volume
        y0 = torch.clamp(y0, zero, max_y)
        y1 = torch.clamp(y1, zero, max_y)
        x0 = torch.clamp(x0, zero, max_x)
        x1 = torch.clamp(x1, zero, max_x)

        input_feature_flat = input_feature.flatten()
        input_feature_flat = input_feature_flat.reshape(
            self.num_batch, self.num_channels, self.width, self.height)
        input_feature_flat = input_feature_flat.permute(0, 2, 3, 1)
        input_feature_flat = input_feature_flat.reshape(-1, self.num_channels)
        dimension = self.height * self.width

        base = torch.arange(self.num_batch) * dimension
        base = base.reshape([-1, 1]).float()

        repeat = torch.ones([self.num_points * self.width * self.height
                             ]).unsqueeze(0)
        repeat = repeat.float()

        base = torch.matmul(base, repeat)
        base = base.reshape([-1])

        base = base.to(device)

        base_y0 = base + y0 * self.height
        base_y1 = base + y1 * self.height

        # top rectangle of the neighbourhood volume
        index_a0 = base_y0 - base + x0
        index_c0 = base_y0 - base + x1

        # bottom rectangle of the neighbourhood volume
        index_a1 = base_y1 - base + x0
        index_c1 = base_y1 - base + x1

        # get 8 grid values
        value_a0 = input_feature_flat[index_a0.type(torch.int64)].to(device)
        value_c0 = input_feature_flat[index_c0.type(torch.int64)].to(device)
        value_a1 = input_feature_flat[index_a1.type(torch.int64)].to(device)
        value_c1 = input_feature_flat[index_c1.type(torch.int64)].to(device)

        # find 8 grid locations
        y0 = torch.floor(y).int()
        y1 = y0 + 1
        x0 = torch.floor(x).int()
        x1 = x0 + 1

        # clip out coordinates exceeding feature map volume
        y0 = torch.clamp(y0, zero, max_y + 1)
        y1 = torch.clamp(y1, zero, max_y + 1)
        x0 = torch.clamp(x0, zero, max_x + 1)
        x1 = torch.clamp(x1, zero, max_x + 1)

        x0_float = x0.float()
        x1_float = x1.float()
        y0_float = y0.float()
        y1_float = y1.float()

        vol_a0 = ((y1_float - y) * (x1_float - x)).unsqueeze(-1).to(device)
        vol_c0 = ((y1_float - y) * (x - x0_float)).unsqueeze(-1).to(device)
        vol_a1 = ((y - y0_float) * (x1_float - x)).unsqueeze(-1).to(device)
        vol_c1 = ((y - y0_float) * (x - x0_float)).unsqueeze(-1).to(device)

        outputs = (value_a0 * vol_a0 + value_c0 * vol_c0 + value_a1 * vol_a1 +
                   value_c1 * vol_c1)

        if self.morph == 0:
            outputs = outputs.reshape([
                self.num_batch,
                self.num_points * self.width,
                1 * self.height,
                self.num_channels,
            ])
            outputs = outputs.permute(0, 3, 1, 2)
        else:
            outputs = outputs.reshape([
                self.num_batch,
                1 * self.width,
                self.num_points * self.height,
                self.num_channels,
            ])
            outputs = outputs.permute(0, 3, 1, 2)
        return outputs

    def deform_conv(self, input, offset, if_offset):
        y, x = self._coordinate_map_3D(offset, if_offset)
        deformed_feature = self._bilinear_interpolate_3D(input, y, x)
        return deformed_feature


#### YOLOV5
class Bottleneck_DySnake(nn.Module):
    # Standard bottleneck
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = DySnakeConv(c_, c2, 3)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))

class C3_DySnake(C3):
    # C3 module with DySnakeConv
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2 * e)
        self.m = nn.Sequential(*(Bottleneck_DySnake(c_, c_, shortcut, g, e=1.0) for _ in range(n)))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1102709.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

安达发|大多数离散型生产模式适用APS自动排程系统

在离散型生产模式中&#xff0c;智能生产排程软件&#xff08;APS&#xff09;的应用越来越广泛。这是因为APS能够根据实时的生产需求和资源状况&#xff0c;自动进行生产计划的制定和调整&#xff0c;从而提高生产效率&#xff0c;降低生产成本&#xff0c;保证生产的顺利进行…

美国股票和加密货币平台【Alpaca】完成1500万美元融资

来源&#xff1a;猛兽财经 作者&#xff1a;猛兽财经 猛兽财经获悉&#xff0c;总部位于美国加利福尼亚州圣马特奥的股票和加密交易经纪平台提供商&#xff0c;近期宣布已从SBI集团获得了1500万美元融资。 该公司打算利用这笔资金加快业务扩张&#xff0c;并将其业务范围扩大到…

iPaaS混合集成平台,打造数字化生态

如今企业分工越来越细&#xff0c;上下游合作越来越紧密、各企业之间的业务系统需要相互协作完成业务、外部API依赖越来越多、同时企业系统运行在多个混合云环境及SaaS中&#xff0c;私有端大量业务系统与云端系统形成了错综复杂的集成关系&#xff0c;企业面临集成技术复杂多样…

Springboot整合taos时序数据库TDengine

1.首先安装TDengine服务端在linux上 TDengine多种安装包的安装和卸载 - TDengine | 涛思数据安装过程直接去官网看,非常详细简单 2.出现的问题 windows连接 invalid app version 版本不对应 版本不对应的问题,需要在linux上安装的版本和windows client版本一致,不然w…

Kubernetes基础(六)-常见 Kubernetes Pod 驱逐场景

Kubernetes Pod 被驱逐是什么意思&#xff1f; 它们被终止&#xff0c;通常是没有足够资源的结果。但是为什么会这样呢&#xff1f; 驱逐是指派给节点的Pod 被终止的过程。 Kubernetes 中最常见的情况之一是Preemption&#xff0c;为了在资源有限的节点中调度新的 Pod&#…

安卓14通过“冻结”缓存应用程序腾出CPU,提高性能和内存效率

本月早些时候&#xff0c;我们听说更新到安卓14似乎提高了谷歌Pixel 7和Pixel 6的效率——提高了电池寿命&#xff0c;并在这个过程中减少了热量的产生。现在看来&#xff0c;安卓14的增效功能细节已经公布。 安卓侦探Mishaal Rahman在X&#xff08;前身为Twitter&#xff09;…

林沛满--快递员的工作策略——TCP窗口

本文整理自&#xff1a;《Wireshark网络分析就这么简单 第1版》 作者&#xff1a;林沛满 著 出版时间&#xff1a;2014-12 假如你是一位勤劳的快递员&#xff0c;要送100个包裹到某公司去&#xff0c;怎样送货才科学? 最简单的方式是每次送1个&#xff0c;总共跑100趟。当然这…

uCOSIII实时操作系统 八 软件定时器

目录 软件定时器概述 使用步骤&#xff1a; 创建软件定时器&#xff1a; 启动软件定时器&#xff1a; 停止软件定时器&#xff1a; 删除软件定时器&#xff1a; 单次定时器&#xff1a; ​编辑周期定时器&#xff1a; 无初始化延时&#xff1a; 有初始化延时&#xff…

LabVIEW中使用Get LV Class Default Value 出现错误1498

LabVIEW中使用Get LV Class Default Value 出现错误1498 在LabVIEW中开发了一个应用程序&#xff0c;其中包含可以在执行时动态配置插件的基类。生成可执行文件后&#xff0c;当应用程序要执行子类时&#xff0c;收到以下错误信息。 Error1498 occurred at Gen LV Class Defa…

Sandboxie+Buster Sandbox Analyzer打造个人沙箱

一、运行环境和需要安装的软件 实验环境&#xff1a;win7_x32或win7_x64 用到的软件&#xff1a;WinPcap_4_1_3.exe、Sandboxie-3-70.exe、Buster Sandbox Analyzer 重点是Sandboxie必须是3.70版本。下载地址&#xff1a;https://github.com/sandboxie-plus/sandboxie-old/blo…

Linux性能优化--使用性能工具发现问题

9.0 概述 本章主要介绍综合运用之前提出的性能工具来缩小性能问题产生原因的范围。阅读本章后&#xff0c;你将能够&#xff1a; 启动行为异常的系统&#xff0c;使用Linux性能工具追踪行为异常的内核函数或应用程序。启动行为异常的应用程序&#xff0c;使用Linux性能工具追…

浪子带你【25天】玩转Python——期中福利

人生苦短&#xff0c;我用Python! 目录 回顾上文 正文 最后的话 回顾上文 浪子带你【25天】玩转Python——5.面向对象编程&#xff08;类和对象&#xff09;-CSDN博客 正文 哈喽&#xff0c;不知不觉中&#xff0c;浪子的【25天】玩转Python已经开播13天啦&#xff01…

【c++智能指针】

目录 一、智能指针的使用及原理二、auto_ptr三、unique_ptr三、shared_ptr四、weak_ptr五、定制删除器 一、智能指针的使用及原理 RAII&#xff08;Resource Acquisition Is Initialization&#xff09;是一种利用对象生命周期来控制程序资源&#xff08;如内存、文件句柄、网…

新型的终端复用器 tmux

以前遇到长时间执行任务时&#xff0c;一般是使用nohup加后台运行&#xff0c;但是涉及到少量代码编写。 同事介绍了一个screen命令&#xff0c;根据文档&#xff0c;此命令已经过时&#xff0c;最新的命令是tmux。 tmux的介绍文档&#xff0c;RedHat的这一篇非常不错。 在文…

vue ref和$refs获取dom元素

vue ref和$refs获取dom元素 **创建 工程&#xff1a; H:\java_work\java_springboot\vue_study ctrl按住不放 右键 悬着 powershell H:\java_work\java_springboot\js_study\Vue2_3入门到实战-配套资料\01-随堂代码素材\day04\准备代码\14-ref和$refs获取dom对象 vue --ve…

Kotlin中的数值类型

在Kotlin中&#xff0c;Byte、Short、Int、Long、Float和Double是基本数据类型&#xff0c;用于表示不同范围和精度的数值。 Byte&#xff08;字节&#xff09;&#xff1a;Byte类型是8位有符号整数类型&#xff0c;取值范围为-128到127。在Kotlin中&#xff0c;可以使用字面值…

《深入浅出OCR》第三章:OCR文字检测

✨专栏介绍: 经过几个月的精心筹备,本作者推出全新系列《深入浅出OCR》专栏,对标最全OCR教程,具体章节如导图所示,将分别从OCR技术发展、方向、概念、算法、论文、数据集等各种角度展开详细介绍。 👨‍💻面向对象: 本篇前言知识主要介绍深度学习知识,全面总结知知识…

Linux Zabbix企业级监控平台+cpolar实现远程访问

文章目录 前言1. Linux 局域网访问Zabbix2. Linux 安装cpolar3. 配置Zabbix公网访问地址4. 公网远程访问Zabbix5. 固定Zabbix公网地址 前言 Zabbix是一个基于WEB界面的提供分布式系统监视以及网络监视功能的企业级的开源解决方案。能监视各种网络参数&#xff0c;保证服务器系…

C++stack和queue模拟实现以及deque的介绍

stack和queue介绍以及模拟实现 1.stack1.1stack的介绍1.2stack的使用 2.queue2.1queue的介绍2.2queue的使用 3.容器适配器3.1什么是适配器 4.stack模拟实现5.queue的模拟实现6.deque&#xff08;双端队列&#xff09; 1.stack 1.1stack的介绍 stack的文档介绍 stack是一种容…

软信天成:流程管理是企业精细化管理的一大利器

流程管理&#xff08;BPM&#xff09;是指组织和管理内部或跨部门的工作流程&#xff0c;主要包括设计、建模、执行、监控和优化业务流程&#xff0c;确保工作按照标准化的步骤进行&#xff0c;从而提高效率、降低成本&#xff0c;促进业务增长。 一、流程管理生命周期五大步骤…