pytorch创建网络的一种错误(gan网络无法学习的一种可能性)

写了一个gan的程序,但是跑起来发现不对劲,网络完全没有学习的迹象,以为是损失函数的问题,但是看了很多博客,也对比了之前的教程,发现损失函数没有问题呀,最后,也是抱着怀疑的试一试的态度,发现问题所在了:
我在构建网络的时候,因为网络在向前推进的过程中会出现维度我没考虑到的情况,所以我在forward()这个函数里面加了网络来向前推进,类似下面这种:

class Generate(nn.Module):
    def __init__(self):
        super(Generate, self).__init__()
        self.net = nn.Sequential(
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Tanh() # 为啥不使用sigmod,后面试一下
        )

    def forward(self, x):
        net1 = nn.Linear(x.size(1), 1024)
        net1.cuda() # 这里要把模型放到gpu上
        x = net1(x)
        x = self.net(x)
        return x

可以看见上面的代码,在forward()函数里面另外加了网络来避免网络向前推进时候中间维度的未知,然后我试着打印出这个类里面的网络结构,打印出来,我发现,nn.Linear(x.size(1), 1024)这一层网络并没有在这个类的网络结构中,于是我将nn.Linear(x.size(1), 1024)这里的网络放到了def __init__里面了,放进去之后,再去跑模型,正常了。
改的例子如下:

class Generate(nn.Module):
    def __init__(self):
        super(Generate, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(784, 1024),# 放到这里来,第一个数值要自己根据自己的输入数据的维度计算
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Tanh() # 为啥不使用sigmod,后面试一下
        )

    def forward(self, x):
        x = self.net(x)
        return x

至于原因,很可能是放到forward()里面的网络并没有参与网络的梯度更新,这也就导致了网络学习不到我们想让他学习的东西,大家创建网络的时候注意一下。

2020 6.11

你可能感兴趣的:(pytorch)