InfoGAN原理PyTorch实现Debug记录

CGAN从无监督GAN改进成有监督的GAN

GAN的基本原理输入是随机噪声,无法控制输出和输入之间的对应关系,也无法控制输出的模式,CGAN全称是条件GAN(Conditional GAN)改进基本的GAN解决了这个问题,CGAN和基本的GAN不同的地方是:
参考下面的链接
https://www.jianshu.com/p/39c57e9a6630
这里面介绍了实现CGAN有三种形式,从网络实现上的三种形式,没有讲解怎样优化目标函数

CGAN的一个问题是输入的有监督标签是离散型输入,如果输入中还有连续型输入,也就是C这个条件是个连续型的,那么将要继续参考InfoGAN

InfoGAN

参考下面的链接,非常详细的讲解了InfoGAN的原理、网络结构的实现、损失函数怎样求解
https://www.jianshu.com/p/fa892c81df60
InfoGAN的Info部分和判别器D共用了前面的网络,那么PyTorch怎么实现共用网咯呢?
参考下面的PyTorch实现
https://mp.weixin.qq.com/s?__biz=MzI3MzkyMzE5Mw==&mid=2247485031&idx=1&sn=e6ccbc33639462d59ee56923c59173b6&chksm=eb1aab71dc6d2267cc52bf769106067c53c867ad6a02063791674937857fb86da36ecd6cbfd9&token=1864035800&lang=zh_CN#rd
InfoGAN原理PyTorch实现Debug记录_第1张图片
原来PyTorch定义判别器类的时候可以分成三个网络,分别是主网络、D网络和C网络、L网络,D网络和C网络和L网络公用主网络,这个例子中的InfoGAN得输入有随机噪声、离散输入(C部分)、连续输入(L部分),forward中先用主网络处理x,之后返回D网络、C网络和L网络
不得不说,这种写法很有趣啊

PyTorch实现Debug记录

我自己实现了InfoGAN网络,运行程序后接二连三出现了很多错误

Bug(1):

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace
参考链接:
https://blog.csdn.net/qq_32953463/article/details/115728762
出现这个错误的原因是Pytorch的版本问题,我的Pytorch是1.11.0版本,如果Pytorch版本低于1.4不会有这个问题,链接中提供了一种不需要重新安装Pytorch的办法,backward()放在一起,step()放在一起,zero_grad()不需要放在一起,如下截图所示,

不得不说,这么神奇,真的解决了

            real_out = netD(real_img).mean()
            fake_out = netD(fake_img).mean()
            d_loss = 1 - real_out + fake_out
            netD.zero_grad()
            g_loss = generator_criterion(fake_out, fake_img, real_img)
            netG.zero_grad()

            d_loss.backward(retain_graph=True)
            g_loss.backward()

            optimizerD.step()
            optimizerG.step()

            fake_img = netG(z)
            fake_out = netD(fake_img).mean()

Bug(2):RuntimeError: Trying to backward through the graph a second time but the buffers have already been f

或者说是pytorch中的retain_graph=True的作用

参考链接:https://blog.csdn.net/qq_39861441/article/details/104129368
总的来说进行一次backward之后,各个节点的值会清除,这样进行第二次backward会报错,如果加上retain_graph==True后,可以再来一次backward。

InfoGAN原理PyTorch实现Debug记录_第2张图片

上面的示例代码中前两个网络D和G在backward的时候使用了retain_graph=True的参数,最后一个网络没有使用此参数,此参数的默认值是False

如果想了解底层的原理,建议阅读下面的链接,里面的图解非常的有趣
https://blog.csdn.net/SY_qqq/article/details/107384161

Bug(3):

RuntimeError: Found dtype Long but expected Float
这个错误来源于torch需要float类型,但是数据中是int类型或者long类型,解决方法是debug一个一个看变量中哪里出现了int或者long类型,假设variable是int或者long类型的变量,将它转换成float类型

variable = variable.to(torch.float32)

你可能感兴趣的:(python,pytorch,深度学习,python)