从零开始的目标检测和关键点检测(二):训练一个Glue的RTMDet模型

news2025/7/27 13:46:55

从零开始的目标检测和关键点检测(二):训练一个Glue的RTMDet模型

  • 一、config文件解读
  • 二、开始训练
  • 三、数据集分析
  • 四、ncnn部署

从零开始的目标检测和关键点检测(一):用labelme标注数据集

从零开始的目标检测和关键点检测(三):训练一个Glue的RTMPose模型

[1]用labelme标注自己的数据集中已经标注好数据集(关键点和检测框),通过labelme2coco脚本将所有的labelme json文件集成为两个coco格式的json文件,即train_coco.json和val_coco.json。训练一个RTMDet模型,需要重写config文件。

一、config文件解读

1、数据集类型即coco格式的数据集,metainfo是指框的类别,因为这里只有一个glue的类,因此NUM_CLASSES为1,注意metainfo类别名后的逗号,

# 数据集类型及路径
dataset_type = 'CocoDataset'
data_root = 'data/glue_134_Keypoint/'
metainfo = {'classes': ('glue',)}
NUM_CLASSES = len(metainfo['classes'])

2、加载backnbone预训练权重和RTMDet-tiny预训练权重

# RTMDet-tiny
load_from = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth'
backbone_pretrain = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-tiny_imagenet_600e.pth'
deepen_factor = 0.167
widen_factor = 0.375
in_channels = [96, 192, 384]
neck_out_channels = 96
num_csp_blocks = 1
exp_on_reg = False

3、训练参数设置,如epoch、batchsize…

MAX_EPOCHS = 200
TRAIN_BATCH_SIZE = 8
VAL_BATCH_SIZE = 4
stage2_num_epochs = 20
base_lr = 0.004
VAL_INTERVAL = 5  # 每隔多少轮评估保存一次模型权重

4、default_runtime,即默认设置,在config文件夹的default_runtime.py可看到。不同的MM-框架的默认设置不一样(如default_scope = 'mmdet'),可以包含这个.py也可以直接复制过来。

default_scope = 'mmdet'
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=1),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=10, max_keep_ckpts=2, save_best='coco/bbox_mAP'),
    # auto coco/bbox_mAP_50 coco/bbox_mAP_75 coco/bbox_mAP_s
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='DetVisualizationHook'))
env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'))
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='DetLocalVisualizer',
    vis_backends=[dict(type='LocalVisBackend')],
    name='visualizer')
log_processor = dict(type='LogProcessor', window_size=50, by_epoch=True)
log_level = 'INFO'
load_from = None
resume = False

5、训练超参数配置

train_cfg = dict(
    type='EpochBasedTrainLoop',
    max_epochs=MAX_EPOCHS,
    val_interval=VAL_INTERVAL,
    dynamic_intervals=[(MAX_EPOCHS - stage2_num_epochs, 1)])

val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# 学习率
param_scheduler = [
    dict(
        type='LinearLR', start_factor=1e-05, by_epoch=False, begin=0,
        end=1000),
    dict(
        type='CosineAnnealingLR',
        eta_min=0.0002,
        begin=150,
        end=300,
        T_max=150,
        by_epoch=True,
        convert_to_iter_based=True)
]

# 优化器
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
    paramwise_cfg=dict(
        norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
auto_scale_lr = dict(enable=False, base_batch_size=16)

6、数据处理pipeline,做数据预处理(数据增强)

# DataLoader
backend_args = None
train_pipeline = [
    dict(type='LoadImageFromFile', backend_args=None),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='CachedMosaic',
        img_scale=(640, 640),
        pad_val=114.0,
        max_cached_images=20,
        random_pop=False),
    dict(
        type='RandomResize',
        scale=(1280, 1280),
        ratio_range=(0.5, 2.0),
        keep_ratio=True),
    dict(type='RandomCrop', crop_size=(640, 640)),
    dict(type='YOLOXHSVRandomAug'),
    dict(type='RandomFlip', prob=0.5),
    dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))),
    dict(
        type='CachedMixUp',
        img_scale=(640, 640),
        ratio_range=(1.0, 1.0),
        max_cached_images=10,
        random_pop=False,
        pad_val=(114, 114, 114),
        prob=0.5),
    dict(type='PackDetInputs')
]
test_pipeline = [
    dict(type='LoadImageFromFile', backend_args=None),
    dict(type='Resize', scale=(640, 640), keep_ratio=True),
    dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))),
    dict(
        type='PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'))
]

7、加载数据和标注并用对应pipeliane做预处理

train_dataloader = dict(
    batch_size=TRAIN_BATCH_SIZE,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    batch_sampler=None,
    dataset=dict(
        type='CocoDataset',
        data_root=data_root,
        metainfo=metainfo,
        ann_file='train_coco.json',
        data_prefix=dict(img='images/'),
        filter_cfg=dict(filter_empty_gt=True, min_size=32),
        pipeline=train_pipeline,
        backend_args=None),
    pin_memory=True)
val_dataloader = dict(
    batch_size=VAL_BATCH_SIZE,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type='CocoDataset',
        data_root=data_root,
        metainfo=metainfo,
        ann_file='val_coco.json',
        data_prefix=dict(img='images/'),
        test_mode=True,
        pipeline=test_pipeline,
        backend_args=None))
test_dataloader = val_dataloader

8、定义模型结构backbone + neck + head

# 模型结构
model = dict(
    type='RTMDet',
    data_preprocessor=dict(
        type='DetDataPreprocessor',
        mean=[103.53, 116.28, 123.675],
        std=[57.375, 57.12, 58.395],
        bgr_to_rgb=False,
        batch_augments=None),
    backbone=dict(
        type='CSPNeXt',
        arch='P5',
        expand_ratio=0.5,
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        channel_attention=True,
        norm_cfg=dict(type='SyncBN'),
        act_cfg=dict(type='SiLU', inplace=True),
        init_cfg=dict(
            type='Pretrained',
            prefix='backbone.',
            checkpoint=backbone_pretrain

        )),
    neck=dict(
        type='CSPNeXtPAFPN',
        in_channels=in_channels,
        out_channels=neck_out_channels,
        num_csp_blocks=num_csp_blocks,
        expand_ratio=0.5,
        norm_cfg=dict(type='SyncBN'),
        act_cfg=dict(type='SiLU', inplace=True)),
    bbox_head=dict(
        type='RTMDetSepBNHead',
        num_classes=NUM_CLASSES,
        in_channels=neck_out_channels,
        stacked_convs=2,
        feat_channels=neck_out_channels,
        anchor_generator=dict(
            type='MlvlPointGenerator', offset=0, strides=[8, 16, 32]),
        bbox_coder=dict(type='DistancePointBBoxCoder'),
        loss_cls=dict(
            type='QualityFocalLoss',
            use_sigmoid=True,
            beta=2.0,
            loss_weight=1.0),
        loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
        with_objectness=False,
        exp_on_reg=exp_on_reg,
        share_conv=True,
        pred_kernel_size=1,
        norm_cfg=dict(type='SyncBN'),
        act_cfg=dict(type='SiLU', inplace=True)),
    train_cfg=dict(
        assigner=dict(type='DynamicSoftLabelAssigner', topk=13),
        allowed_border=-1,
        pos_weight=-1,
        debug=False),
    test_cfg=dict(
        nms_pre=30000,
        min_bbox_size=0,
        score_thr=0.001,
        nms=dict(type='nms', iou_threshold=0.65),
        max_per_img=300))

二、开始训练

1、开始训练

python tools/train.py data/glue_134_Keypoint/rtmdet_tiny_glue.py

训练结果

 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.719
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.483
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.766
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.766
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.766

2、测试一下训练结果

python demo/image_demo.py data/glue_134_Keypoint/test_image/test.png data/glue_134_Keypoint/rtmdet_tiny_glue.py --weights work_dirs/rtmdet_tiny_glue/best_coco_bbox_mAP_epoch_180.pth --device cpu

在这里插入图片描述

3、可视化训练过程

在这里插入图片描述
在这里插入图片描述

4、由于标注数据集的glue都是小目标的,因此大目标无法识别,如下:
在这里插入图片描述

三、数据集分析

1、可视化部分图像

在这里插入图片描述

框标注-框中心点位置分布

在这里插入图片描述

框标注-框宽高分布

显然都是小目标的检测

在这里插入图片描述

四、ncnn部署

在线模型转换:Deploee

上传文件完成在线转换

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1156770.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Jmeter 汉化中文语言

找到 bin -> jmeter.propertise 修改参数:languageen --> languagazh_CN OK!

上位机底部栏 UI如何设置

上位机如果像设置个多页面切换: 位置: 代码如下: "tabBar": {"color": "black","selectedColor": "#d43c33","borderStyle":"black","backgroundColor": …

EVM6678L 开发教程: IBL-TFTP 引导 elf 文件

目录 EVM6678L 开发教程: IBL-TFTP 引导 elf 文件安装 Tftpd64测试工程测试说明 EVM6678L 开发教程: IBL-TFTP 引导 elf 文件 参考: "C:\ti\mcsdk_2_01_02_06\tools\boot_loader\examples\i2c\tftp\docs\README.txt" 此教程介绍如何在 EVM6678L 开发板上实现 IBL-…

【面试经典150 | 链表】旋转链表

文章目录 Tag题目来源题目解读解题思路方法一:遍历 其他语言python3 写在最后 Tag 【单向链表】 题目来源 61. 旋转链表 题目解读 旋转链表,将链表的每个节点向右移动 k 个位置。 解题思路 方法一:遍历 本题题目意思清晰,实现…

【Linux】jdk Tomcat MySql的安装及Linux后端接口部署

一,jdk安装 1.1 上传安装包到服务器 打开MobaXterm通过Linux地址连接到Linux并登入Linux,再将主机中的配置文件复制到MobaXterm 使用命令查看:ll 1.2 解压对应的安装包 解压jdk 解压命令:tar -xvf jdk 加键盘中Tab键即可…

企业级JAVA、数据库等编程规范之命名风格 —— 超详细准确无误

🧸欢迎来到dream_ready的博客,📜相信你对这两篇博客也感兴趣o (ˉ▽ˉ;) 📜 表白墙/留言墙 —— 初级SpringBoot项目,练手项目前后端开发(带完整源码) 全方位全步骤手把手教学 📜 用户登录前后端…

作为网工有必要了解一下什么是SRv6?

什么是SRv6? 【微|信|公|众|号:厦门微思网络】 【微思网络http://www.xmws.cn,成立于2002年,专业培训21年,思科、华为、红帽、ORACLE、VMware等厂商认证及考试,以及其他认证PMP、CISP、ITIL等】 SRv6&…

MFC简单字符串压缩程序

一个mfc简单字符串压缩程序;按以下情况进行压缩; 1 仅压缩连续重复出现的字符。比如”abcbc”无连续重复字符,压缩后还是”abcbc”。 2 压缩的格式为”字符重复的次数字符”。例如,”xxxyyyyyyz”压缩后就成为”3x6yz”。 void …

Centos7环境下cmake3.25的编译与安装

文章目录 0 视频传送门1 卸载当前版本2 下载cmake3.25.0并且解压缩3 使用root用户进入解压缩的目录4 开始执行命令5 创建软连接6 检查版本 0 视频传送门 https://www.bilibili.com/video/BV1Gu4y1J7Ev/?vd_source3353f83539e46042d8cf76efb177a8e4 07-Centos7编译安装cmake3.…

接口请求的六种常见方式详解(get、post、head等)

一.接口请求的六种常见方式: 1、Get 向特定资源发出请求(请求指定页面信息,并返回实体主体) 2、Post 向指定资源提交数据进行处理请求(提交表单、上传文件),又可能导致新的资源的建…

Leetcode—485.最大连续1的个数【中等】明天修改

2023每日刷题(十五) Leetcode—2.两数相加 迭代法实现代码 /*** Definition for singly-linked list.* struct ListNode {* int val;* struct ListNode *next;* };*/ struct ListNode* addTwoNumbers(struct ListNode* l1, struct ListNode* l…

再现“换桥奇迹”|人大金仓助力中国移动完成营销系统国产化升级

堪称传统基建奇迹的“三元桥43小时换新桥”工程的相关报道还历历在目,而中国移动也经历着类似的考验,需要在2天内完成某在线营销系统整体升级。 作为中国移动的重要数据库产品与服务提供商,留给人大金仓的时间只有每天夜间的4小时&#xff0…

要在VMware(虚拟机)上获取相机连接状态并显示在主界面上,您可以使用以下步骤:

在VM上安装相机驱动程序:确保VM中已安装对应的相机驱动程序,以便能够连接和使用相机。 检查相机连接状态:在VM中,打开设备管理器(Device Manager)并检查相机是否显示为已连接状态。如果显示为已连接&#…

什么是 DevOps

DevOps是一套融合软件开发(Dev)和 IT 运营(Ops)的实践,旨在缩短应用程序开发周期并确保以高软件质量持续交付,通过采用 DevOps 实践,您可以帮助组织更可靠、更快速、更高效地交付软件。 什么是…

python 之正则表达式详解

文章目录 r与R原始字符串的特点:示例:正则表达式示例:文件路径示例: 有没有r 带来的影响使用 r 前缀的示例:不使用 r 前缀的示例: \b 作为单词的界限匹配以 "cat" 开头的单词:匹配以 …

从零开始的目标检测和关键点检测(三):训练一个Glue的RTMPose模型

从零开始的目标检测和关键点检测(三):训练一个Glue的RTMPose模型 一、重写config文件二、开始训练三、ncnn部署 从零开始的目标检测和关键点检测(一):用labelme标注数据集 从零开始的目标检测和关键点检测…

第06章 索引的数据结构

第06章 索引的数据结构 1. 索引及其优缺点 1.1 索引概述 MySQL官方对索引的定义为:索引(Index)是帮助MySQL高效获取数据的数据结构。 **索引的本质:**索引是数据结构。你可以简单理解为“排好序的快速查找数据结构”&#xff…

Vue 监听属性 watchEffect

watchEffect 函数:自动收集依赖源,不用指定监听哪个数据,在监听的回调中用到哪个数据,就监听哪个数据。 而 watch 函数:既要指定监听的数据,也要指定监听的回调。 watchEffect 函数:类似于 co…

OpenCV 笔记(4):图像的算术运算、逻辑运算

Part11. 图像的算术运算 图像的本质是一个矩阵,所以可以对它进行一些常见的算术运算,例如加、减、乘、除、平方根、对数、绝对值等等。除此之外,还可以对图像进行逻辑运算和几何变换。 我们先从简单的图像加、减、逻辑运算开始介绍。后续会有…

Window下SRS服务器的搭建

---2023.7.23 准备材料 srs下载:GitHub - ossrs/srs at 3.0release 目前srs release到5.0版本。 srs官方文档:Introduction | SRS (ossrs.net) Docker下载:Download Docker Desktop | Docker 进入docker官网选择window版本直接下载。由…