基于SSD的安全帽检测

news2025/7/18 14:06:55

目录

  • 1. 作者介绍
  • 2. SSD算法介绍
    • 2.1 SSD算法网络结构
    • 2.2 SSD算法训练过程
    • 2.3 SSD算法优缺点
  • 3. 基于SSD的安全帽检测实验
    • 3.1 VOC 2007安全帽数据集
    • 3.2 SSD网络架构
    • 3.3 训练和验证所需的2007_train.txt和2007_val.txt文件生成
    • 3.4 模型训练
    • 3.5 GUI界面
    • 3.6 结果展示
    • 3.7 文件下载
  • 4. 参考连接

1. 作者介绍

胡振远,男,西安工程大学电子信息学院,2023级研究生
研究方向:机器视觉与人工智能
电子邮件:zhenyuan@stu.xpu.edu.cn

张思怡,女,西安工程大学电子信息学院,2022级研究生,张宏伟人工智能课题组
研究方向:机器视觉与人工智能
电子邮件:981664791@qq.com

2. SSD算法介绍

单次检测多框检测器(Single Shot MultiBox Detector,SSD)是一种目标检测算法,它可以在一张图像上同时检测多个目标,并返回它们的位置和类别。SSD算法的基本原理是将图像分成多个不同尺度的特征图,然后在每个特征图上使用卷积操作来检测目标。这种方法能够捕捉不同尺度的目标,因为小尺寸的目标可能只出现在高分辨率的特征图中,而大尺寸的目标可能会出现在低分辨率的特征图中。

SSD算法首先使用一个基础网络(如图VGG16)来提取图像的特征。然后,在基础网络的顶层添加了一系列卷积层,这些层会生成不同分辨率的特征图。这些多尺度的特征图能够覆盖不同大小和形状的目标。在每个特征图的每个位置,SSD生成一组预定义的锚框(也称为先验框)。这些锚框具有不同的宽高比和大小,用于捕捉各种可能的目标。在每个特征图上,SSD使用卷积滤波器进行检测。对于每个锚框,SSD通过分类分支预测其目标类别的概率分布,并通过回归分支预测锚框的边界框偏移量。分类分支负责判断锚框内是否存在目标及其类别,而回归分支负责调整锚框的位置和大小以更准确地围绕目标。

在所有锚框的预测完成后,SSD应用非极大值抑制(Non-Maximum Suppression, NMS)算法来删除重叠的锚框。NMS根据预测的置信度分数选择最优的锚框,同时抑制与其重叠较大的其他锚框,确保每个目标只保留一个检测框。最终,SSD算法将经过过滤的锚框作为检测结果,返回目标的类别和精确位置。通过这种多尺度的检测机制,SSD能够高效地进行目标检测,兼顾检测速度和精度,广泛应用于实时检测任务中。
VGG16 网络架构

2.1 SSD算法网络结构

SSD算法的网络结构可以分为两个部分:基础网络和检测网络。基础网络通常采用经典的卷积神经网络(如VGG16),用于提取图像特征。检测网络由一系列卷积层和池化层组成,用于在特征图上执行目标检测任务。
具体来说,SSD算法的检测网络通常包括以下几个部分:

  1. 特征提取层:在基础网络的基础上,添加一些卷积层和池化层来进一步提取图像的深层特征。这些层通过不断的卷积和下采样操作,提取出更加抽象且具有语义信息的特征表示。
  2. 卷积特征图层:将特征映射到不同的尺度,生成一系列的特征图。这些特征图对应于不同的分辨率,使得网络可以在不同尺度上进行检测,从而能够检测出不同大小的目标。
  3. 检测层:在每个特征图上执行目标检测任务,具体包括分类分支和回归分支。分类分支负责确定特定位置上是否存在目标及其类别,回归分支则负责预测目标的边界框坐标。每个检测层通常包括多个先验框(default boxes),并针对每个先验框进行分类和回归预测。
  4. 后处理层:应用非极大值抑制(Non-Maximum Suppression, NMS)算法来过滤掉冗余的检测结果,并返回最终的目标检测结果。NMS通过选择置信度最高的边界框,并抑制与其重叠较大的其他框,确保每个目标只保留一个最优检测框,从而提高检测结果的准确性和鲁棒性。
    在这里插入图片描述

2.2 SSD算法训练过程

SSD算法的训练过程可以分为两个阶段:预训练和微调。

在预训练阶段,SSD算法使用基础网络进行图像分类任务的训练,以提取图像特征。这个阶段通常采用经典的卷积神经网络(如VGG16或ResNet),并在大型图像分类数据集(如ImageNet)上进行训练,从而获得高质量的图像特征表示。这些预训练的特征随后用于构建SSD的检测网络,在此基础上添加额外的卷积层以生成多尺度特征图。

在微调阶段,SSD算法对整个网络进行微调,以优化目标检测性能。这个阶段的训练数据集通常包括带有目标检测标注的真实图像,如PASCAL VOC或COCO数据集。数据集中的每张图像都包含目标的位置和类别标签。为了增强模型的泛化能力和鲁棒性,常采用各种数据增强技术,如随机裁剪、水平翻转、颜色抖动和缩放等。

在微调阶段,SSD算法通过优化损失函数来学习网络的权重和偏置参数,以最小化目标检测误差。SSD的损失函数通常由两部分组成:分类损失和回归损失。分类损失用于评估分类分支的预测结果是否正确,常使用交叉熵损失函数来衡量。回归损失用于评估回归分支的预测结果是否准确,通常采用平滑L1损失函数(又称Huber损失)来度量预测边界框与真实边界框之间的偏差。

为了加快训练速度和提高模型性能,SSD算法还使用了一些技术措施。例如,数据增强通过增加训练数据的多样性来提高模型的泛化能力;批量归一化(Batch Normalization)用于加速训练过程并稳定模型的训练;Dropout技术用于防止过拟合;学习率调整策略(如学习率衰减或自适应学习率优化器)用于在训练过程中动态调整学习率,从而更有效地找到最优解。

2.3 SSD算法优缺点

SSD算法作为一种用于对象检测的深度学习模型,其主要优点在于高效的检测速度和较高的检测精度。SSD通过在单次前向传递中同时预测多个尺度和纵横比的边界框,从而实现了实时检测。它采用不同尺度的特征图进行多尺度检测,使得对不同大小的对象具有较好的适应性。

此外,SSD的网络结构相对简单,不需要像R-CNN系列那样进行候选区域的生成和分类,因此在推理速度上具有显著优势,适合于实时应用场景,如自动驾驶、视频监控和移动设备上的应用。

然而,SSD算法也存在一些不足之处。首先,在检测小目标时,SSD的表现往往不如一些更加复杂的算法,因为较早层的特征图分辨率较低,导致小目标信息容易丢失。其次,虽然SSD在速度上占优,但在极高精度要求的任务中,其检测精度可能不如一些后续发展的检测算法,如RetinaNet和YOLOv4。最后,由于SSD直接对特征图进行检测,对于背景复杂的场景,其误检率可能较高,需要进一步的后处理步骤来提高精度。

3. 基于SSD的安全帽检测实验

3.1 VOC 2007安全帽数据集

在这里插入图片描述
VOC 2007安全帽数据集是Pascal Visual Object Classes (VOC) Challenge 2007的一部分,旨在为对象检测和分类任务提供标准化的数据集和评估框架。该数据集包含一系列具有复杂场景和多样化物体的图像,其中安全帽是一个具体的目标类别。VOC 2007数据集包括训练集、验证集和测试集,分别用于模型训练、参数调优和性能评估。每个图像都附带有详细的标注信息,包括物体类别、边界框位置等,这些标注信息是由人工精确标记的,以确保高质量的标签数据。此外,数据集还提供了预定义的评价指标,如平均精度(AP),用于衡量模型在对象检测任务中的表现。VOC 2007安全帽数据集在计算机视觉领域广泛应用,尤其是在训练和测试对象检测算法方面,是许多研究工作的基准数据集之一。

3.2 SSD网络架构

所需环境:
torch == 1.2.0

创建一个用于图像检测和预测的SSD对象。首先导入了必要的库和模块,然后定义了一些默认参数,包括模型路径、类别文件路径、输入图像尺寸、主干网络类型、置信度阈值、非极大抑制阈值、先验框尺寸、是否使用无失真缩放以及是否使用CUDA加速。在初始化SSD对象时,更新这些默认参数,计算类别数量并加载先验框,同时初始化用于绘制边界框的颜色,并加载模型和预训练权重。对于图像检测,代码将输入图像转换为RGB格式并调整尺寸,然后进行预处理,包括归一化和添加batch维度,接着将图像输入网络获取预测结果并解码,最后在原图像上绘制边界框和标签。提供计算模型每秒处理帧数的方法,以及生成用于评估模型性能的mAP评估结果的功能。

import colorsys
import os
import time
import warnings

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from PIL import Image, ImageDraw, ImageFont

from nets.ssd import SSD300
from utils.anchors import get_anchors
from utils.utils import cvtColor, get_classes, resize_image, preprocess_input
from utils.utils_bbox import BBoxUtility

warnings.filterwarnings("ignore")

class SSD(object):
    _defaults = {
        #--------------------------------------------------------------------------#
        #   使用自己训练好的模型进行预测要修改model_path和classes_path
        #   model_path指向logs文件夹下的权值文件,classes_path指向model_data下的txt
        #
        #--------------------------------------------------------------------------#
        "model_path": "E:\yanyi\yanyixia\AI\ssd-pytorch-bilibili\logs\ep023-loss3.253-val_loss3.239.pth",
        "classes_path": 'model_data/voc_classes.txt',
        #---------------------------------------------------------------------#
        #   用于预测的图像大小,和train时使用同一个即可
        #---------------------------------------------------------------------#
        "input_shape": [300, 300],
        #-------------------------------#
        #   主干网络的选择
        #   vgg
        #-------------------------------#
        "backbone": "vgg",
        #---------------------------------------------------------------------#
        #   只有得分大于置信度的预测框会被保留下来
        #---------------------------------------------------------------------#
        "confidence": 0.5,
        #---------------------------------------------------------------------#
        #   非极大抑制所用到的nms_iou大小
        #---------------------------------------------------------------------#
        "nms_iou": 0.45,
        #---------------------------------------------------------------------#
        #   用于指定先验框的大小
        #---------------------------------------------------------------------#
        'anchors_size': [30, 60, 111, 162, 213, 264, 315],
        #---------------------------------------------------------------------#
        #   该变量用于控制是否使用letterbox_image对输入图像进行不失真的resize,
        #---------------------------------------------------------------------#
        "letterbox_image": False,
        #-------------------------------#
        #   是否使用Cuda
        #   没有GPU可以设置成False
        #-------------------------------#
        "cuda": False,
    }

    @classmethod
    def get_defaults(cls, n):
        if n in cls._defaults:
            return cls._defaults[n]
        else:
            return "Unrecognized attribute name '" + n + "'"

    #---------------------------------------------------#
    #   初始化ssd
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
        #---------------------------------------------------#
        #   计算总的类的数量
        #---------------------------------------------------#
        self.class_names, self.num_classes = get_classes(self.classes_path)
        self.anchors = torch.from_numpy(get_anchors(self.input_shape, self.anchors_size, self.backbone)).type(torch.FloatTensor)
        if self.cuda:
            self.anchors = self.anchors.cuda()
        self.num_classes                    = self.num_classes + 1

        #---------------------------------------------------#
        #   画框设置不同的颜色
        #---------------------------------------------------#
        hsv_tuples = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
        self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
        self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors))

        self.bbox_util = BBoxUtility(self.num_classes)
        self.generate()

    #---------------------------------------------------#
    #   载入模型
    #---------------------------------------------------#
    def generate(self):
        #-------------------------------#
        #   载入模型与权值
        #-------------------------------#
        self.net    = SSD300(self.num_classes, self.backbone)
        device      = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.net.load_state_dict(torch.load(self.model_path, map_location=device))
        self.net    = self.net.eval()
        print('{} model, anchors, and classes loaded.'.format(self.model_path))

        if self.cuda:
            self.net = torch.nn.DataParallel(self.net)
            cudnn.benchmark = True
            self.net = self.net.cuda()

    #---------------------------------------------------#
    #   检测图片
    #---------------------------------------------------#
    def detect_image(self, image):
        #---------------------------------------------------#
        #   计算输入图片的高和宽
        #---------------------------------------------------#
        image_shape = np.array(np.shape(image)[0:2])
        #---------------------------------------------------------#
        #   在这里将图像转换成RGB图像,防止灰度图在预测时报错。
        #   代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
        #---------------------------------------------------------#
        image       = cvtColor(image)
        #---------------------------------------------------------#
        #   给图像增加灰条,实现不失真的resize
        #   也可以直接resize进行识别
        #---------------------------------------------------------#
        image_data  = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
        #---------------------------------------------------------#
        #   添加上batch_size维度,图片预处理,归一化。
        #---------------------------------------------------------#
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

        with torch.no_grad():
            #---------------------------------------------------#
            #   转化成torch的形式
            #---------------------------------------------------#
            images = torch.from_numpy(image_data).type(torch.FloatTensor)
            if self.cuda:
                images = images.cuda()
            #---------------------------------------------------------#
            #   将图像输入网络当中进行预测!
            #---------------------------------------------------------#
            outputs = self.net(images)
            #-----------------------------------------------------------#
            #   将预测结果进行解码
            #-----------------------------------------------------------#
            results = self.bbox_util.decode_box(outputs, self.anchors, image_shape, self.input_shape, self.letterbox_image,
                                                    nms_iou = self.nms_iou, confidence = self.confidence)
            #--------------------------------------#
            #   如果没有检测到物体,则返回原图
            #--------------------------------------#
            if len(results[0]) <= 0:
                return image

            top_label   = np.array(results[0][:, 4], dtype = 'int32')
            top_conf    = results[0][:, 5]
            top_boxes   = results[0][:, :4]
        #---------------------------------------------------------#
        #   设置字体与边框厚度
        #---------------------------------------------------------#
        font = ImageFont.truetype(font='model_data/simhei.ttf', size=np.floor(3e-2 * np.shape(image)[1] + 0.5).astype('int32'))
        thickness = max((np.shape(image)[0] + np.shape(image)[1]) // self.input_shape[0], 1)
        
        #---------------------------------------------------------#
        #   图像绘制
        #---------------------------------------------------------#


        for i, c in enumerate(top_label):
            predicted_class = self.class_names[int(c)]
            box = top_boxes[i]
            score = top_conf[i]

            top, left, bottom, right = box

            # 确保坐标在图像边界内
            top = max(0, np.floor(top).astype('int32'))
            left = max(0, np.floor(left).astype('int32'))
            bottom = min(image.size[1], np.floor(bottom).astype('int32'))
            right = min(image.size[0], np.floor(right).astype('int32'))

            label = '{} {:.2f}'.format(predicted_class, score)
            draw = ImageDraw.Draw(image)

            # 获取文本的尺寸
            label_size = draw.textbbox((0, 0), label, font=font)[2:]  # 返回值是(left, top, right, bottom),我们只需要宽高

            label = label.encode('utf-8')
            print(label, top, left, bottom, right)

            if top - label_size[1] >= 0:
                text_origin = np.array([left, top - label_size[1]])
            else:
                text_origin = np.array([left, top + 1])

            # 绘制边界框
            for j in range(thickness):
                draw.rectangle([left + j, top + j, right - j, bottom - j], outline=self.colors[int(c)])
            # 绘制标签背景
            draw.rectangle([tuple(text_origin), tuple(text_origin + label_size)], fill=self.colors[int(c)])
            # 绘制文本
            draw.text(text_origin, str(label, 'UTF-8'), fill=(0, 0, 0), font=font)
            del draw

        return image

    def get_FPS(self, image, test_interval):
        #---------------------------------------------------#
        #   计算输入图片的高和宽
        #---------------------------------------------------#
        image_shape = np.array(np.shape(image)[0:2])
        #---------------------------------------------------------#
        #   在这里将图像转换成RGB图像,防止灰度图在预测时报错。
        #   代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
        #---------------------------------------------------------#
        image       = cvtColor(image)
        #---------------------------------------------------------#
        #   给图像增加灰条,实现不失真的resize
        #   也可以直接resize进行识别
        #---------------------------------------------------------#
        image_data  = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
        #---------------------------------------------------------#
        #   添加上batch_size维度,图片预处理,归一化。
        #---------------------------------------------------------#
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

        with torch.no_grad():
            #---------------------------------------------------#
            #   转化成torch的形式
            #---------------------------------------------------#
            images = torch.from_numpy(image_data).type(torch.FloatTensor)
            if self.cuda:
                images = images.cuda()
            #---------------------------------------------------------#
            #   将图像输入网络当中进行预测!
            #---------------------------------------------------------#
            outputs     = self.net(images)
            #-----------------------------------------------------------#
            #   将预测结果进行解码
            #-----------------------------------------------------------#
            results     = self.bbox_util.decode_box(outputs, self.anchors, image_shape, self.input_shape, self.letterbox_image, 
                                                    nms_iou = self.nms_iou, confidence = self.confidence)

        t1 = time.time()
        for _ in range(test_interval):
            with torch.no_grad():
                #---------------------------------------------------------#
                #   将图像输入网络当中进行预测!
                #---------------------------------------------------------#
                outputs     = self.net(images)
                #-----------------------------------------------------------#
                #   将预测结果进行解码
                #-----------------------------------------------------------#
                results     = self.bbox_util.decode_box(outputs, self.anchors, image_shape, self.input_shape, self.letterbox_image, 
                                                        nms_iou = self.nms_iou, confidence = self.confidence)

        t2 = time.time()
        tact_time = (t2 - t1) / test_interval
        return tact_time

    def get_map_txt(self, image_id, image, class_names, map_out_path):
        f = open(os.path.join(map_out_path, "detection-results/"+image_id+".txt"),"w") 
        #---------------------------------------------------#
        #   计算输入图片的高和宽
        #---------------------------------------------------#
        image_shape = np.array(np.shape(image)[0:2])
        #---------------------------------------------------------#
        #   在这里将图像转换成RGB图像,防止灰度图在预测时报错。
        #   代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
        #---------------------------------------------------------#
        image       = cvtColor(image)
        #---------------------------------------------------------#
        #   给图像增加灰条,实现不失真的resize
        #   也可以直接resize进行识别
        #---------------------------------------------------------#
        image_data  = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
        #---------------------------------------------------------#
        #   添加上batch_size维度,图片预处理,归一化。
        #---------------------------------------------------------#
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

        with torch.no_grad():
            #---------------------------------------------------#
            #   转化成torch的形式
            #---------------------------------------------------#
            images = torch.from_numpy(image_data).type(torch.FloatTensor)
            if self.cuda:
                images = images.cuda()
            #---------------------------------------------------------#
            #   将图像输入网络当中进行预测!
            #---------------------------------------------------------#
            outputs     = self.net(images)
            #-----------------------------------------------------------#
            #   将预测结果进行解码
            #-----------------------------------------------------------#
            results     = self.bbox_util.decode_box(outputs, self.anchors, image_shape, self.input_shape, self.letterbox_image, 
                                                    nms_iou = self.nms_iou, confidence = self.confidence)
            #--------------------------------------#
            #   如果没有检测到物体,则返回原图
            #--------------------------------------#
            if len(results[0]) <= 0:
                return 

            top_label   = np.array(results[0][:, 4], dtype = 'int32')
            top_conf    = results[0][:, 5]
            top_boxes   = results[0][:, :4]
        
        for i, c in list(enumerate(top_label)):
            predicted_class = self.class_names[int(c)]
            box             = top_boxes[i]
            score           = str(top_conf[i])

            top, left, bottom, right = box
            if predicted_class not in class_names:
                continue

            f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)),str(int(bottom))))

        f.close()
        return

3.3 训练和验证所需的2007_train.txt和2007_val.txt文件生成

该部分代码用于处理和生成用于训练目标检测模型的数据集标签文件,支持不同模式的操作。首先定义了几个参数,如annotation_modeclasses_pathtrainval_percenttrain_percent,并指定了VOC数据集的路径。根据annotation_mode的值,代码可以生成不同的文件。在annotation_mode为0或1时,代码从VOC数据集中的Annotations文件夹读取XML文件,生成训练集、验证集和测试集的图片ID列表文件(trainval.txt、train.txt、val.txt、test.txt)。在annotation_mode为0或2时,代码进一步处理生成2007_train.txt和2007_val.txt文件,这些文件包含每个图像的路径以及目标边界框和类别信息。通过解析XML文件中的目标对象标签,代码提取边界框坐标和类别信息,并将其写入对应的训练或验证文件中。总之,这段代码主要用于将VOC格式的数据集转换为适合目标检测模型训练的格式。

import os
import random
import xml.etree.ElementTree as ET

from utils.utils import get_classes

#--------------------------------------------------------------------------------------------------------------------------------#
#   annotation_mode用于指定该文件运行时计算的内容
#   annotation_mode为0代表整个标签处理过程,包括获得VOCdevkit/VOC2007/ImageSets里面的txt以及训练用的2007_train.txt、2007_val.txt
#   annotation_mode为1代表获得VOCdevkit/VOC2007/ImageSets里面的txt
#   annotation_mode为2代表获得训练用的2007_train.txt、2007_val.txt
#--------------------------------------------------------------------------------------------------------------------------------#
annotation_mode     = 0
#----------------------------------------------------------------
#   仅在annotation_mode为0和2的时候有效
#-------------------------------------------------------------------#
classes_path        = 'model_data/voc_classes.txt'

trainval_percent    = 0.9
train_percent       = 0.9
#-------------------------------------------------------#
#   指向VOC数据集所在的文件夹
#   默认指向根目录下的VOC数据集
#-------------------------------------------------------#
VOCdevkit_path  = 'VOCdevkit'

VOCdevkit_sets  = [('2007', 'train'), ('2007', 'val')]
classes, _      = get_classes(classes_path)

def convert_annotation(year, image_id, list_file):
    in_file = open(os.path.join(VOCdevkit_path, 'VOC%s/Annotations/%s.xml'%(year, image_id)), encoding='utf-8')
    tree=ET.parse(in_file)
    root = tree.getroot()

    for obj in root.iter('object'):
        difficult = 0 
        if obj.find('difficult')!=None:
            difficult = obj.find('difficult').text
        cls = obj.find('name').text
        if cls not in classes or int(difficult)==1:
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        b = (int(float(xmlbox.find('xmin').text)), int(float(xmlbox.find('ymin').text)), int(float(xmlbox.find('xmax').text)), int(float(xmlbox.find('ymax').text)))
        list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id))
        
if __name__ == "__main__":
    random.seed(0)
    if annotation_mode == 0 or annotation_mode == 1:
        print("Generate txt in ImageSets.")
        xmlfilepath     = os.path.join(VOCdevkit_path, 'VOC2007/Annotations')
        saveBasePath    = os.path.join(VOCdevkit_path, 'VOC2007/ImageSets/Main')
        temp_xml        = os.listdir(xmlfilepath)
        total_xml       = []
        for xml in temp_xml:
            if xml.endswith(".xml"):
                total_xml.append(xml)

        num     = len(total_xml)  
        list    = range(num)  
        tv      = int(num*trainval_percent)  
        tr      = int(tv*train_percent)  
        trainval= random.sample(list,tv)  
        train   = random.sample(trainval,tr)  
        
        print("train and val size",tv)
        print("train size",tr)
        ftrainval   = open(os.path.join(saveBasePath,'trainval.txt'), 'w')  
        ftest       = open(os.path.join(saveBasePath,'test.txt'), 'w')  
        ftrain      = open(os.path.join(saveBasePath,'train.txt'), 'w')  
        fval        = open(os.path.join(saveBasePath,'val.txt'), 'w')  
        
        for i in list:  
            name=total_xml[i][:-4]+'\n'  
            if i in trainval:  
                ftrainval.write(name)  
                if i in train:  
                    ftrain.write(name)  
                else:  
                    fval.write(name)  
            else:  
                ftest.write(name)  
        
        ftrainval.close()  
        ftrain.close()  
        fval.close()  
        ftest.close()
        print("Generate txt in ImageSets done.")

    if annotation_mode == 0 or annotation_mode == 2:
        print("Generate 2007_train.txt and 2007_val.txt for train.")
        for year, image_set in VOCdevkit_sets:
            image_ids = open(os.path.join(VOCdevkit_path, 'VOC%s/ImageSets/Main/%s.txt'%(year, image_set)), encoding='utf-8').read().strip().split()
            list_file = open('%s_%s.txt'%(year, image_set), 'w', encoding='utf-8')
            for image_id in image_ids:
                list_file.write('%s/VOC%s/JPEGImages/%s.jpg'%(os.path.abspath(VOCdevkit_path), year, image_id))

                convert_annotation(year, image_id, list_file)
                list_file.write('\n')
            list_file.close()
        print("Generate 2007_train.txt and 2007_val.txt for train done.")

3.4 模型训练

该部分代码使用SSD模型进行目标检测的训练过程,主要包括冻结阶段和解冻阶段的训练。首先,代码加载必要的库并设置了一些超参数,比如输入形状、骨干网络类型、预训练模型路径等。然后,代码获取类别和锚点配置,并初始化SSD模型。如果预训练权重存在,则会加载这些权重。

接下来定义了损失函数和历史记录器,并从训练和验证数据集对应的txt文件中读取数据。对于冻结阶段训练,代码设置了批次大小和学习率,并创建数据加载器。通过设置网络的一部分参数不可训练,代码实现了冻结部分网络的功能。在训练过程中,代码使用Adam优化器和学习率调度器进行优化。

在解冻阶段训练,再次设置了批次大小和学习率,重新创建数据加载器,并解冻之前冻结的网络部分,使其参与训练。最后,代码在两个阶段内都调用fit_one_epoch函数来执行实际的训练过程,包括前向传播、计算损失、反向传播和参数更新。通过这种方式,代码逐步调整模型的权重,使其在训练数据上表现良好。

import warnings
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.hub import load_state_dict_from_url
from nets.ssd import SSD300
from nets.ssd_training import MultiboxLoss, weights_init
from utils.anchors import get_anchors
from utils.callbacks import LossHistory
from utils.dataloader import SSDDataset, ssd_dataset_collate
from utils.utils import get_classes
from utils.utils_fit import fit_one_epoch

warnings.filterwarnings("ignore")

if __name__ == "__main__":
    # 参数配置
    Cuda = False
    classes_path = 'model_data/voc_classes.txt'
    model_path = 'E:\yanyi\yanyixia\AI\model\ssd_weights.pth'
    input_shape = [300, 300]
    backbone = "vgg"
    pretrained = False
    anchors_size = [30, 60, 111, 162, 213, 264, 315]
    Init_Epoch = 0
    Freeze_Epoch = 50
    Freeze_batch_size = 16
    Freeze_lr = 5e-4
    UnFreeze_Epoch = 100
    Unfreeze_batch_size = 4
    Unfreeze_lr = 1e-4
    Freeze_Train = True
    num_workers = 4
    train_annotation_path = '2007_train.txt'
    val_annotation_path = '2007_val.txt'

    # 获取classes和anchor
    class_names, num_classes = get_classes(classes_path)
    num_classes += 1
    anchors = get_anchors(input_shape, anchors_size, backbone)

    model = SSD300(num_classes, backbone, pretrained)
    if not pretrained:
        weights_init(model)
    if model_path != '':
        # 加载预训练权重
        print('Load weights {}.'.format(model_path))
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model_dict = model.state_dict()
        pretrained_dict = torch.load(model_path, map_location=device)
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) == np.shape(v)}
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

    model_train = model.train()
    if Cuda:
        model_train = torch.nn.DataParallel(model)
        cudnn.benchmark = True
        model_train = model_train.cuda()

    criterion = MultiboxLoss(num_classes, neg_pos_ratio=3.0)
    loss_history = LossHistory("logs/")

    # 读取数据集对应的txt
    with open(train_annotation_path) as f:
        train_lines = f.readlines()
    with open(val_annotation_path) as f:
        val_lines = f.readlines()
    num_train = len(train_lines)
    num_val = len(val_lines)

    # 冻结阶段训练
    if True:
        batch_size = Freeze_batch_size
        lr = Freeze_lr
        start_epoch = Init_Epoch
        end_epoch = Freeze_Epoch

        epoch_step = num_train // batch_size
        epoch_step_val = num_val // batch_size

        if epoch_step == 0 or epoch_step_val == 0:
            raise ValueError("数据集过小,无法进行训练,请扩充数据集。")

        optimizer = optim.Adam(model_train.parameters(), lr, weight_decay=5e-4)
        lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.94)

        train_dataset = SSDDataset(train_lines, input_shape, anchors, batch_size, num_classes, train=True)
        val_dataset = SSDDataset(val_lines, input_shape, anchors, batch_size, num_classes, train=False)

        gen = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers, pin_memory=True,
                         drop_last=True, collate_fn=ssd_dataset_collate)
        gen_val = DataLoader(val_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers, pin_memory=True,
                             drop_last=True, collate_fn=ssd_dataset_collate)

        # 冻结一定部分训练
        if Freeze_Train:
            if backbone == "vgg":
                for param in model.vgg[:28].parameters():
                    param.requires_grad = False
            else:
                for param in model.mobilenet.parameters():
                    param.requires_grad = False

        for epoch in range(start_epoch, end_epoch):
            fit_one_epoch(model_train, model, criterion, loss_history, optimizer, epoch,
                          epoch_step, epoch_step_val, gen, gen_val, end_epoch, Cuda)
            lr_scheduler.step()

    # 解冻阶段训练
    if True:
        batch_size = Unfreeze_batch_size
        lr = Unfreeze_lr
        start_epoch = Freeze_Epoch
        end_epoch = UnFreeze_Epoch

        epoch_step = num_train // batch_size
        epoch_step_val = num_val // batch_size

        if epoch_step == 0 or epoch_step_val == 0:
            raise ValueError("数据集过小,无法进行训练,请扩充数据集。")

        optimizer = optim.Adam(model_train.parameters(), lr, weight_decay=5e-4)
        lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.94)

        train_dataset = SSDDataset(train_lines, input_shape, anchors, batch_size, num_classes, train=True)
        val_dataset = SSDDataset(val_lines, input_shape, anchors, batch_size, num_classes, train=False)

        gen = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers, pin_memory=True,
                         drop_last=True, collate_fn=ssd_dataset_collate)
        gen_val = DataLoader(val_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers, pin_memory=True,
                             drop_last=True, collate_fn=ssd_dataset_collate)

        # 解冻后训练
        if Freeze_Train:
            if backbone == "vgg":
                for param in model.vgg[:28].parameters():
                    param.requires_grad = True
            else:
                for param in model.mobilenet.parameters():
                    param.requires_grad = True

        for epoch in range(start_epoch, end_epoch):
            fit_one_epoch(model_train, model, criterion, loss_history, optimizer, epoch,
                          epoch_step, epoch_step_val, gen, gen_val, end_epoch, Cuda)
            lr_scheduler.step()

3.5 GUI界面

该部分代码实现了一个基于SSD模型的安全帽检测系统的图形用户界面(GUI),使用PyQt5框架。它提供了两个功能:检测图像中的安全帽和检测视频中的安全帽。

import sys
import cv2
import numpy as np
from PyQt5.QtWidgets import QApplication, QMainWindow, QPushButton, QLabel, QFileDialog, QVBoxLayout, QWidget, QHBoxLayout, QFrame
from PyQt5.QtGui import QPixmap, QImage, QFont
from PyQt5.QtCore import Qt, QTimer
from PIL import Image
from ssd import SSD

class MainWindow(QMainWindow):
    def __init__(self):
        super().__init__()

        self.ssd = SSD()
        self.initUI()
        self.timer = QTimer()
        self.timer.timeout.connect(self.update_frame)

    def initUI(self):
        self.setWindowTitle('基于SSD的安全帽检测系统')
        self.setGeometry(100, 100, 1200, 800)  # 设置窗口大小

        mainLayout = QVBoxLayout()
        titleLabel = QLabel('基于SSD的安全帽检测系统', self)
        titleLabel.setAlignment(Qt.AlignCenter)
        titleLabel.setFont(QFont('Arial', 24))
        mainLayout.addWidget(titleLabel)

        centerLayout = QHBoxLayout()
        centerLayout.setAlignment(Qt.AlignCenter)

        self.imageLabel = QLabel(self)
        self.imageLabel.setAlignment(Qt.AlignCenter)
        self.imageLabel.setFrameShape(QFrame.Box)
        self.imageLabel.setFixedSize(1100, 600)  # 设置固定大小
        centerLayout.addWidget(self.imageLabel)

        mainLayout.addLayout(centerLayout)

        buttonLayout = QHBoxLayout()

        self.selectImageButton = QPushButton('请选择图片进行检测', self)
        self.selectImageButton.setStyleSheet("font-size: 30px;")
        self.selectImageButton.setFont(QFont('Arial', 20))
        self.selectImageButton.clicked.connect(self.select_image)
        buttonLayout.addWidget(self.selectImageButton)

        self.selectVideoButton = QPushButton('请选择视频进行检测', self)
        self.selectVideoButton.setStyleSheet("font-size: 30px;")
        self.selectVideoButton.setFont(QFont('Arial', 20))
        self.selectVideoButton.clicked.connect(self.select_video)
        buttonLayout.addWidget(self.selectVideoButton)

        mainLayout.addLayout(buttonLayout)

        container = QWidget()
        container.setLayout(mainLayout)
        self.setCentralWidget(container)

        self.setStyleSheet("""
            QPushButton {
                background-color: #4CAF50;
                color: white;
                border: none;
                padding: 15px 32px;
                text-align: center;
                text-decoration: none;
                display: inline-block;
                font-size: 16px;
                margin: 4px 2px;
                transition-duration: 0.4s;
                cursor: pointer;
                border-radius: 12px;
            }
            QPushButton:hover {
                background-color: white;
                color: black;
                border: 2px solid #4CAF50;
            }
            QLabel {
                background-color: white;
            }
            QFrame {
                border: 2px solid #4CAF50;
                border-radius: 15px;
            }
        """)

    def select_image(self):
        imagePath, _ = QFileDialog.getOpenFileName(self, "选择图片", "",
                                                   "Images (*.png *.xpm *.jpg *.jpeg *.bmp *.tif *.tiff)")
        if imagePath:
            image = Image.open(imagePath)
            result_image = self.ssd.detect_image(image)
            result_image = result_image.convert("RGB")
            result_image = np.array(result_image)
            height, width, channel = result_image.shape
            bytesPerLine = 3 * width
            qImg = QImage(result_image.data, width, height, bytesPerLine, QImage.Format_RGB888)
            pixmap = QPixmap.fromImage(qImg)

            self.imageLabel.setPixmap(pixmap.scaled(self.imageLabel.size(), Qt.KeepAspectRatio, Qt.SmoothTransformation))
            self.imageLabel.adjustSize()

    def select_video(self):
        videoPath, _ = QFileDialog.getOpenFileName(self, "选择视频", "", "Videos (*.mp4 *.avi *.mkv *.mov)")
        if videoPath:
            self.video_path = videoPath
            self.capture = cv2.VideoCapture(self.video_path)
            self.timer.start(30)  # 每30ms更新一次

    def update_frame(self):
        ret, frame = self.capture.read()
        if not ret:
            self.timer.stop()
            return

        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(frame)
        result_image = self.ssd.detect_image(image)
        result_frame = np.array(result_image)
        # result_frame = cv2.cvtColor(result_frame, cv2.COLOR_RGB2BGR)

        height, width, channel = result_frame.shape
        bytesPerLine = 3 * width
        qImg = QImage(result_frame.data, width, height, bytesPerLine, QImage.Format_RGB888)
        pixmap = QPixmap.fromImage(qImg)

        self.imageLabel.setPixmap(pixmap.scaled(self.imageLabel.size(), Qt.KeepAspectRatio, Qt.SmoothTransformation))
        self.imageLabel.adjustSize()


if __name__ == '__main__':
    app = QApplication(sys.argv)
    mainWindow = MainWindow()
    mainWindow.show()
    sys.exit(app.exec_())

3.6 结果展示

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

3.7 文件下载

训练所需的ssd_weights.pth和主干的权值可以在百度云下载。

链接: https://pan.baidu.com/s/1iUVE50oLkzqhtZbUL9el9w
提取码: jgn8

4. 参考连接

睿智的目标检测23——Pytorch搭建SSD目标检测平台

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1822201.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C#版 iText7——画发票PDF(完整)

显示描述&#xff1a; 1、每页显示必须带有发票头、“销售方和购买方信息” 2、明细填充为&#xff1a;当n≤8 行时&#xff0c;发票总高度140mm&#xff0c;每条发票明细行款高度4.375mm&#xff1b; 当8<n≤12行时&#xff0c;发票高度增加17.5mm&#xff0c;不换页&#…

人工智能内容生成元年-AI绘画原理解析

随着人工智能技术的飞速发展&#xff0c;AI绘画作为其引人注目的应用领域&#xff0c;正在以惊人的速度崭露头角。从最初的生成对抗网络&#xff08;GAN&#xff09;到如今的深度学习&#xff0c;AI绘画技术在艺术创作、设计等领域展现出了无限的可能性。其独特的算法和智能化特…

构建 deno/fresh 的 docker 镜像

众所周知, 最近 docker 镜像的使用又出现了新的困难. 但是不怕, 窝们可以使用曲线救国的方法: 自己制作容器镜像 ! 下面以 deno/fresh 举栗, 部署一个简单的应用. 目录 1 创建 deno/fresh 项目2 构建 docker 镜像3 部署和测试4 总结与展望 1 创建 deno/fresh 项目 执行命令…

情侣飞行棋系统微信小程序+H5+微信公众号+APP 源码

情侣飞行棋系统&#xff1a;浪漫与策略并存的双人游戏 &#x1f3b2; 一、引言&#xff1a;寻找爱情的乐趣 在繁忙的生活中&#xff0c;情侣们总是渴望找到一种既能增进感情又能带来乐趣的活动。而“情侣飞行棋系统”正是这样一个完美的选择。它结合了传统飞行棋的玩法和情侣…

接口自动化测试工程化——了解接口测试

什么是接口测试 接口测试也是一种功能测试 我理解的接口测试&#xff0c;其实也是一种功能测试&#xff0c;只是平时大家说的功能测试更多代指 UI 层面的功能测试&#xff0c;而接口测试更偏向于服务端层面的功能测试。 接口测试的目的 测试左移&#xff0c;尽早介入测试&a…

失眠焦虑?这些小妙招助你重拾宁静之夜

在这个快节奏的时代&#xff0c;失眠与焦虑似乎成了不少人的“常客”。每当夜幕降临&#xff0c;躺在床上却辗转反侧&#xff0c;思绪万千&#xff0c;仿佛整个世界的喧嚣都涌入了脑海。&#x1f4ad; 其实&#xff0c;放松心情&#xff0c;调整心态&#xff0c;是缓解失眠焦虑…

【MATLAB源码-第225期】基于matlab的计算器GUI设计仿真,能够实现基础运算,三角函数以及幂运算。

操作环境&#xff1a; MATLAB 2022a 1、算法描述 界面布局 计算器界面的主要元素分为几大部分&#xff1a;显示屏、功能按钮、数字按钮和操作符按钮。 显示屏 显示屏&#xff08;Edit Text&#xff09;&#xff1a;位于界面顶部中央&#xff0c;用于显示用户输入的表达式和…

Java聚合快递系统对接云洋系统快递小程序APP公众号系统源码

聚合快递对接云洋系统小程序&#xff1a;一键解决物流难题 一、引言&#xff1a;为何选择聚合快递对接&#xff1f; 在电商日益繁荣的今天&#xff0c;物流成为了连接卖家与买家的关键桥梁。然而&#xff0c;面对市场上琳琅满目的快递公司&#xff0c;如何高效、便捷地进行快…

Fluid 1.0 版发布,打通云原生高效数据使用的“最后一公里”

作者&#xff1a;顾荣 前言 得益于云原生技术在资源成本集约、部署运维便捷、算力弹性灵活方面的优势&#xff0c;越来越多企业和开发者将数据密集型应用&#xff0c;特别是 AI 和大数据领域应用&#xff0c;运行于云原生环境中。然而&#xff0c;云原生计算与存储分离架构虽…

低代码组件扩展方案在复杂业务场景下的设计与实践

组件是爱速搭的前端页面可视化模块的核心能力之一&#xff0c;它将前端研发人员从无休止的页面样式微调和分辨率兼容工作中解放了出来。 目前&#xff0c;爱速搭通过内置的上百种功能组件&#xff08;120&#xff09;&#xff0c;基本可以覆盖大部分中后台页面的可视化设计场景…

什么是数字人大?一分钟带你了解!

在数字化浪潮席卷全球的今天&#xff0c;中国作为数字经济的领跑者&#xff0c;正积极推动数字技术与国家治理体系的深度融合。其中&#xff0c;“数字人大”作为新时代国家治理体系和治理能力现代化的重要一环&#xff0c;正逐步成为推动民主法治建设、提升人大工作效能的新引…

ThinkPHP5.0 apache服务器配置URL重写,index.php去除

本地环境wamp .htaccess文件代码 <IfModule mod_rewrite.c>Options FollowSymlinks -MultiviewsRewriteEngine onRewriteCond %{REQUEST_FILENAME} !-dRewriteCond %{REQUEST_FILENAME} !-fRewriteRule ^(.*)$ index.php?/$1 [QSA,PT,L] </IfModule> 踩过这个坑&a…

不用下软件,51建模网上传模型就能直接在网页预览!

在数字化时代&#xff0c;3D建模和渲染技术日益成为各行各业不可或缺的工具。然而&#xff0c;传统的建模和预览流程往往需要用户安装复杂的软件&#xff0c;这不仅增加了技术门槛&#xff0c;也限制了模型在不同设备间的共享和查看。 为了解决这一痛点&#xff0c;51建模网凭…

【在线OJ】发帖功能前后段代码实现

一、页面布局 二、前端代码 <template><div id"app"><div style"height: 100vh"><div style"display: flex" ><el-input style"width: 95%" v-model"title" placeholder"输入标题"&g…

光储充一体化,开启绿色出行新篇章

一、追光逐梦&#xff0c;绿色能源点亮未来 在蔚蓝的天空下&#xff0c;光伏发电板如同一片片金色的叶子&#xff0c;静静地捕捉着太阳的光芒。它们不仅为大地带来光明&#xff0c;更是绿色出行的强大后盾。光储充一体化充电站&#xff0c;以光伏为源&#xff0c;储能为桥&…

蓝牙模块的不同版本迭代发展与技术趋势

蓝牙技术自1999年首次亮相以来&#xff0c;已经历了从1.0到5.0的多个版本迭代&#xff0c;每一次的更新都带来了显著的性能提升和广泛的应用前景。本文将综述蓝牙模块的版本迭代&#xff0c;分析其主要改进点&#xff0c;并探讨蓝牙模块在物联网、医疗、穿戴式设备等领域的应用…

Luma AI 推出梦幻机:据说吊打Sora和快手可灵(KLING)|TodayAI

近日&#xff0c;美国初创公司 Luma AI 宣布推出其最新的文本生成视频工具——梦幻机&#xff08;Dream Machine&#xff09;。这一消息发布的时间正好在中国科技公司快手推出其文本生成视频模型可灵&#xff08;KLING&#xff09;几天之后&#xff0c;标志着视频生成领域的又一…

560亿美元薪酬获批!马斯克:特斯拉未来市值将不止5万亿美元

KlipC报道&#xff1a;6月13日&#xff0c;美国电动汽车制造商特斯拉公司举办年度股东大会&#xff0c;其CEO马斯克对特斯拉生产销售、未来车型计划和在无人驾驶能等领域的发展进行了报告。此外&#xff0c;特斯拉股东批准了马斯克的560亿美元薪酬方案以及特斯拉总部迁至得克萨…

qt如何在linux平台上设置编译生成windows程序文件,跨平台?

在开始前刚好我有一些资料&#xff0c;是我根据网友给的问题精心整理了一份「qt的资料从专业入门到高级教程」&#xff0c; 点个关注在评论区回复“888”之后私信回复“888”&#xff0c;全部无偿共享给大家&#xff01;&#xff01;&#xff01;QT本来目标就是跨平台&#xf…

【PL理论】(23) 函数式语言:let-in 示例的分解 | 谁在使用动态作用域?

&#x1f4ad; 写在前面&#xff1a;本章我们将对函数式语言的讲解进行收尾&#xff0c;分解一下之前讲的 let-in 示例。然后讨论一下谁在使用动态作用域。 目录 0x00 let-in 示例的分解 0x01 谁使用动态作用域&#xff1f; 0x00 let-in 示例的分解 让我们详细检查这个示例…