从CNN到ViT:混合网络架构的设计哲学与PyTorch实战
1. 项目概述为什么我们需要混合网络在计算机视觉领域待了十几年我亲眼见证了模型架构的“风水轮流转”。从早期的LeNet、AlexNet到后来统治多年的ResNet、DenseNet等纯卷积神经网络再到这两年Transformer架构特别是Vision Transformer ViT的异军突起感觉就像看了一场技术“擂台赛”。但说实话无论是纯CNN还是纯ViT在实际落地时我总觉得各有各的“脾气”很难找到一个“全能选手”。纯CNN模型比如ResNet在处理图像局部特征、纹理、边缘时那是相当拿手得益于卷积操作的归纳偏置它对平移、缩放有一定的不变性而且参数量和计算量相对可控训练起来对数据量的要求也没那么“苛刻”。但它的“视野”有时候不够广长距离的依赖关系捕捉能力相对较弱尤其是当图像中关键信息分散时CNN可能需要堆叠很多层才能建立这种全局联系。而ViT呢它直接把图像打成一个个“补丁”然后像处理文本序列一样用自注意力机制去处理。这招在捕捉全局上下文信息上简直是“降维打击”模型能一下子“看到”整张图的关联对于需要理解整体场景的任务比如图像分类、目标检测中的上下文推理非常有利。但问题也来了它缺乏CNN那种与生俱来的对图像局部结构的先验知识导致在数据量不足时容易过拟合训练起来更“挑食”而且计算复杂度随着序列长度图像分辨率的平方增长处理高分辨率图像时“钱包”和“显卡”都吃不消。所以这几年我和团队在尝试各种项目时脑子里总在琢磨能不能“鱼与熊掌兼得”把CNN的“局部感知专家”和ViT的“全局关系大师”请到同一个模型里让它们优势互补。这就是“基于卷积和ViT的混合网络”最朴素的出发点。它不是简单的模型堆叠而是一种架构层面的深度融合设计目标是在保持甚至提升模型性能的同时获得更好的效率、更强的鲁棒性以及更灵活的任务适配性。无论你是做图像分类的老手还是刚入行目标检测的新人理解这种混合架构的设计思路都能帮你打开一扇新的大门在模型选型时多一个更优解。2. 混合网络的核心设计哲学与架构选型2.1 设计目标不是替代是协同构建混合网络首要问题是明确设计目标。我们绝不是要造一个“四不像”而是追求112的效果。具体来说核心目标通常包括提升性能与效率的帕累托前沿在相同的计算预算FLOPs或参数量下获得比纯CNN或纯ViT更高的精度如ImageNet Top-1 Acc。或者在达到相同精度时拥有更快的推理速度或更低的内存占用。增强模型鲁棒性结合CNN对局部扰动的不变性和ViT对全局结构理解的稳定性期望模型对常见的图像损坏如噪声、模糊、对抗攻击等具有更强的抵抗力。降低对大规模预训练的依赖利用CNN部分在中小型数据集上良好的可训练性减轻模型对海量有标签数据如JFT-300M预训练的强依赖让更多资源有限的团队也能用好Transformer的强大能力。实现多粒度特征融合自然地将浅层的、细节丰富的局部特征CNN擅长与深层的、语义明确的全局上下文ViT擅长进行融合为下游任务如分割、检测提供更丰富的特征表示。2.2 主流混合范式深度解析根据CNN和Transformer模块在网络中的组织方式目前主流的混合架构可以归纳为三大类每一种都有其独特的“性格”和适用场景。2.2.1 串联式混合清晰的阶段分工这是最直观、也最早被广泛探索的一种方式。你可以把它想象成工厂的“流水线”CNN模块作为“前端处理器”ViT模块作为“后端理解器”。典型代表Convolutional vision Transformer (CvT)LocalViT等。工作原理CNN阶段输入图像首先经过几层通常是2-4个阶段的卷积下采样模块。这个阶段的任务是进行高效的“特征粗提炼”。卷积层快速降低空间分辨率例如从224x224下采样到14x14或28x28同时增加通道维度提取出包含基本纹理、边缘和局部模式的低维特征图。这一步极大地压缩了后续需要处理的序列长度。ViT阶段将CNN阶段输出的特征图通过一个简单的展平操作Flatten或一个可学习的投影层转换成一系列特征向量Tokens然后送入标准的Transformer Encoder堆栈。此时序列长度已经大大减少自注意力机制的计算成本变得可接受。Transformer在这个阶段专注于建立这些高级特征之间的全局依赖关系完成最终的语义理解和分类。优势分析计算高效CNN前置大幅降低了ViT需要处理的序列长度是降低Transformer计算复杂度的最有效手段之一。训练友好CNN部分可以用ImageNet预训练权重初始化ViT部分可以随机初始化或用小规模数据微调整个模型训练收敛更快对数据量的要求降低。结构清晰模块分工明确易于理解和实现。劣势与注意事项信息损失风险CNN下采样过程是信息有损的。过于激进的下采样可能会在早期丢失对细粒度任务如密集预测至关重要的细节信息。因此设计时需要仔细权衡下采样率和后续ViT的输入分辨率。特征对齐CNN输出的特征图与ViT期望的序列输入之间需要做维度转换这个投影层的设计是简单的展平还是带学习的线性层会影响信息传递的效率。实操心得在串联结构中我通常会尝试将CNN部分设计成一个轻量化的“特征提取骨干”比如采用MobileNetV2或EfficientNet的早期块。重点不是让CNN部分有多深而是让它高效地完成“降维”和“初筛”工作把复杂的全局推理放心交给后面的ViT。2.2.2 并联式混合双路信息并行处理并联结构更像是一个“专家委员会”CNN支路和ViT支路同时处理输入或中间特征最后将两者的见解融合。典型代表Concurrent Spatial and Channel ‘Squeeze Excitation’ (scSE)在模块设计上有类似思想但网络级并联的典型如CrossViT的变体思路或一些自定义的双分支网络。工作原理双分支输入同一输入或经过浅层共享层后的特征被同时送入两个并行的分支。一个分支是CNN可能是一组卷积层另一个分支是ViT将图像分块后输入Transformer。独立处理与特征融合两个分支独立进行前向传播分别提取局部特征和全局上下文。在网络的特定阶段例如每个分辨率阶段末尾通过加法Add、拼接Concat或更复杂的注意力融合机制如Cross-Attention将两条支路的特征图进行融合。继续传播融合后的特征既可以作为下一阶段的输入也可以直接用于最终预测。优势分析特征丰富性最大程度地保留了两种范式提取的原始特征融合后的特征理论上包含最全面的信息。灵活性高两个分支的深度、宽度可以独立调整便于针对特定任务进行定制化设计。劣势与注意事项计算成本高维持两条完整的处理路径即使每条路径较浅总计算量也通常高于串联结构。融合策略是关键如何融合是设计的核心难点。简单的拼接会导致通道数激增增加后续计算负担加法要求特征图严格对齐设计跨分支注意力机制则又引入了新的复杂度。训练动态复杂两条分支可能学习速度不同容易出现一条分支主导另一条分支退化的情况需要仔细设计损失函数或训练策略如辅助损失来平衡。2.2.3 深度融合式混合你中有我我中有你这是目前最前沿、也最能体现“混合”精髓的设计。它不再将CNN和Transformer视为独立的模块而是在最基础的构建块Block层面进行融合。典型代表Convolutional Neural Networks with Transformer (ConvNeXt)MobileViTUniFormer。工作原理这类模型通常重新设计了Transformer的基本单元。例如ConvNeXt它本质上是一个“现代化”的纯CNN但其设计大量借鉴了ViT的成功经验如大的卷积核、LayerNorm、GELU激活等达到了媲美Swin Transformer的性能。它证明了精心设计的CNN可以学习到类似Transformer的行为。MobileViT提出了一个“MobileViT Block”。在这个块中首先用标准卷积提取局部特征然后将特征图在空间维度上展开成一系列重叠或不重叠的“块”将这些块视为Tokens在它们之间应用轻量化的Transformer自注意力来建模全局关系最后再变换回特征图格式。这样每个块内部同时完成了卷积和自注意力操作。UniFormer直接在一个统一模块中并行部署局部卷积和全局自注意力并通过可学习的权重动态融合两者的输出。优势分析效率与性能的极致平衡在网络的每一个阶段都同时进行局部和全局建模信息交互更充分、更及时。更适合下游任务输出的始终是标准的2D特征图与现有的大部分计算机视觉任务检测、分割框架无缝兼容无需复杂的适配。参数利用率高避免了模块间冗余通常能获得更紧凑的模型。劣势与注意事项设计复杂度高需要深入理解两种机制的底层原理设计出高效的融合单元创新门槛较高。实现细节敏感如何归一化、如何初始化、注意力头的设置等细节对最终性能影响巨大需要大量的消融实验来验证。避坑指南对于大多数应用者我建议从串联式开始尝试因为它结构简单易于调试和迁移。当对两种机制有更深理解后可以探索深度融合式的现成模型如MobileViT。并联式通常在非常特定的任务如需要极致多尺度特征的任务中才会考虑因为它成本最高。3. 从零构建一个串联式混合网络以CvT为蓝本理论说了这么多我们来点实际的。我将带你一步步实现一个简化版的CvT-Tiny模型用于ImageNet-1k规模的图像分类任务。我们将使用PyTorch框架并详细解释每一个关键步骤的设计意图。3.1 环境准备与依赖安装首先确保你的开发环境已经就绪。我强烈建议使用Anaconda管理环境避免包版本冲突。# 创建并激活一个名为hybrid_net的虚拟环境 conda create -n hybrid_net python3.8 -y conda activate hybrid_net # 安装PyTorch请根据你的CUDA版本访问PyTorch官网获取对应命令 # 例如对于CUDA 11.3 conda install pytorch torchvision torchaudio cudatoolkit11.3 -c pytorch # 安装其他必要的库 pip install timm # 一个非常好的PyTorch模型库我们会参考其设计 pip install tensorboard # 用于可视化训练过程 pip install opencv-python pillow # 用于图像处理3.2 核心模块代码实现我们的简化版CvT主要包含三个部分卷积词嵌入层Convolutional Token Embedding、卷积投影层Convolutional Projection和Transformer编码器。3.2.1 卷积词嵌入层这个层的作用是替代ViT中简单的“线性投影分块”。它使用卷积操作来实现分块和嵌入能更好地保留局部空间信息。import torch import torch.nn as nn import torch.nn.functional as F class ConvEmbedding(nn.Module): 卷积词嵌入层使用卷积实现图像分块和特征嵌入。 输入: (B, C_in, H, W) 输出: (B, num_tokens, embed_dim) def __init__(self, in_channels3, embed_dim64, patch_size7, stride4, padding2): super().__init__() # 使用一个卷积层同时完成分块、展平和特征映射 self.projection nn.Conv2d( in_channels, embed_dim, kernel_sizepatch_size, stridestride, paddingpadding ) # LayerNorm在通道维度上进行这是ViT的常见做法 self.norm nn.LayerNorm(embed_dim) if embed_dim is not None else nn.Identity() def forward(self, x): # x: [B, C, H, W] x self.projection(x) # [B, embed_dim, H, W] # 将空间维度展平形成序列 B, C, H, W x.shape x x.flatten(2).transpose(1, 2) # [B, H*W, embed_dim] x self.norm(x) return x为什么用卷积而不是线性层卷积具有平移等变性且能利用像素间的空间相关性比简单的线性投影更能提取有意义的局部特征。stride决定了分块的重叠程度和最终序列长度。3.2.2 卷积投影层用于Transformer块内在标准的Transformer中Q、K、V的生成是通过线性层。在CvT中这一步也用深度可分离卷积Depthwise Separable Convolution来替代以引入空间上下文。class ConvProjection(nn.Module): 使用深度可分离卷积进行投影 def __init__(self, dim, kernel_size3, stride1, padding1): super().__init__() # 深度卷积每个通道独立进行空间滤波 self.conv nn.Conv2d(dim, dim, kernel_size, stride, padding, groupsdim) # 后续的LayerNorm和线性投影1x1卷积 self.norm nn.LayerNorm(dim) # 1x1卷积等价于线性层用于通道混合/映射 self.projection nn.Conv2d(dim, dim, 1) def forward(self, x): # 输入x的形状假设是 [B, num_tokens, dim] # 为了进行卷积我们需要暂时将其还原为2D特征图 B, N, C x.shape # 这里需要知道原始的空间尺寸H, W。我们假设在构造时传入或者从上下文中获取。 # 为了简化我们在外部处理形状转换。这个模块更常被集成在Attention块内部。 pass在实际的CvT设计中ConvProjection被集成到了注意力模块内部。我们接下来实现一个完整的、带有卷积投影的注意力模块。3.2.3 卷积投影注意力模块class ConvAttention(nn.Module): 带有卷积投影的自注意力模块 def __init__(self, dim, num_heads8, kernel_size3, stride1, padding1, qkv_biasFalse): super().__init__() self.num_heads num_heads self.head_dim dim // num_heads self.scale self.head_dim ** -0.5 # 使用1x1卷积生成q, k, v的“原始”投影 self.qkv nn.Conv2d(dim, dim * 3, 1, biasqkv_bias) # 用于q, k, v的深度可分离卷积 self.conv_q nn.Conv2d(dim, dim, kernel_size, stride, padding, groupsdim) self.conv_k nn.Conv2d(dim, dim, kernel_size, stride, padding, groupsdim) self.conv_v nn.Conv2d(dim, dim, kernel_size, stride, padding, groupsdim) self.proj nn.Conv2d(dim, dim, 1) # 输出投影 self.norm_q nn.LayerNorm(dim) self.norm_k nn.LayerNorm(dim) self.norm_v nn.LayerNorm(dim) def forward(self, x, H, W): # x: [B, N, C], 其中 N H * W B, N, C x.shape # 临时转换回2D格式以进行卷积操作 x_2d x.transpose(1, 2).reshape(B, C, H, W) # 生成q, k, v qkv self.qkv(x_2d).chunk(3, dim1) # 得到三个[B, C, H, W]的张量 q, k, v qkv # 对q, k, v分别进行深度卷积和归一化 q self.conv_q(q).flatten(2).transpose(1, 2) # - [B, N, C] k self.conv_k(k).flatten(2).transpose(1, 2) v self.conv_v(v).flatten(2).transpose(1, 2) q self.norm_q(q) k self.norm_k(k) v self.norm_v(v) # 重塑为多头 q q.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) # [B, h, N, d] k k.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) v v.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) # 自注意力计算 attn (q k.transpose(-2, -1)) * self.scale # [B, h, N, N] attn attn.softmax(dim-1) out (attn v) # [B, h, N, d] out out.permute(0, 2, 1, 3).reshape(B, N, C) # [B, N, C] # 转换回2D进行输出投影 out_2d out.transpose(1, 2).reshape(B, C, H, W) out_2d self.proj(out_2d) out out_2d.flatten(2).transpose(1, 2) # 再转回序列格式 return out设计意图在生成Q、K、V之后不是直接用于注意力计算而是先分别通过一个深度可分离卷积。这个操作相当于为每个Token注入了其局部邻域的信息使得注意力计算在考虑全局关系时也“知晓”每个Token的局部上下文。这可以看作是一种隐式的相对位置编码增强了模型对局部结构的感知。3.2.4 Transformer编码器块与阶段有了注意力模块我们就可以组装完整的Transformer编码器块和阶段。class TransformerBlock(nn.Module): 包含卷积投影注意力的Transformer编码器块 def __init__(self, dim, num_heads, mlp_ratio4., drop_rate0.): super().__init__() self.norm1 nn.LayerNorm(dim) self.attn ConvAttention(dim, num_headsnum_heads) self.drop_path nn.Identity() # 简化省略DropPath self.norm2 nn.LayerNorm(dim) self.mlp nn.Sequential( nn.Linear(dim, int(dim * mlp_ratio)), nn.GELU(), nn.Dropout(drop_rate), nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(drop_rate) ) def forward(self, x, H, W): # 注意这里需要传入特征图的空间尺寸H, W x x self.drop_path(self.attn(self.norm1(x), H, W)) x x self.drop_path(self.mlp(self.norm2(x))) return x class TransformerStage(nn.Module): 一个Transformer阶段包含多个Block def __init__(self, dim, depth, num_heads, mlp_ratio4., drop_rate0.): super().__init__() self.blocks nn.ModuleList([ TransformerBlock(dim, num_heads, mlp_ratio, drop_rate) for _ in range(depth) ]) def forward(self, x, H, W): for blk in self.blocks: x blk(x, H, W) return x3.3 组装CvT-Tiny模型现在我们将卷积嵌入层和Transformer阶段组合起来形成一个三阶段的简化CvT。class SimpleCvT(nn.Module): def __init__(self, in_chans3, num_classes1000): super().__init__() # 阶段1大尺度特征提取生成较长的序列 self.stage1_embed ConvEmbedding(in_chans, embed_dim64, patch_size7, stride4, padding2) self.stage1_transformer TransformerStage(dim64, depth1, num_heads1, mlp_ratio4, drop_rate0.) # 阶段2下采样并增加维度 self.stage2_embed nn.Sequential( nn.LayerNorm(64), # 使用线性层模拟卷积下采样和维度提升 nn.Linear(64, 128), # 这里为了简化我们用一个线性层重塑来模拟空间下采样。 # 更严谨的实现需要用一个卷积层。 ) # 一个技巧我们可以通过重塑序列来模拟空间下采样例如将相邻2x2的tokens合并 # 但为了代码清晰我们假设stage1输出是(B, 3136, 64) - (56*56) # stage2我们目标得到 (B, 784, 128) - (28*28) self.stage2_transformer TransformerStage(dim128, depth2, num_heads2, mlp_ratio4, drop_rate0.) # 阶段3进一步下采样和抽象 self.stage3_embed nn.Sequential( nn.LayerNorm(128), nn.Linear(128, 256), ) # 目标: (B, 196, 256) - (14*14) self.stage3_transformer TransformerStage(dim256, depth10, num_heads4, mlp_ratio4, drop_rate0.) # 分类头 self.norm nn.LayerNorm(256) self.head nn.Linear(256, num_classes) def forward(self, x): B x.shape[0] # 阶段1 x self.stage1_embed(x) # [B, 3136, 64] H1, W1 56, 56 x self.stage1_transformer(x, H1, W1) # 模拟空间下采样将序列重塑为2D然后通过平均池化下采样再展平 # 注意这里是一个简化的、非最优的实现仅用于演示流程。 # 实际CvT使用卷积进行下采样投影。 x_2d x.transpose(1, 2).reshape(B, 64, H1, W1) x_2d F.avg_pool2d(x_2d, kernel_size2, stride2) # - [B, 64, 28, 28] x x_2d.flatten(2).transpose(1, 2) # - [B, 784, 64] x self.stage2_embed(x) # 提升维度 - [B, 784, 128] # 阶段2 H2, W2 28, 28 x self.stage2_transformer(x, H2, W2) # 阶段3下采样 x_2d x.transpose(1, 2).reshape(B, 128, H2, W2) x_2d F.avg_pool2d(x_2d, kernel_size2, stride2) # - [B, 128, 14, 14] x x_2d.flatten(2).transpose(1, 2) # - [B, 196, 128] x self.stage3_embed(x) # - [B, 196, 256] # 阶段3 H3, W3 14, 14 x self.stage3_transformer(x, H3, W3) # 全局平均池化 (在序列维度上平均) x self.norm(x) # [B, 196, 256] x x.mean(dim1) # [B, 256] x self.head(x) return x关键细节提醒上述代码中的下采样部分F.avg_pool2d是一个为了流程清晰而做的简化。在真正的CvT或更优的实现中阶段间的投影和下采样应该由一个卷积层通常是步长为2的卷积来完成这个卷积层同时负责降低空间分辨率和提升通道维度并且这个操作应该被整合到ConvEmbedding类似的模块中。我们的简化版本可能会损失一些信息但对于理解整体架构流程足够了。3.4 模型初始化与训练配置建议混合网络对初始化比较敏感。以下是一些经验性的配置建议def init_weights(m): if isinstance(m, nn.Linear): nn.init.trunc_normal_(m.weight, std.02) if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) elif isinstance(m, nn.Conv2d): # 卷积层使用kaiming初始化 nn.init.kaiming_normal_(m.weight, modefan_out, nonlinearityrelu) if m.bias is not None: nn.init.constant_(m.bias, 0) model SimpleCvT(num_classes1000) model.apply(init_weights) # 训练配置建议 # 优化器AdamW 是训练Transformer类模型的首选 optimizer torch.optim.AdamW(model.parameters(), lr1e-3, weight_decay0.05) # 学习率调度余弦退火配合热身 scheduler torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max300, eta_min1e-5) # 损失函数 criterion nn.CrossEntropyLoss()为什么用AdamW和余弦退火AdamW相比Adam解耦了权重衰减能更好地防止过拟合在Transformer社区是标配。余弦退火能让学习率平滑下降配合热身Warmup策略有助于模型在训练初期稳定进入优化过程。4. 混合网络实战调优技巧与常见问题排查即使有了一个能跑起来的模型想要获得好性能调优和排错是关键。下面分享一些我踩过坑后总结的经验。4.1 超参数调优指南混合网络的超参数比单一架构更多需要系统性地调整。超参数组关键参数调优建议与影响典型范围/值架构参数各阶段Embedding维度决定模型容量。通常逐阶段翻倍如64-128-256。增加维度提升性能但增加计算量。[64, 128, 256], [96, 192, 384]各阶段Transformer深度决定模型“思考”复杂度。后期阶段可以更深。深度增加能提升性能但可能导致梯度问题。Stage1:1-2, Stage2:2-4, Stage3:6-12注意力头数通常与Embedding维度相关dim % num_heads 0。头数多有助于捕捉多样关系但过多可能冗余。通常为[1, 2, 4, 8]与dim匹配MLP扩展比率Transformer中前馈网络的隐藏层放大倍数。4是常用值减小可降低参数量。通常为4可在[3, 4]微调卷积参数卷积核大小在ConvEmbedding和ConvProjection中使用。影响局部感受野。7x7和3x3常见。嵌入层7x7 投影层3x3卷积步长在ConvEmbedding中决定下采样率和序列长度。步长越大序列越短计算越快信息损失风险越高。常用4 也可用[2, 4]组合训练参数基础学习率最重要的参数之一。混合网络通常需要比纯CNN更小的学习率。AdamW: 1e-3 到 5e-4权重衰减防止过拟合的关键。对Transformer部分尤为重要。AdamW: 0.05 到 0.1批量大小受限于GPU内存。大Batch可能需提高学习率但可能影响泛化。根据GPU常用128, 256数据增强对性能影响巨大。RandAugment, MixUp, CutMix能显著提升ViT类模型性能。RandAugment (magnitude 9-15), MixUp (alpha0.8)调优流程建议固定训练策略调架构先用一组保守的训练参数适中学习率、基础数据增强尝试不同的深度、宽度组合找到在验证集上表现最好的模型规模。固定架构精调训练选定一个候选架构后系统调整学习率、权重衰减、数据增强强度。使用学习率扫描LR Finder工具确定大致范围。消融实验验证如果时间充裕对关键设计如是否使用ConvProjection、不同的下采样方式进行消融实验用数据证明其有效性。4.2 常见训练问题与解决方案问题训练初期损失不下降或爆炸可能原因学习率过高权重初始化不当梯度流动不畅。排查与解决启用梯度裁剪torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm1.0)。检查初始化确保所有线性层和卷积层都正确初始化如使用trunc_normal_或kaiming_normal_。添加热身在前5-10个epoch使用线性或余弦热身让学习率从0缓慢增加到设定值。降低学习率尝试将初始学习率降低一个数量级。问题模型在训练集上表现好在验证集上差过拟合可能原因模型容量过大数据增强不足正则化不够。排查与解决增强数据引入更强的数据增强如RandAugment, MixUp, CutMix。这对混合网络尤其是ViT部分至关重要。增加正则化提高Dropout率在MLP和注意力权重后增加权重衰减值。尝试标签平滑使用nn.CrossEntropyLoss(label_smoothing0.1)。简化模型如果数据量有限考虑减少Transformer的深度或Embedding维度。问题训练速度慢可能原因序列长度仍然过长注意力计算是瓶颈。排查与解决优化下采样确保第一阶段卷积嵌入的步长足够大将序列长度降到可接受范围如56x563136。使用混合精度训练torch.cuda.amp可以显著加速训练并减少显存占用。检查实现确保自注意力计算是高效的避免不必要的矩阵转置和复制。问题显存溢出OOM可能原因批量大小太大序列长度太长模型参数过多。排查与解决减小批量大小最直接的方法。梯度累积如果不想减小批量大小影响优化可以使用梯度累积。例如实际批量大小为32但每次只处理8个样本累积4步后再更新权重。激活检查点对于很深的Transformer块可以使用torch.utils.checkpoint来节省显存用计算时间换空间。精简模型考虑使用更小的维度或更浅的深度。4.3 下游任务迁移技巧混合网络在ImageNet上预训练后可以很好地迁移到下游任务如目标检测Faster R-CNN, Mask R-CNN和语义分割UPerNet, DeepLabV3。特征金字塔提取我们的CvT有三个阶段自然对应了不同尺度的特征图stride 4, 8, 16。这对于FPN特征金字塔网络非常友好。可以直接将这三个阶段的输出在转换为2D特征图后作为FPN的输入。微调策略分层学习率对骨干网络预训练的混合网络设置较小的学习率如base_lr * 0.1对新添加的任务特定头如检测头、分割头设置较大的学习率。只解冻部分层在微调初期可以只解冻最后1-2个Transformer阶段和任务头冻结前面的卷积阶段和早期Transformer防止灾难性遗忘。更长的训练周期下游任务通常需要比预训练更长的微调周期尤其是当数据集与ImageNet差异较大时。构建和训练一个混合网络就像指挥一个交响乐团需要让CNN和Transformer这两种不同的“乐器”和谐共鸣。从串联式入手理解数据流再深入探索深度融合式的精妙设计这个过程本身就能极大地加深你对现代视觉模型的理解。记住没有最好的架构只有最适合你具体任务、数据和资源的架构。多实验多分析混合网络的世界充满可能。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2626396.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!