torch和onnx输出结果对比

对比 torch 与 onnx 各输出层的结果，判断 torch 模型转换为 onnx 后的推理结果是否一致。

import cv2
import torch
import numpy as np
from PIL import Image
from models import Net
import torchvision.transforms as transforms
import onnx
import onnxruntime as ort

def data_process(img, size=(512, 512)):
    """Convert an OpenCV BGR image into a normalized, batched tensor.

    Args:
        img: image as returned by ``cv2.imread`` (BGR channel order).
        size: (height, width) passed to ``transforms.Resize``; defaults to
            the previously hard-coded 512x512.

    Returns:
        A float tensor of shape (1, C, H, W) whose values are normalized
        from [0, 1] to [-1, 1] by ``Normalize(mean=0.5, std=0.5)``.
    """
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 per channel: maps ToTensor's [0, 1] range to [-1, 1].
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # OpenCV loads images as BGR; PIL and the transforms expect RGB.
    crop = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    return transform(crop).unsqueeze(0)  # add the leading batch dimension
    
def compare_cos_sim(tensor1, tensor2):
    """Print and return the cosine similarity of two arrays after flattening.

    Args:
        tensor1: numpy array or torch tensor (any device, any shape).
        tensor2: numpy array or torch tensor with the same number of
            elements as ``tensor1``.

    Returns:
        The cosine similarity as a Python float. By convention it is 1.0
        when both inputs are all zeros and 0.0 when only one of them is,
        instead of the ``nan`` a 0/0 division would produce.

    Raises:
        ValueError: if the inputs contain different numbers of elements.
    """
    def _flat(x):
        # Torch tensors may track gradients or live on the GPU; bring them to
        # host memory before handing them to numpy.
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
        # float64 avoids precision loss when summing large float32 outputs.
        return np.asarray(x, dtype=np.float64).reshape(-1)

    t1 = _flat(tensor1)
    t2 = _flat(tensor2)
    if t1.size != t2.size:
        raise ValueError(f'size mismatch: {t1.size} vs {t2.size}')

    n1 = np.linalg.norm(t1)
    n2 = np.linalg.norm(t2)
    if n1 == 0 or n2 == 0:
        # Identical all-zero outputs count as a perfect match; a zero vector
        # against a non-zero one counts as no similarity.
        cos = 1.0 if n1 == n2 else 0.0
    else:
        cos = float(np.dot(t1, t2) / (n1 * n2))
    print('cos', cos)
    return cos

if __name__ == '__main__':
    # Run one image through both the PyTorch model and its ONNX export and
    # compare every output head (shape, value range, cosine similarity).

    imgpath = 'demo.jpg'
    model_path = 'epoch-best.pth'
    # ONNX file exported from the same checkpoint. The original script used this
    # name without ever defining it (NameError); adjust the path to your export.
    export_onnx_file = 'epoch-best.onnx'

    img = cv2.imread(imgpath)
    if img is None:
        # cv2.imread returns None rather than raising on a missing/unreadable
        # file, which would otherwise surface later as an obscure cvtColor error.
        raise FileNotFoundError(f'could not read image: {imgpath}')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch_model = Net()  # build a fresh model instance with the same architecture
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    loaded_model = torch.load(model_path, map_location=device)
    # Load the checkpoint weights into the new instance. Parameter names and
    # shapes must match; pass strict=False to load only the matching subset.
    torch_model.load_state_dict(loaded_model['state_dict'])
    torch_model = torch_model.to(device)
    torch_model.eval()

    crop = data_process(img)  # CPU tensor with a leading batch dimension

    with torch.no_grad():
        # The input must live on the same device as the model weights.
        torch_outputs = torch_model(crop.to(device))

    ort_session = ort.InferenceSession(export_onnx_file)
    # Feed the identical preprocessed tensor to ONNX Runtime; read the input
    # name from the session instead of hard-coding 'data'.
    input_name = ort_session.get_inputs()[0].name
    onnx_outputs = ort_session.run(None, {input_name: crop.numpy()})

    # NOTE(review): assumes the model returns a sequence whose first element is a
    # dict of named output heads, iterated in the same order as the ONNX graph's
    # outputs -- TODO confirm against the export script.
    heads = torch_outputs[0]
    if len(heads) != len(onnx_outputs):
        print(f'warning: torch has {len(heads)} outputs, onnx has {len(onnx_outputs)}')

    for name, onnx_out in zip(heads, onnx_outputs):
        # Move to host memory first: .numpy() fails on CUDA tensors.
        torch_out = heads[name].detach().cpu().numpy()
        print('output:', name)
        print('torch out shape:', torch_out.shape)
        print('onnx out shape:', onnx_out.shape)
        print('torch out min max:', torch_out.min(), torch_out.max())
        print('onnx out min max:', np.min(onnx_out), np.max(onnx_out))
        compare_cos_sim(torch_out, onnx_out)

你可能感兴趣的:(onnx转换,python,开发语言,pytorch)