RepVGG:让VGG风格的ConvNet再次伟大起来

news2025/7/9 16:54:37

引言

经典的卷积神经网络(ConvNet)VGG [31]在图像识别中取得了巨大的成功,其简单的架构由conv、ReLU和池化的堆栈组成。随着Inception [33,34,32,19]、ResNet [12]和DenseNet [17]的出现,许多研究兴趣转向了设计良好的架构,使得模型越来越复杂。一些最近的架构基于自动[44,29,23]或手动[28]架构搜索,或搜索的复合缩放策略[35]。

尽管许多复杂的ConvNet比简单的ConvNet具有更高的精度,但缺点也很明显

  1. 复杂的多分支设计(例如,ResNet中的跳跃连接和Inception中的分支级联)使得模型难以实现和定制,减慢了推理速度,降低了内存利用率
  2. 一些组件(例如,Xception [3]和MobileNets [16,30]中的深度卷积和ShuffleNet [24,41]中的通道重排增加了存储器访问成本,并且缺乏对各种设备的支持。由于影响推理速度的因素太多,浮点运算(FLOPs)的数量并不能精确地反映实际速度。

由于多分支结构的优点都是用于训练,而缺点是不希望用于推理,因此我们提出通过结构重新参数化来解耦训练时多分支结构和推理时普通结构,即通过变换其参数将结构从一种转换到另一种.具体而言,网络结构与一组参数相耦合,Conv层由4阶核张量表示。如果某个结构的参数可以转换成另一个结构耦合的另一组参数,我们就可以用后者等价地替换前者,从而改变整个网络架构

相关工作

从单路径到多分支

在VGG [31]将ImageNet分类的前1位准确率提高到70%以上之后,在使ConvNet复杂化以获得高性能方面出现了许多创新,例如:当代的GoogleNet [33]和后来的Inception模型[34,32,19]采用了精心设计的多分支体系结构ResNet [12]提出了简化的两分支体系结构,而DenseNet [17]通过将低层与许多高层连接起来使拓扑结构更加复杂。神经结构搜索(NAS)[44,29,23,35]和手工设计空间设计[28]可以生成具有更高性能的ConvNet,但代价是大量的计算资源或人力。

方法

快速,节省内存,灵活

许多最近的多分支体系结构的理论FLOPs比VGG低,但运行速度可能不会更快。有两个重要因素对速度有相当大的影响,但flop没有考虑到:内存访问成本(MAC)和并行度。

结构重参数化

 上图展现了各个算子的融合过程

总结:单一的简单模型性能较差,多分支模型性能好、但是效率比较低。作者的想法是模型在多分支的情况下进行训练,然后在测试的时候转换为单一的简单模型。

Conv2d+BN融合实验

from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn


def main():
    torch.random.manual_seed(0)

    f1 = torch.randn(1, 2, 3, 3)

    module = nn.Sequential(OrderedDict(
        conv=nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, stride=1, padding=1, bias=False),
        bn=nn.BatchNorm2d(num_features=2)
    ))

    module.eval()

    with torch.no_grad():
        output1 = module(f1)
        print(output1)

    # fuse conv + bn
    kernel = module.conv.weight
    running_mean = module.bn.running_mean
    running_var = module.bn.running_var
    gamma = module.bn.weight
    beta = module.bn.bias
    eps = module.bn.eps
    std = (running_var + eps).sqrt()
    t = (gamma / std).reshape(-1, 1, 1, 1)  # [ch] -> [ch, 1, 1, 1]
    kernel = kernel * t
    bias = beta - running_mean * gamma / std
    fused_conv = nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, stride=1, padding=1, bias=True)
    fused_conv.load_state_dict(OrderedDict(weight=kernel, bias=bias))

    with torch.no_grad():
        output2 = fused_conv(f1)
        print(output2)

    # np.testing.assert_allclose(output1.numpy(), output2.numpy(), rtol=1e-03, atol=1e-05)
    # print("convert module has been tested, and the result looks good!")


if __name__ == '__main__':
    main()

RepVGG

import time
import torch.nn as nn
import numpy as np
import torch


def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
    result = nn.Sequential()
    result.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                        kernel_size=kernel_size, stride=stride, padding=padding,
                                        groups=groups, bias=False))
    result.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
    return result


class RepVGGBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding=1, dilation=1, groups=1, padding_mode='zeros', deploy=False):
        super(RepVGGBlock, self).__init__()
        self.deploy = deploy
        self.groups = groups
        self.in_channels = in_channels
        self.nonlinearity = nn.ReLU()

        if deploy:
            self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                         kernel_size=kernel_size, stride=stride,
                                         padding=padding, dilation=dilation, groups=groups,
                                         bias=True, padding_mode=padding_mode)

        else:
            self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) \
                if out_channels == in_channels and stride == 1 else None
            self.rbr_dense = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                     stride=stride, padding=padding, groups=groups)
            self.rbr_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1,
                                   stride=stride, padding=0, groups=groups)

    def forward(self, inputs):
        if hasattr(self, 'rbr_reparam'):
            return self.nonlinearity(self.rbr_reparam(inputs))

        if self.rbr_identity is None:
            id_out = 0
        else:
            id_out = self.rbr_identity(inputs)

        return self.nonlinearity(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)

    def get_equivalent_kernel_bias(self):
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
        return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        if kernel1x1 is None:
            return 0
        else:
            return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch):
        if branch is None:
            return 0, 0
        if isinstance(branch, nn.Sequential):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, 'id_tensor'):
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def switch_to_deploy(self):
        if hasattr(self, 'rbr_reparam'):
            return
        kernel, bias = self.get_equivalent_kernel_bias()
        self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.conv.in_channels,
                                     out_channels=self.rbr_dense.conv.out_channels,
                                     kernel_size=self.rbr_dense.conv.kernel_size, stride=self.rbr_dense.conv.stride,
                                     padding=self.rbr_dense.conv.padding, dilation=self.rbr_dense.conv.dilation,
                                     groups=self.rbr_dense.conv.groups, bias=True)
        self.rbr_reparam.weight.data = kernel
        self.rbr_reparam.bias.data = bias
        for para in self.parameters():
            para.detach_()
        self.__delattr__('rbr_dense')
        self.__delattr__('rbr_1x1')
        if hasattr(self, 'rbr_identity'):
            self.__delattr__('rbr_identity')
        if hasattr(self, 'id_tensor'):
            self.__delattr__('id_tensor')
        self.deploy = True

def main():
    f1 = torch.randn(1, 64, 64, 64)
    block = RepVGGBlock(in_channels=64, out_channels=64)
    block.eval()
    with torch.no_grad():
        output1 = block(f1)
        start_time = time.time()
        for _ in range(100):
            block(f1)
        print(f"consume time: {time.time() - start_time}")

        # re-parameterization
        block.switch_to_deploy()
        output2 = block(f1)
        start_time = time.time()
        for _ in range(100):
            block(f1)
        print(f"consume time: {time.time() - start_time}")

        np.testing.assert_allclose(output1.numpy(), output2.numpy(), rtol=1e-03, atol=1e-05)
        print("convert module has been tested, and the result looks good!")


if __name__ == '__main__':
    main()

参考文献

DingXiaoH/RepVGG: RepVGG: Making VGG-style ConvNets Great Again (github.com)

RepVGG网络简介_太阳花的小绿豆的博客-CSDN博客_repvgg网络简介

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/18233.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Struts2】一_idea快速搭建struts2框架

文章目录什么是SSH框架?Struts2框架1、struts2的环境搭建1.1 创建web项目(maven),导入struts2核心jar包1.2 配置web.xml(过滤器),是struts2的入口,先进入1.3 创建核心配置文件struts…

C语言日记 36 类的组合

书P137: 如果声明组合类的对象时没有指定对象的初始值,自动调用无形参的构造函数, 构造内嵌对象时也对应调用内嵌对象的无形参的构造函数。 Q1:这里,对于“构造内嵌对象时也对应调用内嵌对象的无形参的构造函数”;他指的是什么…

STM32F429基于TouchGFX进行简单控制LED和显示ADC值

所需软件: CubeMX KEIL MDK ARM TouchGFX 首先配置外部时钟 配置时钟树,设置180MHZ 使能GPIO口如下,其中PA0用于LED 配置ADC通道 定时器TIM8触发 配置FMC和SDRAM,参数固定 使能DMA2D,参数如下: 配置LTDC 屏幕分…

JAVA反射

今天我们来讲一讲什么是java的反射机制,我们要了解一个新事物之前,我们应该首先的了解它的基本概念,那什么是反射呢? java的反射概念:JAVA反射机制是在运行状态中,对于任意一个类,都能够知道这个类的所有属性和方法;对于任意一个对象&#…

centos7 安装hadoop

文章目录一、将hadoop的压缩文件传递到虚拟机里面二、解压缩三、配置环境变量一、将hadoop的压缩文件传递到虚拟机里面 路径随意,只要你能到时候能找到压缩文件就行。 二、解压缩 这里我给解压到opt/module目录里(没有可以自己创建,主要是为了方便管理) tar -zx…

Java#15(集合)

一.集合和数组的区别 1.从长度方面: 数组的长度是固定的,而集合的长度不是固定的 2.从存储类型方面: 数组可以存储基本数据类型也可以存储引用数据类型,而集合能存储引用数据类型,若是想要存储基本数据类型要将其变成对应的包装类 创建一个集合 二.ArrayList成员方法 1.boole…

【Spring5】使用JdbcTemplate操作mysql数据库

文章目录1 JdbcTemplate简介及操作准备2 添加操作3 修改与删除操作4 数据库查询操作4.1 返回一个值4.2 返回对象4.3 返回集合5 批量操作5.1 批量添加与修改5.2 批量修改与批量删除写在最后1 JdbcTemplate简介及操作准备 Spring框架对JDBC进行封装,使用JdbcTemplate…

java计算机毕业设计ssm教师贴心宝的设计与实现

项目介绍 随着互联网技术的发发展,计算机技术广泛应用在人们的生活中,逐渐成为日常工作、生活不可或缺的工具,高校各种管理系统层出不穷。高校作为学习知识和技术的高等学府,信息技术更加的成熟,为教师开发必要的系统,能够有效的提升管理效率。近年来,高校规模不断扩大,同时在…

[附源码]SSM计算机毕业设计中达小区物业管理系统JAVA

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

c#入参使用引用类型为啥要加ref?

目录 ref修饰入参的常用场景引用类型添加ref的作用是啥?总结那什么是值,什么是引用?大体可以理解为堆栈的区别,在.net中大多数实例存在于托管堆栈中。struct,int32,int64,double,en…

mathtype符号显示不全的问题

目录前言参考问题解决方法补充最后前言 最近在使用Mathtype往Word里插入数学符号时,发现有些符号的间距过大,整个排版不太好看,于是上网查了一下,将方法记录下来。 参考 Word插入数学符号,行间距变大怎么办&#xf…

【408数据结构与算法】—堆排序(二十一)

【408数据结构与算法】—堆排序&#xff08;二十一&#xff09; 一、堆的定义 从堆的定义可以看出&#xff0c;堆实质是满足如下性质的完全二叉树&#xff0c;二叉树中任一非叶子结点均小于&#xff08;大于&#xff09;它的孩子结点 C语言代码实现 #include <stdio.h>…

python零基础入门教程(非常详细),从零基础入门到精通,看完这一篇就够了

前言 本文罗列了了python零基础入门到精通的详细教程&#xff0c;内容均以知识目录的形式展开。 第一章&#xff1a;python基础之markdown Typora软件下载Typora基本使用Typora补充说明编程与编程语言计算机的本质计算机五大组成部分计算机三大核心硬件操作系统 第二章&…

java计算机毕业设计ssm健达企业项目管理系统

项目介绍 随着经济的发展和信息技术的普及,国内许多健达企业都面临了重大的挑战。健达企业的管理流程、战略规划如果不能进行调整,极有可能面临淘汰的风险。特别是企业项目的处理,面对大量的人员信息和业务信息,如果不使用信息系统进行有效的管理和利用,那就会阻碍健达企业的发…

[附源码]SSM计算机毕业设计中华美食网站JAVA

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

gcc和g++的使用

linux编译器-gcc和g的使用 文章目录linux编译器-gcc和g的使用预处理编译汇编链接函数库动态库和静态库file 查看可执行程序ldd 查看可执行程序格式make和makefilestat 查看文件或目录时间在讲gcc和g编译同时&#xff0c;我们复习一下程序翻译的大概过程&#xff0c;并以此为例切…

IntelliJ IDEA 配置启动SprintBoot项目

【原文链接】IntelliJ IDEA 配置启动SprintBoot项目 文章目录一、IDEA 配置maven二、IDEA 配置jdk三、IDEA 启动项目一、IDEA 配置maven &#xff08;1&#xff09; 首先本机配置好maven&#xff0c;具体可参考 Win10系统如何安装配置maven &#xff08;2&#xff09;然后在打…

Windows10中使用VS2022和Cmake编译构建C++开源日志库-spdlog

一、关于C中的开源日志库spdlog Java中有很多日志库&#xff1a;java.util.logging、Log4j、Logback、Log4j2、slf4j、common-logging。C的日志库相对来说就比较少了&#xff0c;比如说glog、log4cpp、spdllog等&#xff0c;目前个人感觉比较好用的C开源日志库当属于spdlog了&…

这次把怎么做好一个PPT讲清-审美篇

要提高审美&#xff0c;主要是靠不断的看优秀的作品来知道什么是美的&#xff0c;这个短时间很难速成&#xff0c;只能靠不断的积累。 如何做出具有高级感的PPT&#xff1f; 已剪辑自: https://zhuanlan.zhihu.com/p/38642831 很多年前&#xff0c;走在大街上的PPT大多长得像…

打破边界,边缘计算有何应用场景?

近年来&#xff0c;随着5G、物联网、人工智能技术的发展&#xff0c;越来越多设备接入到互联网中&#xff0c;数据呈现爆炸式增长&#xff0c;对算力、延时提出更好要求&#xff0c;能够在靠近数据源头位置提供计算服务的边缘计算快速兴起&#xff0c;打破更多的场景边界&#…