UNet网络 图像分割模型学习

news2025/7/18 13:09:46

UNet 由Ronneberger等人于2015年提出,专门针对医学图像分割任务,解决了早期卷积网络在小样本数据下的效率问题和细节丢失难题。

一 核心创新

1.1对称编码器-解码器结构

实现上下文信息高分辨率细节的双向融合

如图所示:编码器进行了4步(红框)到达了瓶颈层(紫框),每一步包含两次3x3卷积+ReLU并通过通过2x2最大池化下采样,到达瓶颈层后,解码器也进行了4步(绿框),使用了转置卷积上采样后与编码器对应层特征拼接(跳跃连接(灰色箭头))后再进行两次卷积。

可以看出解码器和编码器非常的对称,呈现一个U型,所以叫UNet。

其中:
编码器:通过池化逐渐扩大感受野。

解码器:逐步恢复空间分辨率,精确定位目标边界。

跳跃连接:将编码器特征与解码器特征拼接,融合多级信息解决深层网络定位精度下降的问题

1.2跳跃连接(Skip Connections)

解决深层卷积神经网络中空间信息丢失细节模糊的核心问题。

因为编码器下采样会丢失细节,而解码器上采样又难以完全恢复位置信息,所以使用跳跃链接来补偿细节。

1.2.1数学形式表达

设编码器第 $l$ 层输出为 $E_l \in \mathbb{R}^{H_l \times W_l \times C_l}$ , 解码器第 $l$ 层输入为 $D_l \in \mathbb{R}^{H_l \times W_l \times C_{l}'}$, 则跳跃连接操作:

$ D_l' = \text{Concat}(E_l, \text{UpSample}(D_{l+1})) $

Concat : 沿通道维度拼接(Channel-wise Concatenation)

UpSample:  转置卷积/双线性插值将解码器输出的分辨率提升至与编码器相同

1.2.2特征融合方法

编码器每层的输出须与解码器对应层上采样后的尺寸匹配,拼接后总通道数为两者之和。

(黑色圆圈)

# PyTorch代码示例:拼接编码器和解码器特征
def forward(self, decoder_feat, encoder_feat):
    # decoder_feat: [B, C1, H, W] 
    # encoder_feat: [B, C2, H, W]
    merged = torch.cat([decoder_feat, encoder_feat], dim=1)  # 沿通道拼接
    return merged  # 结果维度:[B, C1+C2, H, W]

 1.3端到端精细分割(End-to-End Fine Segmentation)

在少量标注数据下仍能输出像素级预测

直接从原始输入图像生成像素级预测的模型设计范式,无需手动设计特征提取器或多阶段后处理。

1.3.1核心

全流程自动映射:输入 → 特征学习 → 高精度分割结果,中间过程由网络自动优化

细节敏感机制:通过多层次特征融合、边界增强模块等手段保证细粒度分割

无后处理输出:输出可直接使用,无需形态学后处理

1.3.2技术实现

编码器:通过卷积与池化逐层提取高层语义(形状、位置)

# 编码器层示例:每次下采样通道数翻倍
class Encoder(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),#卷积
            nn.BatchNorm2d(out_ch),#标准化(归一+线性变换)
            nn.ReLU(),#非线性激活
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.MaxPool2d(2)#最大值池化
        )
    
    def forward(self, x):
        return self.block(x)

解码器:上采样恢复分辨率 + 跳跃连接补充细节

# 解码器层示例:特征拼接后卷积
class Decoder(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)
        self.conv = nn.Sequential(
            nn.Conv2d(out_ch*2, out_ch, 3, padding=1), # 拼接后通道数翻倍
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
        )
    
    def forward(self, x, skip):
        x = self.up(x)
        x = torch.cat([x, skip], dim=1)  # 与编码器特征拼接
        return self.conv(x)

改良1: 注意力引导跳跃连接:通过空间注意力强化边缘区域(在跳跃连接前应用空间注意力,突出边缘信息)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        avg = torch.mean(x, dim=1, keepdim=True)
        max_pool, _ = torch.max(x, dim=1, keepdim=True)
        concat = torch.cat([avg, max_pool], dim=1)  # 沿通道维度拼接均值和最大值
        mask = self.sigmoid(self.conv(concat))      # 生成空间注意力掩码
        return x * mask                             # 加权增强关键区域

改良2: 多尺度损失监督:在不同解码层注入辅助损失。

class MultiScaleLoss(nn.Module):
    def __init__(self, losses):
        super().__init__()
        self.losses = losses  # 各层对应的损失函数列表
    
    def forward(self, preds, target):
        total_loss = 0
        for pred, loss_fn in zip(preds, self.losses):
            # 将目标下采样至与当前预测同尺寸
            _, _, H, W = pred.shape
            resized_target = F.interpolate(target, size=(H,W), mode='nearest')
            total_loss += loss_fn(pred, resized_target)
        return total_loss

适用性扩展:该范式可迁移至其他密集预测任务,如卫星影像分析、自动驾驶场景理解等。

二 与传统分割模型对比

模型优势局限性
FCN全卷积保留空间信息输出分辨率粗糙,跳跃连接简单
SegNet使用池化索引提升精度特征复用效率低
DeepLab空洞卷积扩大感受野小目标分割边缘模糊
UNet对称结构+密集跳跃连接,细节恢复原版对大尺度变化敏感

三 UNet的改良方法 

3.1跨尺度空洞卷积替换编码器的普通卷积层

在底层使用扩张率=1捕捉细节,高层使用d=3或5扩大感受野。

# 原编码器卷积块
self.encoder_conv = nn.Sequential(
    nn.Conv2d(in_ch, out_ch, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(out_ch, out_ch, 3, padding=1),
    nn.ReLU()
)

# 改进:跨尺度空洞卷积模块
self.encoder_conv = CrossScaleDilatedConv(in_ch, out_ch)

3.2融入密集块融合增强跳跃连接的特征传递

在编码器和解码器拼接前加入密集块

class ImprovedSkipConnection(nn.Module):
    def __init__(self, in_ch):
        super().__init__()
        self.dense_block = DenseBlock(num_layers=4, in_channels=in_ch)
    
    def forward(self, enc_feat, dec_feat):
        enc_processed = self.dense_block(enc_feat)  # 特征增强
        merged = torch.cat([enc_processed, dec_feat], dim=1)
        return merged

# 在UNet解码器中应用
def forward(self, x):
    # ... 编码过程
    d4 = self.upconv4(d5)
    d4 = self.skip_conn4(e4, d4)  # 使用改进的跳跃连接
    d4 = self.decoder_conv4(d4)
    # ...

四 核心代码(未改良)

class UNet(nn.Module):
    def __init__(self, n_class=1):
        super().__init__()
        # 编码器
        self.enc1 = EncoderBlock(3, 64)
        self.enc2 = EncoderBlock(64, 128)
        self.enc3 = EncoderBlock(128, 256)
        self.enc4 = EncoderBlock(256, 512)
        self.bottleneck = EncoderBlock(512, 1024)
        
        # 解码器
        self.upconv4 = UpConv(1024, 512)
        self.dec4 = DecoderBlock(1024, 512)  # 输入1024因拼接
        self.upconv3 = UpConv(512, 256)
        self.dec3 = DecoderBlock(512, 256)
        self.upconv2 = UpConv(256, 128)
        self.dec2 = DecoderBlock(256, 128)
        self.upconv1 = UpConv(128, 64)
        self.dec1 = DecoderBlock(128, 64)
        
        self.final = nn.Conv2d(64, n_class, kernel_size=1)
    
    def forward(self, x):
        # 编码
        e1 = self.enc1(x)
        e2 = self.enc2(F.max_pool2d(e1, 2))
        e3 = self.enc3(F.max_pool2d(e2, 2))
        e4 = self.enc4(F.max_pool2d(e3, 2))
        bn = self.bottleneck(F.max_pool2d(e4, 2))
        
        # 解码
        d4 = self.dec4(self.upconv4(bn), e4)
        d3 = self.dec3(self.upconv3(d4), e3)
        d2 = self.dec2(self.upconv2(d3), e2)
        d1 = self.dec1(self.upconv1(d2), e1)
        return torch.sigmoid(self.final(d1))

class EncoderBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
        )
    
    def forward(self, x):
        return self.conv(x)

class UpConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)
    
    def forward(self, x):
        return self.up(x)

class DecoderBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = EncoderBlock(in_ch, out_ch)
    
    def forward(self, x, skip):
        x = torch.cat([x, skip], dim=1)  # 通道拼接
        return self.conv(x)

UNet凭借其优雅的对称结构和密集跳跃连接,成为医学图像分割的基准模型。通过集成跨尺度空洞卷积密集块融合等模块,可显著提升其对多尺度目标的适应性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2373856.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用 SHAP 进行特征交互检测:揭示变量之间的复杂依赖关系

我们将探讨如何使用 SHAP(SHapley 加法解释)来检测和可视化机器学习模型中的特征交互。了解特征组合如何影响模型预测对于构建更透明、更准确的模型至关重要。SHAP 有助于揭示这些复杂的依赖关系,并使从业者能够以更有意义的方式解释模型决策…

Python-MCPInspector调试

Python-MCPInspector调试 使用FastMCP开发MCPServer,熟悉【McpServer编码过程】【MCPInspector调试方法】-> 可以这样理解:只编写一个McpServer,然后使用MCPInspector作为McpClient进行McpServer的调试 1-核心知识点 1-熟悉【McpServer编…

Java设计模式-策略模式(行为型)

策略模式详解 一、策略模式概述 1.1 基本概念 策略模式是一种行为型设计模式,它主要用于处理算法的不同变体。其核心思想是将算法的定义与使用分离开来,把一系列具体的算法封装成独立的策略类,这些策略类实现相同的策略接口。客户端可以在…

html body 设置heigth 100%,body内元素设置margin-top出滚动条(margin 重叠问题)

今天在用移动端的时候发现个问题&#xff0c;html,body 设置 height&#xff1a;100% 会出现纵向滚动条 <!DOCTYPE html> <html> <head> <title>html5</title> <style> html, body {height: 100%; } * {margin: 0;padding: 0; } </sty…

C语言模糊不清的知识

1、malloc、calloc、realloc的区别和用法 malloc实在堆上申请一段连续指定大小的内存区域&#xff0c;并以void*进行返回&#xff0c;不会初始化内存。calloc与malloc作用一致&#xff0c;只是calloc会初始化内存&#xff0c;自动将内存清零。realloc用于重新分配之前通过mallo…

如何配置光猫+路由器实现外网IP访问内部网络?

文章目录 前言一、网络拓扑理解二、准备工作三、光猫配置3.1 光猫工作模式3.2 光猫端口转发配置&#xff08;路由模式时&#xff09; 四、路由器配置4.1 路由器WAN口配置4.2 端口转发配置4.3 动态DNS配置&#xff08;可选&#xff09; 五、防火墙设置六、测试配置七、安全注意事…

springboot3+vue3融合项目实战-大事件文章管理系统获取用户详细信息-ThreadLocal优化

一句话本质 为每个线程创建独立的变量副本&#xff0c;实现多线程环境下数据的安全隔离&#xff08;线程操作自己的副本&#xff0c;互不影响&#xff09;。 关键解读&#xff1a; 核心机制 • 同一个 ThreadLocal 对象&#xff08;如示意图中的红色区域 tl&#xff09;被多个线…

【高数上册笔记篇02】:数列与函数极限

【参考资料】 同济大学《高等数学》教材樊顺厚老师B站《高等数学精讲》系列课程 &#xff08;注&#xff1a;本笔记为个人数学复习资料&#xff0c;旨在通过系统化整理替代厚重教材&#xff0c;便于随时查阅与巩固知识要点&#xff09; 仅用于个人数学复习&#xff0c;因为课…

c++STL-string的模拟实现

cSTL-string的模拟实现 string的模拟实现string的模拟线性表的实现构造函数析构函数获取长度&#xff08;size&#xff09;和获取容量&#xff08;capacity&#xff09;访问 [] 和c_str迭代器&#xff08;iterator&#xff09;交换swap拷贝构造函数赋值重载&#xff08;&#x…

YashanDB(崖山数据库)V23.4 LTS 正式发布

2024年回顾 2024年11月我们受邀去深圳参与了2024国产数据库创新生态大会。在大会上崖山官方发布了23.3。这个也是和Oracle一样采用的事编年体命名。 那次大会官方希望我们这些在一直从事在一线的KOL帮助产品提一些改进建议。对于这样的想法&#xff0c;我们都是非常乐于合作…

python 写一个工作 简单 番茄钟

1、图 2、需求 番茄钟&#xff08;Pomodoro Technique&#xff09;是一种时间管理方法&#xff0c;由弗朗西斯科西里洛&#xff08;Francesco Cirillo&#xff09;在 20 世纪 80 年代创立。“Pomodoro”在意大利语中意为“番茄”&#xff0c;这个名字来源于西里洛最初使用的一个…

PyCharm 加载不了 conda 虚拟环境,不存在的

#工作记录 前言 在开发过程中&#xff0c;PyCharm 无法加载 Conda 虚拟环境是常见问题。 在不同情况下&#xff0c;“Conda 可执行文件路径”的指定可能会发生变化&#xff0c;不会一尘不变&#xff0c;需要灵活处置。 以下是一系列解决此问题的经验参考。 检查 Conda 安装…

设计模式学习整理

目录 UML类图 设计模式六大原则 1.单一职责原则 2.里氏替换原则 3.依赖倒置原则 4.接口隔离原则 5.迪米特法则(最少知道原则) 6.开(放封)闭原则 设计模式分类 1.创建型模式 2.结构型模式 4.行为型模式 一、工厂模式(factory——简单工厂模式和抽象工厂模式) 1.1、…

二分查找的理解

#define _CRT_SECURE_NO_WARNINGS #include <stdio.h>int binary_search(int arr[], int k, int sz) {int left 0;int right sz - 1;//这个是下标&#xff0c;减一是因为在0开始的&#xff0c;怕越界&#xff08;访问无效&#xff09;while (left < right){int mid…

【Java】线程实例化 线程状态 线程属性

线程实例化 继承 Thread 类 创建类继承自 Thread 类 . class MyThread extends Thread重写 run() 方法 . Overridepublic void run(){// 线程要执行的任务代码}实例化自定义线程类 . 实现 Runnable 接口 创建类实现 Runnable 接口 . class MyRunnable implements Runnable实…

卫宁健康WiNGPT3.0与WiNEX Copilot 2.2:医疗AI创新的双轮驱动分析

引言:医疗AI的双翼时代 在医疗信息化的浪潮中,人工智能技术的深度融入正在重塑整个医疗行业。卫宁健康作为国内医疗健康和卫生领域数字化解决方案的领军企业,持续探索AI技术在医疗场景中的创新应用。2025年5月10日,在第29届中国医院信息网络大会(CHIMA2025)上,卫宁健康…

I2C通讯

3.1. 本章节的代码仓库 1 2 3 4 5 6 #如之前有获取则可跳过 #获取仓库 git clone https://gitee.com/LubanCat/lubancat_rk_code_storage.git#代码所在的位置 lubancat_rk_code_storage/quick_start/i2c3.2. i2c I2C(Inter&#xff0d;Integrated Circuit)是一种通用的总线协…

Excel实现单元格内容拼接

一、应用场景&#xff1a; 场景A&#xff1a;将多个单元格拼接&#xff0c;比如写测试用例时&#xff0c;将多个模块拼接&#xff0c;中间用“-”隔开 场景B&#xff1a;将某单元格内容插入另一单元格固定位置&#xff08;例如在B1中添加A1的内容&#xff09; 二、实际应用&a…

2025前端面试遇到的问题(vue+uniapp+js+css)

Vue相关面试题 vue2和vue3的区别 一、核心架构差异 特性Vue2Vue3响应式系统基于Object.defineProperty基于Proxy&#xff08;支持动态新增/删除属性&#xff09;代码组织方式Options API&#xff08;data/methods分块&#xff09;Composition API&#xff08;逻辑按功能聚合&am…

广东省省考备考(第八天5.11)—言语:逻辑填空(每日一练)

错题 解析 第一空&#xff0c;搭配“期盼”&#xff0c;且根据“生命&#xff0c;是来自上天的馈赠”&#xff0c;可知父母对孩子的出生是非常期盼的。A项“望穿秋水”&#xff0c;形容对远地亲友的殷切盼望&#xff0c;C项“望眼欲穿”&#xff0c;形容盼望殷切&#xff0c;均…