PyTorch模型的多种导出方式提供给其他程序使用
flyfish
 
PyTorch模型的多种导出方式
- PyTorch模型的多种导出方式提供给其他程序使用
- 1 模型可视化
- 2 预训练模型
- 3 ONNX模型导出有输入有输出
- TRAINING导出方式
- EVAL导出方式
 
- 4 自定义输入输出的名字,并可批量推理
- 5 导出JIT模型
 
1 模型可视化
以下使用模型可视化工具时netron
工具下载到本地
 https://github.com/lutzroeder/netron/releases/
 或者在使用
 https://netron.app/
2 预训练模型
当下载一个预训练模型时,只是一个一个的module
 
3 ONNX模型导出有输入有输出
import torch
import torchvision
if __name__ == '__main__':
    input = torch.randn(1, 3, 224, 224)        
    model = torchvision.models.resnet18()                          
    model.load_state_dict(torch.load("resnet18-f37072fd.pth")) 
    model.eval()                               
    torch.onnx.export(model, input, "a.onnx", training=torch.onnx.TrainingMode.TRAINING) 
    torch.onnx.export(model, input, "b.onnx", training=torch.onnx.TrainingMode.EVAL) 
TRAINING导出方式
算子没有融合
 
EVAL导出方式

 当采用EVAL方式进行模型导出的时候,Conv和BatchNorm层进行了合并
4 自定义输入输出的名字,并可批量推理
import torch
import torchvision
if __name__ == '__main__':
	model = torchvision.models.resnet18()                          
	model.load_state_dict(torch.load("resnet18-f37072fd.pth")) 
	model.eval()                               
	batch_size = 4 
	input_data = torch.randn(batch_size, 3, 224, 224)
	output_path = "c.onnx"
	torch.onnx.export(model, input_data, output_path,
		          input_names=["input"], output_names=["output"],
		          dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}})

5 导出JIT模型
JIT(Just-In-Time)
 在Yolov5中叫torchscript
import torch
import torchvision
if __name__ == '__main__':
	model = torchvision.models.resnet18()                          
	model.load_state_dict(torch.load("resnet18-f37072fd.pth")) 
	model.eval()                               
	input = torch.rand(1, 3, 224, 224)
	jit_model = torch.jit.trace(model, input)
	torch.jit.save(jit_model, 'resnet18_jit.trace.pth')
	#script_model = torch.jit.script(model, input)
	#torch.jit.save(script_model, 'resnet18_jit.script.pth')

 本文使用的PyTorch版本 1.10.1



















