从 PyTorch 到 TensorFlow Lite:模型训练与推理

news2025/6/4 22:15:25

一、方案介绍

  1. 研发阶段:利用 PyTorch 的动态图特性进行快速原型验证,快速迭代模型设计。
    • 灵活性与易用性:PyTorch 是一个非常灵活且易于使用的深度学习框架,特别适合研究和实验。其动态计算图特性使得模型的构建和调试变得更加直观,开发者可以在运行时修改模型结构。
    • 快速原型开发:许多研究人员和开发者选择 PyTorch 进行模型训练,因为它支持快速原型开发和灵活的模型设计,能够快速验证新想法并进行迭代。
  2. 转换阶段:将训练好的模型通过 TorchScript 导出为 ONNX 格式,再转换为 TensorFlow 格式,最后生成 TFLite 模型。
    • 专为移动和嵌入式设备优化:TensorFlow Lite 是专为移动和嵌入式设备设计的推理框架,能够在资源有限的环境中高效运行模型,确保在各种设备上实现实时推理。
    • 支持模型量化和优化:TFLite 支持模型量化和优化,能够显著减小模型大小并提高推理速度,适合在手机、边缘设备等场景中使用。这使得开发者能够在不牺牲准确度的情况下,提升模型的运行效率。
  3. 部署阶段:将 TFLite 模型集成到 Android、iOS 或嵌入式系统中,确保模型能够在目标设备上高效运行。
    • 内存和计算资源的优化:在推理阶段,使用 TFLite 可以减少内存占用和计算资源消耗,尤其是在移动设备和嵌入式系统上。这对于需要长时间运行的应用尤为重要,可以延长设备的电池寿命。
    • 多种优化技术:TFLite 提供了多种优化技术,如模型量化(将浮点数转换为整数),可以进一步提高推理速度并降低功耗。这使得在实时应用中能够实现更快的响应时间,提升用户体验。
      在这里插入图片描述

二、实例1:CNN模型的转换

注:python 版本为3.10

2.1 pytorch模型训练

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader

# 检查是否支持 MPS
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")


# 定义 CNN 模型
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.max_pool2d(x, 2)
        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.max_pool2d(x, 2)
        x = x.view(-1, 64 * 7 * 7)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 加载 MNIST 数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

# 初始化模型、损失函数和优化器
model = CNNModel().to(device)  # 将模型移动到 MPS 设备
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
for epoch in range(20):
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)  # 将数据移动到 MPS 设备
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    print(f'Epoch [{epoch + 1}/20], Loss: {loss.item():.6f}')

# 保存模型
torch.save(model.state_dict(), 'cnn_mnist.pth')
print("Model saved as cnn_mnist.pth")

2.2 pth模型转onnx 并验证一致性

import numpy as np
import onnx
import onnxruntime as ort
import torch
import torch.nn as nn


# 定义 CNN 模型
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.max_pool2d(x, 2)
        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.max_pool2d(x, 2)
        x = x.view(-1, 64 * 7 * 7)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# 加载模型并进行推理
model = CNNModel()
model.load_state_dict(torch.load('cnn_mnist.pth', weights_only=True))  # 加载保存的模型权重
model.eval()  # 设置为评估模式

# 创建一个示例输入
dummy_input = torch.randn(1, 1, 28, 28)  # MNIST 图像的形状

# 使用 PyTorch 进行推理
with torch.no_grad():
    pytorch_output = model(dummy_input)

# 导出模型为 ONNX 格式
torch.onnx.export(model, dummy_input, 'cnn_mnist.onnx', export_params=True, opset_version=11)
print("Model exported to cnn_mnist.onnx")

# 使用 ONNX 进行推理
onnx_model = onnx.load('cnn_mnist.onnx')
ort_session = ort.InferenceSession('cnn_mnist.onnx')

# 准备输入数据
onnx_input = dummy_input.numpy()  # 将 PyTorch 张量转换为 NumPy 数组
onnx_input = onnx_input.astype(np.float32)  # 确保数据类型为 float32

# 使用 ONNX 进行推理
onnx_output = ort_session.run(None, {ort_session.get_inputs()[0].name: onnx_input})

# 比较输出
pytorch_output_np = pytorch_output.numpy()  # 将 PyTorch 输出转换为 NumPy 数组
onnx_output_np = onnx_output[0]  # ONNX 输出是一个列表,取第一个元素

# 检查输出是否一致
if np.allclose(pytorch_output_np, onnx_output_np, atol=1e-5):
    print("The outputs are consistent between PyTorch and ONNX.")
else:
    print("The outputs are NOT consistent between PyTorch and ONNX.")

# 打印输出结果
print("PyTorch output:", pytorch_output_np)
print("ONNX output:", onnx_output_np)

The outputs are consistent between PyTorch and ONNX.
PyTorch output: [[ -1.5153266 -11.934659    0.5428004 -16.058285   -3.6684208  -4.596178
  -14.53585    -3.3159208  -5.7872214  -5.3301578]]
ONNX output: [[ -1.5153263 -11.934658    0.5428015 -16.058285   -3.66842    -4.5961757
  -14.53585    -3.3159204  -5.787223   -5.3301597]]

2.3 onnx模型转tflite

参考这个项目:onnx2tflite

git clone https://github.com/MPolaris/onnx2tflite.git
cd onnx2tflite
conda install tensorflow=2.11.0
pip install .
python -m onnx2tflite --weights ../pth2onnx/cnn_mnist.onnx

在这里插入图片描述

2.4 onnx模型和tflite一致性验证

import numpy as np
import onnxruntime as ort
import tensorflow as tf

# 1. 加载 ONNX 模型
onnx_model_path = 'cnn_mnist.onnx'
onnx_session = ort.InferenceSession(onnx_model_path)

# 2. 加载 TFLite 模型
tflite_model_path = 'cnn_mnist.tflite'
tflite_interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
tflite_interpreter.allocate_tensors()

# 3. 准备输入数据
# 假设输入数据是 MNIST 数据集的一部分,形状为 (1, 28, 28, 1)
input_data = np.random.rand(1, 28, 28, 1).astype(np.float32)  # Keras 输入
input_data_onnx = input_data.transpose(0, 3, 1, 2)  # 转换为 ONNX 输入格式 (1, 1, 28, 28)

# 4. 使用相同的输入数据进行推理

# ONNX 模型推理
onnx_input_name = onnx_session.get_inputs()[0].name
onnx_output = onnx_session.run(None, {onnx_input_name: input_data_onnx})[0]
print("ONNX Output:", onnx_output)

# TFLite 模型推理
tflite_input_details = tflite_interpreter.get_input_details()
tflite_output_details = tflite_interpreter.get_output_details()

# 检查 TFLite 输入形状
print("TFLite Input Shape:", tflite_input_details[0]['shape'])

# 设置 TFLite 输入
# 确保输入数据的形状与 TFLite 模型的输入要求一致
tflite_interpreter.set_tensor(tflite_input_details[0]['index'], input_data)
tflite_interpreter.invoke()
tflite_output = tflite_interpreter.get_tensor(tflite_output_details[0]['index'])
print("TFLite Output:", tflite_output)

# 5. 比较输出结果
# 计算输出的差异
onnx_difference = np.abs(onnx_output - tflite_output)

# 输出结果
print("Difference (ONNX vs TFLite):", onnx_difference)

# 检查是否一致
if np.all(onnx_difference < 1e-5):  # 设定一个阈值
    print("The outputs are consistent between ONNX and TFLite models.")
else:
    print("The outputs are not consistent between ONNX and TFLite models.")
ONNX Output: [[ -3.7372704  -6.5073314  -1.1807165  -2.4232314 -10.638929    2.2660115
   -4.5868526  -2.7494073  -0.5609715  -6.331989 ]]
TFLite Input Shape: [ 1 28 28  1]
TFLite Output: [[ -3.7372704   -6.5073323   -1.180716    -2.4232314  -10.638928
    2.2660117   -4.5868545   -2.7494078   -0.56097114  -6.331988  ]]
Difference (ONNX vs TFLite): [[0.0000000e+00 9.5367432e-07 4.7683716e-07 0.0000000e+00 9.5367432e-07
  2.3841858e-07 1.9073486e-06 4.7683716e-07 3.5762787e-07 9.5367432e-07]]
The outputs are consistent between ONNX and TFLite models.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2397101.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【存储基础】存储设备和服务器的关系和区别

文章目录 1. 存储设备和服务器的区别2. 客户端访问数据路径场景1&#xff1a;经过服务器处理场景2&#xff1a;客户端直连 3. 服务器作为"中转站"的作用 刚开始接触存储的时候&#xff0c;以为数据都是存放在服务器上的&#xff0c;服务器和存储设备是一个东西&#…

5.29打卡

浙大疏锦行 DAY 38 Dataset和Dataloader类 知识点回顾&#xff1a; 1. Dataset类的__getitem__和__len__方法&#xff08;本质是python的特殊方法&#xff09; 2. Dataloader类 3. minist手写数据集的了解 作业&#xff1a;了解下cifar数据集&#xff0c;尝试获取其中一张图…

【黑马程序员uniapp】项目配置、请求函数封装

黑马程序员前端项目uniapp小兔鲜儿微信小程序项目视频教程&#xff0c;基于Vue3TsPiniauni-app的最新组合技术栈开发的电商业务全流程_哔哩哔哩_bilibili 参考 有代码&#xff0c;还有app、h5页面、小程序的演示 小兔鲜儿-vue3ts-uniapp-一套代码多端部署: 小兔鲜儿-vue3ts-un…

PyTorch——DataLoader的使用

batch_size, drop_last 的用法 shuffle shuffleTrue 各批次训练的图像不一样 shuffleFalse 在第156step顺序一致

Predixy的docker化

概述 当前已有一套redis cluster的集群&#xff0c;但是fs中的hiredis只能配置单实例redis。 AI了一下方案&#xff0c;可以使用redis的proxy组件来实现从hiredis到redis cluster的互通。 代码地址&#xff1a;https://github.com/joyieldInc/predixy Predixy特性介绍&…

C++ 之 多态 【虚函数表、多态的原理、动态绑定与静态绑定】

目录 前言 1.多态的原理 1.1虚函数表 1.2派生类中的虚表 1.3虚函数、虚表存放位置 1.4多态的原理 1.5多态条件的思考 2.动态绑定与静态绑定 3.单继承和虚继承中的虚函数表 3.1单继承中的虚函数表 3.2多继承(非菱形继承)中的虚函数表 4.问答题 前言 需要声明的&#x…

【JavaWeb】Maven、Servlet、cookie/session

目录 5. Maven6. Servlet6.1 Servlet 简介6.2 HelloServlet6.3 Servlet原理6.4 Mapping( **<font style"color:rgb(44, 44, 54);">映射 ** )问题6.5 ServletContext6.6 HttpServletResponse<font style"color:rgb(232, 62, 140);background-color:rgb(…

Rust 编程实现猜数字游戏

文章目录 编程实现猜数字游戏游戏规则创建新项目默认代码处理用户输入代码解析 生成随机数添加依赖生成逻辑 比较猜测值与目标值类型转换 循环与错误处理优化添加循环优雅处理非法输入​ 最终完整代码核心概念总结 编程实现猜数字游戏 我们使用cargo和rust实现一个经典编程练习…

关于神经网络中的激活函数

这篇博客主要介绍一下神经网络中的激活函数以及为什么要存在激活函数。 首先&#xff0c;我先做一个简单的类比&#xff1a;激活函数的作用就像给神经网络里的 “数字信号” 加了一个 “智能阀门”&#xff0c;让机器能学会像人类一样思考复杂问题。 没有激活i函数的神经网络…

CentOS_7.9 2U物理服务器上部署系统简易操作步骤

近期单位网站革新&#xff0c;鉴于安全加固&#xff0c;计划将原有Windows环境更新到Linux-CentOS 7.9&#xff0c;这版本也没的说&#xff08;绝&#xff09;了&#xff08;版&#xff09;官方停止更新&#xff0c;但无论如何还是被sisi的牵挂着这一大批人&#xff0c;毕竟从接…

短视频平台差异视角下开源AI智能名片链动2+1模式S2B2C商城小程序的适配性研究——以抖音与快手为例

摘要 本文以抖音与快手两大短视频平台为研究对象&#xff0c;从用户群体、内容生态、推荐逻辑三维度分析其差异化特征&#xff0c;并探讨开源AI智能名片链动21模式与S2B2C商城小程序在平台适配中的创新价值。研究发现&#xff0c;抖音的流量中心化机制与优质内容导向适合品牌化…

【笔记】Windows 下载并安装 ChromeDriver

以下是 在 Windows 上下载并安装 ChromeDriver 的笔记&#xff1a; ✅ Windows 下载并安装 ChromeDriver 1️⃣ 确认 Chrome 浏览器版本 打开 Chrome 浏览器 点击右上角 ︙ → 帮助 → 关于 Google Chrome 记下版本号&#xff0c;例如&#xff1a;114.0.5735.199 2️⃣ 下载…

Spark-Core Project

RDD转换算子总结 RDD转换算子分为Value类型、双Value类型和Key - Value类型。 1、Value类型 map&#xff1a;对数据逐条映射转换&#xff0c;可改变数据类型或值。如 dataRDD.map(num > num * 2 运行结果&#xff1a; 2&#xff09;mapPartitions&#xff1a;以分区为单位处…

Wireshark 使用教程:让抓包不再神秘

一、什么是 tshark&#xff1f; tshark 是 Wireshark 的命令行版本&#xff0c;支持几乎所有 Wireshark 的核心功能。它可以用来&#xff1a; 抓包并保存为 pcap 文件 实时显示数据包信息 提取指定字段进行分析 配合 shell 脚本完成自动化任务 二、安装与验证 Kali Linux…

JWT安全:接收无签名令牌.【签名算法设置为none绕过验证】

JWT安全&#xff1a;假密钥【签名随便写实现越权绕过.】 JSON Web 令牌 (JWT)是一种在系统之间发送加密签名 JSON 数据的标准化格式。理论上&#xff0c;它们可以包含任何类型的数据&#xff0c;但最常用于在身份验证、会话处理和访问控制机制中发送有关用户的信息(“声明”)。…

白银价格查询接口如何用Java进行调用?

一、什么是白银价格查询接口&#xff1f; 它聚焦于上海黄金交易所、上海期货交易所等权威市场&#xff0c;精准提供白银价格行情数据&#xff0c;助力用户实时把握市场脉搏&#xff0c;做出明智的投资决策。 二、应用场景 分析软件&#xff1a;金融类平台可以集成本接口&…

FreeBSD 14.3 候选版本附带 Docker 镜像和关键修复

新的月份已经到来&#xff0c;FreeBSD 14.3 候选发布版 1 现已开放测试&#xff0c;它带来了一些您可能会觉得有用的更新&#xff0c;特别是如果您对Docker容器感兴趣的话。RC1 版本中一个非常受欢迎的改进是&#xff0c;FreeBSD 项目已开始将官方开放容器计划 (OCI) 镜像发布到…

「Java教案」算术运算符与表达式

课程目标 1&#xff0e;知识目标 能够区分Java运算符的种类&#xff0c;例如&#xff0c;算术、赋值、关系、逻辑、位运算等。能够区分Java各类运算符的功能和使用场景。能够根据表达式的构成和计算规则&#xff0c;写出正确的表达式。能够根据运算符优先级与结合性&#xff…

论文写作核心要点

不要只读论文里的motivation和method 论文里的图表和统计特征 在论文里找到具有统计意义的东西&#xff0c;那么在语料里也肯定遵循这样的规律&#xff0c;我们就能用机器学习的方法&#xff0c; 我们再用不同方法解决&#xff0c;哪种方法好&#xff0c;就用哪种 实验分析 …

[java]eclipse中windowbuilder插件在线安装

目录 一、打开eclipse 二、打开插件市场 三、输入windowbuilder&#xff0c;点击install 四、进入安装界面 五、勾选我同意... 重启即可 一、打开eclipse 二、打开插件市场 三、输入windowbuilder&#xff0c;点击install 四、进入安装界面 五、勾选我同意... 重启即可