Detectron2 Training in Practice: Training Mask R-CNN on a Custom Dataset (PyTorch 1.8+)
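1. Environment Setup and Installation

Before you start training, make sure your system meets these basic requirements:

- Operating system: Linux, or Windows (requires extra configuration)
- Python ≥ 3.7
- PyTorch ≥ 1.8
- CUDA ≥ 10.2 (11.3 recommended)
- GPU memory ≥ 8 GB (11 GB or more recommended for training Mask R-CNN)

1.1 Installing PyTorch and dependencies

First, install a PyTorch build that matches your CUDA version:

```bash
# For CUDA 11.3
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
```

Then install the remaining dependencies:

```bash
pip install cython opencv-python pyyaml matplotlib tqdm
```

1.2 Installing the COCO API

Detectron2 uses the COCO API (pycocotools) to read COCO-format datasets and to run COCO evaluation:

```bash
pip install "git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI"
```

1.3 Installing Detectron2

For the best compatibility with your local PyTorch and CUDA, build Detectron2 from source:

```bash
git clone https://github.com/facebookresearch/detectron2.git
cd detectron2
pip install -e .
```

Note: if you run into compilation errors, make sure you have GCC ≥ 5.4 and an NVCC that matches the CUDA version of your PyTorch build.

2. Preparing a Custom Dataset

2.1 Converting the dataset format

Detectron2 ships a ready-made helper, register_coco_instances, for COCO-format annotations. If your data is in VOC or another format, the simplest route is to convert it to COCO first:

```python
import json

from detectron2.data.datasets import register_coco_instances


def convert_to_coco(input_annotations, output_file):
    """Write annotations as a COCO instance-segmentation JSON file."""
    coco_format = {
        "info": {},
        "licenses": [],
        "categories": [{"id": 1, "name": "your_class"}],  # replace with your classes
        "images": [],
        "annotations": [],
    }
    # Add the conversion logic here: fill "images" and "annotations" from input_annotations...
    with open(output_file, "w") as f:
        json.dump(coco_format, f)


# Register the datasets: name, extra metadata, JSON file, image directory
register_coco_instances("my_dataset_train", {}, "path/to/train.json", "path/to/images")
register_coco_instances("my_dataset_val", {}, "path/to/val.json", "path/to/images")
```

Registration only lives in memory, so these calls must run in every process (training, evaluation, visualization) that uses the datasets.

2.2 Verifying the dataset

Load a few samples and visualize them to confirm the annotations are in the right format:

```python
import random

import cv2
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.utils.visualizer import Visualizer

dataset_dicts = DatasetCatalog.get("my_dataset_train")
metadata = MetadataCatalog.get("my_dataset_train")

for d in random.sample(dataset_dicts, 3):
    img = cv2.imread(d["file_name"])  # OpenCV loads images as BGR
    # Visualizer expects RGB, so flip the channel order
    visualizer = Visualizer(img[:, :, ::-1], metadata=metadata, scale=0.5)
    vis = visualizer.draw_dataset_dict(d)
    cv2.imshow("Sample", vis.get_image()[:, :, ::-1])  # back to BGR for display
    cv2.waitKey(0)
cv2.destroyAllWindows()
```

3. Configuring Training Parameters

3.1 Basic configuration

Detectron2 configures models through YAML files. The key settings for Mask R-CNN are:

```yaml
MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"  # pretrained weights
  MASK_ON: True  # required for Mask R-CNN; the mask branch is off by default
  BACKBONE:
    NAME: "build_resnet_fpn_backbone"
  RESNETS:
    DEPTH: 50
    OUT_FEATURES: ["res2", "res3", "res4", "res5"]
  FPN:
    IN_FEATURES: ["res2", "res3", "res4", "res5"]
  ROI_HEADS:
    NAME: "StandardROIHeads"
    NUM_CLASSES: 1  # number of your classes (background is not counted)
  ROI_BOX_HEAD:
    NAME: "FastRCNNConvFCHead"
    NUM_FC: 2
    POOLER_RESOLUTION: 7
  ROI_MASK_HEAD:
    NAME: "MaskRCNNConvUpsampleHead"
    NUM_CONV: 4
    POOLER_RESOLUTION: 14
```

This excerpt is not a complete FPN setup on its own (the RPN, anchor generator and ROI_HEADS.IN_FEATURES also need FPN-specific values). In practice it is easier to start from the model zoo config, e.g. cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")), and override only the fields above.

3.2 Tuning training hyperparameters

Adjust the following key parameters to your hardware:

| Parameter | Recommended value | Description |
|---|---|---|
| SOLVER.BASE_LR | 0.001-0.01 | Base learning rate |
| SOLVER.MAX_ITER | 20000-50000 | Maximum number of iterations |
| SOLVER.STEPS | (10000, 18000) | Iterations at which the learning rate decays |
| SOLVER.IMS_PER_BATCH | 2-8 | Images per batch (total across all GPUs) |
| MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE | 64-512 | RoIs sampled per image for the ROI heads |

Every value in SOLVER.STEPS must be smaller than SOLVER.MAX_ITER; at each step the learning rate is multiplied by SOLVER.GAMMA (0.1 by default).

3.3 Data augmentation

Add data augmentation in the config file:

```yaml
INPUT:
  MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)  # multi-scale: shorter side picked at random
  MIN_SIZE_TRAIN_SAMPLING: "choice"
  MAX_SIZE_TRAIN: 1333
  CROP:
    ENABLED: True
    TYPE: "absolute"
    SIZE: (512, 512)
```

The default DatasetMapper also applies random horizontal flips during training (INPUT.RANDOM_FLIP). The default config has no color-augmentation key: INPUT.COLOR_AUG_SSD is only defined by some projects (e.g. DeepLab), and merge_from_file rejects unknown keys. Color jitter has to be added through a custom DatasetMapper, for example with T.RandomBrightness and T.RandomContrast from detectron2.data.transforms.

4. Training

4.1 Starting training

Train with DefaultTrainer:

```python
import os

from detectron2.config import get_cfg
from detectron2.engine import DefaultTrainer

cfg = get_cfg()
cfg.merge_from_file("path/to/config.yaml")
cfg.DATASETS.TRAIN = ("my_dataset_train",)
cfg.DATASETS.TEST = ("my_dataset_val",)
cfg.OUTPUT_DIR = "output"
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()
```

The datasets must be registered (section 2.1) earlier in the same script. DefaultTrainer has no default evaluator, so to get metrics on DATASETS.TEST you need a subclass that implements build_evaluator, as in 4.2.

4.2 Customizing the training logic

To extend the training loop, subclass DefaultTrainer:

```python
from detectron2.data import DatasetMapper, build_detection_test_loader
from detectron2.engine import DefaultTrainer
from detectron2.evaluation import COCOEvaluator


class CustomTrainer(DefaultTrainer):
    @classmethod
    def build_evaluator(cls, cfg, dataset_name):
        return COCOEvaluator(dataset_name, output_dir=cfg.OUTPUT_DIR)

    def build_hooks(self):
        hooks = super().build_hooks()
        # LossEvalHook is user-defined (not part of Detectron2); it is inserted before
        # the last default hook (the PeriodicWriter on the main process)
        hooks.insert(-1, LossEvalHook(
            self.cfg.TEST.EVAL_PERIOD,
            self.model,
            build_detection_test_loader(
                # is_train=True keeps the annotations needed to compute losses
                self.cfg, self.cfg.DATASETS.TEST[0], DatasetMapper(self.cfg, True)
            ),
        ))
        return hooks
```

LossEvalHook is a common community recipe that computes the validation loss every TEST.EVAL_PERIOD iterations; you must define it yourself before this class. It needs a mapper built with is_train=True, because DatasetMapper(cfg, False) drops the annotations and no loss could be computed.

4.3 Monitoring training

Detectron2 supports several ways to monitor training.

TensorBoard integration (DefaultTrainer writes event files to OUTPUT_DIR):

```bash
tensorboard --logdir output
```

Custom metrics, written to the trainer's active EventStorage:

```python
import torch
from detectron2.utils.events import get_event_storage

# Must run inside the training loop (e.g. in a hook's after_step), where the
# trainer's EventStorage is active; `optimizer` and `img` come from that context
storage = get_event_storage()
storage.put_scalar("lr", optimizer.param_groups[0]["lr"], smoothing_hint=False)
# put_image expects an RGB (C, H, W) tensor; img is a BGR HxWx3 uint8 array
storage.put_image("input", torch.from_numpy(img[:, :, ::-1].copy()).permute(2, 0, 1))
```

Evaluation on the validation set:

```python
from detectron2.data import build_detection_test_loader
from detectron2.evaluation import COCOEvaluator, inference_on_dataset

evaluator = COCOEvaluator("my_dataset_val", output_dir="./output")
val_loader = build_detection_test_loader(cfg, "my_dataset_val")
print(inference_on_dataset(trainer.model, val_loader, evaluator))
```

5. Model Optimization and Tuning Tips

5.1 Learning-rate strategy

Give different parts of the model different learning rates, e.g. a lower rate for the pretrained backbone, by overriding DefaultTrainer.build_optimizer:

```python
import torch
from detectron2.engine import DefaultTrainer


class DiffLRTrainer(DefaultTrainer):
    @classmethod
    def build_optimizer(cls, cfg, model):
        params = []
        for key, value in model.named_parameters():
            if not value.requires_grad:
                continue  # skip frozen parameters (e.g. stages frozen by BACKBONE.FREEZE_AT)
            # "backbone" matches every parameter under model.backbone, FPN layers included
            lr = cfg.SOLVER.BASE_LR * 0.1 if "backbone" in key else cfg.SOLVER.BASE_LR
            params.append({"params": [value], "lr": lr})
        return torch.optim.SGD(params, lr=cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM,
                               weight_decay=cfg.SOLVER.WEIGHT_DECAY)
```

Detectron2's LR scheduler scales each parameter group's own initial learning rate, so the 10x ratio is kept through warmup and decay.

5.2 Architecture adjustments

FPN feature levels:

```yaml
MODEL:
  BACKBONE:
    NAME: "build_resnet_fpn_backbone"  # top block is LastLevelMaxPool (adds P6)
  FPN:
    IN_FEATURES: ["res2", "res3", "res4", "res5"]
    OUT_CHANNELS: 256
```

MODEL.FPN has no TOP_BLOCK key: build_resnet_fpn_backbone always uses LastLevelMaxPool, which adds P6. To get both P6 and P7 (LastLevelP6P7), switch BACKBONE.NAME to build_retinanet_resnet_fpn_backbone or register your own backbone, and update RPN.IN_FEATURES and ANCHOR_GENERATOR.SIZES to match the new levels.

Deformable convolution:

```yaml
MODEL:
  RESNETS:
    DEFORM_ON_PER_STAGE: [False, True, True, True]  # res2, res3, res4, res5
    DEFORM_MODULATED: True  # modulated deformable conv (DCNv2)
```

5.3 Loss weighting

Custom loss weights via a registered ROI head:

```python
from detectron2.modeling.roi_heads import ROI_HEADS_REGISTRY, StandardROIHeads


@ROI_HEADS_REGISTRY.register()
class CustomROIHeads(StandardROIHeads):
    def __init__(self, cfg, input_shape):
        super().__init__(cfg, input_shape)
        # Per-term weights applied by FastRCNNOutputLayers.losses()
        self.box_predictor.loss_weight = {"loss_cls": 1.0, "loss_box_reg": 2.0}
        # Scalar weight applied to loss_mask by the mask head
        self.mask_head.loss_weight = 1.5

# Select it in the config: cfg.MODEL.ROI_HEADS.NAME = "CustomROIHeads"
```

The box-regression weight alone can also be set without code via MODEL.ROI_BOX_HEAD.BBOX_REG_LOSS_WEIGHT.

6. Deployment and Inference

6.1 Exporting the model

Export the trained model to TorchScript by scripting. Note that build_model creates an untrained model, so the trained weights have to be loaded first:

```python
import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.export import scripting_with_instances
from detectron2.structures import Boxes

model = trainer.build_model(cfg)  # fresh, untrained model
DetectionCheckpointer(model).load(cfg.OUTPUT_DIR + "/model_final.pth")  # trained weights
model.eval()
# Scripting needs the static type of every field stored in Instances
fields = {"proposal_boxes": Boxes, "objectness_logits": torch.Tensor, "pred_boxes": Boxes,
          "scores": torch.Tensor, "pred_classes": torch.Tensor, "pred_masks": torch.Tensor}
torch.jit.save(scripting_with_instances(model, fields), "model_scripted.pt")
```

tools/deploy/export_model.py in the Detectron2 repository shows the complete export flow, including the tracing alternative.

6.2 Tuning inference

Key parameters for inference:

| Parameter | Recommended value | Description |
|---|---|---|
| MODEL.ROI_HEADS.SCORE_THRESH_TEST | 0.5-0.7 | Minimum confidence for a detection to be kept |
| MODEL.ROI_HEADS.NMS_THRESH_TEST | 0.3-0.5 | IoU threshold for NMS |
| INPUT.MIN_SIZE_TEST | 800 | Shorter side of the test image after resizing |
| INPUT.MAX_SIZE_TEST | 1333 | Upper limit for the longer side |

6.3 Batch inference

DefaultPredictor handles one image at a time; for a batch, call its model directly with the same preprocessing:

```python
import torch
from detectron2.engine import DefaultPredictor

predictor = DefaultPredictor(cfg)  # builds the model, loads cfg.MODEL.WEIGHTS, sets eval mode


def batch_inference(images):
    """images: list of HxWx3 uint8 arrays in BGR order (as returned by cv2.imread)."""
    inputs = []
    for img in images:
        if predictor.input_format == "RGB":
            img = img[:, :, ::-1]
        height, width = img.shape[:2]
        # Same test-time resize that DefaultPredictor applies to a single image
        img = predictor.aug.get_transform(img).apply_image(img)
        # The model expects a float32 CHW tensor in cfg.INPUT.FORMAT channel order
        image = torch.as_tensor(img.astype("float32").transpose(2, 0, 1))
        # height/width make the model rescale its outputs to the original image size
        inputs.append({"image": image, "height": height, "width": width})
    with torch.no_grad():
        return predictor.model(inputs)
```

7. Troubleshooting

7.1 Training problems

Out-of-memory errors:

- Reduce SOLVER.IMS_PER_BATCH.
- Use gradient accumulation. Detectron2's default config has no gradient-accumulation option (there is no SOLVER.GRADIENT_ACCUMULATION_STEPS key), so it has to be implemented in a custom training loop, e.g. by subclassing SimpleTrainer and overriding run_step.

NaN losses:

- Check the annotations (e.g. for degenerate boxes or empty polygons).
- Lower the learning rate.
- Enable gradient clipping:

```yaml
SOLVER:
  CLIP_GRADIENTS:
    ENABLED: True
    CLIP_TYPE: "value"
    CLIP_VALUE: 1.0
```

7.2 Performance tips

Mixed-precision training (DefaultTrainer then switches to AMPTrainer):

```yaml
SOLVER:
  AMP:
    ENABLED: True
```

Data loading and sampling: for long-tailed datasets, RepeatFactorTrainingSampler oversamples images that contain rare classes. Return the loader below from an overridden DefaultTrainer.build_train_loader, or get the same behavior without code by setting DATALOADER.SAMPLER_TRAIN: "RepeatFactorTrainingSampler" and DATALOADER.REPEAT_THRESHOLD.

```python
from detectron2.data import build_detection_train_loader, get_detection_dataset_dicts
from detectron2.data.samplers import RepeatFactorTrainingSampler


def build_train_loader(cfg):
    # Load the same dataset dicts the loader will use, so sampler indices line up
    dataset = get_detection_dataset_dicts(
        cfg.DATASETS.TRAIN, filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS
    )
    # Per-image repeat factors: images containing rare categories get factors > 1
    repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
        dataset, cfg.DATALOADER.REPEAT_THRESHOLD
    )
    sampler = RepeatFactorTrainingSampler(repeat_factors, shuffle=True)
    return build_detection_train_loader(cfg, dataset=dataset, sampler=sampler)
```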
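The repeat factors used by RepeatFactorTrainingSampler are derived from category frequencies. The dependency-free sketch below illustrates the LVIS-style rule that, as far as we understand it, repeat_factors_from_category_frequency follows: for each category c, f_c is the fraction of training images containing c, r_c = max(1, sqrt(t / f_c)) for a threshold t, and an image's factor is the largest r_c among its categories. The function names repeat_factors and epoch_indices are hypothetical, and the stochastic rounding in epoch_indices is a simplified model of how fractional factors become per-epoch repeats, not Detectron2's code.

```python
import math
import random
from collections import defaultdict


def repeat_factors(image_categories, repeat_thresh):
    """Hypothetical helper: per-image repeat factors from category frequency.

    image_categories: list of sets, the category ids present in each image.
    repeat_thresh: threshold t; categories found in fewer than a fraction t
    of the images are oversampled.
    """
    num_images = len(image_categories)
    # Count how many images contain each category
    counts = defaultdict(int)
    for cats in image_categories:
        for c in cats:
            counts[c] += 1
    # r_c = max(1, sqrt(t / f_c)), with f_c = count / num_images
    cat_rep = {
        c: max(1.0, math.sqrt(repeat_thresh / (n / num_images)))
        for c, n in counts.items()
    }
    # r_i = max r_c over the image's categories; images without categories get 1.0
    return [max((cat_rep[c] for c in cats), default=1.0) for cats in image_categories]


def epoch_indices(factors, rng=None):
    """Hypothetical helper: expand factors into one epoch of image indices.

    Stochastic rounding: image i appears floor(r_i) times, plus one more time
    with probability equal to the fractional part, so r_i times on average.
    """
    if rng is None:
        rng = random.Random()
    indices = []
    for i, r in enumerate(factors):
        whole = math.floor(r)
        reps = whole + (1 if rng.random() < r - whole else 0)
        indices.extend([i] * reps)
    return indices


if __name__ == "__main__":
    factors = repeat_factors([{1}, {1}, {1}, {2}], repeat_thresh=0.5)
    print(factors)  # [1.0, 1.0, 1.0, 1.4142135623730951]
```

With t = 0.5, class 2 appears in only a quarter of the images, so the image containing it is drawn about 1.41 times per epoch on average, while images with only the common class are drawn exactly once.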