pytorch模型转换到mxnet工程中使用

有关将 PyTorch 转换为 ONNX,然后加载到 MXNet 的教程

  • 关于ONNX 概述
    • 将 PyTorch 模型转换为 ONNX,然后将模型加载到 MXNet 中
    • 函数介绍
      • 参数
    • 函数介绍
      • 参量
      • 输出

转载自: https://docs.aws.amazon.com/zh_cn/dlami/latest/devguide/tutorial-onnx-pytorch-mxnet.html

关于ONNX 概述

开放神经网络交换 (ONNX) 是一种用于表示深度学习模型的开放格式。ONNX 受到 Amazon Web Services、Microsoft、Facebook 和其他多个合作伙伴的支持。您可以使用任何选定的框架来设计、训练和部署深度学习模型。ONNX 模型的好处是,它们可以在框架之间轻松移动。

将 PyTorch 模型转换为 ONNX,然后将模型加载到 MXNet 中

  1. 使用文本编辑器创建一个新文件,并在脚本中使用以下程序来训练 PyTorch 中的模拟模型,然后将它导出为 ONNX 格式。
# Build a Mock Model in PyTorch with a convolution and a reduceMean layer
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import torch.onnx as torch_onnx

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=(3,3), stride=1, padding=0, bias=False)

    def forward(self, inputs):
        x = self.conv(inputs)
        #x = x.view(x.size()[0], x.size()[1], -1)
        return torch.mean(x, dim=2)

# Use this an input trace to serialize the model
input_shape = (3, 100, 100)
model_onnx_path = "torch_model.onnx"
model = Model()
model.train(False)

# Export the model to an ONNX file
dummy_input = Variable(torch.randn(1, *input_shape))
output = torch_onnx.export(model, 
                          dummy_input, 
                          model_onnx_path, 
                          verbose=False)
print("Export of torch_model.onnx complete!")
  1. 使用文本编辑器创建一个新文件,并在脚本中使用以下程序以在 MXNet 中打开 ONNX 格式文件。
import onnx
import mxnet as mx
from mxnet.contrib import onnx as onnx_mxnet
import numpy as np

# Import the ONNX model into MXNet's symbolic interface
sym, arg, aux = onnx_mxnet.import_model("torch_model.onnx")
print("Loaded torch_model.onnx!")
print(sym.get_internals())
  1. 运行此脚本后,MXNet 将拥有加载的模型,并打印一些基本模型信息。

函数介绍

将一个模型导出到ONNX格式。该exporter会运行一次你的模型,以便于记录模型的执行轨迹,并将其导出

torch.onnx.export(model, args, f, export_params=True, verbose=False, training=False, input_names=None, output_names=None, dynamic_axes = {})

参数

  • model(torch.nn.Module)-要被导出的模型
  • args(参数的集合)-模型的输入,例如,这种model方式是对模型的有效调用。任何非Variable参数都将硬编码到导出的模型中;任何Variable参数都将成为导出的模型的输入,并按照他们在args中出现的顺序输入。因为export运行模型,所以我们需要提供一个输入张量x。只要是正确的类型和大小,其中的值就可以是随机的。请注意,除非指定为动态轴,否则输入尺寸将在导出的ONNX图形中固定为所有输入尺寸。在此示例中,我们使用输入batch_size 1导出模型,但随后dynamic_axes 在torch.onnx.export()。因此,导出的模型将接受大小为[batch_size,3、100、100]的输入,其中batch_size可以是可变的。
  • f-一个类文件的对象(必须实现文件描述符的返回)或一个包含文件名字符串。一个二进制Protobuf将会写入这个文件中。
  • export_params(bool,default True)-如果指定,所有参数都会被导出。如果你只想导出一个未训练的模型,就将此参数设置为False。在这种情况下,导出的模型将首先把所有parameters作为参arguments,顺序由model.state_dict().values()指定。
  • verbose(bool,default False)-如果指定,将会输出被导出的轨迹的调试描述。
  • training(bool,default False)-导出训练模型下的模型。目前,ONNX只面向推断模型的导出,所以一般不需要将该项设置为True。
  • input_names(list of strings, default empty list)-按顺序分配名称到图中的输入节点。
  • output_names(list of strings, default empty list)-按顺序分配名称到图中的输出节点。
  • dynamic_axes{‘input’ : {0 : ‘batch_size’}, ‘output’ : {0 : ‘batch_size’}}) # variable lenght axes

函数介绍

将onnx模型文件导入mxnet

mxnet.contrib.onnx.import_model(model_file)

参量

  • model_file(str)– ONNX模型文件名

输出

  • sym(Symbol)– MXNet符号对象
  • arg_params(dict of str to NDArray)–以mxnet.ndarray.NDArray格式存储的转换参数的字典
  • aux_params(dict of str to NDArray)–以mxnet.ndarray.NDArray格式存储的转换参数的字典

你可能感兴趣的:(pytorch)