这两天在学习SFTGAN,对论文的解读在师兄的论文阅读笔记中有写到,这里就不赘述啦。使用的代码地址:https://github.com/xinntao/BasicSR
因此下面是我从代码中得出的对SFTGAN网络的一些粗略的理解,暂时记录下来,若有不正之处劳烦指正。
SFTGAN是一个生成对抗网络,其G网络(生成网络)结构如图所示:
代码:
class SFTLayer(nn.Module):
def __init__(self):
super(SFTLayer, self).__init__()
self.SFT_scale_conv0 = nn.Conv2d(32, 32, 1)
self.SFT_scale_conv1 = nn.Conv2d(32, 64, 1)
self.SFT_shift_conv0 = nn.Conv2d(32, 32, 1)
self.SFT_shift_conv1 = nn.Conv2d(32, 64, 1)
def forward(self, x):
# x[0]: fea; x[1]: cond
scale = self.SFT_scale_conv1(F.leaky_relu(self.SFT_scale_conv0(x[1]), 0.1, inplace=True))
shift = self.SFT_shift_conv1(F.leaky_relu(self.SFT_shift_conv0(x[1]), 0.1, inplace=True))
return x[0] * (scale + 1) + shift
class ResBlock_SFT(nn.Module):
def __init__(self):
super(ResBlock_SFT, self).__init__()
self.sft0 = SFTLayer()
self.conv0 = nn.Conv2d(64, 64, 3, 1, 1)
self.sft1 = SFTLayer()
self.conv1 = nn.Conv2d(64, 64, 3, 1, 1)
def forward(self, x):
# x[0]: fea; x[1]: cond
fea = self.sft0(x)
fea = F.relu(self.conv0(fea), inplace=True)
fea = self.sft1((fea, x[1]))
fea = self.conv1(fea)
return (x[0] + fea, x[1]) # return a tuple containing features and conditions
class SFT_Net(nn.Module):
def __init__(self):
super(SFT_Net, self).__init__()
self.conv0 = nn.Conv2d(3, 64, 3, 1, 1)
sft_branch = []
for i in range(16):
sft_branch.append(ResBlock_SFT())
sft_branch.append(SFTLayer())
sft_branch.append(nn.Conv2d(64, 64, 3, 1, 1))
self.sft_branch = nn.Sequential(*sft_branch)
self.HR_branch = nn.Sequential(
nn.Conv2d(64, 256, 3, 1, 1),
nn.PixelShuffle(2),
nn.ReLU(True),
nn.Conv2d(64, 256, 3, 1, 1),
nn.PixelShuffle(2),
nn.ReLU(True),
nn.Conv2d(64, 64, 3, 1, 1),
nn.ReLU(True),
nn.Conv2d(64, 3, 3, 1, 1)
)
self.CondNet = nn.Sequential(
nn.Conv2d(8, 128, 4, 4),
nn.LeakyReLU(0.1, True),
nn.Conv2d(128, 128, 1),
nn.LeakyReLU(0.1, True),
nn.Conv2d(128, 128, 1),
nn.LeakyReLU(0.1, True),
nn.Conv2d(128, 128, 1),
nn.LeakyReLU(0.1, True),
nn.Conv2d(128, 32, 1)
)
def forward(self, x):
# x[0]: img; x[1]: seg
cond = self.CondNet(x[1])
fea = self.conv0(x[0])
res = self.sft_branch((fea, cond))
fea = fea + res
out = self.HR_branch(fea)
return out
其实从中我们可以看到,所谓CondNet就是将传统SRResNet生成网络残差块中的BN层(Batch Normalization)替换成了SFT层。原因在师兄的博文中也有提到,相比于BN层的全局统一归一化处理,使用语义分割概率图(segmentation probability maps)做为参考的SFT层对不同物体所在区域进行差异化处理的能力显然更强。这样的话对于物体不同的纹理重建效果应该会得到提升,使得恢复的图片视觉可信度更高。
下面我就结合在学习过程中遇到的问题一步步介绍SFTGAN的代码实现。
参照这论文给出的网络结构图,来到
中查看代码,对于pytorch略微熟悉的同学在这里相信都不会感到有多吃力。我一开始只是想要探究清楚SFT层如何实现,看到这里便以为不过如此。然而观察下面代码的输入数据:
def forward(self, x):
# x[0]: img; x[1]: seg
cond = self.CondNet(x[1])
fea = self.conv0(x[0])
res = self.sft_branch((fea, cond))
fea = fea + res
out = self.HR_branch(fea)
return out
看到了该网络的输入是LR图片与相应的语义分割概率图seg,便自然而然地想要知道seg是如何得来的,数据集在外部应该是怎么样准备的?因此来到:
顿时陷入懵逼状态,直看了一个小时才理清楚,重点在于理解这一块地方:
和这里:
疑问是:
1. bg意为背景,那是怎么样的图片呢?有何作用。
2. category只是一个代表物体的数字,在这里突然出现又有何用?
实际上,在训练这一个网络时,作者使用了两个数据集:
比例大概是:
10 OST data samples and 1 DIV2K general data samples(background)
因此第一段关键代码就好理解,那就是在训练模式下,读取HR图片时每次都有1/10的几率得到DIV2K数据集中的图片,否则就在OST数据集中读取。那么为何要这么大费周章呢?因为OST数据集中的图片分为8种物体,而我们需要DIV2K数据集中的图片充当“什么都不是”的物体——背景。这样训练效果更好,应该是还可以防止过拟合。
但是我们一开始提出的问题还没解决,seg图怎么得来?我们知道既然把DIV2K视为背景了,那么↓就很自然了,seg都是1,都是背景没什么好分割的。
那对于OST中的数据集呢?
应该是使用已训练好的语义分割网络(模型保存在.pth文件中),但是感觉这个torch.load()返回的就已经是一个数组了呀,具体原理还需要进一步学习。在这里只知道seg应该是这样的k张图片:
我们继续看,虽然了解了seg的来源(语义分割网络),但是我们又找到了一个疑问:
我们之所以要使用seg,就是因为它拥有物体分布的空间概率信息,能够帮助我们对图片进行全局差异化的超分辨率重建。这是我最开始的想法,那么其实我只需要只知道“差异”的存在及其分布就够了呀,我还需要理解具体是啥跟啥的差异吗?事实证明我还是图样图森破,只知道这是前景与背景的差异,和知道这是草与房子的差异那是山和天空的差异,显然是后者拥有的信息比较多,而且能够允许我们进行针对性的学习,能够更好的帮助我们进行重建。
那么如何验证呢?很简单只要看category被用在何处就行了。来到:
很明显可以看到,category先是被用在了D鉴别网络的损失函数中,熟悉gan网络的同学一个都知道,一般D网络的损失函数赋予了D网络识别真假图片的能力,但是在SFTGAN的D网络中,损失函数还包括了另外的一项,那就是物体识别。这样一来D网络就有了两个功能:鉴别真假与识别物体。而整个SFTGAN网络也将识别物体作为损失函数的一部分。
当我的思想还错误的停留在SFTGAN特点是“差异性”时,我无法理解为什么要引入这样的一个损失项。但如果将“差异性”进化为“针对性”那么就好理解了。D网络的物体识别功能,促使G网络在进行图片重建时有针对性地学习了有限的几种事物的特征。对于每一张SR图片,在训练阶段都要求达到:
对于将category做为一项损失,我再补充谈谈自己的想法。以往的SR算法只看重PSNR,基于MSE损失的网络本身并不会去理解也不在意它在重建的是什么东西。而SRGAN引入了感知损失,通过VGG网络提取深层特征进行比较,相比之下它更关注自己重建的是啥,但是还不够。因为笼统的提取特征,在遇到不同事物具有相似纹理时往往导致误判 ,也就是说D网络不尽责,区分墙壁与野草的能力差,使得效率低的G网络轻松通过鉴别,出来的图像却逃不过人眼。
如何改进呢?SFTGAN的作者想到(猜测)兵分两路:(1)我直接告诉G网络这是墙壁那是野草你不要搞混了(使用seg图提供信息、SFT层融合信息)(2)一个一个单独教D网络如何识别这k种物体(category损失)。这样导致的结果是什么呢?
G网络知道了差异,但并不清楚面对差异纠结具体该怎么做。好在他的监督者:D网络,胸有成竹不断提供指导。加之训练阶段的数据都是单独的某件物体,让D网络学习得十分充分十分,才能在之后更加尽责。
在这里我认为其实训练阶段除了G网络的SR残差学习在加油进步,事实上亮点、重点应该是在D网络的部分,D网络一旦能够拥有强大的VGG感知信息进行特征比较能力、以及自身的物体识别能力,就能更加严苛地要求G网络生成高质量的图片。而SFT层的作用呢?面对严苛的监督者,G网络如果本身就资质平平,没有潜力,再激励也不行呀。但是我们这个G网络先天根骨清奇(拥有先验信息语义分割概率图),又打开了任督二脉融会贯通(SFT),因此在压力下才能不断提升实力,得到更加可信的SR图像。