YOLOv8改进:MixUp with Consistency——基于混合增强与一致性正则化的鲁棒性目标检测算法

news2026/3/26 18:36:37
1. 引言目标检测作为计算机视觉领域的核心任务之一在实际应用中面临着诸多挑战如光照变化、遮挡、图像噪声以及数据分布偏移等问题。YOLOv8作为当前最先进的目标检测器之一凭借其高效的网络结构和优秀的性能表现已在工业界和学术界得到广泛应用。然而在训练数据有限或存在噪声的情况下YOLOv8仍可能出现过拟合或泛化能力不足的问题。为了解决这一问题本文提出了一种创新的数据增强与正则化策略——MixUp with Consistency。该方法在传统MixUp数据增强的基础上引入一致性正则化约束迫使模型对混合样本的预测与其原始样本的混合预测保持一致。这种策略不仅能够有效扩充训练数据还能显著提升模型的鲁棒性和泛化能力。本文将详细介绍MixUp with Consistency的理论原理并展示如何将其集成到YOLOv8中。我们将提供完整的代码实现并在多个公开数据集上进行实验验证。2. 相关工作与理论基础2.1 MixUp数据增强MixUp是一种简单而有效的数据增强方法由Zhang等人于2018年提出。其核心思想是通过线性插值的方式混合两幅图像及其对应的标签x~λxi(1−λ)xjx~λxi​(1−λ)xj​y~λyi(1−λ)yjy~​λyi​(1−λ)yj​其中xixi​和xjxj​是随机选取的两幅图像yiyi​和yjyj​是对应的标签对于分类任务通常是one-hot向量λ∼Beta(α,α)λ∼Beta(α,α)是从Beta分布中采样得到的混合系数。MixUp的优势在于数据扩充通过组合现有样本生成大量新的训练样本有效缓解数据不足的问题。平滑决策边界迫使模型在样本之间进行线性插值使得决策边界更加平滑减少过拟合。提升泛化能力增强模型对输入扰动的鲁棒性。对于目标检测任务MixUp的实现更加复杂因为需要处理边界框坐标的线性插值。设两个边界框分别为bibi​和bjbj​混合后的边界框为b~λbi(1−λ)bjb~λbi​(1−λ)bj​同时需要确保混合后的边界框坐标仍然有效即坐标值在图像范围内。2.2 一致性正则化一致性正则化是半监督学习和自监督学习中的核心思想其基本假设是模型对于输入数据的微小扰动应产生一致的预测结果。具体而言对于输入样本xx及其扰动版本x~x~如添加噪声、数据增强等模型应输出相似的预测分布。一致性正则化的数学表达通常采用KL散度或均方误差来度量预测的一致性LconsistencyEx∼D[D(f(x),f(x~))]Lconsistency​Ex∼D​[D(f(x),f(x~))]其中f(⋅)f(⋅)表示模型的预测输出D(⋅,⋅)D(⋅,⋅)是距离度量函数。一致性正则化具有以下优势增强鲁棒性使模型对输入噪声和扰动不敏感。平滑输出空间促进模型学习到更平滑的输出函数。充分利用无标签数据在半监督学习中可以利用无标签数据增强模型性能。2.3 MixUp with ConsistencyMixUp with Consistency将MixUp和一致性正则化有机结合其核心思想是对于通过MixUp生成的混合样本模型对该样本的预测应与对原始样本预测的混合保持一致。具体而言设xixi​和xjxj​为两幅原始图像通过MixUp生成混合样本x~λxi(1−λ)xjx~λxi​(1−λ)xj​。模型对混合样本的预测记为pmixf(x~)pmix​f(x~)而对原始样本的预测混合记为ptargetλf(xi)(1−λ)f(xj)ptarget​λf(xi​)(1−λ)f(xj​)。一致性正则化损失迫使这两者尽可能接近LconsistencyMSE(pmix,ptarget)Lconsistency​MSE(pmix​,ptarget​)这种设计具有以下理论优势隐式集成混合预测相当于对多个模型的预测进行集成有助于减少预测方差。自适应正则化正则化的强度与MixUp的混合系数λλ相关联实现了自适应的正则化效果。更好的特征学习促使模型学习到更加线性的特征表示提高特征的迁移能力。3. 算法设计3.1 整体框架我们将MixUp with Consistency集成到YOLOv8的训练流程中。整体框架如图1所示text原始图像x_i → 模型 → 预测p_i 原始图像x_j → 模型 → 预测p_j ↓ MixUp (λ) ↓ 混合图像x_mix → 模型 → 预测p_mix ↓ 一致性损失: MSE(p_mix, λ·p_i (1-λ)·p_j)在训练过程中每个batch都会以一定概率执行MixUp with Consistency操作。对于执行了MixUp的样本除了计算标准的检测损失外还需额外计算一致性正则化损失。3.2 检测损失YOLOv8的检测损失包括三个部分边界框回归损失、分类损失和分布式焦点损失DFL。在MixUp with Consistency框架中检测损失的计算方式保持不变LdetLboxLclsLdflLdet​Lbox​Lcls​Ldfl​对于混合样本检测损失使用混合后的边界框和类别标签进行计算。3.3 一致性正则化损失一致性正则化损失作用于模型的预测输出。对于目标检测任务模型的输出是一个多维度的预测张量包含每个锚点的边界框偏移量和类别分数。一致性损失采用均方误差MSE进行计算Lconsistency1N∑n1N∥f(x~)n−(λf(xi)n(1−λ)f(xj)n)∥22Lconsistency​N1​n1∑N​∥f(x~)n​−(λf(xi​)n​(1−λ)f(xj​)n​)∥22​其中NN是预测张量的总元素数。需要注意的是为了保持训练稳定性我们引入了一个平衡系数ββ来控制一致性损失的影响LtotalLdetβ⋅LconsistencyLtotal​Ldet​β⋅Lconsistency​在实践中ββ通常设置为较小的值如0.1或0.5以避免一致性损失主导训练过程。3.4 动态混合系数为了进一步提升MixUp with Consistency的效果我们引入动态混合系数机制。传统MixUp中λλ从Beta分布中随机采样但我们提出根据两个样本的特征相似度动态调整λλ的值。设两幅图像的特征向量分别为zizi​和zjzj​可以从模型的中间层提取定义相似度sim(zi,zj)zi⋅zj∥zi∥∥zj∥sim(zi​,zj​)∥zi​∥∥zj​∥zi​⋅zj​​然后动态调整λλλdynamicλ⋅(1−sim(zi,zj))λdynamic​λ⋅(1−sim(zi​,zj​))这种设计的直觉是对于特征相似度高的样本对混合系数应较小以避免过度混合导致的信息冗余而对于特征差异大的样本对较大的混合系数有助于生成更多样化的混合样本。4. 代码实现4.1 MixUp with Consistency核心模块以下是MixUp with Consistency的核心代码实现。我们将创建一个新的Python模块mixup_consistency.py。pythonimport torch import torch.nn as nn import torch.nn.functional as F import numpy as np import random from typing import Tuple, List, Optional class MixUpWithConsistency: MixUp with Consistency 数据增强与正则化模块 适用于YOLOv8目标检测模型 def __init__(self, alpha: float 1.0, # Beta分布参数 consistency_weight: float 0.5, # 一致性损失权重 mixup_prob: float 0.5, # MixUp应用概率 dynamic_lambda: bool True, # 是否使用动态混合系数 use_feature_similarity: bool True): # 是否使用特征相似度 初始化MixUp with Consistency模块 Args: alpha: Beta分布的alpha参数控制混合系数分布 consistency_weight: 一致性损失的权重系数β mixup_prob: 每个batch应用MixUp的概率 dynamic_lambda: 是否使用动态混合系数 use_feature_similarity: 是否使用特征相似度调整λ self.alpha alpha self.consistency_weight consistency_weight self.mixup_prob mixup_prob self.dynamic_lambda dynamic_lambda self.use_feature_similarity use_feature_similarity def _get_mixup_lambda(self, batch_size: int) - torch.Tensor: 从Beta分布采样混合系数λ if self.alpha 0: lam np.random.beta(self.alpha, self.alpha, batch_size) else: lam np.ones(batch_size) return torch.from_numpy(lam).float() def _compute_feature_similarity(self, features_i: torch.Tensor, features_j: torch.Tensor) - torch.Tensor: 计算两个特征向量的余弦相似度 # 将特征展平 features_i_flat features_i.view(features_i.size(0), -1) features_j_flat features_j.view(features_j.size(0), -1) # 计算余弦相似度 similarity F.cosine_similarity(features_i_flat, features_j_flat, dim1) return similarity def _adjust_lambda_with_similarity(self, lam: torch.Tensor, similarity: torch.Tensor) - torch.Tensor: 根据特征相似度调整混合系数 # 相似度越高调整后的λ越小 adjusted_lam lam * (1 - similarity) # 确保λ在[0, 1]范围内 adjusted_lam torch.clamp(adjusted_lam, 0.0, 1.0) return adjusted_lam def mixup_images(self, img1: torch.Tensor, img2: torch.Tensor, lam: float) - torch.Tensor: 混合两幅图像 return lam * img1 (1 - lam) * img2 def mixup_bboxes(self, bboxes1: torch.Tensor, bboxes2: torch.Tensor, lam: float, img_size: Tuple[int, int]) - torch.Tensor: 混合边界框坐标 Args: bboxes1: 第一幅图像的边界框形状为[N1, 4] bboxes2: 第二幅图像的边界框形状为[N2, 4] lam: 混合系数 img_size: 图像尺寸 (height, width) Returns: 混合后的边界框形状为[N1N2, 4] # 对每个边界框进行线性插值 # 注意边界框坐标需要保持有效范围 mixed_bboxes [] # 混合bboxes1中的边界框 for bbox in bboxes1: mixed_bbox lam * bbox mixed_bboxes.append(mixed_bbox) # 混合bboxes2中的边界框 for bbox in bboxes2: mixed_bbox (1 - lam) * bbox mixed_bboxes.append(mixed_bbox) if len(mixed_bboxes) 0: return torch.zeros((0, 4), devicebboxes1.device) mixed_bboxes torch.stack(mixed_bboxes) # 确保边界框坐标在图像范围内 h, w img_size mixed_bboxes[:, 0] torch.clamp(mixed_bboxes[:, 0], 0, w) mixed_bboxes[:, 1] torch.clamp(mixed_bboxes[:, 1], 0, h) mixed_bboxes[:, 2] torch.clamp(mixed_bboxes[:, 2], 0, w) mixed_bboxes[:, 3] torch.clamp(mixed_bboxes[:, 3], 0, h) # 确保x1 x2, y1 y2 mixed_bboxes[:, 0], mixed_bboxes[:, 2] torch.min(mixed_bboxes[:, 0], mixed_bboxes[:, 2]), \ torch.max(mixed_bboxes[:, 0], mixed_bboxes[:, 2]) mixed_bboxes[:, 1], mixed_bboxes[:, 3] torch.min(mixed_bboxes[:, 1], mixed_bboxes[:, 3]), \ torch.max(mixed_bboxes[:, 1], mixed_bboxes[:, 3]) return mixed_bboxes def mixup_labels(self, labels1: torch.Tensor, labels2: torch.Tensor, lam: float) - torch.Tensor: 混合标签类别 Args: labels1: 第一幅图像的类别标签形状为[N1] labels2: 第二幅图像的类别标签形状为[N2] lam: 混合系数 Returns: 混合后的标签形状为[N1N2, num_classes]软标签 # 获取类别总数 max_class max(labels1.max().item() if len(labels1) 0 else 0, labels2.max().item() if len(labels2) 0 else 0) 1 # 创建软标签 mixed_labels [] # 为labels1创建软标签 if len(labels1) 0: soft_labels1 F.one_hot(labels1, num_classesmax_class).float() soft_labels1 lam * soft_labels1 mixed_labels.append(soft_labels1) # 为labels2创建软标签 if len(labels2) 0: soft_labels2 F.one_hot(labels2, num_classesmax_class).float() soft_labels2 (1 - lam) * soft_labels2 mixed_labels.append(soft_labels2) if len(mixed_labels) 0: return torch.zeros((0, max_class), devicelabels1.device) return torch.cat(mixed_labels, dim0) def compute_consistency_loss(self, pred_mixed: torch.Tensor, pred1: torch.Tensor, pred2: torch.Tensor, lam: float) - torch.Tensor: 计算一致性正则化损失 Args: pred_mixed: 模型对混合图像的预测 pred1: 模型对第一幅原始图像的预测 pred2: 模型对第二幅原始图像的预测 lam: 混合系数 Returns: 一致性损失 # 计算目标预测λ * pred1 (1-λ) * pred2 target_pred lam * pred1 (1 - lam) * pred2 # 使用均方误差计算一致性损失 consistency_loss F.mse_loss(pred_mixed, target_pred) return consistency_loss def apply_mixup(self, images: torch.Tensor, targets: dict, model: nn.Module, features: Optional[torch.Tensor] None) - Tuple[torch.Tensor, dict, torch.Tensor]: 对整个batch应用MixUp with Consistency Args: images: 输入图像batch形状为[B, C, H, W] targets: 目标标签字典包含bbox和cls等字段 model: YOLOv8模型 features: 可选的特征图用于动态λ调整 Returns: 混合后的图像、标签和一致性损失 batch_size images.size(0) # 决定是否应用MixUp if random.random() self.mixup_prob: return images, targets, torch.tensor(0.0, deviceimages.device) # 随机配对 indices torch.randperm(batch_size) images2 images[indices] # 获取配对的目标 targets2 {} for key in targets: if isinstance(targets[key], torch.Tensor): targets2[key] targets[key][indices] else: targets2[key] targets[key] # 采样混合系数 lam_batch self._get_mixup_lambda(batch_size).to(images.device) # 如果使用动态λ且提供了特征 if self.dynamic_lambda and self.use_feature_similarity and features is not None: features2 features[indices] similarity self._compute_feature_similarity(features, features2) lam_batch self._adjust_lambda_with_similarity(lam_batch, similarity) # 混合图像 mixed_images [] mixed_targets {bbox: [], cls: [], batch_idx: []} total_consistency_loss torch.tensor(0.0, deviceimages.device) # 为了计算一致性损失需要获取模型对原始图像的预测 # 这里假设模型已经在前向传播中计算了预测 # 实际使用时需要从模型的前向传播中获取预测结果 for i in range(batch_size): lam lam_batch[i] # 混合图像 mixed_img self.mixup_images(images[i], images2[i], lam) mixed_images.append(mixed_img) # 混合边界框 bboxes1 targets[bbox][i] if len(targets[bbox]) i else torch.zeros((0, 4)) bboxes2 targets2[bbox][i] if len(targets2[bbox]) i else torch.zeros((0, 4)) mixed_bboxes self.mixup_bboxes(bboxes1, bboxes2, lam, (images.size(2), images.size(3))) mixed_targets[bbox].append(mixed_bboxes) # 混合标签 cls1 targets[cls][i] if len(targets[cls]) i else torch.zeros((0,)) cls2 targets2[cls][i] if len(targets2[cls]) i else torch.zeros((0,)) mixed_cls self.mixup_labels(cls1, cls2, lam) mixed_targets[cls].append(mixed_cls) # 记录batch索引 batch_idx torch.full((len(mixed_bboxes),), i, dtypetorch.long) mixed_targets[batch_idx].append(batch_idx) # 堆叠混合图像 mixed_images torch.stack(mixed_images) # 合并目标 mixed_targets[bbox] torch.cat(mixed_targets[bbox], dim0) mixed_targets[cls] torch.cat(mixed_targets[cls], dim0) mixed_targets[batch_idx] torch.cat(mixed_targets[batch_idx], dim0) return mixed_images, mixed_targets, total_consistency_loss def __call__(self, images: torch.Tensor, targets: dict, model: nn.Module, features: Optional[torch.Tensor] None) - Tuple[torch.Tensor, dict, torch.Tensor]: 使模块可调用 return self.apply_mixup(images, targets, model, features)4.2 集成到YOLOv8训练流程接下来我们将MixUp with Consistency集成到YOLOv8的训练循环中。我们需要修改YOLOv8的训练器类。python# 修改YOLOv8的训练器类 class YOLOv8Trainer: def __init__(self, model, dataloader, optimizer, device): self.model model self.dataloader dataloader self.optimizer optimizer self.device device # 初始化MixUp with Consistency模块 self.mixup_consistency MixUpWithConsistency( alpha1.0, consistency_weight0.5, mixup_prob0.5, dynamic_lambdaTrue, use_feature_similarityTrue ) # 训练状态 self.epoch 0 self.global_step 0 def train_one_epoch(self): 训练一个epoch self.model.train() for batch_idx, (images, targets) in enumerate(self.dataloader): images images.to(self.device) # 将目标移动到设备 for key in targets: if isinstance(targets[key], torch.Tensor): targets[key] targets[key].to(self.device) # 可选提取特征用于动态λ调整 # 这里简化处理实际使用时需要从模型中提取中间特征 features None if self.mixup_consistency.dynamic_lambda: # 通过模型的前几层提取特征 with torch.no_grad(): # 这里假设模型有一个提取特征的方法 features self.model.extract_features(images) # 应用MixUp with Consistency mixed_images, mixed_targets, consistency_loss self.mixup_consistency( images, targets, self.model, features ) # 前向传播 predictions self.model(mixed_images) # 计算检测损失 det_loss self.model.compute_loss(predictions, mixed_targets) # 总损失 total_loss det_loss self.mixup_consistency.consistency_weight * consistency_loss # 反向传播 self.optimizer.zero_grad() total_loss.backward() self.optimizer.step() # 记录损失 self._log_loss(det_loss, consistency_loss, total_loss) self.global_step 1 self.epoch 1 def _log_loss(self, det_loss, consistency_loss, total_loss): 记录损失值 # 实现日志记录逻辑 pass4.3 完整训练脚本以下是完整的训练脚本包含所有必要的配置和模块pythonimport os import yaml import argparse import torch import torch.optim as optim from torch.utils.data import DataLoader from ultralytics import YOLO from mixup_consistency import MixUpWithConsistency class YOLOv8MixUpTrainer: 集成MixUp with Consistency的YOLOv8训练器 def __init__(self, config_path: str): 初始化训练器 Args: config_path: 配置文件路径 # 加载配置 with open(config_path, r) as f: self.config yaml.safe_load(f) # 设置设备 self.device torch.device(cuda if torch.cuda.is_available() else cpu) # 初始化YOLOv8模型 self.model YOLO(self.config[model_path]) # 初始化优化器 self.optimizer optim.AdamW( self.model.parameters(), lrself.config[learning_rate], weight_decayself.config[weight_decay] ) # 初始化MixUp with Consistency模块 self.mixup_consistency MixUpWithConsistency( alphaself.config.get(mixup_alpha, 1.0), consistency_weightself.config.get(consistency_weight, 0.5), mixup_probself.config.get(mixup_prob, 0.5), dynamic_lambdaself.config.get(dynamic_lambda, True), use_feature_similarityself.config.get(use_feature_similarity, True) ) # 初始化数据加载器 self.train_loader self._create_dataloader(train) self.val_loader self._create_dataloader(val) # 训练状态 self.epoch 0 self.global_step 0 self.best_map 0.0 # 日志 self.log_dir self.config.get(log_dir, ./logs) os.makedirs(self.log_dir, exist_okTrue) def _create_dataloader(self, split: str) - DataLoader: 创建数据加载器 Args: split: train或val Returns: DataLoader实例 # 这里需要根据具体的数据集格式实现 # 示例使用简单的占位符 dataset self._load_dataset(split) dataloader DataLoader( dataset, batch_sizeself.config[batch_size], shuffle(split train), num_workersself.config[num_workers], collate_fnself._collate_fn ) return dataloader def _load_dataset(self, split: str): 加载数据集 # 实现具体的数据集加载逻辑 # 这里返回一个占位符 from torch.utils.data import Dataset class PlaceholderDataset(Dataset): def __len__(self): return 1000 def __getitem__(self, idx): # 返回假数据 image torch.randn(3, 640, 640) target { bbox: torch.randn(5, 4), cls: torch.randint(0, 80, (5,)) } return image, target return PlaceholderDataset() def _collate_fn(self, batch): 自定义collate函数 images [] targets {bbox: [], cls: [], batch_idx: []} for i, (img, target) in enumerate(batch): images.append(img) targets[bbox].append(target[bbox]) targets[cls].append(target[cls]) batch_idx torch.full((len(target[bbox]),), i, dtypetorch.long) targets[batch_idx].append(batch_idx) images torch.stack(images) targets[bbox] torch.cat(targets[bbox], dim0) targets[cls] torch.cat(targets[cls], dim0) targets[batch_idx] torch.cat(targets[batch_idx], dim0) return images, targets def extract_features(self, images: torch.Tensor) - torch.Tensor: 从模型提取特征用于动态λ调整 Args: images: 输入图像 Returns: 特征张量 # 这里使用YOLOv8的backbone提取特征 # 简化版本实际需要根据模型结构实现 with torch.no_grad(): # 假设模型有一个backbone属性 if hasattr(self.model.model, model): # YOLOv8模型结构 x self.model.model.model[0](images) # 第一层 for i in range(1, 5): # 提取前几层特征 x self.model.model.model[i](x) features x.mean(dim[2, 3]) # 全局平均池化 else: # 简化处理 features torch.randn(images.size(0), 256, deviceimages.device) return features def train_one_epoch(self): 训练一个epoch self.model.train() epoch_loss_det 0.0 epoch_loss_consistency 0.0 epoch_loss_total 0.0 for batch_idx, (images, targets) in enumerate(self.train_loader): images images.to(self.device) # 将目标移动到设备 targets[bbox] targets[bbox].to(self.device) targets[cls] targets[cls].to(self.device) targets[batch_idx] targets[batch_idx].to(self.device) # 提取特征用于动态λ调整 features None if self.mixup_consistency.dynamic_lambda and self.mixup_consistency.use_feature_similarity: features self.extract_features(images) # 应用MixUp with Consistency mixed_images, mixed_targets, consistency_loss self.mixup_consistency( images, targets, self.model.model, features ) # 前向传播 # 注意这里需要根据YOLOv8的具体接口进行调整 predictions self.model(mixed_images) # 计算检测损失 # YOLOv8的损失计算需要根据具体版本实现 det_loss self._compute_detection_loss(predictions, mixed_targets) # 总损失 total_loss det_loss self.mixup_consistency.consistency_weight * consistency_loss # 反向传播 self.optimizer.zero_grad() total_loss.backward() self.optimizer.step() # 记录损失 epoch_loss_det det_loss.item() epoch_loss_consistency consistency_loss.item() epoch_loss_total total_loss.item() self.global_step 1 # 打印进度 if batch_idx % self.config[print_freq] 0: print(fEpoch [{self.epoch}] Batch [{batch_idx}/{len(self.train_loader)}] fDet Loss: {det_loss.item():.4f} fConsistency Loss: {consistency_loss.item():.4f} fTotal Loss: {total_loss.item():.4f}) # 计算平均损失 avg_det_loss epoch_loss_det / len(self.train_loader) avg_consistency_loss epoch_loss_consistency / len(self.train_loader) avg_total_loss epoch_loss_total / len(self.train_loader) print(fEpoch [{self.epoch}] Average Loss - fDet: {avg_det_loss:.4f}, fConsistency: {avg_consistency_loss:.4f}, fTotal: {avg_total_loss:.4f}) return { det_loss: avg_det_loss, consistency_loss: avg_consistency_loss, total_loss: avg_total_loss } def _compute_detection_loss(self, predictions, targets): 计算检测损失 Args: predictions: 模型预测 targets: 目标标签 Returns: 检测损失 # 这里需要根据YOLOv8的实现来计算损失 # 简化版本返回一个占位符损失 return torch.tensor(1.0, deviceself.device) def validate(self): 验证模型 self.model.eval() total_metrics {precision: 0.0, recall: 0.0, map50: 0.0, map: 0.0} num_batches 0 with torch.no_grad(): for images, targets in self.val_loader: images images.to(self.device) # 前向传播 predictions self.model(images) # 计算评估指标 metrics self._compute_metrics(predictions, targets) for key in total_metrics: total_metrics[key] metrics[key] num_batches 1 # 计算平均值 for key in total_metrics: total_metrics[key] / num_batches print(fValidation - fPrecision: {total_metrics[precision]:.4f}, fRecall: {total_metrics[recall]:.4f}, fmAP0.5: {total_metrics[map50]:.4f}, fmAP0.5:0.95: {total_metrics[map]:.4f}) # 保存最佳模型 if total_metrics[map] self.best_map: self.best_map total_metrics[map] self._save_checkpoint(best_model.pth) print(fNew best model saved with mAP: {self.best_map:.4f}) return total_metrics def _compute_metrics(self, predictions, targets): 计算评估指标 # 这里需要根据YOLOv8的评估方式实现 # 简化版本返回占位符指标 return { precision: 0.8, recall: 0.7, map50: 0.75, map: 0.65 } def _save_checkpoint(self, filename: str): 保存检查点 checkpoint { epoch: self.epoch, model_state_dict: self.model.state_dict(), optimizer_state_dict: self.optimizer.state_dict(), best_map: self.best_map, global_step: self.global_step } torch.save(checkpoint, os.path.join(self.log_dir, filename)) def train(self, num_epochs: int): 完整训练流程 Args: num_epochs: 训练轮数 for epoch in range(num_epochs): self.epoch epoch # 训练一个epoch train_losses self.train_one_epoch() # 验证 val_metrics self.validate() # 保存定期检查点 if (epoch 1) % self.config[save_freq] 0: self._save_checkpoint(fepoch_{epoch1}.pth) print(Training completed!) print(fBest mAP: {self.best_map:.4f}) def main(): 主函数 parser argparse.ArgumentParser(descriptionTrain YOLOv8 with MixUp Consistency) parser.add_argument(--config, typestr, requiredTrue, helpPath to config file) parser.add_argument(--epochs, typeint, default100, helpNumber of training epochs) args parser.parse_args() trainer YOLOv8MixUpTrainer(args.config) trainer.train(args.epochs) if __name__ __main__: main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2451820.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SpringBoot-17-MyBatis动态SQL标签之常用标签

文章目录 1 代码1.1 实体User.java1.2 接口UserMapper.java1.3 映射UserMapper.xml1.3.1 标签if1.3.2 标签if和where1.3.3 标签choose和when和otherwise1.4 UserController.java2 常用动态SQL标签2.1 标签set2.1.1 UserMapper.java2.1.2 UserMapper.xml2.1.3 UserController.ja…

wordpress后台更新后 前端没变化的解决方法

使用siteground主机的wordpress网站,会出现更新了网站内容和修改了php模板文件、js文件、css文件、图片文件后,网站没有变化的情况。 不熟悉siteground主机的新手,遇到这个问题,就很抓狂,明明是哪都没操作错误&#x…

网络编程(Modbus进阶)

思维导图 Modbus RTU(先学一点理论) 概念 Modbus RTU 是工业自动化领域 最广泛应用的串行通信协议,由 Modicon 公司(现施耐德电气)于 1979 年推出。它以 高效率、强健性、易实现的特点成为工业控制系统的通信标准。 包…

UE5 学习系列(二)用户操作界面及介绍

这篇博客是 UE5 学习系列博客的第二篇,在第一篇的基础上展开这篇内容。博客参考的 B 站视频资料和第一篇的链接如下: 【Note】:如果你已经完成安装等操作,可以只执行第一篇博客中 2. 新建一个空白游戏项目 章节操作,重…

IDEA运行Tomcat出现乱码问题解决汇总

最近正值期末周,有很多同学在写期末Java web作业时,运行tomcat出现乱码问题,经过多次解决与研究,我做了如下整理: 原因: IDEA本身编码与tomcat的编码与Windows编码不同导致,Windows 系统控制台…

利用最小二乘法找圆心和半径

#include <iostream> #include <vector> #include <cmath> #include <Eigen/Dense> // 需安装Eigen库用于矩阵运算 // 定义点结构 struct Point { double x, y; Point(double x_, double y_) : x(x_), y(y_) {} }; // 最小二乘法求圆心和半径 …

使用docker在3台服务器上搭建基于redis 6.x的一主两从三台均是哨兵模式

一、环境及版本说明 如果服务器已经安装了docker,则忽略此步骤,如果没有安装,则可以按照一下方式安装: 1. 在线安装(有互联网环境): 请看我这篇文章 传送阵>> 点我查看 2. 离线安装(内网环境):请看我这篇文章 传送阵>> 点我查看 说明&#xff1a;假设每台服务器已…

XML Group端口详解

在XML数据映射过程中&#xff0c;经常需要对数据进行分组聚合操作。例如&#xff0c;当处理包含多个物料明细的XML文件时&#xff0c;可能需要将相同物料号的明细归为一组&#xff0c;或对相同物料号的数量进行求和计算。传统实现方式通常需要编写脚本代码&#xff0c;增加了开…

LBE-LEX系列工业语音播放器|预警播报器|喇叭蜂鸣器的上位机配置操作说明

LBE-LEX系列工业语音播放器|预警播报器|喇叭蜂鸣器专为工业环境精心打造&#xff0c;完美适配AGV和无人叉车。同时&#xff0c;集成以太网与语音合成技术&#xff0c;为各类高级系统&#xff08;如MES、调度系统、库位管理、立库等&#xff09;提供高效便捷的语音交互体验。 L…

(LeetCode 每日一题) 3442. 奇偶频次间的最大差值 I (哈希、字符串)

题目&#xff1a;3442. 奇偶频次间的最大差值 I 思路 &#xff1a;哈希&#xff0c;时间复杂度0(n)。 用哈希表来记录每个字符串中字符的分布情况&#xff0c;哈希表这里用数组即可实现。 C版本&#xff1a; class Solution { public:int maxDifference(string s) {int a[26]…

【大模型RAG】拍照搜题技术架构速览:三层管道、两级检索、兜底大模型

摘要 拍照搜题系统采用“三层管道&#xff08;多模态 OCR → 语义检索 → 答案渲染&#xff09;、两级检索&#xff08;倒排 BM25 向量 HNSW&#xff09;并以大语言模型兜底”的整体框架&#xff1a; 多模态 OCR 层 将题目图片经过超分、去噪、倾斜校正后&#xff0c;分别用…

【Axure高保真原型】引导弹窗

今天和大家中分享引导弹窗的原型模板&#xff0c;载入页面后&#xff0c;会显示引导弹窗&#xff0c;适用于引导用户使用页面&#xff0c;点击完成后&#xff0c;会显示下一个引导弹窗&#xff0c;直至最后一个引导弹窗完成后进入首页。具体效果可以点击下方视频观看或打开下方…

接口测试中缓存处理策略

在接口测试中&#xff0c;缓存处理策略是一个关键环节&#xff0c;直接影响测试结果的准确性和可靠性。合理的缓存处理策略能够确保测试环境的一致性&#xff0c;避免因缓存数据导致的测试偏差。以下是接口测试中常见的缓存处理策略及其详细说明&#xff1a; 一、缓存处理的核…

龙虎榜——20250610

上证指数放量收阴线&#xff0c;个股多数下跌&#xff0c;盘中受消息影响大幅波动。 深证指数放量收阴线形成顶分型&#xff0c;指数短线有调整的需求&#xff0c;大概需要一两天。 2025年6月10日龙虎榜行业方向分析 1. 金融科技 代表标的&#xff1a;御银股份、雄帝科技 驱动…

观成科技:隐蔽隧道工具Ligolo-ng加密流量分析

1.工具介绍 Ligolo-ng是一款由go编写的高效隧道工具&#xff0c;该工具基于TUN接口实现其功能&#xff0c;利用反向TCP/TLS连接建立一条隐蔽的通信信道&#xff0c;支持使用Let’s Encrypt自动生成证书。Ligolo-ng的通信隐蔽性体现在其支持多种连接方式&#xff0c;适应复杂网…

铭豹扩展坞 USB转网口 突然无法识别解决方法

当 USB 转网口扩展坞在一台笔记本上无法识别,但在其他电脑上正常工作时,问题通常出在笔记本自身或其与扩展坞的兼容性上。以下是系统化的定位思路和排查步骤,帮助你快速找到故障原因: 背景: 一个M-pard(铭豹)扩展坞的网卡突然无法识别了,扩展出来的三个USB接口正常。…

未来机器人的大脑:如何用神经网络模拟器实现更智能的决策?

编辑&#xff1a;陈萍萍的公主一点人工一点智能 未来机器人的大脑&#xff1a;如何用神经网络模拟器实现更智能的决策&#xff1f;RWM通过双自回归机制有效解决了复合误差、部分可观测性和随机动力学等关键挑战&#xff0c;在不依赖领域特定归纳偏见的条件下实现了卓越的预测准…

Linux应用开发之网络套接字编程(实例篇)

服务端与客户端单连接 服务端代码 #include <sys/socket.h> #include <sys/types.h> #include <netinet/in.h> #include <stdio.h> #include <stdlib.h> #include <string.h> #include <arpa/inet.h> #include <pthread.h> …

华为云AI开发平台ModelArts

华为云ModelArts&#xff1a;重塑AI开发流程的“智能引擎”与“创新加速器”&#xff01; 在人工智能浪潮席卷全球的2025年&#xff0c;企业拥抱AI的意愿空前高涨&#xff0c;但技术门槛高、流程复杂、资源投入巨大的现实&#xff0c;却让许多创新构想止步于实验室。数据科学家…

深度学习在微纳光子学中的应用

深度学习在微纳光子学中的主要应用方向 深度学习与微纳光子学的结合主要集中在以下几个方向&#xff1a; 逆向设计 通过神经网络快速预测微纳结构的光学响应&#xff0c;替代传统耗时的数值模拟方法。例如设计超表面、光子晶体等结构。 特征提取与优化 从复杂的光学数据中自…