AttributeError: module ‘torch.onnx‘ has no attribute ‘set_training‘解决方案

文章目录

  • 前言
  • 1. 问题描述
  • 2. 问题原因
  • 3. 解决方法

前言

使用pytorch可视化网络结构时,遇到了pytorch和tensorboardX版本不兼容问题,又不能轻易降低pytorch版本,最终参考网上文章找到了问题原因。

1. 问题描述

使用TensorboardX可视化网络结构,示例代码如下。

from tensorboardX import SummaryWriter

with SummaryWriter(comment="XXXXXX") as w:
    w.add_graph(model)

报错:AttributeError: module ‘torch.onnx‘ has no attribute ‘set_training‘

2. 问题原因

参考网上文章1,作者发现是pytorch的高版本修改了一个方法名称。

PyTorch 1.6版本中set_training变成了select_model_mode_for_export

而tensorboardX的升级版本中也仍然没有解决这个问题。

问题就出现在下面这个 set_training

def graph(model, args, verbose=False, **kwargs):

    import torch

    with torch.onnx.set_training(model, False): 
        try:
            trace = torch.jit.trace(model, args)
            graph = trace.graph

        except RuntimeError as e:
            print(e)
            print('Error occurs, No graph saved')

3. 解决方法

是临时的解决方案,修改tensorboardX源码并打包。

第一步:从github拉取tensorboardX源码。

git clone https://github.com/lanpa/tensorboardX

第二步:切换到所需版本的标签。

git checkout v1.8

第三步:修改源码。

# 修改 with torch.onnx.set_training(model, False): 为下面语句
with torch.onnx.select_model_mode_for_export(model, False): 

第四步:重新打包。

(注意:以下步骤在Linux下完成,Windows环境请适当修改路径写法)

pip install wheel
# 切换到tensorboardX所在目录
cd tensorboardX
# 打包
pip wheel --wheel-dir=/root/ ./

# 生成 `tensorboardX-1.8+e136d41-py2.py3-none-any.whl`

第五步:安装。

# 将`whl`文件拷贝到项目目录,并安装
pip install tensorboardX-1.8+e136d41-py2.py3-none-any.whl

再次运行,就不再报错了。


  1. 参考文章 https://blog.csdn.net/qq_42730750/article/details/119741621 ↩︎

你可能感兴趣的:(深度学习,pytorch,深度学习,python)