“Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network”首次使用生成对抗网络(GAN)应用于图像超分辨率(SR),在图像超分辨率领域引起了极大的关注,该工作由twitter公司提出,发表于2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR)
论文地址:http://arxiv.org/abs/1609.04802
作者总结了最近的图像超分辨率的工作,认为大都集中于以均方差(MSE)作为损失函数,造成生成图像过于平滑,缺少高频细节,看起来觉得不真实,感到不舒服。所以提出了基于生成式对抗网络的网络结构,作者认为这是生成式对抗网络第一次应用于4倍下采样图像的超分辨重建工作。
SRGAN利用感知损失(perceptual loss),由对抗损失(adversarial loss)和内容损失(content loss)组成。
对抗损失将图像映射到高位流形空间,并用判别网络去判别重建后的图像和原始图像。而内容损失则是基于感觉相似性(perceptual similarity)而非像素相似性(pixel similarity),所以生成的高分辨图像视觉效果更好。
1.把生成对抗网络思想应用于图像超分辨率工作中,判别器无法分辨出生成的超分辨率图像和真实的图像,使得生成的图像达到Photo-Realistic的效果。
2.设计了新型的感知损失(perceptual loss)作为网络的损失函数。
网络结构如上图所示
GAN的生成器:残差块+卷积层+BN层+ReLU
GAN的判别器:VGG+LeakyReLU+max-pooling
SRGAN是在SRResnet的基础上加上一个鉴别器。GAN的作用,是额外增加一个鉴别器网络和2个损失(g_loss和d_loss),用一种交替训练的方式训练两个网络。
模型可以分为3部分:main(生成)模块,adversarial模块,和vgg模块。只在训练阶段会用到adversarial模块进行计算,而在推断阶段,仅仅使用G网络。
对任何一个问题,都可以让训练过程“对抗化”。“对抗化”的步骤是:先确定该问题的解决方法,把原始方法当成GAN中的G网络 ,再另外增加一个D网络(二分类网络),在原来更新main模块的loss中,增加“生成对抗损失”(要生成让判别器无法区分的数据分布),一起用来更新main模块(也就是GAN中的G网络),用判别损失更新GAN中的D网络。[1]
SRGAN损失包括两部分:内容损失(content loss)和对抗损失(adversarial loss),用一定的权重进行加权和。
content loss
传统的MSE loss,可以得到很高的信噪比,但是这样的方式产生的图像存在高频细节缺失的问题
作者定义了以预训练19层VGG网络的ReLU激活层为基础的VGG loss,求生成图像和参考图像特征表示的欧氏距离。在已经训练好的VGG网络上提出某一层的feature map,将生成的图像和真实图像的这次一层输出的feature map比较。
adversarial loss
生成让判别器无法区分的数据分布。
公式背后的数学意义就是MSE+GAN,每个占一定部分的权重,分别表示空间的相似性、判别器看到的相似性。结合代码实现,更便于理解SRGAN的损失函数。
代码实现
def train(self):
# models setup
self.netG.train()
self.netD.train()
g_train_loss = 0
d_train_loss = 0
for batch_num, (data, target) in enumerate(self.training_loader):
# setup noise
real_label = torch.ones(data.size(0), data.size(1)).to(self.device)
fake_label = torch.zeros(data.size(0), data.size(1)).to(self.device)
data, target = data.to(self.device), target.to(self.device)
# Train Discriminator
self.optimizerD.zero_grad()
d_real = self.netD(target)
d_real_loss = self.criterionD(d_real, real_label)
d_fake = self.netD(self.netG(data))
d_fake_loss = self.criterionD(d_fake, fake_label)
d_total = d_real_loss + d_fake_loss
d_train_loss += d_total.item()
d_total.backward()
self.optimizerD.step()
# Train generator
self.optimizerG.zero_grad()
g_real = self.netG(data)
g_fake = self.netD(g_real)
gan_loss = self.criterionD(g_fake, real_label)
mse_loss = self.criterionG(g_real, target)
g_total = mse_loss + 1e-3 * gan_loss
g_train_loss += g_total.item()
g_total.backward()
self.optimizerG.step()
progress_bar(batch_num, len(self.training_loader), 'G_Loss: %.4f | D_Loss: %.4f' % (g_train_loss / (batch_num + 1), d_train_loss / (batch_num + 1)))
print(" Average G_Loss: {:.4f}".format(g_train_loss / len(self.training_loader)))
作者提出认为PSNR不能从视觉上衡量生成图像的好坏,这里需要找到更好的图像质量评价标准。
“feature maps of these deeper layers focus purely on the content while leaving the adversarial loss focusing on texture details”
使用更深层的网络去改进SRGAN,并找到更合适的损失函数。这里是从网络结构和损失函数的设计层面上来说。
[1] https://blog.csdn.net/DuinoDu/article/details/78819344