# Decomposing the MoE Design in YOLO-Master
Previously we did a close reading of the paper; today we walk through the actual code.

Contents:
1. OptimizedMOEImproved: module construction and forward pass
2. Routing module: EfficientSpatialRouter
3. Expert: SimpleExpert; instance-conditioned adaptation; MoE pruning (MoEPruner); cluster-weighted NMS (CW-NMS)

## 1. OptimizedMOEImproved

- **Homogeneous experts**: every expert is the same `SimpleExpert`, which simplifies parallelization and optimization.
- **Shared expert**: an always-active parallel branch. This is standard in modern MoE designs; it guarantees a performance floor and noticeably improves training stability.
- **Efficient spatial routing**: `EfficientSpatialRouter` uses pre-pooling to shrink the feature map before routing, cutting the routing compute.
- **Stability enhancement**: a Z-loss keeps the router's output logits from exploding numerically, further stabilizing training.
- **Standardized auxiliary losses**: integrates the load-balancing loss and the Z-loss.
- **Initialization strategy**: the router gets a dedicated initialization (Gaussian, std = 0.01) to prevent early "winner-take-all" collapse.

### Module construction

```python
# router_type is EfficientSpatialRouter
self.routing = EfficientSpatialRouter(in_channels, num_experts,
                                      top_k=top_k, noise_std=noise_std)

# expert_type is SimpleExpert
self.experts.append(SimpleExpert(in_channels, out_channels, **kwargs))

# shared expert: always-active 1x1 Conv-BN-SiLU branch
self.shared_expert = nn.Sequential(
    nn.Conv2d(in_channels, out_channels, 1, bias=False),
    nn.BatchNorm2d(out_channels),
    nn.SiLU(inplace=True),
)
```

### Forward pass

1. Get the routing information: `routing_weights, routing_indices, loss_dict = self.routing(x)`
2. Shared expert: `shared_out = self.shared_expert(x)`
3. Top-k selection:

```python
indices_flat = routing_indices.view(B, adaptive_top_k)
weights_flat = routing_weights.view(B, adaptive_top_k)
```

4. Expert computation:

```python
# Select input and compute
inp = x_input[batch_idx]
out = self.experts[i](inp)
w = weights_flat[batch_idx, k_idx].view(-1, 1, 1, 1)
expert_output.index_add_(0, batch_idx, out * w)
```

5. Output: `final_output = shared_out + expert_output`
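To make steps 3–5 concrete, here is a minimal self-contained sketch of the dispatch loop: each sample's output is the shared expert's output plus the weighted sum of its top-k experts' outputs, accumulated with `index_add_`. The shapes and the stand-in routing tensors are illustrative assumptions, not values from the repository.

```python
import torch
import torch.nn as nn

# Illustrative sizes (assumed, not from the repo)
B, C, H, W = 4, 8, 16, 16
num_experts, top_k = 4, 2

x = torch.randn(B, C, H, W)
experts = nn.ModuleList(nn.Conv2d(C, C, 1) for _ in range(num_experts))
shared_expert = nn.Conv2d(C, C, 1)

# Stand-in routing output: normalized [B, top_k] weights and expert indices
weights_flat = torch.softmax(torch.randn(B, top_k), dim=1)
indices_flat = torch.stack([torch.randperm(num_experts)[:top_k] for _ in range(B)])

with torch.no_grad():
    expert_output = torch.zeros(B, C, H, W)
    for i in range(num_experts):
        # (batch, slot) positions where expert i was selected
        batch_idx, k_idx = torch.where(indices_flat == i)
        if batch_idx.numel() == 0:
            continue  # expert i received no samples this batch
        out = experts[i](x[batch_idx])                   # run expert on its subset only
        w = weights_flat[batch_idx, k_idx].view(-1, 1, 1, 1)
        expert_output.index_add_(0, batch_idx, out * w)  # scatter-accumulate weighted outputs

    final_output = shared_expert(x) + expert_output

print(final_output.shape)  # torch.Size([4, 8, 16, 16])
```

Note the inversion of the loop: instead of iterating over samples, we iterate over experts and gather each expert's assigned samples, so every expert runs once on a contiguous sub-batch.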
## 2. Routing module: EfficientSpatialRouter

Downsample first, then route: an AvgPool shrinks the feature map before the router runs, substantially reducing FLOPs.

```python
self.router = nn.Sequential(
    nn.Conv2d(in_channels, reduced_channels, 3, padding=1, bias=False),
    nn.BatchNorm2d(reduced_channels),
    nn.SiLU(inplace=True),
    nn.Conv2d(reduced_channels, num_experts, 1, bias=False),
    nn.BatchNorm2d(num_experts),  # numerical stability
)
```

Processing:

```python
global_logits = torch.mean(out, dim=[2, 3])  # [B, E]
self._process_logits(global_logits, self.noise_std, self.training)

def _process_logits(self, logits: torch.Tensor, noise_std: float,
                    training: bool) -> Tuple[torch.Tensor, torch.Tensor, Dict]:
    """Unified logic to process logits into a Top-K selection."""
    B = logits.shape[0]
    # 1) Add noise during training (simplified Gumbel-Softmax trick)
    if training and noise_std > 0:
        logits = logits + torch.randn_like(logits) * noise_std
    # 2) Compute probabilities
    probs = F.softmax(logits.float(), dim=1).type_as(logits)
    # 3) Select Top-K
    topk_vals, topk_indices = torch.topk(probs, self.top_k, dim=1)
    # 4) Normalize weights
    sum_vals = topk_vals.sum(dim=1, keepdim=True) + 1e-6
    topk_vals = topk_vals / sum_vals
    # 5) Collect loss-related info (train only)
    loss_dict = {}
    if training:
        loss_dict['router_logits'] = logits
        loss_dict['router_probs'] = probs
        loss_dict['topk_indices'] = topk_indices
    return topk_vals, topk_indices, loss_dict
```

## 3. Expert: SimpleExpert

A standard Conv-BN-SiLU-Conv-BN block: easy to optimize, with an ordinary parameter count.

```python
class SimpleExpert(nn.Module):
    def __init__(self, in_channels, out_channels, expand_ratio=2):
        super().__init__()
        hidden_dim = int(in_channels * expand_ratio)
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, hidden_dim, 1, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.SiLU(inplace=True),
            nn.Conv2d(hidden_dim, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        return self.conv(x)

    def compute_flops(self, input_shape):
        return FlopsUtils.count_conv2d(self.conv, input_shape)
```

### Instance-conditioned adaptation (LoRA)

### MoE pruning (MoEPruner)

Automatically prunes low-utilization experts, yielding a 20–30% inference speedup. See `moe/pruning.py`.

### Cluster-weighted NMS (CW-NMS)

A box-fusion algorithm grounded in clustering theory: it replaces hard suppression with a Gaussian-weighted average, noticeably improving localization accuracy.

| Method | Strategy | Pros | Cons |
| --- | --- | --- | --- |
| Traditional NMS | Discard overlapping boxes outright | Fast | May lose precise localization |
| Soft-NMS | Decay confidence | Keeps more candidates | Parameter-sensitive |
| CW-NMS | Gaussian-weighted fusion | High accuracy, robust | Slightly more compute |
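A minimal sketch of the Gaussian-weighted fusion idea behind CW-NMS, not the repository's actual implementation: boxes overlapping the current top-scoring box form a cluster and are fused by a weighted average instead of being discarded. The weight formula `score * exp(-(1 - IoU)^2 / sigma)` and the function name `cw_nms` are assumptions for illustration.

```python
import torch

def iou(box, boxes):
    """IoU of one box [4] against boxes [N, 4], in (x1, y1, x2, y2) format."""
    x1 = torch.maximum(box[0], boxes[:, 0])
    y1 = torch.maximum(box[1], boxes[:, 1])
    x2 = torch.minimum(box[2], boxes[:, 2])
    y2 = torch.minimum(box[3], boxes[:, 3])
    inter = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)
    area_a = (box[2] - box[0]) * (box[3] - box[1])
    area_b = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area_a + area_b - inter + 1e-6)

def cw_nms(boxes, scores, iou_thr=0.5, sigma=0.5):
    """Cluster-weighted NMS sketch: fuse each cluster by a Gaussian-weighted
    average of its member boxes (assumed weighting, for illustration)."""
    order = scores.argsort(descending=True)
    boxes, scores = boxes[order], scores[order]
    keep_boxes, keep_scores = [], []
    while boxes.numel() > 0:
        ious = iou(boxes[0], boxes)
        cluster = ious > iou_thr  # members of the current cluster (includes boxes[0])
        # Gaussian weight: high score and high overlap contribute more
        w = scores[cluster] * torch.exp(-((1 - ious[cluster]) ** 2) / sigma)
        fused = (w.unsqueeze(1) * boxes[cluster]).sum(0) / w.sum()
        keep_boxes.append(fused)
        keep_scores.append(scores[0])
        boxes, scores = boxes[~cluster], scores[~cluster]  # drop the fused cluster
    return torch.stack(keep_boxes), torch.stack(keep_scores)

# Two overlapping boxes fuse into one; the distant box survives on its own.
boxes = torch.tensor([[0., 0., 10., 10.], [1., 1., 11., 11.], [50., 50., 60., 60.]])
scores = torch.tensor([0.9, 0.8, 0.7])
fused, fused_scores = cw_nms(boxes, scores)
print(fused.shape)  # torch.Size([2, 4])
```

Unlike hard NMS, the lower-scoring overlapping box still pulls the fused coordinates toward itself in proportion to its weight, which is where the localization gain comes from.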