yolov7改进优化之蒸馏(一)

news2025/7/16 4:41:45

最近比较忙,有一段时间没更新了,最近yolov7用的比较多,总结一下。上一篇yolov5及yolov7实战之剪枝_CodingInCV的博客-CSDN博客 我们讲了通过剪枝来裁剪我们的模型,达到在精度损失不大的情况下,提高模型速度的目的。上一篇是从速度的角度,这一篇我们从检测性能的角度来改进yolov7(yolov5也类似)。
对于提高检测器的性能,我们除了可以从增加数据、修改模型结构、修改loss等模型本身的角度出发外,深度学习领域还有一个方式—蒸馏。简单的说,蒸馏就是让性能更强的模型(teacher, 参数量更大)来指导性能更弱student模型,从而提高student模型的性能。
蒸馏的方式有很多种,比较简单暴力的比如直接让student模型来拟合teacher模型的输出特征图,当然蒸馏也不是万能的,毕竟student模型和teacher模型的参数量有差距,student模型不一定能很好的学习teacher的知识,对于自己的任务有没有作用也需要尝试。
本篇选择的方法是去年CVPR上的针对目标检测的蒸馏算法:
yzd-v/FGD: Focal and Global Knowledge Distillation for Detectors (CVPR 2022) (github.com)
针对该方法的解读可以参考:FGD-CVPR2022:针对目标检测的焦点和全局蒸馏 - 知乎 (zhihu.com)
本篇暂时不涉及理论,重点在把这个方法集成到yolov7训练。步骤如下。

载入teacher模型

蒸馏首先需要有一个teacher模型,这个teacher模型一般和student同样结构,只是参数量更大、层数更多。比如对于yolov5,可以尝试用yolov5m来蒸馏yolov5s。
train.py增加一个命令行参数:

    parser.add_argument("--teacher-weights", type=str, default="", help="initial weights path")

在train函数中载入teacher weights,过程与原有的载入过程类似,注意,DP或者DDP模型也要对teacher模型做对应的处理。

# teacher model
    if opt.teacher_weights:
        teacher_weights = opt.teacher_weights
        # with torch_distributed_zero_first(rank):
        #     teacher_weights = attempt_download(teacher_weights)  # download if not found locally
        teacher_model = Model(teacher_weights, ch=3, nc=nc).to(device)  # create    
        # load state_dict
        ckpt = torch.load(teacher_weights, map_location=device)  # load checkpoint
        state_dict = ckpt["model"].float().state_dict()  # to FP32
        teacher_model.load_state_dict(state_dict, strict=True)  # load
        #set to eval
        teacher_model.eval()
        #set IDetect to train mode
        # teacher_model.model[-1].train()
        logger.info(f"Load teacher model from {teacher_weights}")  # report

    # DP mode
    if cuda and rank == -1 and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
        if opt.teacher_weights:
            teacher_model = torch.nn.DataParallel(teacher_model)
            
	 # SyncBatchNorm
    if opt.sync_bn and cuda and rank != -1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
        logger.info("Using SyncBatchNorm()")
        if opt.teacher_weights:
	        teacher_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(teacher_model).to(device)

teacher模型不进行梯度计算,因此:

if opt.teacher_weights:
        for param in teacher_model.parameters():
            param.requires_grad = False

蒸馏Loss

蒸馏loss是计算teacher模型的一层或者多层与student的对应层的相似度,监督student模型向teacher模型靠近。对于yolov7,可以去监督三个特征层。
参考FGD的开源代码,我们在loss.py中增加一个FeatureLoss类, 参数暂时使用默认:

class FeatureLoss(nn.Module):

    """PyTorch version of `Feature Distillation for General Detectors`
   
    Args:
        student_channels(int): Number of channels in the student's feature map.
        teacher_channels(int): Number of channels in the teacher's feature map. 
        temp (float, optional): Temperature coefficient. Defaults to 0.5.
        name (str): the loss name of the layer
        alpha_fgd (float, optional): Weight of fg_loss. Defaults to 0.001
        beta_fgd (float, optional): Weight of bg_loss. Defaults to 0.0005
        gamma_fgd (float, optional): Weight of mask_loss. Defaults to 0.0005
        lambda_fgd (float, optional): Weight of relation_loss. Defaults to 0.000005
    """
    def __init__(self,
                 student_channels,
                 teacher_channels,
                 temp=0.5,
                 alpha_fgd=0.001,
                 beta_fgd=0.0005,
                 gamma_fgd=0.001,
                 lambda_fgd=0.000005,
                 ):
        super(FeatureLoss, self).__init__()
        self.temp = temp
        self.alpha_fgd = alpha_fgd
        self.beta_fgd = beta_fgd
        self.gamma_fgd = gamma_fgd
        self.lambda_fgd = lambda_fgd
    
        if student_channels != teacher_channels:
            self.align = nn.Conv2d(student_channels, teacher_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.align = None
        
        self.conv_mask_s = nn.Conv2d(teacher_channels, 1, kernel_size=1)
        self.conv_mask_t = nn.Conv2d(teacher_channels, 1, kernel_size=1)
        self.channel_add_conv_s = nn.Sequential(
            nn.Conv2d(teacher_channels, teacher_channels//2, kernel_size=1),
            nn.LayerNorm([teacher_channels//2, 1, 1]),
            nn.ReLU(inplace=True),  # yapf: disable
            nn.Conv2d(teacher_channels//2, teacher_channels, kernel_size=1))
        self.channel_add_conv_t = nn.Sequential(
            nn.Conv2d(teacher_channels, teacher_channels//2, kernel_size=1),
            nn.LayerNorm([teacher_channels//2, 1, 1]),
            nn.ReLU(inplace=True),  # yapf: disable
            nn.Conv2d(teacher_channels//2, teacher_channels, kernel_size=1))

        self.reset_parameters()

    def forward(self,
                preds_S,
                preds_T,
                gt_bboxes,
                img_metas):
        """Forward function.
        Args:
            preds_S(Tensor): Bs*C*H*W, student's feature map
            preds_T(Tensor): Bs*C*H*W, teacher's feature map
            gt_bboxes(tuple): Bs*[nt*4], pixel decimal: (tl_x, tl_y, br_x, br_y)
            img_metas (list[dict]): Meta information of each image, e.g.,
            image size, scaling factor, etc.
        """
        assert preds_S.shape[-2:] == preds_T.shape[-2:], 'the output dim of teacher and student differ'
        device = gt_bboxes.device
        self.to(device)
        if self.align is not None:
            preds_S = self.align(preds_S)

        N,C,H,W = preds_S.shape

        S_attention_t, C_attention_t = self.get_attention(preds_T, self.temp)
        S_attention_s, C_attention_s = self.get_attention(preds_S, self.temp)
        
        Mask_fg = torch.zeros_like(S_attention_t)
        # Mask_bg = torch.ones_like(S_attention_t)
        wmin,wmax,hmin,hmax = [],[],[],[]
        img_h, img_w = img_metas
        bboxes = gt_bboxes[:,2:6]
        #xywh2xyxy
        bboxes = xywh2xyxy(bboxes)
        new_boxxes = torch.ones_like(bboxes)
        new_boxxes[:, 0] = torch.floor(bboxes[:, 0]*W)
        new_boxxes[:, 2] = torch.ceil(bboxes[:, 2]*W)
        new_boxxes[:, 1] = torch.floor(bboxes[:, 1]*H)
        new_boxxes[:, 3] = torch.ceil(bboxes[:, 3]*H)

        #to int
        new_boxxes = new_boxxes.int()

        for i in range(N):
            new_boxxes_i = new_boxxes[torch.where(gt_bboxes[:,0]==i)]

            wmin.append(new_boxxes_i[:, 0])
            wmax.append(new_boxxes_i[:, 2])
            hmin.append(new_boxxes_i[:, 1])
            hmax.append(new_boxxes_i[:, 3])

            area = 1.0/(hmax[i].view(1,-1)+1-hmin[i].view(1,-1))/(wmax[i].view(1,-1)+1-wmin[i].view(1,-1))

            for j in range(len(new_boxxes_i)):
                Mask_fg[i][hmin[i][j]:hmax[i][j]+1, wmin[i][j]:wmax[i][j]+1] = \
                        torch.maximum(Mask_fg[i][hmin[i][j]:hmax[i][j]+1, wmin[i][j]:wmax[i][j]+1], area[0][j])

        Mask_bg = torch.where(Mask_fg > 0, 0., 1.)
        Mask_bg_sum = torch.sum(Mask_bg, dim=(1,2))
        Mask_bg[Mask_bg_sum>0] /= Mask_bg_sum[Mask_bg_sum>0].unsqueeze(1).unsqueeze(2)

        fg_loss, bg_loss = self.get_fea_loss(preds_S, preds_T, Mask_fg, Mask_bg, 
                        C_attention_s, C_attention_t, S_attention_s, S_attention_t)
        mask_loss = self.get_mask_loss(C_attention_s, C_attention_t, S_attention_s, S_attention_t)
        rela_loss = self.get_rela_loss(preds_S, preds_T)

        loss = self.alpha_fgd * fg_loss + self.beta_fgd * bg_loss \
            + self.gamma_fgd * mask_loss + self.lambda_fgd * rela_loss
            
        return loss, loss.detach()

    def get_attention(self, preds, temp):
        """ preds: Bs*C*W*H """
        N, C, H, W= preds.shape

        value = torch.abs(preds)
        # Bs*W*H
        fea_map = value.mean(axis=1, keepdim=True)
        S_attention = (H * W * F.softmax((fea_map/temp).view(N,-1), dim=1)).view(N, H, W)

        # Bs*C
        channel_map = value.mean(axis=2,keepdim=False).mean(axis=2,keepdim=False)
        C_attention = C * F.softmax(channel_map/temp, dim=1)

        return S_attention, C_attention


    def get_fea_loss(self, preds_S, preds_T, Mask_fg, Mask_bg, C_s, C_t, S_s, S_t):
        loss_mse = nn.MSELoss(reduction='sum')
        
        Mask_fg = Mask_fg.unsqueeze(dim=1)
        Mask_bg = Mask_bg.unsqueeze(dim=1)

        C_t = C_t.unsqueeze(dim=-1)
        C_t = C_t.unsqueeze(dim=-1)

        S_t = S_t.unsqueeze(dim=1)

        fea_t= torch.mul(preds_T, torch.sqrt(S_t))
        fea_t = torch.mul(fea_t, torch.sqrt(C_t))
        fg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_fg))
        bg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_bg))

        fea_s = torch.mul(preds_S, torch.sqrt(S_t))
        fea_s = torch.mul(fea_s, torch.sqrt(C_t))
        fg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_fg))
        bg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_bg))

        fg_loss = loss_mse(fg_fea_s, fg_fea_t)/len(Mask_fg)
        bg_loss = loss_mse(bg_fea_s, bg_fea_t)/len(Mask_bg)

        return fg_loss, bg_loss


    def get_mask_loss(self, C_s, C_t, S_s, S_t):

        mask_loss = torch.sum(torch.abs((C_s-C_t)))/len(C_s) + torch.sum(torch.abs((S_s-S_t)))/len(S_s)

        return mask_loss
     
    
    def spatial_pool(self, x, in_type):
        batch, channel, width, height = x.size()
        input_x = x
        # [N, C, H * W]
        input_x = input_x.view(batch, channel, height * width)
        # [N, 1, C, H * W]
        input_x = input_x.unsqueeze(1)
        # [N, 1, H, W]
        if in_type == 0:
            context_mask = self.conv_mask_s(x)
        else:
            context_mask = self.conv_mask_t(x)
        # [N, 1, H * W]
        context_mask = context_mask.view(batch, 1, height * width)
        # [N, 1, H * W]
        context_mask = F.softmax(context_mask, dim=2)
        # [N, 1, H * W, 1]
        context_mask = context_mask.unsqueeze(-1)
        # [N, 1, C, 1]
        context = torch.matmul(input_x, context_mask)
        # [N, C, 1, 1]
        context = context.view(batch, channel, 1, 1)

        return context


    def get_rela_loss(self, preds_S, preds_T):
        loss_mse = nn.MSELoss(reduction='sum')

        context_s = self.spatial_pool(preds_S, 0)
        context_t = self.spatial_pool(preds_T, 1)

        out_s = preds_S
        out_t = preds_T

        channel_add_s = self.channel_add_conv_s(context_s)
        out_s = out_s + channel_add_s

        channel_add_t = self.channel_add_conv_t(context_t)
        out_t = out_t + channel_add_t

        rela_loss = loss_mse(out_s, out_t)/len(out_s)
        
        return rela_loss


    def last_zero_init(self, m):
        if isinstance(m, nn.Sequential):
            constant_init(m[-1], val=0)
        else:
            constant_init(m, val=0)

    
    def reset_parameters(self):
        kaiming_init(self.conv_mask_s, mode='fan_in')
        kaiming_init(self.conv_mask_t, mode='fan_in')
        self.conv_mask_s.inited = True
        self.conv_mask_t.inited = True

        self.last_zero_init(self.channel_add_conv_s)
        self.last_zero_init(self.channel_add_conv_t)

实例化FeatureLoss

在train.py中,实例化我们定义的FeatureLoss,由于我们要蒸馏三层,所以需要定一个蒸馏损失的数组:

if opt.teacher_weights:
        student_kd_layers = hyp["student_kd_layers"]
        teacher_kd_layers = hyp["teacher_kd_layers"]
        dump_image = torch.zeros((1, 3, imgsz, imgsz), device=device)
        targets = torch.Tensor([[0, 0, 0, 0, 0, 0]]).to(device)
        _, features = model(dump_image, extra_features = student_kd_layers)  # forward
        _, teacher_features = teacher_model(dump_image,
                                               extra_features=teacher_kd_layers)
        kd_losses = []
        for i in range(len(features)):
            feature = features[i]
            teacher_feature = teacher_features[i]
            _, student_channels, _ , _ = feature.shape
            _, teacher_channels, _ , _ = teacher_feature.shape

            kd_losses.append(FeatureLoss(student_channels,teacher_channels))

其中hyp[‘xxx_kd_layers’]是用于指定我们要蒸馏的层序号。
为了提取出我们需要的层的特征图,我们还需要对模型推理的代码进行修改,这个放在下一篇,这一篇先把主要流程过一遍。

蒸馏训练

与普通loss一样,在训练中,首先计算蒸馏loss, 然后进行反向传播,区别只是计算蒸馏loss时需要使用teacher模型也对数据进行推理。

if opt.teacher_weights:
	pred, features = model(imgs, extra_features = student_kd_layers)  # forward
	_, teacher_features = teacher_model(imgs, extra_features = teacher_kd_layers)
	if "loss_ota" not in hyp or hyp["loss_ota"] == 1 and epoch >= ota_start:
		loss, loss_items = compute_loss_ota(
			pred, targets.to(device), imgs
		)
	else:
		loss, loss_items = compute_loss(
			pred, targets.to(device)
		)  # loss scaled by batch_size
	# kd loss
	loss_items = torch.cat((loss_items[0].unsqueeze(0), loss_items[1].unsqueeze(0), loss_items[2].unsqueeze(0), torch.zeros(1, device=device), loss_items[3].unsqueeze(0)))
	loss_items[-1]*=imgs.shape[0]
	for i in range(len(features)):
		feature = features[i]
		teacher_feature = teacher_features[i]

		kd_loss, kd_loss_item = kd_losses[i](feature, teacher_feature, targets.to(device), [imgsz,imgsz])
		loss += kd_loss
		loss_items[3] += kd_loss_item
		loss_items[4] += kd_loss_item

在这里,我们将kd_loss累加到了loss上。计算出总的loss,其他就与普通训练一样了。

结语

这篇文章简述了一下yolov7的蒸馏过程,更多细节将在下一篇中讲述。
f77d79a3b79d6d9849231e64c8e1cdfa~tplv-dy-resize-origshort-autoq-75_330.jpeg

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1102831.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

6.6 图的应用

思维导图: 6.6.1 最小生成树 ### 6.6 图的应用 #### 主旨:图的概念可应用于现实生活中的许多问题,如网络构建、路径查询、任务排序等。 --- #### 6.6.1 最小生成树 **概念**:要在n个城市中建立通信联络网,则最少需…

【Mysql】Innodb数据结构(四)

概述 MySQL 服务器上负责对表中数据的读取和写入工作的部分是存储引擎 ,而服务器又支持不同类型的存储引擎,比如 InnoDB 、MyISAM 、Memory 等,不同的存储引擎一般是由不同的人为实现不同的特性而开发的,真实数据在不同存储引擎中…

如何让大模型自由使用外部知识与工具

本文将分享为什么以及如何使用外部的知识和工具来增强视觉或者语言模型。 全文目录: 1. 背景介绍 OREO-LM: 用知识图谱推理来增强语言模型 REVEAL: 用多个知识库检索来预训练视觉语言模型 AVIS: 让大模型用动态树决策来调用工具 技术交流群 建了技术交流群&a…

微信小程序会议OA系统

Flex弹性布局 Flex弹性布局是一种 CSS3 的布局模式,也叫Flexbox。它可以让容器中的元素按一定比例自动分配空间,使得它们在不同宽度、高度等情况下仍能保持整齐和密集不间隙地排列。 在使用Flexbox弹性布局时,首先需要创建一个容器和若干个…

JNDI-Injection-Exploit工具安装

从github上下载安装 git clone https://github.com/welk1n/JNDI-Injection-Exploit.git 打开 cd JNDI-Injection-Exploit 编译安装,Maven入门百科_maven中quickstart是什么意思-CSDN博客 mvn clean package -DskipTests 因为提示mvn错误,解决下…

Spring中Setter注入详解

目录 一、setter注入是什么 二、setter注入详解 三、JDK内置类型注入方式 3.1 数组类型 3.2 set集合类型 3.3 list集合 3.4 map集合 3.5 properties类型 四、用户自定义类型 一、setter注入是什么 书接上回,我们发现在Spring配置文件中为类对象的属性赋值时&#x…

java SpringBoot基础

目录 SpringBootWeb快速入门前言需求开发步骤创建SpringBoot工程(需要联网)定义请求处理类运行测试 HTTP协议HTTP概述HTTP-请求协议格式GET方式的请求协议POST方式的请求协议 HTTP-响应协议格式HTTP-协议解析 WEB服务器-Tomcat简介基本使用注意事项 Spri…

智慧渔业方案:AI渔政视频智能监管平台助力水域禁渔执法

一、方案背景 国内有很多水库及河流设立了禁渔期,加强渔政执法监管对保障国家渔业权益、维护渔业生产秩序、保护渔民群众生命财产安全、推进水域生态文明建设具有重要意义。目前,部分地区的监管手段信息化水平低下,存在人员少、职责多、任务…

排序【七大排序】

文章目录 1. 排序的概念及引用1.1 排序的概念1.2 常见的排序算法 2. 常见排序算法的实现2.1 插入排序2.1.1基本思想:2.1.2 直接插入排序2.1.3 希尔排序( 缩小增量排序 ) 2.2 选择排序2.2.1基本思想:2.2.2 直接选择排序:2.2.3 堆排序 2.3 交换排序2.3.1冒…

新一代开源语音库CoQui TTS冲到了GitHub 20.5k Star

Coqui TTS 项目介绍 Coqui 文本转语音(Text-to-Speech,TTS)是新一代基于深度学习的低资源零样本文本转语音模型,具有合成多种语言语音的能力。该模型能够利用共同学习技术,从各语言的训练资料集转换知识,来…

Leetcode刷题详解——将x减到0的最小操作数

1. 题目链接:1658. 将 x 减到 0 的最小操作数 2. 题目描述: 给你一个整数数组 nums 和一个整数 x 。每一次操作时,你应当移除数组 nums 最左边或最右边的元素,然后从 x 中减去该元素的值。请注意,需要 修改 数组以供接下来的操作…

windows常用命令

一.文件操作 dir:查看文件当前路径目录列表 cd .. :返回上一级目录 cd 路径:进入路径

解决telnet不是内部或外部以及验证某个端口是否开放

1.怎么解决telnet不是内部或外部命令 (1)telnet在win10下默认是不开启的,所以需要我们自己手动开启。 (2)在控制面板中,我们选择程序–启动或关闭windows功能,然后勾选Telnet客户端选项&#…

Python 连接数据库添加字段

任务需求: 数据库hospital集合所有数据添加一个八位数的编码 import pymongo# 连接数据customer(库)hospital(集合) client pymongo.MongoClient(host127.0.0.1) db client.customer collection db.hospitalhospit…

正向代理与反向代理

正向代理 客户端想要直接与目标服务器连接,但是无法直接进行连接,就需要先去访问中间的代理服务器,让代理服务器代替客户端去访问目标服务器 反向代理 屏蔽掉服务器的信息,经常用在多台服务器的分布式部署上,像一些大型…

Unreal Engine 4 + miniconda + Python2.7 + Pycharm

1.​首先启用UE4插件里的Python Scripting插件 ​ 2. 在UE4项目设置中 开启Python开发者模式 生成unreal.py文件,用于在Pychram中引入Unreal PythonAPI 生成的unreal.py 在: "项目路径\Intermediate\PythonStub\unreal.py"3. 安装Miniconda…

问题记录1 json解析问题

问题: json解析int类型不符合预期,使用json.NewDecoder解决。 示例如下: package mainimport ("bytes""encoding/json""fmt" )func main() {data1 : map[string]interface{}{}data1["id"] int64(4…

纽交所上市公司安费诺宣布将以1.397亿美元收购无线解决方案提供商PCTEL

来源:猛兽财经 作者:猛兽财经 猛兽财经获悉,纽交所上市公司安费诺(APH)宣布将以每股7美元现金,总价格1.397亿美元收购无线解决方案提供商PCTEL(PCTI)。 该交易预计将在第四季度或2024年初完成。 Lake Street Capital Markets担任…

如何通过工单管理系统提高服务质量和客户满意度?

在这个高速发展的时代,企业面临着前所未有的挑战。其中,如何提高工作服务效率,成为了摆在每个企业面前的关键问题。在这个背景下,一款全新的工单管理系统——“的修”应运而生,它可以为您提供了优化工单流程的解决方案…