Transformer Fusion for Indoor RGB-D Semantic Segmentation: an unofficial implementation


Disclaimer: there is no official code for this paper, so what follows is my own rough analysis.
First, the encoder, which is a Swin Transformer. Assume the RGB input has shape (1,3,480,480) and the depth input (1,1,480,480). The feature maps then evolve stage by stage as **(1,64,120,120) → (1,64,120,120) → (1,128,60,60) → (1,256,30,30) → (1,512,15,15)**.
Next comes a convolution followed by flattening. The paper reduces the channel count here; assume a 1×1 convolution that shrinks the channels from 512 to 256. Flattening then turns (1,256,15,15) into (1,256,225).
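A minimal sketch of this reduce-and-flatten step (the 1×1 convolution and the final permute into a token sequence are my assumptions; the paper only says the channels are reduced):

```python
import torch
import torch.nn as nn

proj = nn.Conv2d(512, 256, kernel_size=1)  # assumed 1x1 channel reduction
feat = torch.randn(1, 512, 15, 15)         # stage-4 encoder output
tokens = proj(feat).flatten(2)             # (1, 256, 225)
tokens = tokens.permute(0, 2, 1)           # (1, 225, 256): one token per spatial position
```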
This sequence is then fed into the three-step fusion module.
First, note that fusion changes neither the size nor the dimensionality: the output is also (1,256,225).
Next, the position encoding shown in the figure:
Since it is added to K and Q, its dimensions must match theirs.
From the formula we can see that the position encoding is added after Q, K, V have been generated, not before; otherwise V would have picked it up as well, yet V carries no position encoding. Also, multi-head attention is used here, and the encoding is added to the Q and K of each head individually.
Substitute the parameters inside the parentheses of Eq. 5 into MultiHead(Q, K, V).
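The equation images do not survive here; based on the description above, Eqs. 4 and 5 plausibly take the following form (my reconstruction, with $P_i$ denoting the per-head slice of the projected position encoding):

$$\mathrm{head}_i = \mathrm{softmax}\!\left(\frac{(Q_i + P_i)(K_i + P_i)^{\top}}{\sqrt{d_k}}\right) V_i$$

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\, W^{O}$$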
There are some open questions about this addition of the position encoding:
Its computation is the same as in ViT, nothing different; the key point is that it is added only to Q and K, never to V.
Consider the dimension flow in ViT, which uses multi-head attention. Conceptually each head computes its own Q, K, V (one pass per head, as many passes as there are heads), and the per-head results are concatenated along the feature dimension.
In the actual code (12 heads), however, Q, K, V for all 12 heads are computed in a single pass.
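A minimal sketch of this fused computation (the standard ViT trick; shapes follow ViT-Base with 12 heads and embedding dimension 768):

```python
import torch
import torch.nn as nn

B, N, C, H = 1, 197, 768, 12     # batch, tokens, embed dim, heads (ViT-Base)
qkv = nn.Linear(C, C * 3)        # a single linear yields Q, K and V for all heads at once
x = torch.randn(B, N, C)
q, k, v = qkv(x).reshape(B, N, 3, H, C // H).permute(2, 0, 3, 1, 4)
print(q.shape)                   # torch.Size([1, 12, 197, 64]): all heads in one tensor
```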

In this paper the position encoding is built from the output of the first stage, i.e. (1,14400,64), which is reshaped back into a map of (1,64,120,120). We know the self-attention input is (1,225,256), so for the addition to work the channel dimension must be raised to 256; this is the paper's "projection into a higher-dimensional space" (richer semantic information). Two 3×3 convolutions then extract local features, and finally the map is downsampled and reshaped into a sequence.
My guess at the shape flow: (1,64,120,120) → (1,256,120,120) → (1,256,120,120) → (1,256,120,120) → (1,256,15,15) → (1,225,256).
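A sketch of this position-encoding branch under the guessed shape flow (1×1 projection, two 3×3 convolutions, 8× downsampling; the `Spge` module in the full listing below follows the same pattern):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

pe_proj = nn.Sequential(
    nn.Conv2d(64, 256, kernel_size=1),              # project to the higher token dimension
    nn.Conv2d(256, 256, kernel_size=3, padding=1),  # local feature extraction
    nn.Conv2d(256, 256, kernel_size=3, padding=1),
)
stage1 = torch.randn(1, 64, 120, 120)               # stage-1 encoder output
pe = F.interpolate(pe_proj(stage1), size=(15, 15))  # 8x downsampling
pe = pe.flatten(2).permute(0, 2, 1)                 # (1, 225, 256) token sequence
```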
Back to the fusion. The master sequence (1,225,256) goes through a ViT-style linear projection to produce Q, K, V. As in ViT this is multi-head; with 8 heads, Q, K and V each have shape (1,8,225,32). The position encoding (1,225,256) is likewise reshaped into eight heads of (1,225,32) each, so that it can be added to Q and K. Why V gets no position encoding, I honestly don't know; perhaps to stay consistent with the later modules? The complementary sequence is handled the same way.
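A sketch of the per-head addition (shapes as above; the linear projection of the encoding is my assumption, mirroring the `pos` layer in the `Attention` class further down):

```python
import torch
import torch.nn as nn

B, N, C, H = 1, 225, 256, 8
pos = nn.Linear(C, C)                                 # assumed projection of the position encoding
pe = pos(torch.randn(B, N, C))
pe = pe.reshape(B, N, H, C // H).permute(0, 2, 1, 3)  # (1, 8, 225, 32): one slice per head
q = torch.randn(B, H, N, C // H) + pe                 # added to Q ...
k = torch.randn(B, H, N, C // H) + pe                 # ... and K
v = torch.randn(B, H, N, C // H)                      # V receives no position encoding
```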

Then comes the cross attention. It is basically the same as before; the difference is that, in each of the two branches, Q and the position encoding added to Q come from the other branch. The two branches have the same number of heads and the same dimensions, only the modality differs, so the matrix products still work out. Also, the position encodings are the same ones used in the preceding self-attention: the single encoding derived from stage 1 is used throughout.
Finally, the depth guidance. The output of the previous module's auxiliary branch, (1,225,256), is again linearly projected into heads. Since it produces only a single output (the query), its linear layer does not need the ×3 expansion; likewise the main branch needs only two outputs (key and value), so its embed_dim is expanded ×2. The result is finally reshaped back into a map, (1,256,15,15).
Decoder: 8× upsampling. (In my implementation below this is a DeepLabV3+-style decoder: the ASPP output is first upsampled 8× to match the low-level feature, and the fused map is then upsampled 4× to full resolution.)
Below is my own code, and it is quite rough. I swapped the Swin Transformer for a ResNet-34, which keeps the structure a bit simpler; since every stage outputs a feature map, the position-encoding branch never has to convert a sequence back into an image. Everything uses very simple functions; corrections are welcome.

```python
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
# from src.models.refine import Refine
from src.models.resnet import ResNet34
from src.models.context_modules import get_context_module
from src.models.resnet import BasicBlock, NonBottleneck1D


# from src.models.model_utils import ConvBNAct, Swish, Hswish

def convt_bn_relu(ch_in, ch_out, kernel=3, stride=2, padding=1, output_padding=1,
                  bn=True, relu=True):
    assert (kernel % 2) == 1, \
        'only odd kernel is supported but kernel = {}'.format(kernel)

    layers = []
    layers.append(nn.ConvTranspose2d(ch_in, ch_out, kernel, stride, padding,
                                     output_padding, bias=not bn))
    if bn:
        layers.append(nn.BatchNorm2d(ch_out))
    if relu:
        layers.append(nn.ReLU(inplace=True))

    layers = nn.Sequential(*layers)

    return layers


def conv_bn_relu(ch_in, ch_out, kernel, stride=1, padding=0, bn=True, relu=True):
    assert (kernel % 2) == 1, \
        'only odd kernel is supported but kernel = {}'.format(kernel)

    layers = []
    layers.append(nn.Conv2d(ch_in, ch_out, kernel, stride, padding,
                            bias=not bn))
    if bn:
        layers.append(nn.BatchNorm2d(ch_out))
    if relu:
        layers.append(nn.ReLU(inplace=True))

    layers = nn.Sequential(*layers)

    return layers

def drop_path(x, drop_prob: float = 0., training: bool = False):

    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output

class DropPath(nn.Module):

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

class Mlp(nn.Module):
    """
    MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

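# Spge: position-encoding branch. It lifts the stage-1 feature map to the token
# dimension with a 1x1 conv, applies two 3x3 convs for local context, then
# downsamples 8x and flattens the result into a (B, N, C) token sequence.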
class Spge(nn.Module):
    def __init__(self,
                 in_channels,
                 ):
        super(Spge, self).__init__()
        self.spge_rgb = nn.Sequential(
            nn.Conv2d(in_channels, in_channels * 4, kernel_size=1, stride=1, padding=0),
            conv_bn_relu(in_channels * 4, in_channels * 4, kernel=3, padding=1, bn=True, relu=True),
            conv_bn_relu(in_channels * 4, in_channels * 4, kernel=3, padding=1, bn=True, relu=True)
        )
        self.spge_depth = nn.Sequential(
            nn.Conv2d(in_channels, in_channels * 4, kernel_size=1, stride=1, padding=0),
            conv_bn_relu(in_channels * 4, in_channels * 4, kernel=3, padding=1, bn=True, relu=True),
            conv_bn_relu(in_channels * 4, in_channels * 4, kernel=3, padding=1, bn=True, relu=True)
        )
    def forward(self, rgb, depth):  # (1,64,120,120)
        _, _, h, w = rgb.size()
        rgb_f = self.spge_rgb(rgb)        # (1,256,120,120)
        depth_f = self.spge_depth(depth)  # (1,256,120,120)
        rgb_f_1 = F.interpolate(rgb_f, size=(h // 8, w // 8), mode='nearest').flatten(2).permute(0, 2, 1)      # (1,225,256)
        depth_f_1 = F.interpolate(depth_f, size=(h // 8, w // 8), mode='nearest').flatten(2).permute(0, 2, 1)  # (1,225,256)
        return rgb_f_1, depth_f_1

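# Self-attention for the first fusion step: the projected position encoding is
# split into heads and added to q and k only; v is left untouched.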
class Attention(nn.Module):
    def __init__(self,
                 dim,   # dim of the input tokens
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.pos = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop_ratio)

    def forward(self, x, pe):
        # x: [batch_size, num_tokens, total_embed_dim]
        B, N, C = x.shape   # (1,225,256)
        b, n, c = pe.shape  # (1,225,256)
        # project the position encoding and split it into heads
        posenc = self.pos(pe).reshape(b, n, self.num_heads, c // self.num_heads).permute(0, 2, 1, 3)  # (1,8,225,32)
        # qkv(): -> [batch_size, num_tokens, 3 * total_embed_dim]
        # reshape: -> [batch_size, num_tokens, 3, num_heads, embed_dim_per_head]
        # permute: -> [3, batch_size, num_heads, num_tokens, embed_dim_per_head]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)  # (3,1,8,225,32)
        # the position encoding is added to q and k only; v stays untouched
        q = qkv[0] + posenc  # (1,8,225,32)
        k = qkv[1] + posenc  # (1,8,225,32)
        v = qkv[2]           # (1,8,225,32)

        # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_tokens]
        # @: multiply -> [batch_size, num_heads, num_tokens, num_tokens]
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (1,8,225,225)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # @: multiply -> [batch_size, num_heads, num_tokens, embed_dim_per_head]
        # transpose: -> [batch_size, num_tokens, num_heads, embed_dim_per_head]
        # reshape: -> [batch_size, num_tokens, total_embed_dim]
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # (1,225,256)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

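# Cross-attention for the second fusion step: q (and the position encoding added
# to it) comes from the other modality y, while k and v come from x.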
class Attention_c(nn.Module):
    def __init__(self,
                 dim,   # dim of the input tokens
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.):
        super(Attention_c, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.pos = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop_ratio)

    def forward(self, x, y, x_pos, y_pos):  # e.g. x=depth_s, y=rgb_s, x_pos=depth_pe, y_pos=rgb_pe

        B, N, C = x.shape #(1,225,256)
        b, n, c = x_pos.shape#(1,225,256)

        x_pos = self.pos(x_pos).reshape(b, n,self.num_heads, c // self.num_heads).permute(0,2,1,3)#(1,8,225,32)
        y_pos = self.pos(y_pos).reshape(b, n,self.num_heads, c // self.num_heads).permute(0,2,1,3)#(1,8,225,32)

        qkv_x = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)#(3,1,8,225,32)
        qkv_y = self.qkv(y).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)#(3,1,8,225,32)

        q = qkv_y[0] + y_pos #(1,8,225,32)
        k = qkv_x[1] + x_pos #(1,8,225,32)
        v = qkv_x[2]#(1,8,225,32)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

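# Attention_g: depth-guided attention for the final fusion step. The guiding
# branch y produces only the query (single-output linear), while the main
# branch x produces key and value.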
class Attention_g(nn.Module):
    def __init__(self,
                 dim,   # dim of the input tokens
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.):
        super(Attention_g, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv_x = nn.Linear(dim, dim * 2, bias=qkv_bias)  # main branch only needs k and v
        self.qkv_y = nn.Linear(dim, dim, bias=qkv_bias)      # guiding branch only needs q
        self.pos = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop_ratio)

    def forward(self, x, y, x_pos, y_pos):

        B, N, C = x.shape  # (1,225,256)
        b, n, c = x_pos.shape

        x_pos = self.pos(x_pos).reshape(b, n, self.num_heads, c // self.num_heads).permute(0, 2, 1, 3)  # (1,8,225,32)
        y_pos = self.pos(y_pos).reshape(b, n, self.num_heads, c // self.num_heads).permute(0, 2, 1, 3)  # (1,8,225,32)

        kv_x = self.qkv_x(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)  # (2,1,8,225,32)
        q_y = self.qkv_y(y).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)  # (1,8,225,32)

        q = q_y + y_pos      # (1,8,225,32)
        k = kv_x[0] + x_pos  # (1,8,225,32)
        v = kv_x[1]          # (1,8,225,32)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class Block_s(nn.Module):
    def __init__(self,
                 dim=256,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 norm_layer=nn.LayerNorm):
        super(Block_s, self).__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)

    def forward(self, x,rgb_pos):
        x = x + self.drop_path(self.attn(self.norm1(x),rgb_pos)) #(1,225,256)
        return x

class Block_c(nn.Module):
    def __init__(self,
                 dim,
                 num_heads,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 norm_layer=nn.LayerNorm):
        super(Block_c, self).__init__()
        self.norm1 = norm_layer(dim)
        self.attn_c = Attention_c(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)

    def forward(self, x,y,rgb_pos,depth_pos):
        x = x + self.drop_path(self.attn_c(self.norm1(x),self.norm2(y),rgb_pos,depth_pos))#(1,225,256)
        return x

class Block_g(nn.Module):
    def __init__(self,
                 dim,
                 num_heads,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 norm_layer=nn.LayerNorm):
        super(Block_g, self).__init__()
        self.norm1 = norm_layer(dim)
        self.attn_g = Attention_g(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                  attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)

    def forward(self, x, y, rgb_pos, depth_pos):
        x = x + self.drop_path(self.attn_g(self.norm1(x), self.norm2(y), rgb_pos, depth_pos))
        return x


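# Three-step fusion module: per-modality self-attention, bidirectional
# cross-attention, then depth-guided attention; the fused tokens are finally
# reshaped back into a feature map.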
class Transformer_fusion(nn.Module):
    def __init__(self,
                 num_heads=8,
                 qkv_bias = False,
                 embed_dim =256,
                 qk_scale = None,
                 drop_ratio = 0.1,
                 attn_drop_ratio =0.,
                 drop_path_ratio =0.,
                 norm_layer = nn.LayerNorm,
                 ):
        super(Transformer_fusion,self).__init__()

        self.rgb_out = Block_s(dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                    drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=drop_path_ratio,
                    norm_layer=norm_layer)

        self.depth_out = Block_s(dim=embed_dim, num_heads=num_heads,qkv_bias=qkv_bias,
                             qk_scale=qk_scale,drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=drop_path_ratio,
                             norm_layer=norm_layer)

        self.rgb_c = Block_c(dim=embed_dim, num_heads=num_heads,  qkv_bias=qkv_bias,
                             qk_scale=qk_scale,drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=drop_path_ratio,
                             norm_layer=norm_layer)

        self.depth_c = Block_c(dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias,
                           qk_scale=qk_scale, drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio,
                           drop_path_ratio=drop_path_ratio,
                           norm_layer=norm_layer)

        self.rgb_g = Block_g(dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                    drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=drop_path_ratio,
                    norm_layer=norm_layer)

        self.norm_rgb = norm_layer(embed_dim)
        self.norm_depth = norm_layer(embed_dim)

    def forward(self,rgb,depth,rgb_pe,depth_pe):
        b,n,c = rgb.shape #(1,225,256)
        rgb_s   = self.rgb_out  (rgb,rgb_pe)#(1,225,256)
        depth_s = self.depth_out(depth,depth_pe)#(1,225,256)

        rgb_c   = self.rgb_c    (rgb_s,depth_s,rgb_pe,depth_pe)#(1,225,256)
        depth_c = self.depth_c  (depth_s,rgb_s,depth_pe,rgb_pe)#(1,225,256)

        rgb_g = self.rgb_g(rgb_c, depth_c, rgb_pe, depth_pe)  # (1,225,256)
        # tokens back to a feature map: (B, N, C) -> (B, C, H, W)
        out = rgb_g.permute(0, 2, 1).reshape(b, c, int(np.sqrt(n)), int(np.sqrt(n)))  # (1,256,15,15)

        return out

class DeepLabV3Plus(nn.Module):
    def __init__(self, in_channels=[64, 128, 256, 512], num_classes=40, norm_layer=nn.BatchNorm2d):
        super(DeepLabV3Plus, self).__init__()
        self.num_classes = num_classes

        self.aspp = ASPP(in_channels=in_channels[2], atrous_rates=[6, 12, 18], norm_layer=norm_layer)
        self.low_level = nn.Sequential(
            nn.Conv2d(in_channels[0], 256, kernel_size=3, stride=1, padding=1),
            norm_layer(256),
            nn.ReLU(inplace=True)
        )
        self.block = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
            norm_layer(256),
            nn.ReLU(inplace=True),
            convt_bn_relu(256,128,kernel=3,stride=4,padding=1,output_padding=3),
            nn.Dropout(0.1),
            nn.Conv2d(128, num_classes, 1))

    def forward(self, c1,c4): #(1,64,120,120)/(1,512,15,15)
        c1 = self.low_level(c1)#(1,256,120,120)
        c4 = self.aspp(c4)#(1,256,15,15)
        c4 = F.interpolate(c4, c1.size()[2:], mode='bilinear', align_corners=True)#(1,256,120,120)
        output = self.block(torch.cat([c4, c1], dim=1)) #(1,40,480,480)
        return output


class ASPPConv(nn.Module):
    def __init__(self, in_channels, out_channels, atrous_rate, norm_layer):
        super(ASPPConv, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=atrous_rate, dilation=atrous_rate, bias=False),
            norm_layer(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.block(x)


class AsppPooling(nn.Module):
    def __init__(self, in_channels, out_channels, norm_layer):
        super(AsppPooling, self).__init__()
        self.gap = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            norm_layer(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        size = x.size()[2:]
        pool = self.gap(x)
        out = F.interpolate(pool, size, mode='bilinear', align_corners=True)
        return out


class ASPP(nn.Module):
    def __init__(self, in_channels, atrous_rates, norm_layer):
        super(ASPP, self).__init__()
        out_channels = 256
        self.b0 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            norm_layer(out_channels),
            nn.ReLU(True)
        )

        rate1, rate2, rate3 = tuple(atrous_rates)
        self.b1 = ASPPConv(in_channels, out_channels, rate1, norm_layer)
        self.b2 = ASPPConv(in_channels, out_channels, rate2, norm_layer)
        self.b3 = ASPPConv(in_channels, out_channels, rate3, norm_layer)
        self.b4 = AsppPooling(in_channels, out_channels, norm_layer=norm_layer)

        self.project = nn.Sequential(
            nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
            norm_layer(out_channels),
            nn.ReLU(True),
            nn.Dropout(0.5)
        )

    def forward(self, x):
        feat1 = self.b0(x)
        feat2 = self.b1(x)
        feat3 = self.b2(x)
        feat4 = self.b3(x)
        feat5 = self.b4(x)
        x = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)
        x = self.project(x) #(1,256,15,15)
        return x

class ESANet(nn.Module):
    def __init__(self,
                 height=480,
                 width=480,
                 num_classes=40,
                 encoder_rgb='resnet34',
                 encoder_depth='resnet34',
                 encoder_block='BasicBlock',
                 pretrained_on_imagenet=True,
                 pretrained_dir='',
                 activation='relu',
                 ):
        super(ESANet, self).__init__()
        in_channels = 64

        # set activation function
        if activation.lower() == 'relu':
            self.activation = nn.ReLU(inplace=True)

        else:
            raise NotImplementedError(
                'Only relu, swish and hswish as activation function are '
                'supported so far. Got {}'.format(activation))

        if encoder_rgb == 'resnet50' or encoder_depth == 'resnet50':
            warnings.warn('Parameter encoder_block is ignored for ResNet50. '
                          'ResNet50 always uses Bottleneck')

        # rgb encoder
        if encoder_rgb == 'resnet34':
            self.encoder_rgb = ResNet34(
                block=encoder_block,
                pretrained_on_imagenet=pretrained_on_imagenet,
                pretrained_dir=pretrained_dir,
                activation=self.activation)

        else:
            raise NotImplementedError(
                'Only ResNets are supported for '
                'encoder_rgb. Got {}'.format(encoder_rgb))

        # depth encoder
        if encoder_depth == 'resnet34':
            self.encoder_depth = ResNet34(
                block=encoder_block,
                pretrained_on_imagenet=pretrained_on_imagenet,
                pretrained_dir=pretrained_dir,
                activation=self.activation,
                input_channels=1)
        else:
            raise NotImplementedError(
                'Only ResNets are supported for '
                'encoder_depth. Got {}'.format(encoder_rgb))

        self.conv1_rgb = conv_bn_relu(4, 64, kernel=3, stride=2, padding=1,bn=True)

        self.rgbconv = conv_bn_relu(in_channels*8,in_channels*4,kernel=3,stride=1,padding=1,bn=True)
        self.depthconv = conv_bn_relu(in_channels*8,in_channels*4,kernel=3,stride=1,padding=1,bn=True)

        self.spge = Spge(in_channels)
        self.transfusion = Transformer_fusion(num_heads=8,
                 qkv_bias = False,
                 embed_dim =256,
                 qk_scale = None,
                 drop_ratio = 0.1,
                 attn_drop_ratio =0.,
                 drop_path_ratio =0.,
                 norm_layer = nn.LayerNorm,)

        self.decoder = DeepLabV3Plus(
            in_channels=[64, 128, 256, 512],
            num_classes=40,
            norm_layer=nn.BatchNorm2d
        )

    def forward(self, rgb, depth):#rgb, depth

        rgb_0 = torch.cat([rgb,depth],dim=1) #(1,4,480,480)
        rgb1 = self.conv1_rgb(rgb_0)#(1,64,240,240)
        depth1 = self.encoder_depth.forward_first_conv(depth)  # (1,64,240,240)

        fuse1_1 = F.max_pool2d(rgb1, kernel_size=3, stride=2, padding=1)#(1,64,120,120)
        fuse2_2 = F.max_pool2d(depth1, kernel_size=3, stride=2, padding=1)#(1,64,120,120)

        # block 1
        rgb2   = self.encoder_rgb.forward_layer1(fuse1_1)  # (1,64,120,120)
        depth2 = self.encoder_depth.forward_layer1(fuse2_2)  # (1,64,120,120)

        # block 2
        rgb3   = self.encoder_rgb.forward_layer2(rgb2)  # (1,128,60,60)
        depth3 = self.encoder_depth.forward_layer2(depth2)  # (1,128,60,60)

        # block 3
        rgb4   = self.encoder_rgb.forward_layer3(rgb3)  # (1,256,30,30)
        depth4 = self.encoder_depth.forward_layer3(depth3)  # (1,256,30,30)

        # block 4
        rgb5   = self.encoder_rgb.forward_layer4(rgb4)  # (1,512,15,15)
        depth5 = self.encoder_depth.forward_layer4(depth4)  # (1,512,15,15)

        rgb6 = self.rgbconv(rgb5).flatten(2).permute(0,2,1) #(1,256,15,15)--->(1,225,256)
        depth6 = self.depthconv(depth5).flatten(2).permute(0,2,1)#(1,256,15,15)--->(1,225,256)

        # position encoding
        rgb_pe ,depth_pe = self.spge(rgb2,depth2)

        # transformer fusion
        out = self.transfusion(rgb6,depth6,rgb_pe , depth_pe)#(1,256,15,15)

        # decoder

        final_out = self.decoder(rgb2,out) #(1,40,480,480)

        return final_out


def main():
    # ------------------------------ instantiate ESANet ------------------------------
    model = ESANet()

    model.eval()
    rgb_image = torch.randn(1, 3, 480, 480)
    depth_image = torch.randn(1, 1, 480, 480)
    summary(model, input_size=[(3, 480, 480), (1, 480, 480)], device='cpu')
    with torch.no_grad():
        output = model(rgb_image, depth_image)
    print(output.shape)


if __name__ == '__main__':
    main()
```
