Errors when visualizing a PyTorch network graph with TensorBoard

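I tried to visualize the structure (graph) of a PyTorch network in TensorBoard through tensorboardX, using PyTorch 1.4.0. I took a few detours during setup and record them here. The final solution is at the end of the post.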

1. Problem 1

File "D:\Softwares\Anaconda3\lib\site-packages\tensorboardX\writer.py", line 419, in add_graph
    self.file_writer.add_graph(graph(model, input_to_model, verbose))
  File "D:\Softwares\Anaconda3\lib\site-packages\tensorboardX\graph.py", line 69, in graph
    trace, _ = torch.jit.get_trace_graph(model, args)
AttributeError: module 'torch.jit' has no attribute 'get_trace_graph'

Initial workaround: in tensorboardX\graph.py (line 69 in the traceback), change

trace, _ = torch.jit.get_trace_graph(model, args)

to:

trace, _ = torch.jit._get_trace_graph(model, args)

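This is only a stopgap. PyTorch 1.4 made the function private (_get_trace_graph), and any edit to files under site-packages is lost the next time the package is reinstalled.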
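A version-tolerant lookup avoids hard-coding one name: it tries the public name first and falls back to the private one. Below is a minimal sketch that uses a hypothetical helper, resolve_attr. It runs against a stand-in namespace so it works without PyTorch. Inside tensorboardX's graph.py, the call would be resolve_attr(torch.jit, "get_trace_graph", "_get_trace_graph")(model, args). Keep in mind that _get_trace_graph is a private API and could change again.

```python
from types import SimpleNamespace


def resolve_attr(module, *names):
    """Return the first attribute of `module` found among `names` (hypothetical helper).

    Raises AttributeError if none of the names exist.
    """
    for name in names:
        if hasattr(module, name):
            return getattr(module, name)
    raise AttributeError(f"{getattr(module, '__name__', module)!r} has none of {names}")


# Stand-in for torch.jit in PyTorch 1.4, where only the underscored name exists.
fake_jit = SimpleNamespace(_get_trace_graph=lambda model, args: ("graph", None))

# Prefer the old public name, fall back to the private one.
get_trace_graph = resolve_attr(fake_jit, "get_trace_graph", "_get_trace_graph")
trace, _ = get_trace_graph(None, ())
print(trace)  # graph
```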

2. Problem 2

File "D:\Softwares\Anaconda3\lib\site-packages\tensorboardX\writer.py", line 419, in add_graph
    self.file_writer.add_graph(graph(model, input_to_model, verbose))
  File "D:\Softwares\Anaconda3\lib\site-packages\tensorboardX\graph.py", line 79, in graph
    torch.onnx._optimize_trace(trace, False)
  File "D:\Softwares\Anaconda3\lib\site-packages\torch\onnx\__init__.py", line 163, in _optimize_trace
    return utils._optimize_graph(graph, operator_export_type)
  File "D:\Softwares\Anaconda3\lib\site-packages\torch\onnx\utils.py", line 135, in _optimize_graph
    graph = torch._C._jit_pass_onnx(graph, operator_export_type)
TypeError: _jit_pass_onnx(): incompatible function arguments. The following argument types are supported:
    1. (arg0: torch::jit::Graph, arg1: torch._C._onnx.OperatorExportTypes) -> torch::jit::Graph

Initial workaround: in torch\onnx\utils.py (line 135 in the traceback), change

graph = torch._C._jit_pass_onnx(graph, operator_export_type)

to:

graph = torch._C._jit_pass_onnx(graph, OperatorExportTypes.ONNX)

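This is only a stopgap. It hard-codes OperatorExportTypes.ONNX inside torch itself, which also overrides the export type for every other caller of _optimize_graph, such as torch.onnx.export.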
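The traceback shows the actual mismatch. tensorboardX calls torch.onnx._optimize_trace(trace, False), passing a boolean where this PyTorch version expects an OperatorExportTypes value. This is presumably left over from an older signature that took a boolean flag. Fixing the call site in tensorboardX, rather than patching torch, could look roughly like the sketch below. Two parts of it are assumptions. The enum is a hypothetical stand-in that mirrors two member names of torch.onnx.OperatorExportTypes, so the code runs without PyTorch. The mapping (False to ONNX, True to ONNX_ATEN) is also assumed.

```python
from enum import Enum


class OperatorExportTypes(Enum):
    """Hypothetical stand-in mirroring two member names of torch.onnx.OperatorExportTypes."""
    ONNX = 0
    ONNX_ATEN = 1


def coerce_export_type(value):
    """Map a legacy boolean flag to an export-type member (assumed mapping).

    Enum members pass through unchanged; anything else is rejected.
    """
    if isinstance(value, bool):
        return OperatorExportTypes.ONNX_ATEN if value else OperatorExportTypes.ONNX
    if isinstance(value, OperatorExportTypes):
        return value
    raise TypeError(f"unsupported export type: {value!r}")


print(coerce_export_type(False))  # OperatorExportTypes.ONNX
```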

3. Problem 3

File "D:\Softwares\Anaconda3\lib\site-packages\tensorboardX\writer.py", line 419, in add_graph
    self.file_writer.add_graph(graph(model, input_to_model, verbose))
  File "D:\Softwares\Anaconda3\lib\site-packages\tensorboardX\graph.py", line 82, in graph
    graph = trace.graph()
AttributeError: 'torch._C.Graph' object has no attribute 'graph'

Initial fix:

Upgrade tensorboardX.

pip install --upgrade tensorboardx

On its own, this was only a partial fix.
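The error means that with PyTorch 1.4, the traced object handed to tensorboardX is already a torch._C.Graph, so there is no .graph() method left to call. If upgrading is not an option, a call site that accepts both shapes could look like the minimal sketch below. LegacyTrace, Graph and unwrap_graph are hypothetical stand-ins so the code runs without PyTorch.

```python
class LegacyTrace:
    """Hypothetical stand-in for the old trace object, which exposed its graph via .graph()."""

    def __init__(self, graph):
        self._graph = graph

    def graph(self):
        return self._graph


class Graph:
    """Hypothetical stand-in for torch._C.Graph, which has no .graph() method."""


def unwrap_graph(trace):
    # Old API: the trace wraps the graph, so call .graph().
    # New API: the object already is the graph, so return it unchanged.
    graph_method = getattr(trace, "graph", None)
    return graph_method() if callable(graph_method) else trace


g = Graph()
print(unwrap_graph(g) is g, unwrap_graph(LegacyTrace(g)) is g)  # True True
```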


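Final solution:

All of these problems are caused by a version mismatch between PyTorch 1.4.0 and the installed tensorboardX/tensorboard. Upgrading both tensorboard and tensorboardX fixes them. If you already patched torch\onnx\utils.py as in Problem 2, restore that file (for example by reinstalling torch), because upgrading tensorboardX does not touch it.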

pip install --upgrade tensorboardx

pip install --upgrade tensorboard
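You can confirm which versions are actually installed before and after upgrading with a small script based on the standard library's importlib.metadata (Python 3.8+). It only reads package metadata and does not import torch or tensorboardX. The distribution names are the ones used with pip above.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(*packages):
    """Return {package: installed version string, or None if not installed}."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions("torch", "tensorboard", "tensorboardX").items():
        print(f"{name:12} {ver or 'not installed'}")
```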
