保姆级教程:用ONNXRuntime对比YOLO11的PyTorch与ONNX输出差异
保姆级教程用ONNXRuntime对比YOLO11的PyTorch与ONNX输出差异在模型部署的实践中PyTorch到ONNX的转换是常见需求但转换后的模型输出是否与原始模型一致却容易被忽视。本文将手把手教你如何通过ONNXRuntime对比YOLO11模型在PyTorch和ONNX两种格式下的输出差异建立完整的验证流程。1. 环境准备与模型导出首先确保已安装必要的Python包pip install ultralytics onnxruntime numpy opencv-pythonYOLO11模型的导出有两种常见方式1.1 直接导出预训练模型from ultralytics import YOLO # 加载官方预训练模型 model YOLO(yolo11n.pt) # 导出为ONNX格式 model.export( formatonnx, dynamicFalse, simplifyFalse, nmsTrue, conf0.25, iou0.45 )1.2 自定义训练后导出# 训练自定义模型 model YOLO(yolo11n.pt) model.train(datacoco128.yaml, epochs10) # 导出训练好的模型 model.export( formatonnx, opset17, nmsFalse # 保留原始输出格式 )关键区别nmsTrue时输出格式为[1,N,6]False时为[1,84,8400]2. PyTorch模型推理验证在转换前必须确保原始PyTorch模型工作正常import cv2 import numpy as np from PIL import Image # 加载测试图像 img cv2.imread(bus.jpg) img_rgb cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img_resized cv2.resize(img_rgb, (640, 640)) # PyTorch推理 results model.predict(img_resized, conf0.25) boxes results[0].boxes # 输出关键信息 print(f检测框数量: {len(boxes)}) print(前5个框的置信度:, boxes.conf[:5].tolist()) print(类别分布:, boxes.cls.unique(return_countsTrue))典型输出示例检测框数量: 42 前5个框的置信度: [0.92, 0.89, 0.87, 0.85, 0.82] 类别分布: (tensor([2, 5, 7]), tensor([15, 3, 24]))3. ONNX模型推理与对比3.1 基础推理设置import onnxruntime as ort # 创建推理会话 sess ort.InferenceSession(yolo11n.onnx, providers[CPUExecutionProvider]) # 准备输入数据 input_name sess.get_inputs()[0].name input_data np.expand_dims( img_resized.transpose(2,0,1).astype(np.float32)/255.0, axis0 ) # 执行推理 onnx_output sess.run(None, {input_name: input_data})[0]3.2 输出差异分析我们需要从三个维度对比差异1. 坐标差异分析# 计算框坐标的L2距离 pytorch_boxes boxes.xyxyn.cpu().numpy() onnx_boxes onnx_output[0,:,:4] coord_diff np.sqrt(np.sum((pytorch_boxes - onnx_boxes)**2, axis1)) print(f平均坐标差异: {np.mean(coord_diff):.4f}) print(f最大坐标差异: {np.max(coord_diff):.4f})2. 置信度波动分析conf_diff np.abs(boxes.conf.cpu().numpy() - onnx_output[0,:,4]) print(f置信度平均差异: {np.mean(conf_diff):.4f}) print(f差异超过0.1的比例: {np.mean(conf_diff0.1):.2%})3. 类别一致性检查cls_match (boxes.cls.cpu().numpy() onnx_output[0,:,5]).mean() print(f类别一致率: {cls_match:.2%})3.3 差异可视化import matplotlib.pyplot as plt plt.figure(figsize(12,4)) plt.subplot(131) plt.hist(coord_diff, bins50) plt.title(坐标差异分布) plt.subplot(132) plt.hist(conf_diff, bins50) plt.title(置信度差异分布) plt.subplot(133) plt.bar([匹配, 不匹配], [cls_match, 1-cls_match]) plt.title(类别一致性) plt.tight_layout() plt.show()4. 常见问题排查指南4.1 输出维度不匹配当遇到输出shape不一致时按此流程排查检查PyTorch模型的输出层结构确认ONNX导出时的nms参数设置使用Netron可视化模型结构4.2 数值差异过大若发现显著差异# 检查输入数据一致性 print(输入数据范围对比:) print(fPyTorch输入: {input_data.min():.3f}~{input_data.max():.3f}) print(fONNX输入: {img_resized.min():.3f}~{img_resized.max():.3f}) # 检查预处理是否一致 assert np.allclose( input_data, img_resized.transpose(2,0,1)[None]/255.0, atol1e-5 )4.3 典型异常案例案例1输出全零可能原因模型导出时opset版本不兼容输入数据未归一化案例2置信度异常低解决方案# 调整导出参数重新导出 model.export(conf0.01) # 降低置信度阈值5. 高级对比技巧5.1 批处理差异分析# 准备批处理数据 batch_size 4 batch_data np.stack([input_data[0]]*batch_size) # PyTorch推理 pt_output model(batch_data)[0] # ONNX推理 onnx_output sess.run(None, {input_name: batch_data})[0] # 计算批处理差异 batch_diff np.mean(np.abs(pt_output - onnx_output)) print(f批处理平均差异: {batch_diff:.6f})5.2 量化误差分析# 将模型导出为FP16 model.export(formatonnx, halfTrue) # 对比FP32与FP16结果 fp16_sess ort.InferenceSession(yolo11n_fp16.onnx, providers[CPUExecutionProvider]) fp16_output fp16_sess.run(None, {input_name: input_data})[0] quant_error np.max(np.abs(onnx_output - fp16_output)) print(fFP16量化最大误差: {quant_error:.4f})5.3 跨设备一致性验证# 在CUDA设备上运行 cuda_sess ort.InferenceSession(yolo11n.onnx, providers[CUDAExecutionProvider]) cuda_output cuda_sess.run(None, {input_name: input_data})[0] device_diff np.max(np.abs(onnx_output - cuda_output)) print(fCPU与CUDA输出差异: {device_diff:.6f})6. 自动化验证脚本以下是一个完整的验证脚本模板import json from pathlib import Path class ModelValidator: def __init__(self, model_path): self.model_path Path(model_path) self.results {} def run_validation(self, test_imagebus.jpg): # 实现完整的验证流程 self._validate_pytorch() self._export_onnx() self._validate_onnx() self._compare_results() # 保存验证结果 with open(self.model_path.parent/validation.json, w) as f: json.dump(self.results, f, indent2) # 各验证方法实现...使用方式validator ModelValidator(yolo11n.pt) validator.run_validation()在实际项目中建议将这类验证流程集成到CI/CD管道中确保每次模型更新都能自动验证输出一致性。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2449192.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!