PyTorch对WGAN(Wasserstein生成对抗网络)的实现

        生成对抗网络(GAN)的用途非常广泛,可以“无中生有”图片,人物动漫头像,去掉场景中的雨、黑白转彩色的图片与视频、视频预测、2D推导3D等等,对于Goodfellow的封神之作,大家有兴趣的可以阅读:Generative Adversarial Nets  

作为开山之作,肯定也存在诸多问题,如下:

1、纳什均衡一般很难达到,这就造成模型要设计的比较精巧了,生成器和辨别器要旗鼓相当才能很好的进行下去
2、训练梯度的不稳定
3、模型的崩溃(mode collapse,生成的样本单一,如MNIST中的单个数字)
4、梯度消失(辨别器训练的太好,生成器的损失函数很难降下去,反之,辨别器很差,那么生成器就显得没有多大意义)

后续有很多对GAN的改良版本,我们来熟悉WGAN这篇优质的论文:Wasserstein GAN

很好的解决了上述出现的问题,而且为调试提供了更有意义的学习曲线和超参数的搜索。
对于WGAN的详细了解,可以参阅知乎这篇不错的文章:令人拍案叫绝的Wasserstein GAN

Wasserstein距离,又叫Earth-Mover(EM)推土机距离,相比较前面学习的KL和JS散度(KL散度与JS散度的公式与代码的简要实现)要更加占优势,简单来说就是两个分布没有重叠或完全重叠,Wasserstein距离仍然能够反映它们的远近。

推土机距离

PyTorch对WGAN(Wasserstein生成对抗网络)的实现_第1张图片

代码看下效果(使用scipy接口):

from scipy.stats import wasserstein_distance as wd

wd1 = wd([0, 1, 3], [1, 1, 4])#0.6666666666666667
wd2 = wd([1, 1, 4], [1, 1, 4])#0.0
wd3 = wd([10, 111, 41], [1, 1, 4])#52.0
#另外这个既然叫做距离也就是对称的,交换之后的值还是一样的
wd4 = wd([1, 1, 4],[10, 111, 41])#52.0

也就是说KL和JS散度在重合和不重合的极端情况下,它们的散度是突变的,而EM就是一种平滑的,这样迭代梯度才有意义。另外在论文中把辨别器换了一种叫法,叫做评论器,也就是说辨别器是辨别真伪,而评论器是计算分布的距离,叫法更适合。
相对GAN主要做了哪些修改呢?

1、判别器最后一层去掉sigmoid(不是分类问题,求的是距离,是回归问题)
2、去掉了生成器和判别器的损失函数的log(这个就是区别于KL散度和JS散度)
3、每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c
4、不用基于动量的优化算法(包括momentum和Adam,容易造成loss不稳定),推荐使用RMSProp优化算法

我们来运行源码看下效果:

根据自己的路径来修改(数据集我放在D盘的LSUNDIR目录下面)

(pytorch) D:\>python C:\Users\Tony\Downloads\WassersteinGAN\main.py --dataset lsun --dataroot LSUNDIR --cuda
Namespace(Diters=5, adam=False, batchSize=64, beta1=0.5, clamp_lower=-0.01, clamp_upper=0.01, cuda=True, dataroot='LSUNDIR', dataset='lsun', experiment=None, imageSize=64, lrD=5e-05, lrG=5e-05, mlp_D=False, mlp_G=False, n_extra_layers=0, nc=3, ndf=64, netD='', netG='', ngf=64, ngpu=1, niter=25, noBN=False, nz=100, workers=2) 

错误处理

    transforms.Scale(opt.imageSize),
AttributeError: module 'torchvision.transforms' has no attribute 'Scale'

main.py里面将transforms.Scale替换为transforms.Resize,有多处,新版本丢弃了Scale的写法

    dataset = dset.LSUN(db_path=opt.dataroot, classes=['bedroom_train'],
TypeError: __init__() got an unexpected keyword argument 'db_path'

查看LSUN定义,参数名称现在是root了,将db_path=opt.dataroot修改为root=opt.dataroot

 ModuleNotFoundError: No module named 'lmdb'
没有安装lmdb的数据库,安装:conda install python-lmdb

lmdb.Error: LSUNDIR\bedroom_train_lmdb: ϵͳ�Ҳ���ָ����·����
缺少数据集,需要用到的LSUN的10个场景的数据集,看情况下载:http://dl.yf.io/lsun/scenes/ 

这个lmdb( Lightning Memory-Mapped Database),是闪电式的内存映射型数据库,一个字快,针对神经网络大型数据集而设计的,不属于关系型,保存的是Key-Value对,比如将不同类型的图片等统一转换成这个格式进行存储。
对这数据库感兴趣的可以参阅:Python对于lmdb(Lightning Memory-Mapped Database)闪电式内存映射数据库的使用 

 

(pytorch) D:\>python C:\Users\Tony\Downloads\WassersteinGAN\main.py --dataset lsun --dataroot LSUNDIR --cuda
Namespace(Diters=5, adam=False, batchSize=64, beta1=0.5, clamp_lower=-0.01, clamp_upper=0.01, cuda=True, dataroot='LSUNDIR', dataset='lsun', experiment=None, imageSize=64, lrD=5e-05, lrG=5e-05, mlp_D=False, mlp_G=False, n_extra_layers=0, nc=3, ndf=64, netD='', netG='', ngf=64, ngpu=1, niter=25, noBN=False, nz=100, workers=2)
Random Seed:  9768
DCGAN_G(
  (main): Sequential(
    (initial:100-512:convt): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (initial:512:batchnorm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (initial:512:relu): ReLU(inplace=True)
    (pyramid:512-256:convt): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (pyramid:256:batchnorm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (pyramid:256:relu): ReLU(inplace=True)
    (pyramid:256-128:convt): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (pyramid:128:batchnorm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (pyramid:128:relu): ReLU(inplace=True)
    (pyramid:128-64:convt): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (pyramid:64:batchnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (pyramid:64:relu): ReLU(inplace=True)
    (final:64-3:convt): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (final:3:tanh): Tanh()
  )
)
DCGAN_D(
  (main): Sequential(
    (initial:3-64:conv): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (initial:64:relu): LeakyReLU(negative_slope=0.2, inplace=True)
    (pyramid:64-128:conv): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (pyramid:128:batchnorm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (pyramid:128:relu): LeakyReLU(negative_slope=0.2, inplace=True)
    (pyramid:128-256:conv): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (pyramid:256:batchnorm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (pyramid:256:relu): LeakyReLU(negative_slope=0.2, inplace=True)
    (pyramid:256-512:conv): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (pyramid:512:batchnorm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (pyramid:512:relu): LeakyReLU(negative_slope=0.2, inplace=True)
    (final:512-1:conv): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
  )
)
Traceback (most recent call last):
  File "C:\Users\Tony\Downloads\WassersteinGAN\main.py", line 167, in
    data_iter = iter(dataloader)
  File "D:\Anaconda3\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 444, in __iter__
    return self._get_iterator()
  File "D:\Anaconda3\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 390, in _get_iterator
    return _MultiProcessingDataLoaderIter(self)
  File "D:\Anaconda3\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 1077, in __init__
    w.start()
  File "D:\Anaconda3\envs\pytorch\lib\multiprocessing\process.py", line 121, in start
    self._popen = self._Popen(self)
  File "D:\Anaconda3\envs\pytorch\lib\multiprocessing\context.py", line 224, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "D:\Anaconda3\envs\pytorch\lib\multiprocessing\context.py", line 327, in _Popen
    return Popen(process_obj)
  File "D:\Anaconda3\envs\pytorch\lib\multiprocessing\popen_spawn_win32.py", line 93, in __init__
    reduction.dump(process_obj, to_child)
  File "D:\Anaconda3\envs\pytorch\lib\multiprocessing\reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
TypeError: cannot pickle 'Environment' object

(pytorch) D:\>Traceback (most recent call last):
  File "", line 1, in
  File "D:\Anaconda3\envs\pytorch\lib\multiprocessing\spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "D:\Anaconda3\envs\pytorch\lib\multiprocessing\spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
EOFError: Ran out of input

操作系统的问题,本人使用的是Windows,使用Linux没有问题,修改如下:
num_workers=int(opt.workers)修改为num_workers=0。然后训练的情况如下,大概花了2小时,工作环境:Windows下面的1050的显卡,使用的是church_outdoor_train_lmdb这个数据集:PyTorch对WGAN(Wasserstein生成对抗网络)的实现_第2张图片

PyTorch对WGAN(Wasserstein生成对抗网络)的实现_第3张图片
将在samples目录迭代生成很多的“假图”,挑选几张看下:

PyTorch对WGAN(Wasserstein生成对抗网络)的实现_第4张图片PyTorch对WGAN(Wasserstein生成对抗网络)的实现_第5张图片PyTorch对WGAN(Wasserstein生成对抗网络)的实现_第6张图片PyTorch对WGAN(Wasserstein生成对抗网络)的实现_第7张图片PyTorch对WGAN(Wasserstein生成对抗网络)的实现_第8张图片

将bedroom_train_lmdb的数据集也下载下来了,解压之后50.4GB,好家伙,我这机器这得跑多久啊,可以正常迭代,也是在samples文件夹生成很多卧室图片,没有继续测试下完毕,很耗时。

PyTorch对WGAN(Wasserstein生成对抗网络)的实现_第9张图片

你可能感兴趣的:(深度学习框架(PyTorch),生成对抗网络,推土机距离,WGAN错误处理,lsun数据集)