论文“Recovering Realistic Texture in Image Super-resolution by Deep Spatial Feature Transform”发表于 CVPR 2018。
作者的论文、补充材料、数据集及代码地址：http://mmlab.ie.cuhk.edu.hk/projects/SFTGAN/
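这篇论文提出利用语义类别先验来解决超分辨率结果纹理不真实的问题，即在超分辨率重建过程中引入语义分割图作为条件，语义分割图则由图像分割网络生成。具体做法是使用空间特征变换（SFT）层：根据语义图逐像素地生成缩放（scale）和平移（shift）参数，对网络的中间特征做仿射变换（见下方代码中的 SFTLayer）。文章还比较了不同分辨率输入下语义分割的误差，发现对分割网络而言，输入是高分辨率还是低分辨率图像，对分割精度的影响并不大。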
class SFTLayer(nn.Module):
    """Spatial Feature Transform (SFT) layer.

    Two small branches of 1x1 convolutions predict a per-pixel ``scale`` and
    ``shift`` from a condition tensor. These are applied to a feature tensor
    as ``fea * (scale + 1) + shift``.

    Args:
        cond_nc: Number of channels of the condition tensor (default 32).
        fea_nc: Number of channels of the feature tensor being modulated
            (default 64).

    The sub-module attribute names are unchanged from the original layer, so
    existing ``state_dict`` checkpoints keep loading.
    """

    def __init__(self, cond_nc=32, fea_nc=64):
        super(SFTLayer, self).__init__()
        # Scale branch: cond_nc -> cond_nc -> fea_nc, all 1x1 convolutions.
        self.SFT_scale_conv0 = nn.Conv2d(cond_nc, cond_nc, 1)
        self.SFT_scale_conv1 = nn.Conv2d(cond_nc, fea_nc, 1)
        # Shift branch: same structure, independent weights.
        self.SFT_shift_conv0 = nn.Conv2d(cond_nc, cond_nc, 1)
        self.SFT_shift_conv1 = nn.Conv2d(cond_nc, fea_nc, 1)

    def forward(self, x):
        """Modulate ``fea`` with parameters predicted from ``cond``.

        Args:
            x: Indexable pair where ``x[0]`` is the feature tensor
                (``fea_nc`` channels) and ``x[1]`` is the condition tensor
                (``cond_nc`` channels). The 1x1 convolutions keep the
                condition's spatial size, so it must be broadcast-compatible
                with the features.

        Returns:
            ``x[0] * (scale + 1) + shift``, with the features' shape when
            spatial sizes match.
        """
        fea, cond = x[0], x[1]
        # inplace=True only overwrites the freshly computed conv output.
        scale = self.SFT_scale_conv1(
            F.leaky_relu(self.SFT_scale_conv0(cond), 0.1, inplace=True))
        shift = self.SFT_shift_conv1(
            F.leaky_relu(self.SFT_shift_conv0(cond), 0.1, inplace=True))
        # The "+ 1" makes a zero-predicted scale an identity multiplication.
        return fea * (scale + 1) + shift
class ResBlock_SFT(nn.Module):
    """Residual block with SFT modulation.

    Residual path: SFT -> 3x3 conv -> ReLU -> SFT -> 3x3 conv. The result is
    added back onto the input features. The block takes and returns a
    ``(fea, cond)`` tuple so that several blocks can be chained inside an
    ``nn.Sequential``; ``cond`` is passed through unchanged.
    """

    def __init__(self):
        super(ResBlock_SFT, self).__init__()
        self.sft0 = SFTLayer()
        self.conv0 = nn.Conv2d(64, 64, 3, 1, 1)  # 3x3, stride 1, pad 1: keeps H, W
        self.sft1 = SFTLayer()
        self.conv1 = nn.Conv2d(64, 64, 3, 1, 1)

    def forward(self, x):
        """Apply the residual SFT block.

        Args:
            x: Indexable pair; ``x[0]`` is the 64-channel feature tensor and
                ``x[1]`` is the condition tensor consumed by both SFT layers.

        Returns:
            Tuple ``(x[0] + residual, x[1])``. Returning the condition lets
            the next block in a ``nn.Sequential`` receive it as well.
        """
        identity, cond = x[0], x[1]
        fea = self.sft0(x)
        fea = F.relu(self.conv0(fea), inplace=True)
        fea = self.sft1((fea, cond))
        fea = self.conv1(fea)
        return (identity + fea, cond)
class SFT_Net(nn.Module):
    """SFT-GAN generator: x4 super-resolution conditioned on a segmentation map.

    Args:
        in_nc: Channels of the low-resolution input image (default 3).
        out_nc: Channels of the generated image (default 3).
        seg_nc: Channels of the segmentation map fed to ``CondNet``
            (default 8).
        n_blocks: Number of ``ResBlock_SFT`` blocks in the trunk
            (default 16).

    All arguments default to the original hard-coded values. The module
    attribute names and their construction order are unchanged, so existing
    checkpoints still load.
    """

    def __init__(self, in_nc=3, out_nc=3, seg_nc=8, n_blocks=16):
        super(SFT_Net, self).__init__()
        self.conv0 = nn.Conv2d(in_nc, 64, 3, 1, 1)

        # Trunk: n_blocks tuple-passing residual blocks, then one SFT layer
        # (which returns a plain tensor) and a final conv on that tensor.
        sft_branch = [ResBlock_SFT() for _ in range(n_blocks)]
        sft_branch.append(SFTLayer())
        sft_branch.append(nn.Conv2d(64, 64, 3, 1, 1))
        self.sft_branch = nn.Sequential(*sft_branch)

        # Upsampler: two (conv -> PixelShuffle(2) -> ReLU) stages give x4.
        # Each conv emits 256 = 64 * 2**2 channels, so PixelShuffle(2)
        # returns 64 channels at twice the spatial size.
        self.HR_branch = nn.Sequential(
            nn.Conv2d(64, 256, 3, 1, 1),
            nn.PixelShuffle(2),
            nn.ReLU(True),
            nn.Conv2d(64, 256, 3, 1, 1),
            nn.PixelShuffle(2),
            nn.ReLU(True),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(64, out_nc, 3, 1, 1),
        )

        # Condition network: the 4x4 / stride-4 conv shrinks the segmentation
        # map by a factor of 4. The 1x1 convs then map it to the 32 channels
        # that SFTLayer expects.
        self.CondNet = nn.Sequential(
            nn.Conv2d(seg_nc, 128, 4, 4),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(128, 128, 1),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(128, 128, 1),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(128, 128, 1),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(128, 32, 1),
        )

    def forward(self, x):
        """Super-resolve ``x[0]`` using the segmentation map ``x[1]``.

        Args:
            x: Indexable pair ``(img, seg)``.
                - ``img`` has ``in_nc`` channels.
                - ``seg`` has ``seg_nc`` channels.
                - Because ``CondNet`` downsamples by 4, ``seg`` must be 4x the
                  spatial size of ``img``. Only then does the condition line
                  up with the low-resolution features.

        Returns:
            Image tensor with ``out_nc`` channels at 4x the spatial size of
            ``img``.
        """
        img, seg = x[0], x[1]
        cond = self.CondNet(seg)
        fea = self.conv0(img)
        res = self.sft_branch((fea, cond))
        fea = fea + res  # global skip connection around the SFT trunk
        return self.HR_branch(fea)
实验结果