PyTorch 转 ONNX 模型

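将 PyTorch 模型转换为 ONNX 模型。
如果模型是用 DDP（DistributedDataParallel）训练并保存的，由于 DDP 会把模型包装在 `module` 属性下，state_dict 中每个权重名称前都会带有 `module.` 前缀；加载到未包装的模型之前，需要先去掉这个前缀：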

# Strip the "module." prefix that DistributedDataParallel adds to every
# parameter name so the weights load into the unwrapped model.
# Only a *leading* prefix is removed (repeatedly, for nested wrappers):
# splitting on "module." anywhere in the key would also mangle names such
# as "encoder.submodule.weight" -> "weight".
# NOTE(review): `checkpoint` (a state dict) and `model_pos` must already be
# defined before this snippet runs.
ddp_prefix = 'module.'
new_checkpoint = {}
for k, value in checkpoint.items():
    key = k
    while key.startswith(ddp_prefix):
        key = key[len(ddp_prefix):]
    new_checkpoint[key] = value
    print(k, key)  # show original -> stripped name for inspection

# Loading the prefixed state dict directly fails with strict=True, because
# its keys do not match the unwrapped model's parameter names.
model_pos.load_state_dict(new_checkpoint, strict=True)

完整代码如下，其中包含对 PyTorch 模型与转换后的 ONNX 模型输出是否一致的验证（代码中的 `args` 与 `model_pos` 需在此之前定义好）：

import os

import numpy as np
import onnx
import onnxruntime as rt
import torch

from utils import *

# Load the trained weights for model_pos from disk.
# NOTE(review): `args` and `model_pos` are defined outside this snippet.
chk_filename = os.path.join(args.checkpoint, args.resume or args.evaluate)
print('Loading checkpoint', chk_filename)
# Map every tensor onto the CPU, whatever device it was saved from.
checkpoint = torch.load(chk_filename, map_location='cpu')
checkpoint = checkpoint['model_pos']

# DDP-trained weights carry a leading "module." on every key. Strip only
# that leading prefix (repeatedly, for nested wrappers); splitting on
# "module." anywhere would mangle keys such as "encoder.submodule.weight".
ddp_prefix = 'module.'
new_checkpoint = {}
for k, value in checkpoint.items():
    key = k
    while key.startswith(ddp_prefix):
        key = key[len(ddp_prefix):]
    new_checkpoint[key] = value
    print(k, key)  # show original -> stripped name for inspection

# The raw (prefixed) state dict would fail strict loading because its keys
# do not match the unwrapped model's parameter names.
model_pos.load_state_dict(new_checkpoint, strict=True)

# Export model_pos to ONNX, then check that ONNX Runtime reproduces the
# PyTorch output on the same random input.
output_file = './pf_f_27_1s.onnx'

# Dummy input used both to trace the export and to compare outputs.
# NOTE(review): shape (1, 27, 17, 2) presumably means
# (batch, frames, joints, xy) — confirm against model_pos.
test_instance = torch.rand((1, 27, 17, 2))

EXPORT_AND_VERIFY = True  # set to False to skip export + verification
if EXPORT_AND_VERIFY:
    # torch.onnx.export exports in eval mode by default, but the PyTorch
    # reference run below uses whatever mode the module is in (train mode
    # for a freshly built nn.Module). Dropout / batch-norm would then make
    # the two outputs differ, so switch to inference mode on the CPU first.
    model_pos.cpu()
    model_pos.eval()

    print('export .........')
    torch.onnx.export(model_pos, test_instance, output_file,
                      input_names=['input'], output_names=['output'],
                      opset_version=10)
    print('Finished ******************')

    # Structural validity check of the exported graph.
    onnx_model = onnx.load(output_file)
    onnx.checker.check_model(onnx_model)

    # PyTorch reference output; no autograd graph is needed.
    with torch.no_grad():
        pytorch_results = model_pos(test_instance)
    if not isinstance(pytorch_results, (list, tuple)):
        if not isinstance(pytorch_results, torch.Tensor):
            raise AssertionError(
                f'Unexpected model output type: {type(pytorch_results)!r}')
        pytorch_results = [pytorch_results]

    # Graph inputs that are not initializers are the real feed inputs.
    initializer_names = {node.name for node in onnx_model.graph.initializer}
    net_feed_input = [node.name for node in onnx_model.graph.input
                      if node.name not in initializer_names]
    if len(net_feed_input) != 1:
        raise AssertionError(
            f'Expected exactly one graph input, got {net_feed_input}')

    # Name the provider explicitly: onnxruntime >= 1.9 requires `providers`
    # when the build ships more than one execution provider (e.g. GPU).
    sess = rt.InferenceSession(output_file,
                               providers=['CPUExecutionProvider'])
    onnx_results = sess.run(
        None, {net_feed_input[0]: test_instance.detach().numpy()})

    # Explicit raises (not `assert`) so the check also runs under `python -O`.
    if len(pytorch_results) != len(onnx_results):
        raise AssertionError(
            f'Output count differs: PyTorch {len(pytorch_results)}, '
            f'ONNX {len(onnx_results)}')
    for pt_result, onnx_result in zip(pytorch_results, onnx_results):
        if not np.allclose(pt_result.detach().cpu().numpy(), onnx_result,
                           atol=1.e-5):
            raise AssertionError(
                'The outputs are different between Pytorch and ONNX')
    print('The numerical values are same between Pytorch and ONNX')

你可能感兴趣的：pytorch，深度学习，人工智能