onnx 模型转换及推理

最近在部署模型在安卓端时发现ncnn模型不及原来的pytorch精度高,误差挺大的,分类误差超过了0.1,怀疑是模型转换造成的精度损失,这里就验证一下是不是pytorch -> onnx 模型中产生的误差。

1.模型转换

先来看一下pyotrch 到 onnx 模型转换,这个网上已经很多资料了,这里再贴一下代码:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def model_converter():
    model = torch.load('test.pth').to(device)  # 这里保存的是完整模型
    model.eval()

    dummy_input = torch.randn(1, 3, 96, 96, device=device)
    input_names = ['data']
    output_names = ['fc']
    torch.onnx.export(model, dummy_input, 'test.onnx', 
                      export_params=True, 
                      verbose=True, 
                      input_names=input_names, 
                      output_names=output_names)

会输出test.onnx模型。

 

2.onnx 模型 onnx_runtime 推理

import cv2
import numpy as np
import onnxruntime as rt

def image_process(image_path):
    mean = np.array([[[0.485, 0.456, 0.406]]])      # 训练的时候用来mean和std
    std = np.array([[[0.229, 0.224, 0.225]]])

    img = cv2.imread(image_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (96, 96))                 # (96, 96, 3)

    image = img.astype(np.float32)/255.0
    image = (image - mean)/ std

    image = image.transpose((2, 0, 1))              # (3, 96, 96)
    image = image[np.newaxis,:,:,:]                 # (1, 3, 96, 96)

    image = np.array(image, dtype=np.float32)
    
    return image

def onnx_runtime():
    imgdata = image_process('test.jpg')
    
    sess = rt.InferenceSession('test.onnx')
    input_name = sess.get_inputs()[0].name  
    output_name = sess.get_outputs()[0].name

    pred_onnx = sess.run([output_name], {input_name: imgdata})

    print("outputs:")
    print(np.array(pred_onnx))

onnx_runtime()

onnx模型预测结果:

[1.7607212e-04 3.5554171e-05 2.4718046e-04 9.9977958e-01 1.0728532e-01
   3.8951635e-05 4.4435263e-05]

pytorch模型预测结果:

[1.7604830e-04 3.5546487e-05 2.4715674e-04 9.9977964e-01 1.0728550e-01
  3.8929291e-05 4.4419667e-05]

可以看到onnx输出的分类概率和pytorch模型预测的概率几乎完全一样,可见一开始我遇到问题不是pytorch->onnx 模型转换造成的,可能是在移动端图片预处理没弄好造成的。

 

你可能感兴趣的:(AI,算法,深度学习,人工智能,onnx模型推理,pytorch转onnx)