# pytorch_to_onnx: export a pretrained MobileNetV2 from PyTorch to ONNX

import torch.utils.data
from torch.autograd import Variable
from torchvision.models import mobilenet_v2
import onnxruntime
import numpy as np
from onnxruntime.datasets import get_example
import cv2

# Build a pretrained MobileNetV2 and switch it to inference mode on the CPU.
# The network is wrapped in DataParallel; the bare network is reachable as
# ``model.module``, which is what gets exported below.
# NOTE(review): newer torchvision releases deprecate ``pretrained=True`` in
# favour of the ``weights=`` argument — confirm against the installed version.
model = torch.nn.DataParallel(mobilenet_v2(pretrained=True)).to('cpu').eval()
# Build the example input used to trace the model for ONNX export.
# Tracing only depends on the input's shape; a random tensor such as
# torch.randn(1, 3, 224, 224) would work equally well.
# NOTE(review): the image is fed as raw BGR values in 0-255 without the usual
# RGB conversion / normalization, so predictions on it are not meaningful.
# That is harmless for export and output comparison, but would need fixing
# before using this tensor for real inference.
IMAGE_PATH = '/home/dfy/PycharmProjects/Te_deepin/mao.jpeg'

image = cv2.imread(IMAGE_PATH)
if image is None:
    # cv2.imread reports a missing/unreadable file by returning None instead
    # of raising; without this check the failure surfaces later as an
    # unrelated-looking numpy error in np.transpose.
    raise FileNotFoundError('could not read image: ' + IMAGE_PATH)
image = cv2.resize(image, (224, 224))

# HWC -> CHW float32, then add a leading batch dimension -> (1, 3, 224, 224).
chw_image = np.transpose(image, (2, 0, 1)).astype(np.float32)
dummy_input = torch.from_numpy(chw_image).unsqueeze(0)
# Names given to the exported graph's input and output tensors, so that the
# ONNX model can be fed and queried by name.
input_names = ['input']
output_names = ['output']

# Reference PyTorch output for the same input, kept for comparison with the
# exported model. torch.onnx.export does not return the model output, so it
# is computed explicitly (without autograd tracking).
with torch.no_grad():
    torch_out = model.module(dummy_input)

# Export the bare network rather than the DataParallel wrapper. Uses the
# public export API instead of the private torch.onnx._export.
torch.onnx.export(
    model.module,
    dummy_input,
    'mobile_net.onnx',
    verbose=True,
    input_names=input_names,
    output_names=output_names,
)
# Test the exported ONNX model: run it with ONNX Runtime on the same input
# and check that its output matches the PyTorch model's output.
with torch.no_grad():
    expected_output = model.module(dummy_input).numpy()

session = onnxruntime.InferenceSession(
    'mobile_net.onnx', providers=['CPUExecutionProvider']
)
onnx_output = session.run(
    [output_names[0]], {input_names[0]: dummy_input.numpy()}
)[0]

# The two runtimes may differ by small floating-point error, so compare with
# a tolerance rather than exact equality.
np.testing.assert_allclose(expected_output, onnx_output, rtol=1e-3, atol=1e-4)
print('ONNX Runtime output matches PyTorch output')

# Related topics (from the original blog post): machine learning (deep learning)