Pytorch 1.0.0版本问题(一)之module ‘torch.nn‘ has no attribute ‘Flatten‘

1.错误生成原因


在使用nn.Flatten()函数的时候报错

import torch
from torch import nn

net = nn.Sequential(nn.Flatten(), nn.Linear(784, 256), nn.ReLU(),
                    nn.Linear(256, 10))

错误如下:

AttributeError                            Traceback (most recent call last)
 in 
----> 1 net = nn.Sequential(nn.Flatten(), nn.Linear(784, 256), nn.ReLU(),
      2                     nn.Linear(256, 10))
      3 
      4 def init_weights(m):
      5     if type(m) == nn.Linear:

AttributeError: module 'torch.nn' has no attribute 'Flatten'

原因是因为使用的Pytorch1.0.0版本,没有定义Flatten函数。

2.解决方法过程

过程包括试错部分,可跳转到 3.总结 直接看如何操作。

        最简单的就是更新Pytorch版本即可,但由于我以前安装为了适应GPU版本等兼容问题,不更新Pytorch版本采用第二个方案。

思路灵感来自 axis 报错在对应文件找到对应错误代码将 axis 替换成 dim 即可

       那么既然缺少Flatten函数,那就给它补上。找到1.9.0版本的Pytorch下载

https://github.com/pytorch/pytorch/releases

解压后在

D:\pytorch-1.9.0\pytorch-1.9.0\torch\nn\modules

文件里面找到  flatten.py  文件

       再找到本地安装的位置,我使用的Anaconda安装的,将在其  flatten.py 文件拷贝到Anaconda安装对应的文件的夹里。例如下:

D:\Anaconda3\Lib\site-packages\torch\nn\modules

同时在此文件夹中找到__init__.py文件按照最新版本中的对应文件加入如下代码:

from .flatten import Flatten, Unflatten

以及在__all__ = 中加入对应代码

'Flatten',

如下图所示:Pytorch 1.0.0版本问题(一)之module ‘torch.nn‘ has no attribute ‘Flatten‘_第1张图片

 重新运行会提示错误:

ModuleNotFoundError: No module named 'torch.types'

通过查询找到1.9.0版本的 types.py 文件在 D:\pytorch-1.9.0\pytorch-1.9.0\torch 中

将其 复制 到本地安装位置D:\Anaconda3\Lib\site-packages\torch

再次运行又提示错误:

     19 _dtype = torch.dtype
     20 _device = torch.device
---> 21 _qscheme = torch.qscheme
     22 _size = Union[torch.Size, List[_int], Tuple[_int, ...]]
     23 _layout = torch.layout

AttributeError: module 'torch' has no attribute 'qscheme'

由于我没有在1.9.0版本找到qscheme.py,故在 types.py 中找到 21 行将其加上 ‘#’ 屏蔽

注:因为为没有需要用到torch.qscheme的情况故把它屏蔽了,如何后续需要用到,再找解决更新方法

运行成功结果如下:

Pytorch 1.0.0版本问题(一)之module ‘torch.nn‘ has no attribute ‘Flatten‘_第2张图片

3.总结

方法一:更新Pytorch版本

方法二:在原来Pytorch版本上修改(注: torch.qscheme会被屏蔽暂时未找解决方法,不使用此函数可尝试此方法)

下载新版Pytorch版本

下载地址:https://github.com/pytorch/pytorch/releases

首先,解压找到目录D:xx\pytorch-1.9.0\pytorch-1.9.0\torch\nn\modules中的flatten.py文件复制到自己电脑安装的相同路径下。

其次,找到本地目录下的__init__.py文件,按照解压中对应的__init__.py文件添加

from .flatten import Flatten, Unflatten__all__ = 中添加'Flatten', 即可。

最后,再找到解压目录下D:xx\pytorch-1.9.0\torch的types.py文件,同理将其复制,同时打开types.py文件找到21行的_qscheme = torch.qscheme将加上其‘#’屏蔽。

你可能感兴趣的:(Pytorch,jupyter,Python,python)