从 PyTorch 到 ONNX:深度学习模型导出全解析

news2025/5/25 22:15:58

在模型训练完毕后,我们通常希望将其部署到推理平台中,比如 TensorRT、ONNX Runtime 或移动端框架。而 ONNX(Open Neural Network Exchange)正是 PyTorch 与这些平台之间的桥梁。

本文将以一个图像去噪模型 SimpleDenoiser 为例,手把手带你完成 PyTorch 模型导出为 ONNX 格式的全过程,并解析每一行代码背后的逻辑。

准备工作

我们假设你已经训练好一个图像去噪模型并保存为 .pth 文件,模型结构自编码器实现如下(略):

class SimpleDenoiser(nn.Module):
    def __init__(self):
        super(SimpleDenoiser, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1), nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 3, 3, padding=1)
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

导出代码分解

我们现在来看导出脚本的核心逻辑,并分块解释它的每一部分。

1. 导入模块 & 设置路径

//torch:核心框架

//train.SimpleDenoiser:从训练脚本复用模型结构

//os:用于创建输出目录

import torch
from train import SimpleDenoiser  # 模型结构
import os

2. 导出函数定义

//这个函数接收三个参数:

//pth_path: 训练得到的模型参数文件路径

//onnx_path: 导出的 ONNX 文件保存路径

//input_size: 模拟推理输入的尺寸(默认 1×3×256×256)
def export_model_to_onnx(pth_path, onnx_path, input_size=(1, 3, 256, 256)):

3. 加载模型和权重

//自动检测 CUDA 可用性,加载模型到对应设备;

//使用 load_state_dict() 加载训练好的参数;

//model.eval() 让模型切换到推理模式(关闭 Dropout/BatchNorm 更新);
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SimpleDenoiser().to(device)
model.load_state_dict(torch.load(pth_path, map_location=device))
model.eval()

4. 构造假输入(Dummy Input)

//ONNX 导出需要一个具体的输入样本,我们这里用 torch.randn 生成一个形状为 (1, 3, 256, 256) 的随机图//像;

//输入必须放在同一个设备上(GPU 或 CPU);
dummy_input = torch.randn(*input_size).to(device)

5. 导出为 ONNX

torch.onnx.export(
    model,  //要导出的模型
    dummy_input,  //示例输入张量
    onnx_path, //	导出路径
    export_params=True,  //是否导出权重
    opset_version=11,  //ONNX 的算子集版本,通常推荐 11 或 13
    do_constant_folding=True,  //优化常量表达式,减小模型体积
    input_names=['input'],  //自定义输入输出张量的名称
    output_names=['output'],  //声明哪些维度可以变动,比如 batch size、图像大小等(部署时更灵活)
    dynamic_axes={
        'input': {0: 'batch_size', 2: 'height', 3: 'width'},
        'output': {0: 'batch_size', 2: 'height', 3: 'width'}
    }
)

6. 创建目录并调用函数

//确保输出文件夹存在,并调用导出函数生成最终模型。
if __name__ == "__main__":
    os.makedirs("onnx", exist_ok=True)
    export_model_to_onnx("weights/denoiser.pth", "onnx/denoiser.onnx")

导出后如何验证?

pip install onnxruntime
import onnxruntime
import numpy as np

sess = onnxruntime.InferenceSession("onnx/denoiser.onnx")
input = np.random.randn(1, 3, 256, 256).astype(np.float32)
output = sess.run(None, {"input": input})
print("输出 shape:", output[0].shape)

 模型预览:

总结

导出 ONNX 模型的流程主要包括:

  1. 加载模型结构 + 权重

  2. 准备 dummy 输入张量

  3. 调用 torch.onnx.export() 进行导出

  4. 设置 dynamic_axes 可变尺寸以增强部署适配性

这套流程适用于大部分视觉模型(分类、去噪、分割等),也是后续进行 TensorRT 推理或移动端部署的基础。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2336478.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Android 应用添加Tile到SystemUI QuickSettings

安卓源码里有谷歌给的关于 Tile 的说明。 frameworks/base/packages/SystemUI/docs/qs-tiles.md SystemUI QuickSettings 简称QS,指的是 下拉菜单里的区域。区域里的一个选项就是一个 Tile 。 下图是 frameworks/base/packages/SystemUI/docs/ 里的附图示例&#…

【MySQL】前缀索引、索引下推、访问方法,自适应哈希索引

最左前缀原则 对于INDEX(name, age)来说最左前缀可以是联合索引的最左N个字段, 也可以是字符串索引的最左M个字符。 SELECT * FROM t WHERE name LIKE 张%其效果和单独创建一个INDEX(name)的效果是一样的若通过调整索引字段的顺序, 可以少维护一个索引树, 那么这个顺序就是需要…

Android Studio开发知识:从基础到进阶

引言 Android开发作为移动应用开发的主流方向之一,曾吸引了无数开发者投身其中。然而,随着市场饱和和技术迭代,当前的Android开发就业形势并不乐观,竞争日益激烈。尽管如此,掌握扎实的开发技能仍然是脱颖而出的关键。本…

ocr-身份证正反面识别

在阿里云官网,申请一个token [阿里官方]身份证OCR文字识别_API专区_云市场-阿里云 (aliyun.com) 观察一下post请求body部分json字符串,我们根据这个创建一个java对象 先默认是人像面 public class IdentityBody {public String image;class configure…

单节锂电池4.2V升压5V都有哪些国产芯片推荐?国产SL4011高效,高性价比

针对单节锂电池(4.2V)升压至5V应用中 SL4011升压芯片 的核心优势解析,结合其技术参数与典型应用场景进行详细说明: 1. 宽输入电压与高兼容性 输入范围:2.7V-12V,完美覆盖单节锂电池全周期电压(3…

机器学习 | 神经网络介绍 | 概念向

文章目录 📚从生物神经元到人工神经元📚神经网络初识🐇激活函数——让神经元“动起来”🐇权重与偏置——调整信息的重要性🐇训练神经网络——学习的过程🐇过拟合与正则化——避免“死记硬背” &#x1f440…

视频孪生重构施工逻辑:智慧工地的数字化升级

当"智慧工地"概念在2017年首次写入《建筑业发展"十三五"规划》时,行业普遍将其等同于摄像头与传感器的简单叠加。十年数字浪潮冲刷下,智慧工地的内涵已发生本质跃迁:从工具层面的信息化改造,进化为基于视频数…

六根觉性:穿透表象的清净觉知之光

在喧嚣的禅堂里,老禅师轻叩茶盏,清脆的声响划破沉寂。这声"叮"不仅震动耳膜,更叩击着修行者的心性——这正是佛教揭示的六根觉性在世间万相中的妙用。当我们凝视《楞严经》中二十五圆通法门,六根觉性犹如六道澄明之光&a…

spring:注解@Component、@Controller、@Service、@Reponsitory

背景 spring框架的一个核心功能是IOC,就是将Bean初始化加载到容器中,Bean是如何加载到容器的,可以使用spring注解方式或者spring XML配置方式。 spring注解方式直接对项目中的类进行注解,减少了配置文件内容,更加便于…

Halcon应用:九点标定-手眼标定

提示:若没有查找的算子,可以评论区留言,会尽快更新 Halcon应用:九点标定-手眼标定 前言一、Halcon应用?二、应用实战1、图形理解[eye-to-hand]:1.1、开始应用2 图形理解[eye-in-hand] 前言 本篇博文主要用…

【iOS】OC高级编程 iOS多线程与内存管理阅读笔记——自动引用计数(一)

自动引用计数 前言alloc/retain/release/dealloc实现苹果的实现 autoreleaseautorelease实现苹果的实现 总结 前言 此前,写过一遍对自动引用计数的简单学习,因此掠过其中相同的部分:引用计数初步学习 alloc/retain/release/dealloc实现 由于…

Python爬虫第15节-2025今日头条街拍美图抓取实战

目录 一、项目背景与概述 二、环境准备与工具配置 2.1 开发环境要求 2.2 辅助工具配置 三、详细抓取流程解析 3.1 页面加载机制分析 3.2 关键请求识别技巧 3.3 参数规律深度分析 四、爬虫代码实现 五、实现关键 六、法律与道德规范 一、项目概述 在当今互联网时代&a…

智慧城市像一张无形大网,如何紧密连接你我他?

智慧城市作为复杂巨系统,其核心在于通过技术创新构建无缝连接的网络,使物理空间与数字空间深度融合。这张"无形大网"由物联网感知层、城市数据中台、人工智能中枢、数字服务入口和安全信任机制五大支柱编织而成,正在重塑城市运行规…

网络安全·第四天·扫描工具Nmap的运用

今天我们要介绍网络安全中常用的一种扫描工具Nmap,它被设计用来快速扫描大型网络,主要功能包括主机探测、端口扫描以及版本检测,小编将在下文详细介绍Nmap相应的命令。 Nmap的下载安装地址为:Nmap: the Network Mapper - Free Se…

黑龙江 GPU 服务器租用:开启高效计算新征程

随着人工智能、深度学习、大数据分析等技术的广泛应用,对强大计算能力的需求日益迫切。GPU 服务器作为能够提供卓越并行计算能力的关键设备,在这一进程中发挥着至关重要的作用。对于黑龙江地区的企业、科研机构和开发者而言,选择合适的 GPU 服…

SparseDrive---论文阅读

纯视觉下的稀疏场景表示 算法动机&开创性思路 算法动机: 依赖于计算成本高昂的鸟瞰图(BEV)特征表示。预测和规划的设计过于直接,没有充分利用周围代理和自我车辆之间的高阶和双向交互。场景信息是在agent周围提取&#xff…

Unchained 内容全面上链,携手 Walrus 迈入去中心化媒体新时代

加密新闻媒体 Unchained — — 业内最受信赖的声音之一 — — 现已选择 Walrus 作为其去中心化存储解决方案,正式将其所有媒体内容(文章、播客和视频)上链存储。Walrus 将替代 Unchained 现有的中心化存储架构,接管其全部历史内容…

确保连接器后壳高性能互连的完整性

本文探讨了现代后壳技术如何促进高性能互连的电气和机械完整性,以及在规范阶段需要考虑的一些关键因素。 当今的航空航天、国防和医疗应用要求连接器能够提供高速和紧凑的互连,能够承受振动和冲击,并保持对电磁和射频干扰 (EMI/R…

C++学习Day0:c++简介

目录 一、.C语言的发展史二、C特点三、面向对象的重要术语四、面向过程和面向对象的区别?五、开发环境:六、创建文件步骤:1.点击新建项目2.在弹出的开始栏中按如下操作3.在.pro文件中添加(重要!!&#xff0…

从零开始构建 Ollama + MCP 服务器

Model Context Protocol(模型上下文协议)在过去几个月里已经霸占了大家的视野,出现了许多酷炫的集成示例。我坚信它会成为一种标准,因为它正在定义工具与代理或软件与 AI 模型之间如何集成的新方式。 我决定尝试将 Ollama 中的一…