Error from torch.onnx.export(model, input, "xxx.proto", verbose=True)

While trying to export a PyTorch model to xxx.proto, I ran into the following error:

Traceback (most recent call last):
  File "/home/leon/Leon/sparse-to-dense-Leon/main.py", line 357, in <module>
    main()
  File "/home/leon/Leon/sparse-to-dense-Leon/main.py", line 188, in main
    train(train_loader, model,criterion, optimizer, epoch)  # train for one epoch
  File "/home/leon/Leon/sparse-to-dense-Leon/main.py", line 223, in train
    torch.onnx.export(model, input, "model2.proto", verbose=True)
  File "/usr/local/lib/python3.5/site-packages/torch/onnx/__init__.py", line 25, in export
    return utils.export(*args, **kwargs)
  File "/usr/local/lib/python3.5/site-packages/torch/onnx/utils.py", line 84, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names)
  File "/usr/local/lib/python3.5/site-packages/torch/onnx/utils.py", line 140, in _export
    trace.set_graph(_optimize_graph(trace.graph(), aten))
  File "/usr/local/lib/python3.5/site-packages/torch/onnx/utils.py", line 95, in _optimize_graph
    graph = torch._C._jit_pass_onnx(graph, aten)
  File "/usr/local/lib/python3.5/site-packages/torch/onnx/__init__.py", line 40, in _run_symbolic_function
    return utils._run_symbolic_function(*args, **kwargs)
  File "/usr/local/lib/python3.5/site-packages/torch/onnx/utils.py", line 368, in _run_symbolic_function
    return fn(g, *inputs, **attrs)
TypeError: upsample_bilinear2d() got an unexpected keyword argument 'align_corners' (occurred when translating upsample_bilinear2d)

Reproducing the problem:

Running the following code triggers the error:

import torch

def main():
    data=torch.rand(1, 2, 2, 3)
    print("input", data)
    print(data.size())
    model=torch.nn.Upsample((4, 6), mode='bilinear', align_corners=True)
    pre=model(data)
    print("output", pre)
    print(pre.size())
    print("Transform the model to proto")
    torch.onnx.export(model, data, "model.proto", verbose=True)

main()

Output and error:

input tensor([[[[ 0.2396,  0.4801,  0.2886],
          [ 0.9168,  0.7868,  0.5738]],

         [[ 0.5297,  0.5078,  0.8996],
          [ 0.8995,  0.8765,  0.3325]]]])
torch.Size([1, 2, 2, 3])
output tensor([[[[ 0.2396,  0.3358,  0.4320,  0.4418,  0.3652,  0.2886],
          [ 0.4653,  0.5121,  0.5589,  0.5426,  0.4631,  0.3837],
          [ 0.6911,  0.6885,  0.6859,  0.6434,  0.5611,  0.4787],
          [ 0.9168,  0.8648,  0.8128,  0.7442,  0.6590,  0.5738]],

         [[ 0.5297,  0.5209,  0.5122,  0.5862,  0.7429,  0.8996],
          [ 0.6529,  0.6440,  0.6352,  0.6467,  0.6786,  0.7106],
          [ 0.7762,  0.7672,  0.7581,  0.7072,  0.6144,  0.5215],
          [ 0.8995,  0.8903,  0.8811,  0.7677,  0.5501,  0.3325]]]])
torch.Size([1, 2, 4, 6])
Transform the model to proto
Traceback (most recent call last):
  File "test.py", line 14, in <module>
    main()
  File "test.py", line 12, in main
    torch.onnx.export(model, data, "model.proto", verbose=True)
  File "/usr/local/lib/python3.5/site-packages/torch/onnx/__init__.py", line 25, in export
    return utils.export(*args, **kwargs)
  File "/usr/local/lib/python3.5/site-packages/torch/onnx/utils.py", line 84, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names)
  File "/usr/local/lib/python3.5/site-packages/torch/onnx/utils.py", line 140, in _export
    trace.set_graph(_optimize_graph(trace.graph(), aten))
  File "/usr/local/lib/python3.5/site-packages/torch/onnx/utils.py", line 95, in _optimize_graph
    graph = torch._C._jit_pass_onnx(graph, aten)
  File "/usr/local/lib/python3.5/site-packages/torch/onnx/__init__.py", line 40, in _run_symbolic_function
    return utils._run_symbolic_function(*args, **kwargs)
  File "/usr/local/lib/python3.5/site-packages/torch/onnx/utils.py", line 368, in _run_symbolic_function
    return fn(g, *inputs, **attrs)
TypeError: upsample_bilinear2d() got an unexpected keyword argument 'align_corners' (occurred when translating upsample_bilinear2d)

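Analysis:

The torch.onnx exporter in this version does not fully support interpolation. The last frames of the traceback show the exporter calling the ONNX symbolic function for upsample_bilinear2d as fn(g, *inputs, **attrs). That function does not accept the align_corners keyword coming from torch.nn.Upsample(..., align_corners=True), so the conversion fails with a TypeError. I have reported the issue on the official PyTorch forum and am waiting for a fix: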

https://discuss.pytorch.org/t/transform-model-to-xxx-proto-failed/24353
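To see why the message reads "got an unexpected keyword argument", here is a minimal pure-Python sketch of the dispatch pattern visible at the bottom of the traceback (fn(g, *inputs, **attrs)). The two symbolic functions and the attribute dictionary are hypothetical stand-ins chosen to match the error message, not the real torch.onnx code.

```python
import inspect

# Hypothetical stand-in for an old exporter's symbolic function:
# its signature has no align_corners parameter (illustration only).
def old_upsample_bilinear2d(g, x, output_size):
    return ("Upsample", x, output_size)

# Hypothetical stand-in for a newer symbolic that accepts align_corners.
def new_upsample_bilinear2d(g, x, output_size, align_corners):
    return ("Upsample", x, output_size, align_corners)

def run_symbolic_function(fn, g, inputs, attrs):
    # Same call shape as in the traceback: node attributes become keywords.
    return fn(g, *inputs, **attrs)

def accepts_all_attrs(fn, attrs):
    # True only if every attribute name is a parameter of fn.
    params = inspect.signature(fn).parameters
    return all(name in params for name in attrs)

# Assumed attributes of the traced upsample node (illustration only).
attrs = {"output_size": (4, 6), "align_corners": True}
inputs = ("x",)

print(accepts_all_attrs(old_upsample_bilinear2d, attrs))  # False
print(accepts_all_attrs(new_upsample_bilinear2d, attrs))  # True
try:
    run_symbolic_function(old_upsample_bilinear2d, None, inputs, attrs)
except TypeError as e:
    print("TypeError:", e)
```

Because the attributes are forwarded as keywords, any attribute the symbolic function's signature does not list makes the call fail before any ONNX node is built. That matches the idea that a fix has to come from the exporter side (a newer torch.onnx) rather than from the model code.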

Solution:

Upgrading PyTorch fixed the problem: sudo pip install torch --upgrade
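After upgrading, it can help to confirm that the interpreter really picks up the new version before re-running the export. Below is a small pure-Python sketch of such a check. The post does not say which release fixed the error, so MINIMUM_VERSION is a hypothetical threshold, not a confirmed fix version.

```python
import re

def parse_version(version):
    # "2.1.0+cu118" -> (2, 1, 0): local build tags after "+" are dropped,
    # and only the leading digits of each part are kept ("0a0" -> 0).
    release = version.split("+", 1)[0]
    numbers = []
    for part in release.split(".")[:3]:
        match = re.match(r"\d+", part)
        numbers.append(int(match.group()) if match else 0)
    while len(numbers) < 3:
        numbers.append(0)
    return tuple(numbers)

def needs_upgrade(installed, minimum):
    # Tuple comparison orders versions numerically, part by part.
    return parse_version(installed) < parse_version(minimum)

# Hypothetical threshold for illustration; not taken from the post.
MINIMUM_VERSION = "1.0.0"
print(needs_upgrade("0.4.0", MINIMUM_VERSION))  # True
print(needs_upgrade("1.0.1", MINIMUM_VERSION))  # False
```

In a real environment you would pass torch.__version__ as the installed version. Pre-release suffixes are ignored here, which keeps the sketch short but means something like "1.0.0a0" compares equal to "1.0.0".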
