Pytorch复习笔记--导出Onnx模型为动态输入和静态输入

news2025/7/17 7:45:59

目录

1--动态输入和静态输入

2--Pytorch API

3--完整代码演示

4--模型可视化

5--测试动态导出的Onnx模型


1--动态输入和静态输入

        当使用 Pytorch 将网络导出为 Onnx 模型格式时,可以导出为动态输入和静态输入两种方式。动态输入即模型输入数据的部分维度是动态的,可以由用户在使用模型时自主设定;静态输入即模型输入数据的维度是静态的,不能够改变,当用户使用模型时只能输入指定维度的数据进行推理。

        显然,动态输入的通用性比静态输入更强。

2--Pytorch API

        在 Pytorch 中,通过 torch.onnx.export() 的 dynamic_axes 参数来指定动态输入和静态输入,dynamic_axes 的默认值为 None,即默认为静态输入。

        以下展示动态导出的用法,通过定义 dynamic_axes 参数来设置动态导出输入。dynamic_axes 中的 0、2、3 表示相应的维度设置为动态值;

# 导出为动态输入
input_name = 'input'
output_name = 'output'
torch.onnx.export(model, 
                    input_data, 
                    "Dynamics_InputNet.onnx",
                    opset_version=11,
                    input_names=[input_name],
                    output_names=[output_name],
                    dynamic_axes={
                        input_name: {0: 'batch_size', 2: 'input_height', 3: 'input_width'},
                        output_name: {0: 'batch_size', 2: 'output_height', 3: 'output_width'}})

3--完整代码演示

        在以下代码中,定义了一个网络,并使用动态导出和静态导出两种方式,将网络导出为 Onnx 模型格式。

import torch
import torch.nn as nn

class Model_Net(nn.Module):
    def __init__(self):
        super(Model_Net, self).__init__()
        self.layer1 = nn.Sequential(

            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )
        
    def forward(self, data):
        data = self.layer1(data)
        return data

if __name__ == "__main__":

    # 设置输入参数
    Batch_size = 8
    Channel = 3
    Height = 256
    Width = 256
    input_data = torch.rand((Batch_size, Channel, Height, Width))

    # 实例化模型
    model = Model_Net()

    # 导出为静态输入
    input_name = 'input'
    output_name = 'output'
    torch.onnx.export(model, 
                      input_data, 
                      "Static_InputNet.onnx", 
                      verbose=True, 
                      input_names=[input_name], 
                      output_names=[output_name])

    # 导出为动态输入
    torch.onnx.export(model, 
                      input_data, 
                      "Dynamics_InputNet.onnx",
                      opset_version=11,
                      input_names=[input_name],
                      output_names=[output_name],
                      dynamic_axes={
                          input_name: {0: 'batch_size', 2: 'input_height', 3: 'input_width'},
                          output_name: {0: 'batch_size', 2: 'output_height', 3: 'output_width'}})

4--模型可视化

        通过 netron 库可视化导出的静态模型和动态模型,代码如下:

import netron

netron.start("./Dynamics_InputNet.onnx")

        静态模型可视化:

         动态模型可视化:

5--测试动态导出的Onnx模型

import numpy as np
import onnx
import onnxruntime
 
if __name__ == "__main__":
    input_data1 = np.random.rand(4, 3, 256, 256).astype(np.float32)
    input_data2 = np.random.rand(8, 3, 512, 512).astype(np.float32)
    
    # 导入 Onnx 模型
    Onnx_file = "./Dynamics_InputNet.onnx"
    Model = onnx.load(Onnx_file)
    onnx.checker.check_model(Model) # 验证Onnx模型是否准确
    
    # 使用 onnxruntime 推理
    model = onnxruntime.InferenceSession(Onnx_file, providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'])
    input_name = model.get_inputs()[0].name
    output_name = model.get_outputs()[0].name
 
    output1 = model.run([output_name], {input_name:input_data1})
    output2 = model.run([output_name], {input_name:input_data2})
 
    print('output1.shape: ', np.squeeze(np.array(output1), 0).shape)
    print('output2.shape: ', np.squeeze(np.array(output2), 0).shape)

         由输出结果可知,对应动态输入 Onnx 模型,其输出维度也是动态的,并且为对应关系,则表明导出的 Onnx 模型无误。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/155378.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Elasticsearch:运用 Go 语言实现 Elasticsearch 搜索 - 8.x

在我之前的文章 “Elasticsearch:Go 客户端简介 - 8.x”,我对 Elasticsearch golang 客户端做了一个简单的介绍。在今天的这篇文章中,我将详细介绍如何使用这个客户端来一步一步地连接到 Elasticsearch,进而创建索引,搜…

流程编辑器bpmnjs的改造1:设计器外观和布局

重新设计页面,弄一个比较规范的设计器外观和布局,bpmnjs.css加入如下的代码:/* Frame CSS */html,body{width:100%;height:100%}.toolsBar{position:fixed;width:100%;height:40px;background-color:#FFF; border-bottom:1px solid #E1E1E1;d…

Linux安装Docker完整详细教程

目录 Docker及系统版本 Docker的自动化安装 Docker的手动安装(CentOS7) 1.1 卸载历史版本的Docker 1.2 安装依赖包 1.3 更新本地镜像源(也可以叫做:设置源仓库) 1.4 Docker安装 1.5 配置镜像加速 Docker启动 删除Docker Docker其…

连接池PgBouncer部署与踩坑实践

安装 可以直接使用 yum install pgbouncer 安装(rpm管理的是1.14版本) 或者在http://www.pgbouncer.org/downloads/ 下载最新的tat.gz包 解压出来进入目录,通过 ./configure --prefix/home/pgbouncermake & make install 安装&…

01等概率发生器、随机函数、对数器

1.数据结构 数据结构:是由连续结构、跳转结构或者连续加跳转(可能有多个叉)结构组成 数据结构是很多算法得以进行的载体 数组:便于寻址不便于删增数据(需要不断移动数据,如果不动可能就不是连续结构) 链表(跳转结构…

jupyter notebook 暗黑模式新方法

1 直接浏览器采用暗黑模式 (1)首先我们打开谷歌浏览器,在浏览器地址栏中输入“chrome://flags”然后按下回车键。 (2)之后我们会进入谷歌浏览器的实验室页面,在页面左上方的搜索框中输入“enable-force-…

DocuWare客户案例——温德姆镇使用 DocuWare Cloud 改善市民服务

DocuWare客户案例——温德姆镇使用 DocuWare Cloud 改善市民服务 新冠疫情刚开始时,州和地方政府除了发挥传统作用以外,还要负责遏制疫情的关键措施。税收和联邦援助的收入没有增加,跟不上这一新职责的需求。采用减轻管理负担的技术是节省资源…

2022十大边缘计算开源项目

随着“开源”被纳入“十四五”规划发展纲要,“支持数字技术开源社区等创新联合体发展,完善开源知识产权和法律体系,鼓励企业开放软件源代码、硬件设计和应用服务”。开源发展按下了加速键! 开源软件生态蓬勃发展,边缘…

Internet结构和ISP

目录 1. ISP / IXP / ICP 定义 2. 网络连接宏观结构 3. 网络连接层级结构 4. ISP 连接方式 1. ISP / IXP / ICP 定义 ISP:Internet Service Provider,即互联网服务提供商。主要为用户提供互联网接入业务、信息业务的运营商,如移动和电信等。 …

数据结构学习之栈

这里写目录标题栈的定义与性质栈的实现栈的定义栈的功能栈的创建入栈出栈栈顶判断栈为空得到栈的个数栈的销毁栈的定义与性质 第一个问题:什么是栈? 栈的定义是: 一种特殊的线性表,其只允许在固定的一端进行插入和删除元素操作。…

【从零开始学习深度学习】45. Pytorch迁移学习微调方法实战:使用微调技术进行2分类图片热狗识别模型训练【含源码与数据集】

通常为了使模型的预测精度达到较高的标准,需要收集十分庞大的数据集来进行模型训练。一种比较巧妙解决该问题的办法是应用迁移学习(transfer learning),将从某个已有的数据集学到的知识迁移到目标数据集上。例如,假如我…

微信小程序安装 Vant 组件库与API Promise组件库并实现简单的增删改查

在项目内右键空白处选择在外部终端打开2、在终端窗口输入 npm init -y,创建package-lock.jsonnpm init -y3、在终端输入npm i vant/weapp1.3.3 -S --production,创建node_modules文件夹npm i vant/weapp1.3.3 -S --production4、详情-本地设置&#xff0…

Vue2.0开发之——Vue组件-组件的实例对象(36)

一 概述 浏览器无法直接解析Vue文件package.json中的’vue-template-compiler’将vue结尾的文件解析为js文件交给浏览器处理Count组件实例对象 二 浏览器无法直接解析Vue文件 将Vue文件拖放到浏览器中无法直接显示 三 package.json中的’vue-template-compiler’将vue结尾的文…

软件著作权登记指南

一、什么是计算机软件《计算机软件保护条例》第二条、第三条规定,本条例所称计算机软件(以下简称软件),是指计算机程序及其有关文档;(一)计算机程序,是指为了得到某种结果而可以由计…

第13章 Token的Postman、Swagger和Vue调试

1 准备工作 1.1 WebApi.Controllers.JwtSettingModel namespace WebApi.Test { /// <summary> /// 【Jwt设置模型--纪录】 /// <remarks> /// 摘要&#xff1a; /// 通过该纪录中的属性成员实例存储“AppSettings.json”文件中的Jwt相关设置数据&#xff0…

java应用程序多级缓存架构

多级缓存架构 一级缓存&#xff1a;OpenResty—Lua—Redis 二级缓存&#xff1a;Nginx proxy-cache 三级缓存&#xff1a;Redis 使用OpenResty lua脚本访问redis proxy-cache 缓存注解 <!--依赖--> <dependency><groupId>org.springframework.boot</gr…

最新研究发现:天然海绵含有抑制Omicron变体感染的天然化合物

本文原文首发于2023年1月9日E-LIFESTYLE &#xff08;阅读时间4分钟&#xff09; 附标题&#xff1a;通过研究370多种来自植物、真菌和海绵等天然来源的化合物&#xff0c;寻找可用于治疗新冠肺炎的新抗病毒药物&#xff0c;用这些天然化合物制成的溶液中沐浴人类被SARS-CoV-2感…

SolidWorks装配体保存成零件,能有效压缩文件体积,方便二次装配

SolidWorks装配体保存成零件&#xff0c;能有效压缩文件体积&#xff0c;方便二次装配1. 先使用solidworks打开我们要转换成零件的装配体2. 然后点击上方保存下面的小三角&#xff0c;选择另存为3.之后选择要保存的位置&#xff0c;点击文件格式&#xff0c;然后在文件格式里找…

Zabbix监控服务详解+实战

目录 一、监控体系概述 1. 为什么需要监控 2. 监控目标与流程 &#xff08;1&#xff09;监控的目标 &#xff08;2&#xff09; 监控的流程 3. 监控的对象 &#xff08;1&#xff09;CPU监控 &#xff08;2&#xff09;磁盘监控 &#xff08;3&#xff09;内存监控 …

win7电脑怎么录屏?免费的录屏软件分享

现在大家的电脑一般是win10、11系统&#xff0c;但是还是有一些小伙伴喜欢使用win7系统的电脑。那你知道win7电脑怎么录屏吗&#xff1f;有没有好用且简单的win7电脑录屏软件推荐&#xff1f;当然有&#xff01;今天小编给使用win7电脑的小伙伴推荐两款简单且好用的电脑录屏软件…