forward() missing 1 required positional argument: ‘indices‘错误解决

在自编码器中,进行上池化操作时报了forward() missing 1 required positional argument: 'indices’的错误。
部分代码:

def __init__():
	self.pool1 = nn.MaxPool2d((3,3),stride=2)
	self.pool2 = nn.MaxUnpool2d((3,3),stride=2)
、、、、、
def forward():
	tempx= self.pool1(x)
	y = self.pool2(tempx)

uppool需要将池化返回索引作为位置参数,这些索引将与 return_indices=True 一起返回。
修改为:

def __init__():
	self.pool1 = nn.MaxPool2d((3,3),stride=2,return_indices=True )
	self.pool2 = nn.MaxUnpool2d((3,3),stride=2)
、、、、、
def forward():
	tempx,indices= self.pool1(x)
	y = self.pool2(tempx,indices)

你可能感兴趣的:(金蛋错误,python,深度学习,pytorch)