SFTGAN学习笔记

这两天在学习SFTGAN,对论文的解读在师兄的论文阅读笔记中有写到,这里就不赘述啦。使用的代码地址:https://github.com/xinntao/BasicSR
因此下面是我从代码中得出的对SFTGAN网络的一些粗略的理解,暂时记录下来,若有不正之处劳烦指正。
SFTGAN是一个生成对抗网络,其G网络(生成网络)结构如图所示:
SFTGAN学习笔记_第1张图片
代码:

class SFTLayer(nn.Module):
    def __init__(self):
        super(SFTLayer, self).__init__()
        self.SFT_scale_conv0 = nn.Conv2d(32, 32, 1)
        self.SFT_scale_conv1 = nn.Conv2d(32, 64, 1)
        self.SFT_shift_conv0 = nn.Conv2d(32, 32, 1)
        self.SFT_shift_conv1 = nn.Conv2d(32, 64, 1)

    def forward(self, x):
        # x[0]: fea; x[1]: cond
        scale = self.SFT_scale_conv1(F.leaky_relu(self.SFT_scale_conv0(x[1]), 0.1, inplace=True))
        shift = self.SFT_shift_conv1(F.leaky_relu(self.SFT_shift_conv0(x[1]), 0.1, inplace=True))
        return x[0] * (scale + 1) + shift


class ResBlock_SFT(nn.Module):
    def __init__(self):
        super(ResBlock_SFT, self).__init__()
        self.sft0 = SFTLayer()
        self.conv0 = nn.Conv2d(64, 64, 3, 1, 1)
        self.sft1 = SFTLayer()
        self.conv1 = nn.Conv2d(64, 64, 3, 1, 1)

    def forward(self, x):
        # x[0]: fea; x[1]: cond
        fea = self.sft0(x)
        fea = F.relu(self.conv0(fea), inplace=True)
        fea = self.sft1((fea, x[1]))
        fea = self.conv1(fea)
        return (x[0] + fea, x[1])  # return a tuple containing features and conditions


class SFT_Net(nn.Module):
    def __init__(self):
        super(SFT_Net, self).__init__()
        self.conv0 = nn.Conv2d(3, 64, 3, 1, 1)

        sft_branch = []
        for i in range(16):
            sft_branch.append(ResBlock_SFT())
        sft_branch.append(SFTLayer())
        sft_branch.append(nn.Conv2d(64, 64, 3, 1, 1))
        self.sft_branch = nn.Sequential(*sft_branch)

        self.HR_branch = nn.Sequential(
            nn.Conv2d(64, 256, 3, 1, 1),
            nn.PixelShuffle(2),
            nn.ReLU(True),
            nn.Conv2d(64, 256, 3, 1, 1),
            nn.PixelShuffle(2),
            nn.ReLU(True),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(64, 3, 3, 1, 1)
        )

        self.CondNet = nn.Sequential(
            nn.Conv2d(8, 128, 4, 4),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(128, 128, 1),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(128, 128, 1),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(128, 128, 1),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(128, 32, 1)
        )

    def forward(self, x):
        # x[0]: img; x[1]: seg
        cond = self.CondNet(x[1])
        fea = self.conv0(x[0])
        res = self.sft_branch((fea, cond))
        fea = fea + res
        out = self.HR_branch(fea)
        return out

其实从中我们可以看到,所谓CondNet就是将传统SRResNet生成网络残差块中的BN层(Batch Normalization)替换成了SFT层。原因在师兄的博文中也有提到,相比于BN层的全局统一归一化处理,使用语义分割概率图(segmentation probability maps)做为参考的SFT层对不同物体所在区域进行差异化处理的能力显然更强。这样的话对于物体不同的纹理重建效果应该会得到提升,使得恢复的图片视觉可信度更高。

下面我就结合在学习过程中遇到的问题一步步介绍SFTGAN的代码实现。
参照这论文给出的网络结构图,来到在这里插入图片描述
中查看代码,对于pytorch略微熟悉的同学在这里相信都不会感到有多吃力。我一开始只是想要探究清楚SFT层如何实现,看到这里便以为不过如此。然而观察下面代码的输入数据:

    def forward(self, x):
        # x[0]: img; x[1]: seg
        cond = self.CondNet(x[1])
        fea = self.conv0(x[0])
        res = self.sft_branch((fea, cond))
        fea = fea + res
        out = self.HR_branch(fea)
        return out

看到了该网络的输入是LR图片与相应的语义分割概率图seg,便自然而然地想要知道seg是如何得来的,数据集在外部应该是怎么样准备的?因此来到: 在这里插入图片描述
顿时陷入懵逼状态,直看了一个小时才理清楚,重点在于理解这一块地方:
SFTGAN学习笔记_第2张图片
和这里:
SFTGAN学习笔记_第3张图片
疑问是:

1. bg意为背景,那是怎么样的图片呢?有何作用。
2. category只是一个代表物体的数字,在这里突然出现又有何用?

实际上,在训练这一个网络时,作者使用了两个数据集: 在这里插入图片描述
比例大概是:
10 OST data samples and 1 DIV2K general data samples(background)
因此第一段关键代码就好理解,那就是在训练模式下,读取HR图片时每次都有1/10的几率得到DIV2K数据集中的图片,否则就在OST数据集中读取。那么为何要这么大费周章呢?因为OST数据集中的图片分为8种物体,而我们需要DIV2K数据集中的图片充当“什么都不是”的物体——背景。这样训练效果更好,应该是还可以防止过拟合。
但是我们一开始提出的问题还没解决,seg图怎么得来?我们知道既然把DIV2K视为背景了,那么↓就很自然了,seg都是1,都是背景没什么好分割的。
SFTGAN学习笔记_第4张图片
那对于OST中的数据集呢?
在这里插入图片描述
应该是使用已训练好的语义分割网络(模型保存在.pth文件中),但是感觉这个torch.load()返回的就已经是一个数组了呀,具体原理还需要进一步学习。在这里只知道seg应该是这样的k张图片:
SFTGAN学习笔记_第5张图片
我们继续看,虽然了解了seg的来源(语义分割网络),但是我们又找到了一个疑问:

  • category的作用是什么呢?

我们之所以要使用seg,就是因为它拥有物体分布的空间概率信息,能够帮助我们对图片进行全局差异化的超分辨率重建。这是我最开始的想法,那么其实我只需要只知道“差异”的存在及其分布就够了呀,我还需要理解具体是啥跟啥的差异吗?事实证明我还是图样图森破,只知道这是前景与背景的差异,和知道这是草与房子的差异那是山和天空的差异,显然是后者拥有的信息比较多,而且能够允许我们进行针对性的学习,能够更好的帮助我们进行重建。
那么如何验证呢?很简单只要看category被用在何处就行了。来到: 在这里插入图片描述
SFTGAN学习笔记_第6张图片
SFTGAN学习笔记_第7张图片
很明显可以看到,category先是被用在了D鉴别网络的损失函数中,熟悉gan网络的同学一个都知道,一般D网络的损失函数赋予了D网络识别真假图片的能力,但是在SFTGAN的D网络中,损失函数还包括了另外的一项,那就是物体识别。这样一来D网络就有了两个功能:鉴别真假与识别物体。而整个SFTGAN网络也将识别物体作为损失函数的一部分。
当我的思想还错误的停留在SFTGAN特点是“差异性”时,我无法理解为什么要引入这样的一个损失项。但如果将“差异性”进化为“针对性”那么就好理解了。D网络的物体识别功能,促使G网络在进行图片重建时有针对性地学习了有限的几种事物的特征。对于每一张SR图片,在训练阶段都要求达到:

  1. G网络觉得总体上看上去像(MSE损失,即PSNR值);
  2. F网络(特征提取网络)觉得特征看上去像(经过VGG提取feature map的MSE损失);
  3. D网络认不出来真假(二分类交叉熵损失函数);
  4. D网络觉得这是该物体没错(多分类交叉熵)。
    大多数生成对抗网络满足的都是头三条,而SFTGAN加入了第四条。而奇怪的是,在代码有实现的这一项损失,论文中并未提及:
    SFTGAN学习笔记_第8张图片
    但是,阅读论文中的实验数据准备部分我们就能知道作者是有这个意图的:
    SFTGAN学习笔记_第9张图片
    事实上论文中的这段话也解决了我的一大困惑,那就是这项损失的target只有一种类别,那么当图片中存在多种物体时网络是否还有效呢?现在才知道原来人家是早有准备的。而且,在test时,由于已训练好的生成网络借助语义分割概率图,已经能够有针对性得对不同的事物进行处理,因此test数据也就不用选裁了,即使存在多个物体也没问题。
    在test阶段,输入网络的数据包括一张LR图片,以及k张语义分割概率图,每一张概率图中都显示着某一特定物体在该位置存在的概率。输入数据兵分两路进入生成网络,得到输出的SR图像。当然这里也就没F网络和D网络啥事了,G网络已经获得了根据概率图针对性重建图像的能力了。而语义分割概率图应该算是一种先验的信息。

对于将category做为一项损失,我再补充谈谈自己的想法。以往的SR算法只看重PSNR,基于MSE损失的网络本身并不会去理解也不在意它在重建的是什么东西。而SRGAN引入了感知损失,通过VGG网络提取深层特征进行比较,相比之下它更关注自己重建的是啥,但是还不够。因为笼统的提取特征,在遇到不同事物具有相似纹理时往往导致误判 ,也就是说D网络不尽责,区分墙壁与野草的能力差,使得效率低的G网络轻松通过鉴别,出来的图像却逃不过人眼。
SFTGAN学习笔记_第10张图片
如何改进呢?SFTGAN的作者想到(猜测)兵分两路:(1)我直接告诉G网络这是墙壁那是野草你不要搞混了(使用seg图提供信息、SFT层融合信息)(2)一个一个单独教D网络如何识别这k种物体(category损失)。这样导致的结果是什么呢?
G网络知道了差异,但并不清楚面对差异纠结具体该怎么做。好在他的监督者:D网络,胸有成竹不断提供指导。加之训练阶段的数据都是单独的某件物体,让D网络学习得十分充分十分,才能在之后更加尽责。
在这里我认为其实训练阶段除了G网络的SR残差学习在加油进步,事实上亮点、重点应该是在D网络的部分,D网络一旦能够拥有强大的VGG感知信息进行特征比较能力、以及自身的物体识别能力,就能更加严苛地要求G网络生成高质量的图片。而SFT层的作用呢?面对严苛的监督者,G网络如果本身就资质平平,没有潜力,再激励也不行呀。但是我们这个G网络先天根骨清奇(拥有先验信息语义分割概率图),又打开了任督二脉融会贯通(SFT),因此在压力下才能不断提升实力,得到更加可信的SR图像。

你可能感兴趣的:(SFTGAN学习笔记)