转载自: https://docs.aws.amazon.com/zh_cn/dlami/latest/devguide/tutorial-onnx-pytorch-mxnet.html
开放神经网络交换 (ONNX) 是一种用于表示深度学习模型的开放格式。ONNX 受到 Amazon Web Services、Microsoft、Facebook 和其他多个合作伙伴的支持。您可以使用任何选定的框架来设计、训练和部署深度学习模型。ONNX 模型的好处是,它们可以在框架之间轻松移动。
# Build a Mock Model in PyTorch with a convolution and a reduceMean layer
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import torch.onnx as torch_onnx
class Model(nn.Module):
    """Mock model for ONNX export: one 3x3 convolution followed by a
    mean-reduction over the height axis (conv + reduceMean layers).

    Input:  tensor of shape (batch, 3, H, W)
    Output: tensor of shape (batch, 32, W - 2) for padding=0, kernel=3
    """

    def __init__(self):
        super(Model, self).__init__()
        # 3 input channels -> 32 output channels, 3x3 kernel, no bias.
        self.conv = nn.Conv2d(in_channels=3, out_channels=32,
                              kernel_size=(3, 3), stride=1, padding=0,
                              bias=False)

    def forward(self, inputs):
        x = self.conv(inputs)
        # Reduce over dim=2 (height) -> (batch, channels, width).
        return torch.mean(x, dim=2)
# Use a dummy input to trace the model and serialize it to ONNX.
input_shape = (3, 100, 100)
model_onnx_path = "torch_model.onnx"
model = Model()
model.train(False)  # switch to inference mode for a clean export trace

# Export the model to an ONNX file. export() runs the model once on the
# dummy input to record its execution trace.
# NOTE: torch.autograd.Variable is deprecated (no-op since torch 0.4);
# a plain tensor from torch.randn is passed directly instead.
dummy_input = torch.randn(1, *input_shape)
output = torch_onnx.export(model,
                           dummy_input,
                           model_onnx_path,
                           verbose=False)
print("Export of torch_model.onnx complete!")
import onnx
import mxnet as mx
from mxnet.contrib import onnx as onnx_mxnet
import numpy as np
# Load the serialized ONNX graph into MXNet's symbolic interface.
onnx_file = "torch_model.onnx"
sym, arg, aux = onnx_mxnet.import_model(onnx_file)
print("Loaded torch_model.onnx!")
# Show every internal symbol of the imported computation graph.
print(sym.get_internals())
将一个模型导出为 ONNX 格式。该导出器(exporter)会运行一次你的模型,以便记录模型的执行轨迹,并据此将模型导出:
torch.onnx.export(model, args, f, export_params=True, verbose=False, training=False, input_names=None, output_names=None, dynamic_axes={})
将 ONNX 模型文件导入 MXNet:
mxnet.contrib.onnx.import_model(model_file)