CondConv 动态卷积学习笔记 (附代码)

news2025/7/19 2:49:52

论文地址:https://arxiv.org/abs/1904.04971

代码地址:https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv

1.是什么?

CondConv是一种条件参数卷积,也称为动态卷积,它是一种即插即用的模块,可以为每个样例学习一个特定的卷积核参数。通过替换标准卷积,CondConv可以提升模型的尺寸与容量,同时保持高效推理。CondConv的设计优越性在于它只需要做一次卷积,而不像其他方法需要做n次卷积,这样可以大大减少计算开销。CondConv可以对现有网络中的标准卷积进行替换,同时适用于深度卷积与全连接层。实验结果表明,CondConv可以显著提高模型的性能。 

2.为什么?

CNN在诸多计算机视觉任务中取得了前所未有的成功,但其性能的提升更多源自模型尺寸与容量的提升以及更大的数据集。模型的尺寸提升进一步加剧了计算量的提升,进一步加大优秀模型的部署难度。

现有CNN的一个基本假设:对所有样例采用相同的卷积参数。这就导致:为提升模型的容量,就需要加大模型的参数、深度、通道数,进一步导致模型的计算量加大、部署难度提升。由于上述假设以及终端部署需求,当前高效网络往往具有较少的参数量。然而,在某些计算机视觉应用中(如终端视频处理、自动驾驶),模型实时性要求高,对参数量要求较低。

作者提出一种条件参数卷积用于解决上述问题,它通过输入计算卷积核参数打破了传统的静态卷积特性。特别的,作者将CondConv中的卷积核参数化为多个专家知识的线性组合(其中,是通过梯度下降学习的加权系数):。为更有效的提升模型容量,在网络设计过程中可以提升专家数量,这比提升卷积核尺寸更为高效,同时专家知识只需要进行一次组合,这就可以在提升模型容量的同时保持高效推理。

3 怎么样?

3.1网络结构

结构1,如下图,首先它采用更细粒度的集成方式,每一个卷积层都拥有多套权重,卷积层的输入分别经过不同的权重卷积之后组合输出,缺点是但这计算量依旧很大。

 

 结构2,如图2,为了解决图1计算大问题,作者提出既然输入相同,卷积是一种线性计算,COMBINE也是一个线性计算(比如加权求和),作者将多套权重加权组合之后,只做一次卷积就能完成相当的效果!计算量相比上图,大大降低。

3.2 原理

在常规卷积中,其卷积核参数经训练确定且对所有输入样本“一视同仁”;而在CondConv中,卷积核参数参数通过对输入进行变换得到,该过程可以描述为:

这里x xx表示上一个layer的输出,n nn表示这一层Condconv Layer有n nn个expert(expert就是该层的卷积核W),σ 表示激活函数,a_{i}=r_{i}(x)表示一个样本依赖的加权参数。
所以一个CondConv层的卷积核参数的由来,就是通过上述的线性组合公式。整个流程可以概括为:依赖于输入x,在卷积操作之前,通过routing函数r_{i}(x)计算出每一个expert前面的系数a_{i} ,再通过线性组合,得到CondConv层最终的kernal,最后与输入x xx做卷积,并进行activation。在这里,routing weight的计算公式如下:

对于输入x xx,首先做GlobalAveragePooling,随后右乘一个矩阵R(该矩阵的目的是将维度映射到n个expert上面,以实现后续的线性组合),最后通过sigmoid将每一个维度上的权值规约到[0,1]区间。因此,根据输入x xx的不同,就会得到不同的routing weight向量,进而CondConv层的kernal也各有差异。
 

3.3代码实现

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch import Tensor
import functools
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.utils import _pair
from torch.nn.parameter import Parameter
 
 
class _routing(nn.Module):
 
    def __init__(self, in_channels, num_experts, dropout_rate):
        super(_routing, self).__init__()
        
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(in_channels, num_experts)
 
    def forward(self, x):
        x = torch.flatten(x)
        x = self.dropout(x)
        x = self.fc(x)
        return F.sigmoid(x)
    
 
class CondConv2D(_ConvNd):
    def __init__(self, in_channels, out_channels, kernel_size, 
                 stride=1, padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros', 
                 num_experts=3, dropout_rate=0.2):
        
        # tuple
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super(CondConv2D, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias, padding_mode)
 
        self._avg_pooling = functools.partial(F.adaptive_avg_pool2d, output_size=(1, 1))
        self._routing_fn = _routing(in_channels, num_experts, dropout_rate)
        
        self.weight = Parameter(torch.Tensor(
            num_experts, out_channels, in_channels // groups, *kernel_size))
        
        self.reset_parameters()
 
    def _conv_forward(self, input, weight):
        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
                            weight, self.bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
    
    def forward(self, inputs):
        b, _, _, _ = inputs.size()
        res = []
        for input in inputs:
            input = input.unsqueeze(0)
            pooled_inputs = self._avg_pooling(input)
            routing_weights = self._routing_fn(pooled_inputs)
            kernels = torch.sum(routing_weights[: ,None, None, None, None] * self.weight, 0)
            out = self._conv_forward(input, kernels)
            res.append(out)
        return torch.cat(res, dim=0)

参考:

动态卷积之CondConv和DynamicConv

CondConv:用于有效推理的条件参数化卷积

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1157584.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据结构之集合框架

1.Java集合框架的定义 Java 集合框架 Java Collection Framework ,又被称为容器 container ,是定义在 java.util 包下的一组接口 interfaces和其实现类 classes 。 其主要表现为将多个元素 element 置于一个单元中,用于对这些元素进行…

向上管理中的沟通技巧

一. 背景 我们要弄清楚两个问题为什么要向上管理呢,向上管理主要是要做什么呢? 首先,第一个问题为什么要向上管理?向上管理的本质是为了同时给公司、上司和自己带来最好的结果,并有意识地配合和改变工作方法&#xf…

iPhone连不上Wi-Fi?看完这篇文章你就知道了!

大家在使用苹果手机的过程中有没有遇到过这样的情况:手机突然连接不上Wi-Fi,或者连接了也根本使用不了。遇到上述情况请不要着急,iphone连不上wifi是由很多种原因导致的。那么,iPhone连接不上Wi-Fi时该怎么办呢? 我们…

测试可用的安防视频分析软件:烟火检测、车型检测、玩手机打电话检测、厨帽检测、抽烟检测、人员入侵检测

下载地址:https://pan.baidu.com/s/1R1MvD_KQ3uB-0KL_N3is-w?pwdwa33 随着AI、大数据、云计算和边缘计算等技术的迅猛发展,我国的视频监控市场正处于全新的阶段。借助AI深度学习技术的进步,现代化的安防视频监控系统通过边缘计算设备上的AI识…

用友NC BeanShell RCE漏洞

一、漏洞简介 用友 NC 是面向集团企业的管理软件,其在同类市场占有率中达到亚太第一。用友 NC 由于对外开放了 BeanShell 接口,攻击者可以在未授权的情况下直接访问该接口,并构造恶意数据执行任意代码从而获取服务器权限。 二、影响版本 NC …

误删的文件不在回收站如何找回?分享3个简单方法!

“我前段时间清理电脑的时候误删了一些比较重要的文件,通常我都会使用回收站来还原这些文件的,但昨天不小心清空了回收站,想问问还有机会找回我的文件吗?” 为了保证用户的权益,误删的文件通常会先被移入电脑回收站中。…

SpringBootWeb案例——Tlias智能学习辅助系统(1)

目录 需求与准备环境搭建REST风格的API接口开发规范-统一响应结果 部门管理部门列表查询功能删除部门新增部门请求路径优化查询部门修改部门 员工管理分页查询分页插件PageHelper分页查询(带条件) (难点)删除员工 需求与准备 1、部门管理 包括: 查询部门列表 删除部…

解决远程连接数据库缓慢的问题【图文】【非常详细】

问题概述 当我们远程访问数据库,遇到连接不上或者连接等待时间较长,问题大概率就出在数据库远程链接解析的问题,就是在MySQL的配置文件中增加如下配置参数: [mysqld] skip-name-resolve 具体操作如下 解决步骤 打开mysql所在文…

中小企业选择外贸管理系统有哪些常见的误区?

中小企业基础设施相对薄弱、人员管理松散,选择外贸管理系统是很多管理者的解决方案。选型系统不是一蹴而就的,其中会遇到很多问题甚至进入误区,那么中小企业选择外贸管理系统有哪些常见的误区? 本地部署比云服务更安全 CRM数据安…

【唠唠嵌入式】__如何学习单片机?

目录 前言 个人定位,从事软件还是硬件? 学习内容 (* ̄︶ ̄)创作不易!期待你们的 点赞、收藏和评论喔。 前言 作为一个老司机,多年来跟单片机、Keil、C语言、AD、烙铁、风枪、示波器、电子元器件纠缠不清…

【机器学习】五、贝叶斯分类

我想说:“任何事件都是条件概率。”为什么呢?因为我认为,任何事件的发生都不是完全偶然的,它都会以其他事件的发生为基础。换句话说,条件概率就是在其他事件发生的基础上,某事件发生的概率。 条件概率是朴…

获取Webshell方法

CMS系统指的是内容管理系统。已经有别人开发好了整个网站的前后端,使用者只需要部署cms,然后通过后台添加数据,修改图片等工作,就能搭建好一个的WEB系统。 CMS获取Webshell方法 WordPress后台拿Webshell phpcms拿Webshell 非CMS…

Vue:实现复制按钮功能

作者:CSDN @ _乐多_ 本文记录了vue开发中,复制按钮的实现代码。用于复制网页中的一个数或者字符串啥的。 效果如下图所示, 文章目录 <el-button @click="copyToClipboard(wgs84Position2.altitude)">复制</el-button>data(

AI赋能,轻松出爆文!AI新闻创作新时代,你准备好了吗?

众所周知&#xff0c;传统新闻报道需要大量的人工参与&#xff0c;不仅耗时耗力&#xff0c;还对媒体工作者的文字功底和知识积累有很高的要求。但随着人工智能技术的发展&#xff0c;大模型在新闻写作领域展现出强大的潜力。通过AI写作技术&#xff0c;在很大程度上实现了新闻…

小程序商城免费搭建之java商城 电子商务Spring Cloud+Spring Boot+二次开发+mybatis+MQ+VR全景+b2b2c

1. 涉及平台 平台管理、商家端&#xff08;PC端、手机端&#xff09;、买家平台&#xff08;H5/公众号、小程序、APP端&#xff08;IOS/Android&#xff09;、微服务平台&#xff08;业务服务&#xff09; 2. 核心架构 Spring Cloud、Spring Boot、Mybatis、Redis 3. 前端框架…

第二证券:深圳规划建设中心城区1公里超充圈

大湾区之声消息&#xff0c;本年《深圳市新能源轿车超充设备专项规划》发布&#xff0c;明晰了深圳“超充之城”制作的“路线图”。现在&#xff0c;深圳累计建成超充站41座&#xff0c;还有11座超充站正在制作中。到2023年年末&#xff0c;争夺建成不少于150座共用超充站&…

SpringBoot数据响应、分层解耦、三层架构

响应数据 ResponseBody 类型&#xff1a;方法注解、类注解位置&#xff1a;Controller方法、类上作用&#xff1a;将方法返回值直接响应&#xff0c;如果返回值类型是 实体对象/集合 &#xff0c;将会转换为json格式响应说明&#xff1a;RestController Controller Respons…

免费硬盘数据恢复软件EasyRecovery2024

大部分人在存储数据时&#xff0c;都会将数据存储在硬盘之中。虽然数据存放在硬盘中相对安全&#xff0c;但是硬盘的开盘数据也会出现误删的状况&#xff0c;那么当数据误删时&#xff0c;该怎么进行修复呢。下面就来介绍硬盘开盘数据恢复能全部恢复吗&#xff0c;硬盘开盘数据…

tbh Cutter切割节点

在tbh过程中&#xff0c;如果想把一个元素置为一个位图图片某个图像的后边&#xff0c;如何做呢&#xff1f; 如&#xff1a;把圆形放到花的后边&#xff1a; 最后实现效果&#xff1a;&#xff08;请忽略边缘的细节&#xff0c;这里只是记录用法&#xff09; 第一步&#xf…

windows10下Node.js安装教程

文章目录 windows10下Node.js安装教程下载安装包执行安装检查环境变量测试 windows10下Node.js安装教程 下载安装包 官网 执行安装 检查环境变量 系统已经为Node.js添加了相应的系统环境变量 测试 打开命令窗口 使用命令&#xff1a; END