从零构建MMRotate旋转检测实战:自定义数据集制作与模型调优全解析
1. 环境准备与MMRotate安装第一次接触旋转目标检测时我被各种坐标转换搞得头晕眼花。直到发现MMRotate这个神器才让整个流程变得清晰可控。作为OpenMMLab家族成员它封装了R3Det、Rotated Faster RCNN等主流旋转检测算法特别适合遥感影像中的飞机、船舶等旋转目标识别。安装过程其实比想象中简单。我习惯先用conda创建独立环境避免包冲突conda create -n mmrotate python3.8 -y conda activate mmrotate重点来了——PyTorch版本选择。实测发现PyTorch 1.8 CUDA 11.1组合最稳定安装命令如下conda install pytorch1.8.0 torchvision0.9.0 cudatoolkit11.1 -c pytorch接着安装MMCV全家桶这里容易踩坑。必须根据CUDA版本选择对应的mmcv-fullpip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.8.0/index.html最后安装MMRotate本体git clone https://github.com/open-mmlab/mmrotate.git cd mmrotate pip install -r requirements/build.txt pip install -v -e .验证安装是否成功时建议直接跑demo测试。修改demo/image_demo.py中的三个关键参数--img测试图片路径--config选择configs/rotated_retinanet/rotated_retinanet_r50_fpn_1x_dota_le90.py--checkpoint下载官方预训练模型2. 自定义数据集制作实战2.1 旋转标注的艺术roLabelImg详解给旋转目标打标签就像玩俄罗斯方块角度不对全盘皆输。经过多次实践我总结出roLabelImg的黄金三原则长边对齐规则始终让标注框的长边(w)与目标长轴重合。就像给斜放的铅笔盒贴标签应该沿着最长边贴角度范围控制保持角度在[-90°,90°]之间。超过这个范围会导致后续DOTA格式转换出错快捷键流操作熟练使用A/D旋转角度、Z/X微调、C复制上个标注效率提升300%标注完成后会生成包含的XML文件关键字段如下object nameship/name robndbox cx512.3/cx cy256.8/cy w124.5/w h32.7/h angle-45.2/angle /robndbox /object2.2 DOTA格式转换的数学奥秘将roLabelImg的XML转为DOTA格式需要坐标系转换。核心是这行旋转公式def rotatePoint(xc, yc, xp, yp, theta): xoff xp - xc yoff yp - yc cosTheta math.cos(theta) sinTheta math.sin(theta) pResx cosTheta * xoff sinTheta * yoff pResy -sinTheta * xoff cosTheta * yoff return xc pResx, yc pResy转换后的DOTA标签示例1024 768 1152 768 1152 896 1024 896 ship 0 512 256 640 128 768 256 640 384 plane 1每行代表一个目标包含8个坐标点、类别名和难度标记。建议转换后立即可视化检查我曾因角度符号搞反导致所有检测框偏移90度。2.3 智能裁剪与数据集划分遥感图像往往尺寸巨大直接训练会爆显存。MMRotate内置的img_split.py脚本能智能切割图像我的推荐配置{ image_path: data/train/images, label_path: data/train/labels, output_path: data/split_train, patch_width: 1024, patch_height: 1024, overlap_area: 200 }重叠区域(overlap_area)设置很关键太小会导致目标被切割太大会增加计算量。对于船舶检测200像素重叠刚刚好。数据集划分建议比例类型比例用途train70%主训练集val15%超参调优test15%最终评估3. 模型训练中的魔鬼细节3.1 配置文件的双重修改法则MMRotate采用级联配置系统需要修改两处关键文件模型配置文件如rotated_faster_rcnn_r50_fpn_1x_dota_le90.pymodel dict( roi_headdict( bbox_headdict( num_classes5, # 修改为你的类别数 loss_bboxdict(typeGDLoss, loss_typekld)) ), train_cfgdict( rcnndict( assignerdict( pos_iou_thr0.5, # 正样本IoU阈值 neg_iou_thr0.3 # 负样本IoU阈值 ) ) ) )数据集配置文件dotav1.pydata dict( samples_per_gpu4, # 根据显存调整 workers_per_gpu2, # 建议CPU核心数/GPU数量 traindict( typeDOTADataset, ann_filedata/train/annfiles/, img_prefixdata/train/images/ ), valdict(...), testdict(...) )3.2 学习率调优的温度计法则初始学习率设置有个实用公式base_lr 0.02 * (batch_size / 16)当使用4卡GPU、每卡4张图片时optimizer dict( typeSGD, lr0.02 * (4*4)/16, # 0.02 momentum0.9, weight_decay0.0001)训练过程中如果出现loss震荡可以尝试余弦退火策略lr_config dict( policyCosineAnnealing, warmuplinear, warmup_iters500, warmup_ratio1.0/3, min_lr1e-5)4. 实战中的避坑指南4.1 显存不足的三大解决方案梯度累积通过累积多个batch的梯度模拟大batchoptimizer_config dict( typeGradientCumulativeOptimizerHook, cumulative_iters4) # 累积4次混合精度训练减少显存占用还能加速fp16 dict(loss_scale512.)模型轻量化使用更小的backbonemodel dict( backbonedict( depth18, # 改用ResNet18 init_cfgdict(typePretrained, checkpointtorchvision://resnet18)) )4.2 数据增强的黄金组合对于旋转目标检测这套组合拳效果显著train_pipeline [ dict(typeLoadImageFromFile), dict(typeLoadAnnotations, with_bboxTrue), dict(typeRResize, img_scale(1024, 1024)), dict( typeRRandomFlip, flip_ratio0.5, direction[horizontal, vertical, diagonal]), dict( typePolyRandomRotate, rotate_ratio0.5, angles_range180, auto_boundFalse), dict(typeNormalize, mean[123.675, 116.28, 103.53], std[58.395, 57.12, 57.375]), dict(typePad, size_divisor32), dict(typeDefaultFormatBundle), dict(typeCollect, keys[img, gt_bboxes, gt_labels]) ]4.3 模型选择的决策树根据场景选择合适模型高精度需求Rotated Cascade RCNN速度优先Rotated RetinaNet小目标检测R3Det带有特征精炼模块不规则目标S2ANet自适应特征对齐我在船舶检测项目中对比发现模型mAP速度(FPS)显存占用Rotated Faster RCNN72.315.210.4GBR3Det75.811.713.2GBS2ANet77.19.814.5GB最终选择R3Det作为平衡点通过模型剪枝将推理速度提升到18FPS。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2510159.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!