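I tried to use TensorBoard to display the structure of a PyTorch network. The PyTorch version is 1.4.0. I took a few detours during setup and record them here; for the final solution, see the end of this post.
1. Problem 1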
File "D:\Softwares\Anaconda3\lib\site-packages\tensorboardX\writer.py", line 419, in add_graph
self.file_writer.add_graph(graph(model, input_to_model, verbose))
File "D:\Softwares\Anaconda3\lib\site-packages\tensorboardX\graph.py", line 69, in graph
trace, _ = torch.jit.get_trace_graph(model, args)
AttributeError: module 'torch.jit' has no attribute 'get_trace_graph'
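Initial fix:
In tensorboardX\graph.py, change
trace, _ = torch.jit.get_trace_graph(model, args)
to:
trace, _ = torch.jit._get_trace_graph(model, args)
This workaround is only so-so: it patches the installed tensorboardX source by hand.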
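The rename above can also be handled from calling code instead of editing files under site-packages. The following is a minimal sketch of that idea, not what tensorboardX itself does: the helper name first_available is hypothetical, and a SimpleNamespace stands in for torch.jit so the snippet runs without PyTorch installed.

```python
from types import SimpleNamespace


def first_available(namespace, *names):
    """Return the first attribute of `namespace` that exists among `names`."""
    for name in names:
        if hasattr(namespace, name):
            return getattr(namespace, name)
    raise AttributeError(f"none of {names} found on {namespace!r}")


# Stand-in for torch.jit (an assumption for illustration): only the
# underscore-prefixed name exists, as in the traceback above.
fake_jit = SimpleNamespace(_get_trace_graph=lambda model, args: ("trace", "out"))

# Prefer the public name, fall back to the private one.
get_trace_graph = first_available(fake_jit, "get_trace_graph", "_get_trace_graph")
trace, _ = get_trace_graph("model", ("args",))
```

With a real installation the call would presumably be first_available(torch.jit, "get_trace_graph", "_get_trace_graph"). Note that _get_trace_graph is a private function, so its behavior may differ between PyTorch releases.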
2. Problem 2
File "D:\Softwares\Anaconda3\lib\site-packages\tensorboardX\writer.py", line 419, in add_graph
self.file_writer.add_graph(graph(model, input_to_model, verbose))
File "D:\Softwares\Anaconda3\lib\site-packages\tensorboardX\graph.py", line 79, in graph
torch.onnx._optimize_trace(trace, False)
File "D:\Softwares\Anaconda3\lib\site-packages\torch\onnx\__init__.py", line 163, in _optimize_trace
return utils._optimize_graph(graph, operator_export_type)
File "D:\Softwares\Anaconda3\lib\site-packages\torch\onnx\utils.py", line 135, in _optimize_graph
graph = torch._C._jit_pass_onnx(graph, operator_export_type)
TypeError: _jit_pass_onnx(): incompatible function arguments. The following argument types are supported:
1. (arg0: torch::jit::Graph, arg1: torch._C._onnx.OperatorExportTypes) -> torch::jit::Graph
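Initial fix:
In torch\onnx\utils.py, change
graph = torch._C._jit_pass_onnx(graph, operator_export_type)
to:
graph = torch._C._jit_pass_onnx(graph, OperatorExportTypes.ONNX)
The traceback shows tensorboardX calling torch.onnx._optimize_trace(trace, False), so a bool arrives where an OperatorExportTypes value is expected. This workaround is only so-so as well, since it edits the installed PyTorch source.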
3. Problem 3
File "D:\Softwares\Anaconda3\lib\site-packages\tensorboardX\writer.py", line 419, in add_graph
self.file_writer.add_graph(graph(model, input_to_model, verbose))
File "D:\Softwares\Anaconda3\lib\site-packages\tensorboardX\graph.py", line 82, in graph
graph = trace.graph()
AttributeError: 'torch._C.Graph' object has no attribute 'graph'
Initial fix:
Upgrading tensorboardX is enough here.
pip install --upgrade tensorboardx
This is also only a partial fix.
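Final solution:
All of these problems are caused by mismatched package versions. Upgrading both tensorboard and tensorboardX fixes them:
pip install --upgrade tensorboard tensorboardx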
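Since the root cause is a version mismatch, it helps to see which versions are actually installed before and after upgrading. The following is a small sketch using only the standard library (importlib.metadata, Python 3.8+). The function name installed_versions is made up for illustration, and the package names passed in are assumed to match the names the packages were installed under.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    names = ["torch", "tensorboard", "tensorboardX"]
    for name, version in installed_versions(names).items():
        print(f"{name}: {version or 'not installed'}")
```

Running this once before and once after the upgrade shows whether pip actually changed the tensorboard and tensorboardX versions in the environment the script runs in.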