GAN 起航篇

关于GAN的点滴理解

  • 1.关于Loss函数
    • 1.1 Goodfellow 大神的 paper 《Generative Adversarial Nets》
    • 1.2 Loss的具体选择
    • 1.3 训练过程中generator和discriminator中的label
  • 2.炼丹心法

1.关于Loss函数

1.1 Goodfellow 大神的 paper 《Generative Adversarial Nets》

核心算法:
GAN 起航篇_第1张图片
对于判别器Discriminator,其loss函数为:

max ⁡   log ⁡ D ( r e a l ) + log ⁡ ( 1 − D ( f a k e ) ) \qquad\qquad \max \ \log{D(real)}+\log\big(1-D(fake)\big) max logD(real)+log(1D(fake))

这里, D ( r e a l ) 指 的 是 将 r e a l   i m g 判 断 为 T r u e 的 概 率 , D ( f a k e ) 指 的 是 将 f a k e   i m g 判 断 为 真 的 概 率 。 D(real)指的是将real \ img判断为True的概率,D(fake)指的是将fake \ img判断为真的概率。 D(real)real imgTrueD(fake)fake img

原文中最大化判别器的Loss,即为将真图判别真,将生成图片判别为假的可能性最大。

很多程序实现的时候,使用的是如下类似形式:

min ⁡   L o s s ( D ( r e a l ) , 1 ) + L o s s ( D ( f a k e ) , 0 ) \qquad\qquad \min\ Loss \big(D(real),1 \big)+Loss \big(D(fake),0 \big ) min Loss(D(real),1)+Loss(D(fake),0)

而对于生成器Generator,其目标仅为更大程度将生成的图片判断为真,paper中使用的是:

min ⁡   log ⁡ ( 1 − D ( f a k e ) ) \qquad\qquad \min\ \log(1-D(fake)) min log(1D(fake))

类似地,代码实现:

min ⁡   L o s s ( D ( f a k e ) , 1 ) \qquad\qquad \min \ Loss \big(D(fake),1\big) min Loss(D(fake),1)

1.2 Loss的具体选择

Loss具体可选择一些基本的函数。如在github cyclegan实现中,给出了两种:

if gan_mode == 'lsgan':
    self.loss = nn.MSELoss()
elif gan_mode == 'vanilla':
    self.loss = nn.BCEWithLogitsLoss()

实际使用的是'lsgan',即MSELoss

parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')

而在DCGAN的某些实现中,
采用BCE loss:
min ⁡ l i = − [ y i ⋅ l o g x i + ( 1 − y i ) ⋅ l o g ( 1 − x i ) ] \qquad\qquad \min l_i=−[y_i⋅logx_i+(1−y_i)⋅log(1−x_i)] minli=[yilogxi+(1yi)log(1xi)]
正好等价于:
max ⁡   log ⁡ D ( r e a l ) + log ⁡ ( 1 − D ( f a k e ) ) \qquad\qquad \max \ \log{D(real)}+\log\big(1-D(fake)\big) max logD(real)+log(1D(fake))

1.3 训练过程中generator和discriminator中的label

截取部分代码1,以说明问题。

# Initialize BCELoss function
criterion = nn.BCELoss()

# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.

discriminator

## Train with all-real batch
label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
# Forward pass real batch through D
output = netD(real_cpu).view(-1)
# Calculate loss on all-real batch
errD_real = criterion(output, label)

## Train with all-fake batch
# Generate batch of latent vectors
noise = torch.randn(b_size, nz, 1, 1, device=device)
# Generate fake image batch with G
fake = netG(noise)
label.fill_(fake_label)
# Classify all fake batch with D
output = netD(fake.detach()).view(-1)
 # Calculate D's loss on the all-fake batch
errD_fake = criterion(output, label)

generator

label.fill_(real_label)  # fake labels are real for generator cost
# Since we just updated D, perform another forward pass of all-fake batch through D
output = netD(fake).view(-1)
# Calculate G's loss based on this output
errG = criterion(output, label)

针对生成的fake image:

discriminator训练过程中使用的标签是:

label.fill_(fake_label)

而在generator训练过程中使用的标签是:

label.fill_(real_label) 

这正是体现对抗思路之所在。discriminator希望尽量将generator产生的fake image判断为假;而generator却期望产生的图片尽量糊弄得了discriminator,让其判断成真。

屁股决定脑袋。

2.炼丹心法

最近在用SRGAN做超分辨任务,发现训练GAN确实不是一件太容易的事。

比如说:discriminator和generator要不要分开训练;discriminator model loss一直为1或者0该怎么办?

How to Train a GAN? Tips and tricks to make GANs work此GitHub给出了不少训练要点,后续逐一领悟其要旨。

pytorch官方网站发布的 pytorch DCGAN TUTORIAL也提到了此文。


  1. pytorch DCGAN TUTORIA ↩︎

你可能感兴趣的:(GAN,深度学习)