ConvMixer:Patches Are All You Need

news2025/7/27 4:35:03

Patches Are All You Need

发表时间:[Submitted on 24 Jan 2022];

发表期刊/会议:Computer Vision and Pattern Recognition;

论文地址:https://arxiv.org/abs/2201.09792;

代码地址:https://github.com/locuslab/convmixer;


0 摘要

尽管CNN多年以来一直是计算机视觉任务的主要架构,但最近的一些工作表明,基于Transformer的模型,尤其是ViT,在某些情况下会超越CNN的性能(尤其是后来的swin transformer,完全超越CNN, 里程碑);

然而,因为Transformer的self-attention运行时间为二次的/平方的( O ( n 2 ) O(n^2) O(n2)),ViT使用patch embedding,将图像的小区域组合成单个输入特征,以便应用于更大的图像尺寸。

这就引出一个问题: ViT的性能是由于Transformer本身就足够强大,还是因为输入是patch?

本文为后者提供了一些证据;

本文提出一种非常简单的模型:ConvMixer,思想类似于MLP-Mixer;

  • MLP-Mixer直接在作为输入的patch上操作,分离空间和通道维度的混合信息,并在整个网络中保持相同的大小和分辨率。

  • ConvMixer只使用标准卷积来实现混合步骤。

尽管它很简单,但本文表明ConvMixer在类似的参数计数和数据集大小方面优于ViT、MLP-Mixer和它们的一些变体,此外还优于经典视觉模型(如ResNet)。


1 简介

本文探索一个问题:ViT的性能强大是因为Transformer结构本身,还是更多的来源于这种patch的表征形式?

本文提出一个非常简单的卷积架构,我们称之为“ConvMixer”,因为它与最近提出的MLP-Mixer相似(Tolstikhin et al, 2021)。

ConvMixer的许多方面都和ViT或MLP-Mixer类似

  • 直接对patch进行操作;
  • 在所有层中保持相同的分辨率和大小表示(feature map不降维、没有下采样);
  • 不会对连续层的表示进行下采样;
  • 将信息的“通道混合”与“空间混合”分开(depthwise 和 pointwise conv);

不同之处:

  • ConvMixer只通过标准卷积来完成所有这些操作;

结论:patch的表征形式很重要;


2 ConvMixer模型

2.0 模型概述

如图2所示:

  • 输入图像大小为 c × n × n c×n×n c×n×n,c-通道,n-宽度/高度;
  • patch大小为 p p p,进行patch embedding后,个数为 n / p × n / p n/p × n/p n/p×n/p,一个嵌入成h维的向量,得到向量块(也可以叫feature map) h × ( n / p ) × ( n / p ) h×(n/p)×(n/p) h×(n/p)×(n/p)
    • 这个patch embedding不同于Transformer的patch embedding;
    • 这一步相当于用一个输入通道为 c c c,输出通道为 h h h,卷积核大小=patch_size, stride = patch_size的卷积核去卷出的feature map;
  • 将这个feature map进行GeLU激活和BN,输入进ConvMixer Layer中;
  • ConvMixer层由深度卷积depthwise conv和逐点卷积pointwise conv和残差连接组成,每一个卷积之后都会有GeLU激活和BN;
    • depthwise conv: 将 h h h个通道各自进行卷积=>空间混合;
    • pointwise conv:1×1的卷积,对通道之间混合;
  • ConvMixer层会循环depth次;
  • 最后接入分类头;
图2:ConvMixer概述

Pytorch实现:

class ConvMixerLayer(nn.Module):
    def __init__(self,dim,kernel_size = 9):
        super().__init__()
        #残差结构
        self.Resnet =  nn.Sequential(
            nn.Conv2d(dim,dim,kernel_size=kernel_size,groups=dim,padding='same'),
            nn.GELU(),
            nn.BatchNorm2d(dim)
        )
        #逐点卷积
        self.Conv_1x1 = nn.Sequential(
            nn.Conv2d(dim,dim,kernel_size=1),
            nn.GELU(),
            nn.BatchNorm2d(dim)
        )
    def forward(self,x):
        x = x +self.Resnet(x)
        x = self.Conv_1x1(x)
        return x

class ConvMixer(nn.Module):
    def __init__(self,dim,depth,kernel_size=9, patch_size=7, n_classes=1000):
        super().__init__()
        self.conv2d1 = nn.Sequential(
            nn.Conv2d(3,dim,kernel_size=patch_size,stride=patch_size),
            nn.GELU(),
            nn.BatchNorm2d(dim)
        )
        self.ConvMixer_blocks =nn.ModuleList([])

        for _ in range(depth):
            self.ConvMixer_blocks.append(ConvMixerLayer(dim=dim,kernel_size=kernel_size))

        self.head =  nn.Sequential(
            nn.AdaptiveAvgPool2d((1,1)),
            nn.Flatten(),
            nn.Linear(dim,n_classes)
        )
    def forward(self,x):
    	#编码时的卷积
        x = self.conv2d1(x)
		#多层ConvMixer_block  的计算
        for ConvMixer_block in  self.ConvMixer_blocks:
             x = ConvMixer_block(x)
        #分类输出
        x = self.head(x)

        return x


model = ConvMixer(dim=128,depth=2)
print(model)
ConvMixer(
  (conv2d1): Sequential(
    (0): Conv2d(3, 128, kernel_size=(7, 7), stride=(7, 7))
    (1): GELU()
    (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (ConvMixer_blocks): ModuleList(
    (0): ConvMixerLayer(
      (Resnet): Sequential(
        (0): Conv2d(128, 128, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=128)
        (1): GELU()
        (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (Conv_1x1): Sequential(
        (0): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
        (1): GELU()
        (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): ConvMixerLayer(
      (Resnet): Sequential(
        (0): Conv2d(128, 128, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=128)
        (1): GELU()
        (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (Conv_1x1): Sequential(
        (0): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
        (1): GELU()
        (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (head): Sequential(
    (0): AdaptiveAvgPool2d(output_size=(1, 1))
    (1): Flatten(start_dim=1, end_dim=-1)
    (2): Linear(in_features=128, out_features=1000, bias=True)
  )
)


2.1 参数设计

ConvMixer的实例化依赖于四个参数:

  • the “width” or hidden dimension: h h h (patch embedding的维度);
  • ConvMixer层的循环次数: d e p t h depth depth
  • 控制模型内部分辨率的patch size: p p p
  • 深度卷积层的核大小: k k k

其他ConvMixer模型的命名规则:ConvMixer-h/d;


2.2 动机

本文的架构是基于混合的想法;特别地,我们选择了深度卷积dw来混合空间位置和点卷积来pw混合通道位置。

以前工作的一个关键观点是,MLP和自我注意可以混合远的空间位置,也就是说,它们可以有任意大的接受域。因此,我们使用大核卷积来混合遥远的空间位置。

虽然自我注意和MLP理论上更灵活,允许大的接受域和内容感知行为,但卷积的归纳偏差非常适合视觉任务。通过使用这样的标准操作,我们也可以看到与传统的金字塔形、逐步下采样的卷积网络设计相比,patch表示本身的效果。


3 实验

3.1 训练设置

主要在ImageNet-1k分类上评估ConvMixers,没有任何预训练或其他数据;

将ConvMixer添加到timm框架,并使用接近标准的设置对其进行训练: 除了默认的timm增强外,我们还使用RandAugment、mixup、CutMix、随机擦除和梯度范数裁剪。使用AdamW优化器;

由于计算量有限,我们绝对没有在ImageNet上进行超参数调优,并且训练的epoch比竞争对手少。

因此,我们的模型可能过度正则化或不正则化,我们报告的准确性可能低估了我们模型的能力。


3.2 实验结果


  • 精度:在ImageNet上,参数为52M的ConvMixer-1536/20可以达到81.4%的top-1精度,参数为21M的ConvMixer-768/32可以达到80.2%的top-1精度;
  • 宽度:更宽的ConvMixer似乎收敛更快,但需要大量内存和计算;
  • 内核大小:当将内核大小从k = 9减小到k = 3时,ConvMixer-1536/20的精度下降了≈1%;
  • patch大小:较小patch的ConvMixers基本上更好,更大的patch可能需要更深的ConvMixers;除了将patch大小从7增加到14,其他都保持不变,ConvMixer-1536/20达到了78.9%的top-1精度,但速度快了大约4倍;
  • 激活函数:用ReLU训练了一个模型,证明在最近的各向同性模型中流行的GELU是不必要的。


3.3 比较

将ConvMixer模型与ResNet/DeiT/ResMLP比较,结果如表1、图1所示;


  • 同等参数量,ConvMixer-1536/20的性能优于ResNet-152和ResMLP-B24;
  • ConvMixers在推理方面比竞争对手慢得多,可能是由于它们的patch尺寸更小; 超参数调优和优化可以缩小这一差距。有关更多讨论和比较,请参见表2和附录A。


4 相关工作

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/394709.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python编程训练题2

1.11 有 n 盏灯&#xff0c;编号 1&#xff5e;n&#xff08;0<n<100&#xff09;。第 1 个人把所有灯打开&#xff0c;第 2 个人按下所有编号为 2 的倍数的开关&#xff08;这些灯将被关掉&#xff09;&#xff0c;第 3 个人按下所有编号为 3 的倍数的开关&#xff08;其…

【华为OD机试2023】租车骑绿岛 C++ Java Python

【华为OD机试2023】租车骑绿岛 C++ Java Python 前言 如果您在准备华为的面试,期间有想了解的可以私信我,我会尽可能帮您解答,也可以给您一些建议! 本文解法非最优解(即非性能最优),不能保证通过率。 Tips1:机试为ACM 模式 你的代码需要处理输入输出,input/cin接收输入…

如何实现在on ethernetPacket中自动回复NDP response消息

对于IPv4协议来说,如果主机想通过目标ipv4地址发送以太网数据帧给目的主机,需要在数据链路层填充目的mac地址。根据目标ipv4地址查找目标mac地址,这是ARP协议的工作原理 对于IPv6协议来说,根据目标ipv6地址查找目标mac地址,它使用的不是ARP协议,而是邻居发现NDP(Neighb…

Oracle启动数据库报ORA-01102解决办法

1.机器启动之后登录服务器使用sqlplus / as sysdba 登录数据库发现数据库并没有启动之前把数据库服务添加过开机自启动 2.使用startup命令启动数据库报错了 SYSorcl>startup; ORACLE 例程已经启动。 Total System Global Area 2471931904 bytes Fixed Size 2255752 byt…

框架——MyBatis的入门案例

框架概述1.1什么是框架框架&#xff08;Framework&#xff09;是整个或部分系统的可重用设计&#xff0c;表现为一组抽象构件及构件实例间交与的方法&#xff1b;另一种定义认为&#xff0c;框架是可被应用开发者定制的应用骨架。前者是从应用方面而后者是从目的方面给出的定义…

关基系统国产化全面落地,ZoomEye Pro支持信创资产识别

信创产业发展的核心动力是IT底层架构的独立自主&#xff0c;为了尽快推进关键信息基础设施系统的国产化替代&#xff0c;一方面国家不断推出相关政策&#xff0c;协调各方资源&#xff0c;提供强有力的政策支撑&#xff0c;另一方面也在各关基行业有序推进重要信息基础设施的国…

第四章:面向对象编程

第四章&#xff1a;面向对象编程 4.1&#xff1a;面向过程与面向对象 面向过程(POP)与面向对象(OOP) 二者都是一种思想&#xff0c;面向对象是相对于面向过程而言的。面向过程&#xff0c;强调的是功能行为&#xff0c;以函数为最小单位&#xff0c;考虑怎么做。面向对象&…

2024秋招BAT核心算法 | 详解图论

图论入门与最短路径算法 图的基本概念 由节点和边组成的集合 图的一些概念&#xff1a; ①有向边&#xff08;有向图&#xff09;&#xff0c;无向边&#xff08;无向图&#xff09;&#xff0c;权值 ②节点&#xff08;度&#xff09;&#xff0c;对应无向图&#xff0c;…

抓狂!谷歌账号又又登录异常?给你支招解决

最近&#xff0c;就有很多朋友向东哥反馈说&#xff0c;谷歌账号登录异常了&#xff0c;明明账号密码都是对的&#xff0c;愣是登不上去&#xff0c;严重影响工作进度&#xff0c;很是捉急。所以东哥今天就总结了一份谷歌账号登录异常的解决方案&#xff0c;希望能帮助到大家&a…

CAS详解

CAS详解一 简介二 CAS底层原理2.1.AtomicInteger内部的重要参数2.2.AtomicInteger.getAndIncrement()分析2.2.1.getAndIncrement()方法分析2.2.2.举例分析三 CAS缺点四 CAS会导致"ABA问题"4.1.AtomicReference 原⼦引⽤。4.2.ABA问题的解决(AtomicStampedReference 类…

Eslint、Stylelint、Prettier、lint-staged、husky、commitlint【前端代码校验规则】

一、Eslint yarn add typescript-eslint/eslint-plugin typescript-eslint/parser eslint eslint-config-prettier eslint-config-standard-with-typescript eslint-plugin-import eslint-plugin-n eslint-plugin-prettier eslint-plugin-promise eslint-plugin-react eslint-…

实验四:搜索

实验四&#xff1a;搜索 1.填格子 题目描述 有一个由数字 0、1 组成的方阵中&#xff0c;存在一任意形状的封闭区域&#xff0c;封闭区域由数字1 包围构成&#xff0c;每个节点只能走上下左右 4 个方向。现要求把封闭区域内的所有空间都填写成2 输入要求 每组测试数据第一…

Provisioning Edge Inference as a Service via Online Learning 阅读笔记

通过在线学习提供边缘推理服务 一、论文研究背景、动机和主要贡献 研究背景 趋势&#xff1a;机器学习模型训练从中央云服务器逐步转移到边缘服务器 好处&#xff1a; 与云相比&#xff1a;a.低延迟 b.保护用户隐私&#xff08;数据不会上传到云&#xff09;与on-device相…

如何理解元数据、数据元、元模型、数据字典、数据模型这五个的关系?如何进行数据治理呢?数据治理该从哪方面入手呢?

如何理解元数据、数据元、元模型、数据字典、数据模型这五个的关系&#xff1f;如何进行数据治理呢&#xff1f;数据治理该从哪方面入手呢&#xff1f;导读一、数据元二、元数据三、数据模型四、数据字典五、元模型导读 请问元数据、数据元、数据字典、数据模型及元模型的区别…

数仓治理之数据梳理

目录 1.定义 2.用途作用 3.实施方法 3.1自上而下 3.1.1数据域梳理 3.1.2数据主题梳理 3.1.3 数据实体梳理 3.1.4设计数据模型 3.1.5优点 3.1.5缺点 3.2自下而上 3.2.1需求分析 3.2.2展现 3.2.3分析逻辑 3.2.4数据建模 3.2.5优点 3.2.6缺点 1.定义 “数据梳理”即对…

SpringBoot 如何保证接口安全?

为什么要保证接口安全对于互联网来说&#xff0c;只要你系统的接口暴露在外网&#xff0c;就避免不了接口安全问题。 如果你的接口在外网裸奔&#xff0c;只要让黑客知道接口的地址和参数就可以调用&#xff0c;那简直就是灾难。举个例子&#xff1a;你的网站用户注册的时候&am…

【云原生kubernetes】k8s数据存储之Volume使用详解

目录 一、什么是Volume 二、k8s中的Volume 三、k8s中常见的Volume类型 四、Volume 之 EmptyDir 4.1 EmptyDir 特点 4.2 EmptyDir 实现文件共享 4.2.1 关于busybox 4.3 操作步骤 4.3.1 创建配置模板文件yaml 4.3.2 创建Pod 4.3.3 访问nginx使其产生访问日志 4.3.4 …

I.MX6ULL_Linux_系统篇(27) 系统烧录工具

前面我们已经移植好了 uboot 和 linux kernle&#xff0c;制作好了根文件系统。但是我们移植都是通过网络来测试的&#xff0c;在实际的产品开发中肯定不可能通过网络来运行&#xff0c;因此我们需要将 uboot、 linux kernel、 .dtb(设备树)和 rootfs 这四个文件烧写到板子上的…

Nginx学习 (2) —— 虚拟主机配置

文章目录虚拟主机原理域名解析与泛域名解析&#xff08;实践&#xff09;配置文件中ServerName的匹配规则技术架构多用户二级域名短网址虚拟主机原理 为什么需要虚拟主机&#xff1a; 当一台主机充当服务器给用户提供资源的时候&#xff0c;并不是一直都有很大的用户量&#…

数据库面试题总结——DBA面试battle指南

目录 前言 数据库复制 oracle和pg的同步原理 mysql的同步原理 mysql的GTID 主从架构如何保证数据不丢失 oracle的保护模式 pg的日志传输模​​​​​​​式 mysql同步模式 从库只读 oracle的只读 pg的只读 mysql的只读 索引结构和寻迹 B树索引 索引寻迹 绑定执…