【深度学习】U-net网络结构搭建 | pytorch

news2025/7/27 12:00:38

文章目录

  • 前言
  • 一、U-net网络结构复现(上采样部分采用转置卷积nn.ConvTranspose2d)
    • 1.1、整体结构介绍
    • 1.2、encoder部分实现(左边网络部分)
    • 1.3、decoder部分实现(右边网络部分)
    • 1.4、整个网络搭建
  • 二、U-net网络结构复现(上采样部分采用上采样nn.Upsample)


前言

U-net论文地址:U-net论文
参考的一个还不错的开源项目地址:U-net开源项目地址
参考视频:视频1
视频2

一、U-net网络结构复现(上采样部分采用转置卷积nn.ConvTranspose2d)

1.1、整体结构介绍

首先我们看看论文里面的网络结构:
在这里插入图片描述
U-net网络是典型的encoder-decoder,整个呈U字形
1)左边的网络,随着不断向下,宽高减小,通道数增加
2)右边的网络,随着不断向上,宽高变大,通道数减少,最后恢复到和原来差不多的形状
3)最后输出的通道数是需要分类的个数
4)网络里使用的都是3X3卷积,如果想使得最后的输出图和原图宽高一致,可以使用311卷积(卷积核大小为3,步长1,扩充1)
5)灰色的线一共有4根,指的是“特征融合”,即两次卷积后生成的特征图与箭头指向的特征图进行torch.cat()操作,在通道维度进行拼接(dim=1,因为图片输入的维度为4,通道维数正好为1,b,c,h,w:0,1,2,3)
6) 右边decoder结构中,对宽高进行增大恢复的操作有两种:上采样nn.Upsample和转置卷积nn.ConvTranspose2d。在开源项目中,两种对应的方法代码是不一样的,解下来的部分我们首先讨论转置卷积nn.ConvTranspose2d进行扩大。

1.2、encoder部分实现(左边网络部分)

左边的网洛其实和VGG16比较类似,接下来我们具体看一下每部分的构造:

在这里插入图片描述

1)上图中,黄色框里为连续两次的311卷积(当然弄310卷积也可),我们可以定义为如下代码:

#搭建双倍卷积块
class doubleconv(nn.Module):
    def __init__(self,in_channels,out_channels):
        super(doubleconv,self).__init__()

        self.conv2=nn.Sequential(
            nn.Conv2d(in_channels,out_channels,3,1,1,bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1,bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.conv2(x)  #更加简洁

这种打黄色的部分在左边网络中出现了5次,右边的网络出现了4次
2)这种双倍卷积块,加上一个向下2倍池化,就可以合成我们红色框里的内容了:

#搭建下采样模块(里面包含了双倍卷积块)
class down(nn.Module):
    def __init__(self,in_channels,out_channels):
        super(down, self).__init__()
        self.pool_conv2=nn.Sequential(
            nn.MaxPool2d(2),
            doubleconv(in_channels,out_channels)
        )
    def forward(self,x):
        x=self.pool_conv2(x)
        return x

1.3、decoder部分实现(右边网络部分)

3)再看右边的网络,一个向上的311转置卷积,再来一个双倍卷积块,就能合成紫色框里的内容了:

#搭建上采样模块(里面包含了双倍卷积块)
class up(nn.Module):
    def __init__(self,in_channels,out_channels):
        super(up, self).__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = doubleconv(in_channels, out_channels)

    def forward(self,x1,x2):
        x1=self.up(x1)
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        x=self.conv(x)
        return x

好了,这部分的代码是有点绕的,突然多了很多看不懂的东西出来,我们一步步看:
首先是

        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = doubleconv(in_channels, out_channels)

明明这里前面nn.ConvTranspose2dd(in_channels, in_channels // 2, kernel_size=2, stride=2)已经将通道数减半了,为什么后面使用双倍卷积块时,初始的通道块依然是in_channels呢?
我们看看现场图,以最下面的灰色箭头为例子吧:
在这里插入图片描述
绿色的箭头正是nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2),转置卷积完成之后生成的特征图,会和之前左边网络的特征图发生特征融合,正是因为这个原因,特征图的通道数翻倍了,所以才会出现self.conv = doubleconv(in_channels, out_channels)里初始通道数依然是in_channels的情况。

然后是

        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])

为何会无缘无故有这一步操作呢?:
我们再回到案发现场:
在这里插入图片描述
注意看我打红色圈的部分,这是两个进行特征融合的特征图宽高,一个是64X64,一个是56X56。正常情况下这肯定是无法融合的(有关特征图的融合具体可以参考我之前的博客:特征融合的方法),需要对特征图进行一定裁剪;同时左边网络是向下16倍采样,56不能被16整除,为防止后续计算出现问题,我们需要将56X56的特征图进行一定操作变成64X64 ,这段代码的作用的就是这样的。

我们在实操一下代码就知道了”:

import torch
import torch.nn.functional as F

x1=torch.rand(1,512,56,56)
x2=torch.rand(1,512,64,64)
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
print(x2.size()[2])
print(x1.size()[2])
print(x2.size()[3])
print(x1.size()[3])
print(diffX)
print(diffY)

x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                diffY // 2, diffY - diffY // 2])
print(x1.shape)
x = torch.cat([x2, x1], dim=1)
print(x.shape)

在这里插入图片描述
融合完成后,最后通道数翻倍了,宽高为64X64

4)在网络的最后,有一个1X1卷积块,用以调整最后的特征图的输出通道数

class outconv(nn.Module):
    def __init__(self,in_channels, out_channels):
        super(outconv, self).__init__()
        self.outconv=nn.Conv2d(in_channels, out_channels,1)
    def forward(self,x):
        return self.outconv(x)

1X1卷积块不改变宽高,只改变通道数。

1.4、整个网络搭建

之前的模块搭建完毕后,即可完成最后网络的搭建:

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes):
        super(UNet, self).__init__()

        self.inc = doubleconv(n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)

        self.down4 = down(512, 1024 )
        self.up1 = up(1024, 512)
        self.up2 = up(512, 256)
        self.up3 = up(256, 128)
        self.up4 = up(128, 64)
        self.outc = outconv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        out = self.outc(x)
        return out

其中n_channels是输入图片的通道数,一般我们就为3,最后n_classes为需要分类的数量,也就是最后输出的通道数。
特征融合一共进行了四次,发生在如下代码处:

        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)

我们实例化网络验证一下,假设输入的图像为3通道的572X572图像,我们最后分类的个数为2:

unet=UNet(3,2)
A =torch.rand(1,3,572,572)
B=unet(A)
print(B.shape)

最后输出的形状为
在这里插入图片描述
当然因为我们这里使用的是311卷积,和论文里的输出特征图有一定区别(文章里的特征图宽高不断-2,应该使用的是310卷积块)

二、U-net网络结构复现(上采样部分采用上采样nn.Upsample)

这个下次来写,要赶回去洗头了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/16711.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

React源码分析5-commit

前两章讲到了,react 在 render 阶段的 completeUnitWork 执行完毕后,就执行 commitRoot 进入到了 commit 阶段,本章将讲解 commit 阶段执行过程源码。 总览 commit 阶段相比于 render 阶段要简单很多,因为大部分更新的前期操作都…

Dubbo框架基本使用

一:软件架构的演练过程【了解】 单体应用架构--->垂直应用架构--->分布式架构(SOA架构/微服务架构) 1.单体应用架构 单体应用架构,就是将一个系统的多个模块做成一个项目,然后部署到tomcat服务器上 优点: 项目架…

第01章+Java概述

课程链接:韩顺平Java_程序举例_哔哩哔哩_bilibili 什么叫程序 程序:计算机执行某些操作或解决某个问题而编写的一系列有序指令的集合。 Java版本迭代 官网介绍: Oracle Java SE Support Roadmap LTS为长期支持版本:推荐使用…

IQM的Unimon:一种新的量子比特,可促进量子计算机的实用化

​ 量子处理器中unimon 量子比特的艺术效果图。(图片来源:网络) 来自芬兰IQM量子计算机公司、阿尔托大学和芬兰VTT技术研究中心的一组科学家发现了一种新的超导量子比特——unimon,可提高量子计算的准确性。该团队已经实现了第一…

解读阿里Q2财报:阿里云的跨周期引擎

昨天,阿里巴巴公布今年6月到9月财务业绩,显示云业务总收入为267.6亿元,在除去阿里内部使用的额度后,抵销跨分部交易后营收为207.57亿元,比上一个季度增长超17%。 具体看,值得关注的有三点: 1、…

Python爬取公交线路信息及站点shp数据 文末附数据下载地址

本篇主要记录爬取公交网整个过程,由于这次所用方法虽比较常规,但由于该网站页面内容转码原因以及遍历链接较多,所以小坑还是比较多的,特在此进行记录。 以前爬过百度地图,当时用的是API平台,加上网站比较规范,所以标签节点什么的都比较清晰,但这次由于特殊原因所选择的…

对JavaScript中的Math.random随机函数破解

什么是随机 在通常的说法中,随机性是指事件中明显实际缺乏可预测性,事件、符号或步骤的随机序列通常没有顺序 举个例子,比如我们在抛硬币,硬币的结果取决于很多因素,比如说我们施加的力,空气阻力&#xff…

Linus shell 在一个脚本中调用另外一个脚本变量

1.新建public.sh文件,并添加以下内容: 2.新建ceshi.sh文件,并添加以下内容: 3.在终端赋予ceshi.sh文件执行权限,并运行该文件。

角度回归(复数与欧拉公式,L1,L2)

文章目录1 BEV下,Eula 损失函数2 BEV下,PointPillars使用sin联合SmoothL13 透视图下, MultiBin 全局方向损失4 L1/L2-norm 的周期损失函数1 BEV下,Eula 损失函数 Yolo-complex的论文中,对于BEV视角下,目标…

SDN和NFV的区别?

前言 网络功能虚拟化(Network Functions Virtualization,NFV)是一种关于网络架构的概念。我们平时使用的x86服务器由硬件厂商生产,在安装了不同的操作系统以及软件后实现了各种各样的功能。而传统的网络设备并没有采用这种模式&am…

2000-2019年各省产业结构合理化指数(干春晖泰尔指数)

2000-2019年省级产业结构合理化指数(干春晖泰尔指数) 1、来源:统计NJ及各省统计NJ 2、时间2000-2019年 3、数据说明:含原始数据和计算过程 4、范围包括全国31省 5、指标包括:各省总产值、第一产业增加值、第二产业…

C++基础知识要点--表达式 (Primer C++ 第五版 · 阅读笔记)

目录表达式基础算术运算符逻辑和关系运算符赋值运算符递增和递减运算符成员访问运算符条件运算符位运算符sizeof运算符逗号运算符类运算符运算符优先级表表达式 基础 当一个对象被用作右值的时候,用的是对象的 值(内容);当对象被用作左值的时候&#x…

Linux 信号

概念:信号不是信号量,信号量是进程间的一种通信方式,信号是系统中的软件中断,指一种事件通知机制,通知进程发生了某个事件,打断当前的操作,去处理这个事件。 种类:一共有62种信号&a…

Linux之用户管理、权限管理、程序安装卸载

一. 用户管理 1. 查看账户 (1). 查看当前账号:whoami ​​(2). 查看系统当前登录的账号:who ​​补充常用选项: ​​(3). 查看系统所有的账号: cat /etc/passwd ​​2. exit:退出登录账户 如果是图形界面&#xff0c…

curl命令的常用操作

curl是非常实用的命令行工具,用来与服务器之间传输数据。它的命令行参数多达几十种。 在Linux环境中使用curl命令可以进行接口测试。利用curl对http协议发送Get/Post/Delete/Put请求,同时还可以携带header来满足接口的特定需求。 curl命令的语法 curl[options] [U…

Linux03-网络设置

一、说明 在上一节,咱使用VMware安装了虚拟机,网络设置选择了 “桥接模式” ,本节咱们来具体讨论一下网络连接方式和网络设置。 实验环境:CentOS7 VMware 二、桥接模式 当我们设置桥接模式时,虚拟机是直接使用物理…

eNSP出现错误,错误代码40暴力解决方案

如果你和我一样,在eNSP中启动一个设备时发生了错误,错误代码为40,那么这篇文件可能会帮助你。 首先你可以仔细地按照这篇说明中的做法进行操作,如果你电脑也是win10,并且之前没有安装过wireshark,virtualb…

后端总说他啥也没动,我从线上调了一下测试接口,你再说一句动没动

◇ 不知道广大前端同学有没有过这样的经历,在做新需求联调的时候,原本上一个版本已经做的好好的功能,前后端已经联调好的。这次做需求的时候,测试发现好多地方都不对了。 ◇ 开发人员经常说的一句话就是:我啥也没动啊…

Java -- 每日一问:你了解Java应用开发中的注入攻击吗?

典型回答 注入式(Inject)攻击是一类非常常见的攻击方式,其基本特征是程序允许攻击者将不可信的动态内容注入到程序中,并将其执行,这就可能完全改变最初预计的执行过程,产生恶意效果。 下面是几种主要的注…

Web前端:2022年Web开发者的五大CSS工具

据相关数据统计,2018年至2028年,网络开发人员的就业预计将增长13%,这意味着网站开发者的需求量很大,而企业需要专业人员来构建网站,而高效制作优秀网站的最佳方法是拥有最好的web开发工具。 对优秀web开发工具的需求使…