核心算法:
对于判别器Discriminator,其loss函数为:
max log D ( r e a l ) + log ( 1 − D ( f a k e ) ) \qquad\qquad \max \ \log{D(real)}+\log\big(1-D(fake)\big) max logD(real)+log(1−D(fake))
这里, D ( r e a l ) 指 的 是 将 r e a l i m g 判 断 为 T r u e 的 概 率 , D ( f a k e ) 指 的 是 将 f a k e i m g 判 断 为 真 的 概 率 。 D(real)指的是将real \ img判断为True的概率,D(fake)指的是将fake \ img判断为真的概率。 D(real)指的是将real img判断为True的概率,D(fake)指的是将fake img判断为真的概率。
原文中最大化判别器的Loss,即为将真图判别真,将生成图片判别为假的可能性最大。
很多程序实现的时候,使用的是如下类似形式:
min L o s s ( D ( r e a l ) , 1 ) + L o s s ( D ( f a k e ) , 0 ) \qquad\qquad \min\ Loss \big(D(real),1 \big)+Loss \big(D(fake),0 \big ) min Loss(D(real),1)+Loss(D(fake),0)
而对于生成器Generator,其目标仅为更大程度将生成的图片判断为真,paper中使用的是:
min log ( 1 − D ( f a k e ) ) \qquad\qquad \min\ \log(1-D(fake)) min log(1−D(fake))
类似地,代码实现:
min L o s s ( D ( f a k e ) , 1 ) \qquad\qquad \min \ Loss \big(D(fake),1\big) min Loss(D(fake),1)
Loss具体可选择一些基本的函数。如在github cyclegan实现中,给出了两种:
if gan_mode == 'lsgan':
self.loss = nn.MSELoss()
elif gan_mode == 'vanilla':
self.loss = nn.BCEWithLogitsLoss()
实际使用的是'lsgan',即MSELoss
。
parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
而在DCGAN的某些实现中,
采用BCE loss:
min l i = − [ y i ⋅ l o g x i + ( 1 − y i ) ⋅ l o g ( 1 − x i ) ] \qquad\qquad \min l_i=−[y_i⋅logx_i+(1−y_i)⋅log(1−x_i)] minli=−[yi⋅logxi+(1−yi)⋅log(1−xi)],
正好等价于:
max log D ( r e a l ) + log ( 1 − D ( f a k e ) ) \qquad\qquad \max \ \log{D(real)}+\log\big(1-D(fake)\big) max logD(real)+log(1−D(fake))
截取部分代码1,以说明问题。
# Initialize BCELoss function
criterion = nn.BCELoss()
# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.
discriminator
## Train with all-real batch
label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
# Forward pass real batch through D
output = netD(real_cpu).view(-1)
# Calculate loss on all-real batch
errD_real = criterion(output, label)
## Train with all-fake batch
# Generate batch of latent vectors
noise = torch.randn(b_size, nz, 1, 1, device=device)
# Generate fake image batch with G
fake = netG(noise)
label.fill_(fake_label)
# Classify all fake batch with D
output = netD(fake.detach()).view(-1)
# Calculate D's loss on the all-fake batch
errD_fake = criterion(output, label)
generator
label.fill_(real_label) # fake labels are real for generator cost
# Since we just updated D, perform another forward pass of all-fake batch through D
output = netD(fake).view(-1)
# Calculate G's loss based on this output
errG = criterion(output, label)
针对生成的fake image:
discriminator训练过程中使用的标签是:
label.fill_(fake_label)
而在generator训练过程中使用的标签是:
label.fill_(real_label)
这正是体现对抗思路之所在。discriminator希望尽量将generator产生的fake image判断为假;而generator却期望产生的图片尽量糊弄得了discriminator,让其判断成真。
屁股决定脑袋。
最近在用SRGAN做超分辨任务,发现训练GAN确实不是一件太容易的事。
比如说:discriminator和generator要不要分开训练;discriminator model loss一直为1或者0该怎么办?
How to Train a GAN? Tips and tricks to make GANs work此GitHub给出了不少训练要点,后续逐一领悟其要旨。
pytorch官方网站发布的 pytorch DCGAN TUTORIAL也提到了此文。
pytorch DCGAN TUTORIA ↩︎