conv2d() received an invalid combination of arguments问题解决

在学习动手学深度学习风格迁移这一部分的时候,程序运行的时候抱错:conv2d() received an invalid combination of arguments

具体来说,先使用函数SynthesizedImage定义一个图像,它的权重是更新的目标,经get_inits实例化,通过训练更新图像的权重,获得风格迁移后的图像。

class SynthesizedImage(nn.Module):
    def __init__(self, img_shape, **kwargs):
        super(SynthesizedImage, self).__init__(**kwargs)
        self.weight = nn.Parameter(torch.rand(*img_shape))
        
    def forward(self):
        return self.weight
    
def get_inits(content_img, lr, lr_decay_epoch, init_random):
    gen_img = SynthesizedImage(content_img.shape).to(device)
    if not init_random:  
        gen_img.weight.data.copy_(content_img.data)

    optimizer = torch.optim.Adam(gen_img.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, lr_decay_epoch, 0.8)
    return gen_img(), optimizer, scheduler

参考:在python中遇到的错误(二):用pytorch的CNN发生的报错_游鱼不知夏的博客-CSDN博客

发现可能是初始化数据出了问题。经过检查发现函数get_inits返回值写成了是gen_img,它的格式是:

返回的参数应该写成gen_img(),返回后的格式是:

这样就不会报错了。

这里蕴含一个知识点:pytorch模型定义。下面举几个例子就能明白,为什么gen_img的格式是, gen_img()的格式是

简单地说,就是将模型实例化之后,gen_img代表模型自身,gen_img()执行了魔法函数forward(),得到forward()的返回值

第一个例子

class Net1(nn.Module):
    def __init__(self, img_shape, **kwargs):
        super(Net1, self).__init__(**kwargs)
        self.weight = nn.Parameter(torch.rand(*img_shape))
        
    def forward(self):
        return self.weight

model = Net1([2,3])

>>print(model)

Net1()

>>print(model())

Parameter containing:
tensor([[0.6031, 0.3673, 0.7362],
        [0.9071, 0.1086, 0.0191]], requires_grad=True)

>>print(type(model))

>>print(type(model()))

第二个例子

class Net2(nn.Module):
    def __init__(self, a):
        super(Net2, self).__init__()
        self.conv1 = nn.Conv2d(3, 5, 3)

    def forward(self, x):        
        return self.conv1
    
model = Net2(1)

>>print(model)

Net2(
  (conv1): Conv2d(3, 5, kernel_size=(3, 3), stride=(1, 1))
)

>>print(model(1))

Conv2d(3, 5, kernel_size=(3, 3), stride=(1, 1))

>>print(type(model))

>>print(type(model(1)))

第三个例子

class Net3(nn.Module):
    def __init__(self, a):
        super(Net3, self).__init__()
        self.weight = 123

    def forward(self):        
        return 456
    
model = Net3(1)

>>print(model)

Net3()

>>print(model())

456

>>print(type(model))

>>print(type(model()))

你可能感兴趣的:(代码报错,python,深度学习,人工智能)