自定义模型“XX object is not subscriptable”解决方案

自定义了一个Linear类,
并用self.add_module('L1',nn.Linear(3,2))添加了一层线性变换,

class Linear(nn.Module):
    def __init__(self) :
        super(Linear,self).__init__()
        self.add_module('L1',nn.Linear(3,2))
        self.add_module('L2',nn.Linear(2,1)) 
        self.add_module('S3',nn.Sigmoid())

然后想要获取权重

LLL=Linear()
print(LLL[0].weight)

就报了这样的错误:TypeError: 'Linear' object is not subscriptable
然而用nn.Sequential() 定义模型时却不会有这样的问题
所以要怎么解决呢?
跳到nn.Module.add_module()的函数声明,注释里这样写到:
自定义模型“XX object is not subscriptable”解决方案_第1张图片
“这个新加入的模型可以用你给它的名字来获取到”
所以正确写法是:

print(LLL.L1.weight)

你可能感兴趣的:(python,pytorch,深度学习,机器学习)