从零到一:基于MMPretrain框架定制化训练专属图像分类模型
1. 环境准备与框架安装第一次接触MMPretrain时我对着官方文档折腾了半天环境配置。后来发现用mim这个包管理工具能省去80%的依赖问题。先确保你的Python环境是3.7版本然后执行下面这组命令pip install openmim mim install mmengine mim install mmcv mim install mmpretrain这里有个坑要注意如果系统里有多个Python版本记得用python -m pip指定版本。我之前在Ubuntu上就遇到过pip默认指向Python2.7的情况装完一堆报错。安装完成后验证下是否成功import mmpretrain print(mmpretrain.__version__)建议用conda创建独立环境特别是当你要跑不同版本的实验时。有次我在服务器上同时跑两个项目因为环境冲突浪费了一整天。Windows用户可能会遇到VC编译问题直接安装Visual Studio Build Tools就能解决。2. 数据集准备实战官方示例用的都是标准数据集但实际项目中我们往往要处理自定义数据。以花卉识别为例我的文件夹结构是这样的flower_data/ ├── train/ │ ├── rose/ │ ├── tulip/ │ └── ... └── val/ ├── rose/ ├── tulip/ └── ...关键点在于类别子目录的命名。有次我把daisy拼成dasiy训练时直接报维度错误。建议先用这个脚本检查数据完整性from pathlib import Path data_root Path(flower_data) for split in [train, val]: for cls_dir in (data_root/split).iterdir(): if not any(cls_dir.glob(*.jpg)): print(f空文件夹警告: {cls_dir})对于非标准尺寸的图片MMPretrain会自动resize但建议提前用OpenCV批量处理到相近尺寸。我遇到过一批4000x3000的图片直接训练把显存撑爆了。3. 模型配置魔改技巧官方提供的ResNet18配置是个不错的起点但需要修改几个关键参数复制configs/resnet/resnet18_8xb32_in1k.py为my_resnet18_8xb32_flowers.py修改num_classes102根据你的类别数调整学习率策略把milestones[30,60,90]改为[20,40]对小数据集更友好最容易被忽略的是data_preprocessor里的mean和std值。如果用预训练模型却不改这些参数效果会大打折扣。有个取巧的方法from mmpretrain import get_model model get_model(resnet18_8xb32_in1k) print(model.data_preprocessor.mean) # 输出预训练模型的归一化参数对于自定义数据集加载建议继承CustomDataset而不是照搬ImageNet的写法。这是我改良后的数据集类模板from mmpretrain.datasets import CustomDataset class FlowerDataset(CustomDataset): METAINFO { classes: (rose, tulip, ...), # 你的类别列表 palette: [(255,0,0), (0,255,0), ...] # 可视化用的颜色 } def __init__(self, **kwargs): super().__init__(**kwargs) # 自定义初始化逻辑4. 训练调参实战心得启动训练前先运行以下命令检查配置是否有效mim train mmpretrain my_resnet18_8xb32_flowers.py --work-dir ./work_dirs --validate几个实用参数--cfg-options临时覆盖配置项比如optim_wrapper.optimizer.lr0.01--auto-scale-lr根据batch size自动缩放学习率--resume从上次中断处继续训练训练过程中要盯紧这几个指标train/acc如果一直不上升可能是学习率太小val/acc与训练集差距过大说明过拟合memory显存占用突然飙升可能有bug我用RTX 3090训练ResNet18的实测数据批量大小32显存占用约5GB100个epoch耗时约2小时1万张图片最佳验证准确率出现在第65epoch左右5. 模型部署与优化训练完的模型可以通过tools/test.py快速验证mim test mmpretrain \ ./work_dirs/resnet18_8xb32_flowers/epoch_100.pth \ --config my_resnet18_8xb32_flowers.py \ --metrics accuracy precision recall想要部署到生产环境建议导出为ONNX格式from mmpretrain import get_model model get_model(resnet18_8xb32_in1k, pretrainedwork_dirs/epoch_100.pth) torch.onnx.export(model, torch.rand(1,3,224,224), flower.onnx)对于边缘设备部署可以试试量化压缩model.cpu() quantized_model torch.quantization.quantize_dynamic( model, {torch.nn.Linear}, dtypetorch.qint8)6. 避坑指南路径问题Windows用户注意反斜杠转义建议用pathlib.Path处理路径版本冲突MMCV和PyTorch版本必须严格匹配参考官方兼容性表格显存不足尝试减小batch_size或使用amp自动混合精度标签错误先用小批量数据(--max-keep-ckpts1)快速验证流程过拟合添加model.head.dropout0.5或数据增强有次我遇到验证集准确率始终为0最后发现是val_dataloader里忘了设置shuffleFalse。这种错误日志不会直接报错但会导致评估失效。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2487896.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!