SRGAN论文阅读笔记

“Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network”首次使用生成对抗网络(GAN)应用于图像超分辨率(SR),在图像超分辨率领域引起了极大的关注,该工作由twitter公司提出,发表于2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR)

论文地址:http://arxiv.org/abs/1609.04802

简介

作者总结了最近的图像超分辨率的工作,认为大都集中于以均方差(MSE)作为损失函数,造成生成图像过于平滑,缺少高频细节,看起来觉得不真实,感到不舒服。所以提出了基于生成式对抗网络的网络结构,作者认为这是生成式对抗网络第一次应用于4倍下采样图像的超分辨重建工作。

SRGAN利用感知损失(perceptual loss),由对抗损失(adversarial loss)和内容损失(content loss)组成。
对抗损失将图像映射到高位流形空间,并用判别网络去判别重建后的图像和原始图像。而内容损失则是基于感觉相似性(perceptual similarity)而非像素相似性(pixel similarity),所以生成的高分辨图像视觉效果更好。

主要贡献

1.把生成对抗网络思想应用于图像超分辨率工作中,判别器无法分辨出生成的超分辨率图像和真实的图像,使得生成的图像达到Photo-Realistic的效果。

2.设计了新型的感知损失(perceptual loss)作为网络的损失函数。

SRGAN网络结构

SRGAN论文阅读笔记_第1张图片网络结构如上图所示
GAN的生成器:残差块+卷积层+BN层+ReLU
GAN的判别器:VGG+LeakyReLU+max-pooling

SRGAN是在SRResnet的基础上加上一个鉴别器。GAN的作用,是额外增加一个鉴别器网络和2个损失(g_loss和d_loss),用一种交替训练的方式训练两个网络。
模型可以分为3部分:main(生成)模块,adversarial模块,和vgg模块。只在训练阶段会用到adversarial模块进行计算,而在推断阶段,仅仅使用G网络。
对任何一个问题,都可以让训练过程“对抗化”。“对抗化”的步骤是:先确定该问题的解决方法,把原始方法当成GAN中的G网络 ,再另外增加一个D网络(二分类网络),在原来更新main模块的loss中,增加“生成对抗损失”(要生成让判别器无法区分的数据分布),一起用来更新main模块(也就是GAN中的G网络),用判别损失更新GAN中的D网络。[1]

SRGAN的损失函数

SRGAN损失包括两部分:内容损失(content loss)和对抗损失(adversarial loss),用一定的权重进行加权和。
SRGAN论文阅读笔记_第2张图片
content loss
在这里插入图片描述
传统的MSE loss,可以得到很高的信噪比,但是这样的方式产生的图像存在高频细节缺失的问题

SRGAN论文阅读笔记_第3张图片
作者定义了以预训练19层VGG网络的ReLU激活层为基础的VGG loss,求生成图像和参考图像特征表示的欧氏距离。在已经训练好的VGG网络上提出某一层的feature map,将生成的图像和真实图像的这次一层输出的feature map比较。

adversarial loss
在这里插入图片描述
生成让判别器无法区分的数据分布。

公式背后的数学意义就是MSE+GAN,每个占一定部分的权重,分别表示空间的相似性、判别器看到的相似性。结合代码实现,更便于理解SRGAN的损失函数。

代码实现

    def train(self):
        # models setup
        self.netG.train()
        self.netD.train()
        g_train_loss = 0
        d_train_loss = 0
        for batch_num, (data, target) in enumerate(self.training_loader):
            # setup noise
            real_label = torch.ones(data.size(0), data.size(1)).to(self.device)
            fake_label = torch.zeros(data.size(0), data.size(1)).to(self.device)
            data, target = data.to(self.device), target.to(self.device)

            # Train Discriminator
            self.optimizerD.zero_grad()
            d_real = self.netD(target)
            d_real_loss = self.criterionD(d_real, real_label)

            d_fake = self.netD(self.netG(data))
            d_fake_loss = self.criterionD(d_fake, fake_label)
            d_total = d_real_loss + d_fake_loss
            d_train_loss += d_total.item()
            d_total.backward()
            self.optimizerD.step()

            # Train generator
            self.optimizerG.zero_grad()
            g_real = self.netG(data)
            g_fake = self.netD(g_real)
            gan_loss = self.criterionD(g_fake, real_label)
            mse_loss = self.criterionG(g_real, target)

            g_total = mse_loss + 1e-3 * gan_loss
            g_train_loss += g_total.item()
            g_total.backward()
            self.optimizerG.step()

            progress_bar(batch_num, len(self.training_loader), 'G_Loss: %.4f | D_Loss: %.4f' % (g_train_loss / (batch_num + 1), d_train_loss / (batch_num + 1)))

        print("    Average G_Loss: {:.4f}".format(g_train_loss / len(self.training_loader)))

未来工作

作者提出认为PSNR不能从视觉上衡量生成图像的好坏,这里需要找到更好的图像质量评价标准。
“feature maps of these deeper layers focus purely on the content while leaving the adversarial loss focusing on texture details”
使用更深层的网络去改进SRGAN,并找到更合适的损失函数。这里是从网络结构和损失函数的设计层面上来说。

[1] https://blog.csdn.net/DuinoDu/article/details/78819344

你可能感兴趣的:(图像超分辨率)