告别Anchor和NMS:用PyTorch从零开始手搓DETR,理解Transformer如何颠覆目标检测
从零实现DETR用Transformer重构目标检测范式当YOLO和Faster R-CNN仍在目标检测领域占据主导地位时Facebook Research在2020年提出的DETR(DEtection TRansformer)带来了一场范式革命。这个将Transformer引入计算机视觉的架构彻底摒弃了沿用多年的anchor设计和NMS后处理用最纯粹的端到端方式重新定义了目标检测。1. 传统检测器的先天缺陷与DETR的革新在计算机视觉领域目标检测任务长期被两阶段如Faster R-CNN和单阶段如YOLO、SSD方法统治。这些方法虽然效果显著但都存在一些根本性限制Anchor机制的束缚需要精心设计anchor的大小、长宽比和数量NMS后处理的矛盾既要去除冗余框又可能误删正确预测复杂pipeline区域提议、ROI对齐等多步骤处理引入大量超参数# 传统检测器典型流程示例 anchors generate_anchors(image_size) # 生成预设anchor proposals region_proposal_network(features) # 区域提议 rois roi_align(proposals, features) # ROI对齐 predictions classification_head(rois) # 分类预测 final_boxes non_max_suppression(predictions) # NMS处理DETR的创新在于用Transformer的全局注意力机制替代了这些手工设计组件传统方法组件DETR对应方案优势对比Anchor boxesObject queries无需预设形状可学习NMS二分图匹配避免启发式阈值设置多阶段处理单阶段端到端简化训练流程2. DETR核心架构深度解析2.1 骨干网络与位置编码DETR使用标准CNN如ResNet作为骨干网络提取图像特征但与传统方法不同这些特征会与位置编码结合后输入Transformerclass DETRBackbone(nn.Module): def __init__(self, resnet): super().__init__() self.body IntermediateLayerGetter(resnet, return_layers{layer4: 0}) def forward(self, images): features self.body(images.tensors) pos_encoding self.position_encoding(images) # 位置编码 return features, pos_encoding位置编码有两种实现方式正弦位置编码固定模式具有平移不变性可学习位置编码通过训练自适应调整实际应用中正弦编码在小数据集上表现更好而可学习编码在大规模数据时可能更具优势2.2 Transformer编码器-解码器结构DETR的Transformer模块是其核心创新点与传统NLP中的Transformer有所不同class DETRTransformer(nn.Module): def __init__(self, d_model512, nhead8, num_layers6): encoder_layer TransformerEncoderLayer(d_model, nhead) self.encoder TransformerEncoder(encoder_layer, num_layers) decoder_layer TransformerDecoderLayer(d_model, nhead) self.decoder TransformerDecoder(decoder_layer, num_layers) def forward(self, src, mask, query_embed, pos_embed): memory self.encoder(src, src_key_padding_maskmask, pospos_embed) hs self.decoder(query_embed, memory, memory_key_padding_maskmask, pospos_embed, query_posquery_embed) return hs, memory关键设计细节编码器处理图像特征建立全局上下文关系解码器接收固定数量的object queries可学习参数每层解码器都会输出中间预测辅助训练2.3 Object queries的奥秘Object queries是DETR中最富创意的设计之一它们作为解码器的输入数量决定了最大检测目标数每个query对应一个潜在的检测目标通过注意力机制与全局图像特征交互# 典型实现方式 num_queries 100 # COCO数据集常用值 query_embed nn.Embedding(num_queries, hidden_dim)3. 二分图匹配替代NMS的优雅方案传统检测器使用NMS去除冗余预测框而DETR采用匈牙利算法进行一对一匹配class HungarianMatcher(nn.Module): def __init__(self, cost_class1, cost_bbox1, cost_giou1): self.cost_class cost_class self.cost_bbox cost_bbox self.cost_giou cost_giou torch.no_grad() def forward(self, outputs, targets): bs, num_queries outputs[pred_logits].shape[:2] out_prob outputs[pred_logits].flatten(0, 1).softmax(-1) out_bbox outputs[pred_boxes].flatten(0, 1) tgt_ids torch.cat([v[labels] for v in targets]) tgt_bbox torch.cat([v[boxes] for v in targets]) cost_class -out_prob[:, tgt_ids] cost_bbox torch.cdist(out_bbox, tgt_bbox, p1) cost_giou -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)) C self.cost_bbox * cost_bbox self.cost_class * cost_class self.cost_giou * cost_giou C C.view(bs, num_queries, -1).cpu() sizes [len(v[boxes]) for v in targets] indices [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] return [(torch.as_tensor(i, dtypetorch.int64), torch.as_tensor(j, dtypetorch.int64)) for i, j in indices]匹配成本由三部分组成分类概率成本边界框L1距离成本GIoU相似度成本4. 从零实现DETR关键组件4.1 模型构建流程完整DETR模型的搭建遵循清晰的结构def build_detr(args): # 1. 构建骨干网络 backbone build_backbone(args) # 2. 构建Transformer transformer build_transformer(args) # 3. 组合成完整模型 model DETR( backbone, transformer, num_classesargs.num_classes, num_queriesargs.num_queries ) # 4. 构建匹配器和损失函数 matcher build_matcher(args) criterion SetCriterion( num_classesargs.num_classes, matchermatcher, weight_dictweight_dict ) return model, criterion4.2 训练技巧与参数设置DETR训练需要特别注意以下方面学习率调度通常采用带warmup的分步下降策略梯度裁剪防止Transformer训练不稳定损失权重精心平衡分类与回归损失# 典型训练配置 optimizer torch.optim.AdamW([ {params: [p for n, p in model.named_parameters() if backbone not in n and p.requires_grad]}, {params: [p for n, p in model.named_parameters() if backbone in n and p.requires_grad], lr: args.lr_backbone} ], lrargs.lr, weight_decayargs.weight_decay) lr_scheduler torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)5. DETR的局限与改进方向尽管DETR带来了创新但仍存在一些挑战训练收敛慢通常需要500epoch才能达到最佳效果小目标检测性能全局注意力可能忽略细小物体计算资源需求Transformer的自注意力复杂度随图像尺寸平方增长后续改进模型如Deformable DETR通过引入可变形注意力机制有效缓解了这些问题。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2447741.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!