PyTorch 模型转为 ONNX,以及 ONNX 模型推理测试

1、生成 mobilenet_v2 的 .pt 权重文件

用 torchvision 构建 mobilenet_v2 网络结构,并将其权重(state_dict)保存为 mobilenet_v2.pt

import torch
from torch import nn
from torchvision import models
import torch.nn.functional as F
class MobileNet_v2(nn.Module):
    """MobileNetV2 feature extractor with a small softmax classification head.

    The torchvision MobileNetV2 model is built and its last child
    (``classifier``) is dropped, keeping only the convolutional ``features``.
    A global average pool, a new linear layer and a softmax then produce
    per-class probabilities.

    :param num_classes: number of output classes (default 2).
    :param pretrained: forwarded to ``models.mobilenet_v2``. When True, the
        backbone starts from torchvision's pretrained weights. The new ``fc``
        head is always freshly initialized.
    """

    def __init__(self, num_classes=2, pretrained=True):
        super(MobileNet_v2, self).__init__()
        model = models.mobilenet_v2(pretrained=pretrained)
        # torchvision's MobileNetV2 has two children, (features, classifier);
        # dropping the last one removes the original dropout + linear head.
        modules = list(model.children())[:-1]
        # NOTE: the attribute name "resnet" is kept on purpose. It is the
        # prefix of the state_dict keys written to / read from mobilenet_v2.pt.
        self.resnet = nn.Sequential(*modules)
        self.fc = nn.Linear(1280, num_classes)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, images):
        """Return class probabilities of shape ``[N, num_classes]``.

        :param images: image batch of shape ``[N, 3, H, W]``.
        """
        x = self.resnet(images)  # [N, 1280, H/32, W/32], i.e. 7x7 for 224x224
        x = F.adaptive_avg_pool2d(x, (1, 1))  # [N, 1280, 1, 1]
        x = torch.flatten(x, 1)  # [N, 1280]
        x = self.fc(x)
        return self.softmax(x)

model = MobileNet_v2()
x = torch.rand(1, 3, 224, 224)
# Only the parameters/buffers are saved; step 2 rebuilds the module and
# restores them with load_state_dict.
torch.save(model.state_dict(), "mobilenet_v2.pt")
# Run the sanity check in inference mode. In training mode BatchNorm would
# normalize with batch statistics and update its running stats, so the printed
# result would not match the (eval-mode) ONNX model.
model.eval()
with torch.no_grad():
    out = model(x)
print(out)

2、将 mobilenet_v2.pt 转换为 mobilenet_v2.onnx,之后即可脱离 PyTorch 框架,进行跨平台部署。

import torch
from torch import nn
from torchvision import models
import torch.nn.functional as F
import torch.onnx
class MobileNet_v2(nn.Module):
    """MobileNetV2 feature extractor with a small softmax classification head.

    The torchvision MobileNetV2 model is built and its last child
    (``classifier``) is dropped, keeping only the convolutional ``features``.
    A global average pool, a new linear layer and a softmax then produce
    per-class probabilities.

    :param num_classes: number of output classes (default 2).
    :param pretrained: forwarded to ``models.mobilenet_v2``. When True, the
        backbone starts from torchvision's pretrained weights. The new ``fc``
        head is always freshly initialized.
    """

    def __init__(self, num_classes=2, pretrained=True):
        super(MobileNet_v2, self).__init__()
        model = models.mobilenet_v2(pretrained=pretrained)
        # torchvision's MobileNetV2 has two children, (features, classifier);
        # dropping the last one removes the original dropout + linear head.
        modules = list(model.children())[:-1]
        # NOTE: the attribute name "resnet" is kept on purpose. It is the
        # prefix of the state_dict keys written to / read from mobilenet_v2.pt.
        self.resnet = nn.Sequential(*modules)
        self.fc = nn.Linear(1280, num_classes)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, images):
        """Return class probabilities of shape ``[N, num_classes]``.

        :param images: image batch of shape ``[N, 3, H, W]``.
        """
        x = self.resnet(images)  # [N, 1280, H/32, W/32], i.e. 7x7 for 224x224
        x = F.adaptive_avg_pool2d(x, (1, 1))  # [N, 1280, 1, 1]
        x = torch.flatten(x, 1)  # [N, 1280]
        x = self.fc(x)
        return self.softmax(x)

model = MobileNet_v2()
model.load_state_dict(torch.load("mobilenet_v2.pt", map_location=torch.device('cpu')))
# Export the inference graph explicitly (BatchNorm uses its running stats).
model.eval()
# An example input you would normally provide to your model's forward() method.
# Only its shape/dtype matter: they fix the traced graph's input signature.
x = torch.rand(1, 3, 224, 224)
# Export the model. The .onnx file is the result; the call's return value is
# not the model output, so it is not kept.
torch.onnx.export(
    model,
    x,
    "mobilenet_v2.onnx",
    export_params=True,
    input_names=["images"],
    output_names=["probs"],
    # Keep the batch dimension symbolic so the ONNX model accepts any batch size.
    dynamic_axes={"images": {0: "batch"}, "probs": {0: "batch"}},
)

3、使用 ONNX Runtime 对 ONNX 模型进行推理

注意:PyTorch 模型(.pt)的输入类型为 torch.Tensor,而 ONNX Runtime 的输入类型为 numpy 数组,因此需要先用 to_numpy 进行转换。

import torch
import torch.onnx
import onnxruntime

class OnnxModel:
    """Thin wrapper around an ``onnxruntime.InferenceSession``.

    :param onnx_path: path to the ``.onnx`` file to load.
    :param providers: execution providers passed to ``InferenceSession``.
        When None, every provider available in the installed onnxruntime
        build is used. Since onnxruntime 1.9, builds with several providers
        (e.g. GPU builds) raise ``ValueError`` if ``providers`` is not given
        explicitly.
    """

    def __init__(self, onnx_path, providers=None):
        if providers is None:
            providers = onnxruntime.get_available_providers()
        self.onnx_session = onnxruntime.InferenceSession(onnx_path, providers=providers)
        self.input_name = self.get_input_name(self.onnx_session)
        self.output_name = self.get_output_name(self.onnx_session)

    def get_output_name(self, onnx_session):
        """Return the names of all outputs of *onnx_session*, in graph order.

        :param onnx_session: the inference session to inspect.
        :return: list of output names.
        """
        return [node.name for node in onnx_session.get_outputs()]

    def get_input_name(self, onnx_session):
        """Return the names of all inputs of *onnx_session*, in graph order.

        :param onnx_session: the inference session to inspect.
        :return: list of input names.
        """
        return [node.name for node in onnx_session.get_inputs()]

    def get_input_feed(self, input_name, image_numpy):
        """Build the ``input_feed`` mapping expected by ``InferenceSession.run``.

        :param input_name: list of graph input names.
        :param image_numpy: either a single array, which is fed to every input
            (enough for single-input models), or a dict mapping input names
            to arrays, which is used as-is (for multi-input models).
        :return: dict mapping input name -> array.
        """
        if isinstance(image_numpy, dict):
            # A dict is already a name -> array feed; copy it so the caller's
            # object is not shared with the session call.
            return dict(image_numpy)
        return {name: image_numpy for name in input_name}

    def forward(self, image_numpy):
        """Run the model and return every output.

        :param image_numpy: numpy input (or dict of arrays, see
            ``get_input_feed``). ONNX Runtime does not accept torch tensors.
            For the exported MobileNetV2 this is presumably a float32 array
            shaped like the export example, (1, 3, 224, 224). TODO confirm
            against the exported model.
        :return: list with one numpy array per name in ``self.output_name``.
        """
        input_feed = self.get_input_feed(self.input_name, image_numpy)
        return self.onnx_session.run(self.output_name, input_feed=input_feed)

def to_numpy(tensor):
    """Convert a torch tensor to a numpy array on the CPU.

    ``numpy()`` refuses tensors that still track gradients. ``detach()`` is
    harmless for tensors that don't, so it is applied unconditionally.
    """
    detached = tensor.detach()
    return detached.cpu().numpy()
# Load the exported model and run it once on a random (1, 3, 224, 224) input.
onnx_model_path = "mobilenet_v2.onnx"
model = OnnxModel(onnx_model_path)
x = torch.rand(1, 3, 224, 224)
# ONNX Runtime consumes numpy arrays, so convert the torch tensor first.
x_numpy = to_numpy(x)
out = model.forward(x_numpy)
print(out)

你可能感兴趣的:(模型部署,pytorch,深度学习,python)