在使用nn.Flatten()函数的时候报错
import torch
from torch import nn
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 256), nn.ReLU(),
nn.Linear(256, 10))
错误如下:
AttributeError Traceback (most recent call last)
in ----> 1 net = nn.Sequential(nn.Flatten(), nn.Linear(784, 256), nn.ReLU(), 2 nn.Linear(256, 10)) 3 4 def init_weights(m): 5 if type(m) == nn.Linear: AttributeError: module 'torch.nn' has no attribute 'Flatten'
原因是因为使用的Pytorch1.0.0版本,没有定义Flatten函数。
过程包括试错部分,可跳转到 3.总结 直接看如何操作。
最简单的就是更新Pytorch版本即可,但由于我以前安装为了适应GPU版本等兼容问题,不更新Pytorch版本采用第二个方案。
思路灵感来自 axis 报错在对应文件找到对应错误代码将 axis 替换成 dim 即可
那么既然缺少Flatten函数,那就给它补上。找到1.9.0版本的Pytorch下载
https://github.com/pytorch/pytorch/releases
解压后在
D:\pytorch-1.9.0\pytorch-1.9.0\torch\nn\modules
文件里面找到 flatten.py 文件
再找到本地安装的位置,我使用的Anaconda安装的,将在其 flatten.py 文件拷贝到Anaconda安装对应的文件的夹里。例如下:
D:\Anaconda3\Lib\site-packages\torch\nn\modules
同时在此文件夹中找到__init__.py文件按照最新版本中的对应文件加入如下代码:
from .flatten import Flatten, Unflatten
以及在__all__ = 中加入对应代码
'Flatten',
重新运行会提示错误:
ModuleNotFoundError: No module named 'torch.types'
通过查询找到1.9.0版本的 types.py 文件在 D:\pytorch-1.9.0\pytorch-1.9.0\torch 中
将其 复制 到本地安装位置D:\Anaconda3\Lib\site-packages\torch
再次运行又提示错误:
19 _dtype = torch.dtype 20 _device = torch.device ---> 21 _qscheme = torch.qscheme 22 _size = Union[torch.Size, List[_int], Tuple[_int, ...]] 23 _layout = torch.layout AttributeError: module 'torch' has no attribute 'qscheme'
由于我没有在1.9.0版本找到qscheme.py,故在 types.py 中找到 21 行将其加上 ‘#’ 屏蔽
注:因为为没有需要用到torch.qscheme的情况故把它屏蔽了,如何后续需要用到,再找解决更新方法。
运行成功结果如下:
方法一:更新Pytorch版本
方法二:在原来Pytorch版本上修改(注: torch.qscheme会被屏蔽暂时未找解决方法,不使用此函数可尝试此方法)
下载新版Pytorch版本
下载地址:https://github.com/pytorch/pytorch/releases
首先,解压找到目录D:xx\pytorch-1.9.0\pytorch-1.9.0\torch\nn\modules中的flatten.py文件复制到自己电脑安装的相同路径下。
其次,找到本地目录下的__init__.py文件,按照解压中对应的__init__.py文件添加
from .flatten import Flatten, Unflatten在__all__ = 中添加'Flatten', 即可。
最后,再找到解压目录下D:xx\pytorch-1.9.0\torch的types.py文件,同理将其复制,同时打开types.py文件找到21行的_qscheme = torch.qscheme将加上其‘#’屏蔽。