【AI模型学习】上/下采样

news2025/5/23 10:51:11

文章目录

  • 分割中的上/下采样
  • 下采样
    • SegFormer和PVT(使用卷积)
    • Swin-Unet(使用 Patch Merging)
  • 上采样
    • SegFormer(interpolate)
    • Swin-Unet(Patch Expanding)
    • 逐级interpolate的方式
    • 反卷的方式


基于Transformer架构的图像分割模型(如 SegFormer、Swin-Unet)中,上采样和下采样结构几乎是标准配置。

分割中的上/下采样

为什么需要下采样?

  1. 提取高层语义特征
    Transformer擅长全局建模,结合下采样可以:降低分辨率;聚焦于更宽范围的上下文。

  2. 减少计算成本
    原始输入图像太大,直接送入多层Transformer(特别是多头注意力)会导致计算量和显存爆炸。

为什么要上采样?

  1. 恢复空间分辨率
    Segmentation任务最终要输出与输入图像同样大小的分割mask;

  2. 细粒度定位
    但如果没有上采样、跳跃连接或融合,容易失去细节;所以上采样常结合UNet-like结构来补偿细节损失。

模型下采样方式上采样方式
SETRViT backbone patchify多层反卷积上采样
SegFormerMLP Mixer + 4阶段卷积下采样多层插值 + FFN
Swin-UnetSwin Transformer 下采样Patch expanding + skip连接

下采样

SegFormer和PVT(使用卷积)

# 输入:img.shape = [B, 3, 512, 512]

# Stage 1
x1 = Conv2d(3, 32, kernel_size=7, stride=4, padding=3)(img)    # → [B, 32, 128, 128]
x1 = x1.flatten(2).transpose(1, 2)                             # → [B, 16384, 32]
x1 = TransformerBlock(x1)

# Stage 2
x2 = Conv2d(32, 64, kernel_size=3, stride=2, padding=1)(x1_reshaped)  # → [B, 64, 64, 64]
x2 = x2.flatten(2).transpose(1, 2)                                    # → [B, 4096, 64]
x2 = TransformerBlock(x2)

# 后面还有 Stage3、Stage4 类似

Shape 演化:

Stage1: [B, 128×128=16384, 32]
Stage2: [B, 64×64=4096, 64]
Stage3: [B, 32×32=1024, 160]
Stage4: [B, 16×16=256, 256]

Swin-Unet(使用 Patch Merging)

# 初始 patch embedding(patch_size=4)
x = Conv2d(3, 96, kernel_size=4, stride=4)(img)      # [B, 96, 128, 128]
x = x.flatten(2).transpose(1, 2)                     # → [B, 16384, 96]

# Stage 1
x = SwinBlock(x)                                     # [B, 16384, 96]
x = PatchMerging(x)                                  # → [B, 4096, 192]

# Stage 2
x = SwinBlock(x)                                     # [B, 4096, 192]
x = PatchMerging(x)                                  # → [B, 1024, 384]

# Stage 3
x = SwinBlock(x)
x = PatchMerging(x)                                  # → [B, 256, 768]

Shape 演化:

Stage0: [B, 128×128=16384, 96]
Stage1: [B, 64×64=4096, 192]
Stage2: [B, 32×32=1024, 384]
Stage3: [B, 16×16=256, 768]

过程不难,只是不好描述,可以看相关教程,这里就把代码贴出来

class PatchMerging(nn.Module):
    def __init__(self, in_dim):
        super().__init__()
        self.reduction = nn.Linear(in_dim * 4, in_dim * 2)

    def forward(self, x, H, W):
        # x: [B, H*W, C] → [B, H, W, C]
        x = x.view(B, H, W, C)

        # 拆分四个方向的 token
        x0 = x[:, 0::2, 0::2, :]  # top-left
        x1 = x[:, 1::2, 0::2, :]  # bottom-left
        x2 = x[:, 0::2, 1::2, :]  # top-right
        x3 = x[:, 1::2, 1::2, :]  # bottom-right

        x = torch.cat([x0, x1, x2, x3], dim=-1)  # → [B, H/2, W/2, 4C]
        x = x.view(B, -1, 4 * C)                 # → [B, H/2*W/2, 4C]
        x = self.reduction(x)                   # → [B, H/2*W/2, 2C]
        return x


上采样

SegFormer(interpolate)

在这里插入图片描述

def forward(self, x1, x2, x3, x4):  
    # 输入来自4个Stage:
    # x1: [B, 128*128, 32]
    # x2: [B, 64*64,   64]
    # x3: [B, 32*32,  160]
    # x4: [B, 16*16,  256]

    B = x1.shape[0]

    # === 1. Linear Projection:通道都投影为 256 ===
    _x1 = self.linear1(x1).permute(0, 2, 1).reshape(B, 256, 128, 128)  # [B, 256, 128, 128]
    _x2 = self.linear2(x2).permute(0, 2, 1).reshape(B, 256,  64,  64)  # [B, 256,  64,  64]
    _x3 = self.linear3(x3).permute(0, 2, 1).reshape(B, 256,  32,  32)  # [B, 256,  32,  32]
    _x4 = self.linear4(x4).permute(0, 2, 1).reshape(B, 256,  16,  16)  # [B, 256,  16,  16]

    # === 2. 上采样到统一大小 ===
    _x2 = F.interpolate(_x2, size=(128, 128), mode='bilinear', align_corners=False)  # [B, 256, 128, 128]
    _x3 = F.interpolate(_x3, size=(128, 128), mode='bilinear', align_corners=False)  # [B, 256, 128, 128]
    _x4 = F.interpolate(_x4, size=(128, 128), mode='bilinear', align_corners=False)  # [B, 256, 128, 128]

    # === 3. 拼接所有层 ===
    fused = torch.cat([_x1, _x2, _x3, _x4], dim=1)  # [B, 4*256=1024, 128, 128]

    # === 4. 1x1卷积融合通道数 ===
    out = self.fuse_conv(fused)  # [B, 256, 128, 128]

    return out

Swin-Unet(Patch Expanding)

看图也能看出来,十分经典的U-Net结构。

在这里插入图片描述

在上采样阶段
输入:

一个高语义 token,维度为 [4C],是上一步 Patch Merging 得到的。

  1. Linear 映射
    [4C] 投影为 [C] × 4,也就是还原为 2×2 patch 每格的 C 维向量。

  2. reshape → [H, W, 2, 2, C] → [2H, 2W, C]
    把这 4 个 token 安排到一个新的空间位置(上采样 ×2)。

  3. 最终输出为:

    Token 数量 × 4 , 通道数 ÷ 2 \text{Token 数量} \times 4,\quad \text{通道数} \div 2 Token 数量×4通道数÷2

class PatchExpanding(nn.Module):
    def __init__(self, in_dim, expand_ratio=2):
        super().__init__()
        # Linear: [B, H*W, in_dim] → [B, H*W, out_dim = (expand_ratio^2) * out_channels]
        # 例如:in_dim = 512,expand_ratio = 2 → 输出 4×C = 1024
        self.linear = nn.Linear(in_dim, in_dim // 2 * expand_ratio**2)
        self.expand_ratio = expand_ratio

    def forward(self, x, H, W):
        # x: [B, H*W, C]
        B, N, C = x.shape
        R = self.expand_ratio  # 通常为 2

        # 线性投影:C → 4 * (C/2),也就是 [B, H*W, 4*C'],每个 token 展开为 2×2 的 patch
        x = self.linear(x)                      # [B, H*W, 4*C'] = [B, H*W, R*R*(C//2)]

        #  reshape 成图像形式,带有 2×2 子结构 → [B, H, W, R, R, C']
        x = x.view(B, H, W, R, R, C // 2)       # [B, H, W, 2, 2, C//2]

        #  调整顺序,将 2×2 子结构移入空间维度 → [B, H*2, W*2, C//2]
        x = x.permute(0, 1, 3, 2, 4, 5)         # [B, H, 2, W, 2, C//2]
        x = x.reshape(B, H * R, W * R, C // 2)  # [B, 2H, 2W, C//2]

        #  flatten 成 token 序列形式(可再送入 Transformer)→ [B, 4*H*W, C//2]
        x = x.view(B, -1, C // 2)               # [B, 4*H*W, C//2]
        return x


逐级interpolate的方式

  1. 输入来自编码器 4 个 stage

    • x4:[16×16, 512] ← 最深层
    • x3:[32×32, 320]
    • x2:[64×64, 128]
    • x1:[128×128, 64] ← 最浅层
  2. 通道统一:
    每个特征图先通过 1×1 卷积或 Linear 映射,统一成相同维度(如全部 → 256 或 512)

  3. 上采样与融合(逐级):

    f4 = Conv(x4)                          # [16×16]
    f3 = F.interpolate(f4, scale=2) + Conv(x3)  # → [32×32]
    f2 = F.interpolate(f3, scale=2) + Conv(x2)  # → [64×64]
    f1 = F.interpolate(f2, scale=2) + Conv(x1)  # → [128×128]
    

反卷的方式

很经典的设计,不必过多介绍。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2383828.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Unity Shader入门(更新中)

参考书籍:UnityShader入门精要(冯乐乐著) 参考视频:Bilibili《Unity Shader 入门精要》 写在前面:前置知识需要一些计算机组成原理、线性代数、Unity的基础 这篇记录一些学历过程中的理解和笔记(更新中&…

嵌入式学习的第二十六天-系统编程-文件IO+目录

一、文件IO相关函数 1.read/write cp #include <fcntl.h> #include <sys/stat.h> #include <sys/types.h> #include <stdio.h> #include<unistd.h> #include<string.h>int main(int argc, char **argv) {if(argc<3){fprintf(stderr, …

珠宝课程小程序源码介绍

这款珠宝课程小程序源码&#xff0c;基于ThinkPHPFastAdminUniApp开发&#xff0c;功能丰富且实用。ThinkPHP提供稳定高效的后台服务&#xff0c;保障数据安全与处理速度&#xff1b;FastAdmin助力快速搭建管理后台&#xff0c;提升开发效率&#xff1b;UniApp则让小程序能多端…

KNN模型思想与实现

KNN算法简介 核心思想&#xff1a;通过样本在特征空间中k个最相似样本的多数类别来决定其类别归属。"附近的邻居确定你的属性"是核心逻辑 决策依据&#xff1a;采用"多数表决"原则&#xff0c;即统计k个最近邻样本中出现次数最多的类别 样本相似性度量 …

fscan教程1-存活主机探测与端口扫描

实验目的 本实验主要介绍fscan工具信息收集功能&#xff0c;对同一网段的主机进行存活探测以及常见服务扫描。 技能增长 通过本次实验的学习&#xff0c;了解信息收集的过程&#xff0c;掌握fscan工具主机探测和端口扫描功能。 预备知识 fscan工具有哪些作用&#xff1f; …

腾讯2025年校招笔试真题手撕(三)

一、题目 今天正在进行赛车车队选拔&#xff0c;每一辆赛车都有一个不可以改变的速度。现在需要选取速度差距在10以内的车队&#xff08;车队中速度的最大值减去最小值不大于10&#xff09;&#xff0c;用于迎宾。车队的选拔按照的是人越多越好的原则&#xff0c;给出n辆车的速…

怎样通过神经网络估计股票走向

本博文将教会你如何通过神经网络建立股票模型并对其进行未来趋势估计&#xff0c;尽管博主已通过此方法取得一定利润&#xff0c;但是建议大家不要过分相信AI。本博文仅用于代码学习&#xff0c;请大家谨慎投资。 一、通过爬虫爬取股票往年数据 在信息爆炸的当今时代&#xf…

【RocketMQ 生产者和消费者】- 生产者启动源码-上报生产者和消费者心跳信息到 broker(3)

文章目录 1. 前言2. sendHeartbeatToAllBrokerWithLock 上报心跳信息3. prepareHeartbeatData 准备心跳数据4. sendHearbeat 发送心跳上报请求5. broker 处理心跳请求5.1 heartBeat 处理心跳包5.2 createTopicInSendMessageBackMethod 创建重传 topic5.3 registerConsumer 注册…

Python----循环神经网络(Word2Vec的优化)

一、负采样 基本思想&#xff1a; 在训练过程中&#xff0c;对于每个正样本&#xff08;中心词和真实上下文词组成的词对&#xff09;&#xff0c;随机采样少量&#xff08;如5-20个&#xff09;负样本&#xff08;中心词与非上下文词组成的词对&#xff09;。 模型通过区分正…

Simon J.D. Prince《Understanding Deep Learning》

学习神经网络和深度学习推荐这本书&#xff0c;这本书站位非常高&#xff0c;且很多问题都深入剖析了&#xff0c;甩其他同类书籍几条街。 多数书&#xff0c;不深度分析、没有知识体系&#xff0c;知识点零散、章节之间孤立。还有一些人Tian所谓的权威&#xff0c;醒醒吧。 …

开搞:第四个微信小程序:图上县志

原因&#xff1a;我换了一个微信号来搞&#xff0c;因为用同一个用户&#xff0c;备案只能一个个的来。这样不行。所以我换了一个。原来注册过小程序。现在修改即可。注意做好计划后&#xff0c;速度备案和审核&#xff0c;不然你时间浪费不起。30元花起。 结构&#xff1a; -…

Seata源码—7.Seata TCC模式的事务处理一

大纲 1.Seata TCC分布式事务案例配置 2.Seata TCC案例服务提供者启动分析 3.TwoPhaseBusinessAction注解扫描源码 4.Seata TCC案例分布式事务入口分析 5.TCC核心注解扫描与代理创建入口源码 6.TCC动态代理拦截器TccActionInterceptor 7.Action拦截处理器ActionIntercept…

【语法】C++的map/set

目录 平衡二叉搜索树 set insert() find() erase() swap() map insert() 迭代器 erase() operator[] multiset和multimap 在之前学习的STL中&#xff0c;string&#xff0c;vector&#xff0c;list&#xff0c;deque&#xff0c;array都是序列式容器&#xff0c;它们的…

vue vite textarea标签按下Shift+Enter 换行输入,只按Enter则提交的实现思路

注意input标签不能实现&#xff0c;需要用textarea标签 直接看代码 <template><textareav-model"message"keydown.enter"handleEnter"placeholder"ShiftEnter 换行&#xff0c;Enter 提交"></textarea> </template>&l…

深入理解 PlaNet(Deep Planning Network):基于python从零实现

引言&#xff1a;基于模型的强化学习与潜在动态 基于模型的强化学习&#xff08;Model-based Reinforcement Learning&#xff09;旨在通过学习环境动态的模型来提高样本效率。这个模型可以用来进行规划&#xff0c;让智能体在不需要与真实环境进行每一次决策交互的情况下&…

仿腾讯会议——视频发送接收

1、 添加音频模块 2、刷新图片&#xff0c;触发重绘 3、 等比例缩放视频帧 4、 新建视频对象 5、在中介者内定义发送视频帧的函数 6、完成发送视频的函数 7、 完成开启/关闭视频 8、绑定视频的信号槽函数 9、 完成开启/关闭视频 10、 完成发送视频 11、 完成刷新图片显示 12、完…

从3.7V/5V到7.4V,FP6291在应急供电智能门锁中的应用

在智能家居蓬勃发展的当下&#xff0c;智能门锁以其便捷、安全的特性&#xff0c;成为现代家庭安防的重要组成部分。在智能门锁电量耗尽的情况下&#xff0c;应急电源外接移动电源&#xff08;USB5V输入&#xff09; FP6291升压到7.4V供电可应急开锁。增强用户在锁具的安全性、…

【人工智障生成日记1】从零开始训练本地小语言模型

&#x1f3af; 从零开始训练本地小语言模型&#xff1a;MiniGPT TinyStories&#xff08;4090Ti&#xff09; &#x1f9ed; 项目背景 本项目旨在以学习为目的&#xff0c;从头构建一个完整的本地语言模型训练管线。目标是&#xff1a; ✅ 不依赖外部云计算✅ 完全本地运行…

Selenium-Java版(frame切换/窗口切换)

frame切换/窗口切换 前言 切换到frame 原因 解决 切换回原来的主html 切换到新的窗口 问题 解决 回到原窗口 法一 法二 示例 前言 参考教程&#xff1a;Python Selenium Web自动化 2024版 - 自动化测试 爬虫_哔哩哔哩_bilibili 上期文章&#xff1a;Sel…

一文深度解析:Pump 与 PumpSwap 的协议机制与技术差异

在 Solana 链上&#xff0c;Pump.fun 和其延伸产品 PumpSwap 构成了 meme coin 发行与流通的两大核心场景。从初期的游戏化发行模型&#xff0c;到后续的自动迁移与交易市场&#xff0c;Pump 系列协议正在推动 meme coin 从“爆发性投机”走向“协议化运营”。本文将从底层逻辑…