本文结合了目前几乎所有先进的图像修复技术:在部分卷积的基础上提出了门控卷积,结合了 CA(Contextual Attention)中的注意力机制,并借鉴 Adversarial Edge 图像修复中的边缘先验信息,提出了用户可交互的草图先验信息。此外,基于 spectral-normalized GAN 提出了 SN-PatchGAN 鉴别器。本文所用的损失函数只有 l1 重建损失和 SN-PatchGAN 损失。
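为了介绍门控卷积,得先提一下部分卷积。对于分类、分割等任务,网络输入的像素全部是有效的;而对于修复任务,孔洞区域的像素是无效像素,如果将其与其他区域的像素同等处理,必然会造成修复结果模糊、颜色不一致等问题。基于这个原因,部分卷积(partial convolution)被提出,它的实现机制在我上一篇《Image Inpainting for Irregular Holes Using Partial Convolutions》中有提到。它的目的在于使卷积的结果尽量只依赖于有效像素。部分卷积有效提高了非规则掩模上的图像修复质量,但是仍然存在一些问题:(1)它把每个空间位置硬性地划分为有效或无效,滤波器范围内无论覆盖 1 个还是 9 个有效像素,下一层的掩码都会被置为 1;(2)无法兼容用户额外输入的草图等信息;(3)随着网络加深,无效像素逐渐消失,掩码值最终全部变为 1;(4)每一层的所有通道共享同一个掩码,限制了灵活性。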
部分卷积与门控卷积的图示区别如下图:
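基于上述部分卷积存在的问题,本文作者提出了门控卷积。门控卷积取代了部分卷积中硬门控的掩码(mask)更新规则,从数据中自动学习软掩码。其数学表达如下(其中 $W_g$、$W_f$ 是两组不同的卷积核):
$$Gating_{y,x}=\sum\sum W_g\cdot I,\qquad Feature_{y,x}=\sum\sum W_f\cdot I,\qquad O_{y,x}=\phi(Feature_{y,x})\odot\sigma(Gating_{y,x})$$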
这里的 $I$ 是特征图,$\sigma$ 是 sigmoid 函数,$\phi$ 是激活函数,可以是 ReLU、ELU 或 LeakyReLU。实际上就是对 $I$ 分别做两次卷积,其中一个卷积的结果经过 sigmoid 函数,将值全部限制在 0~1 之间,然后与另一个卷积(经过激活函数 $\phi$)得到的特征图逐像素相乘。
门控卷积的代码实现非常简单,如下:
# 1. Gated convolution module
class Gated_Conv(nn.Module):
    """Gated convolution: one convolution yields both features and a soft gate.

    A single ``nn.Conv2d`` produces ``2 * out_ch`` channels. The first half is
    squashed by a sigmoid into a gate in (0, 1), the second half goes through
    ``activation``, and the two halves are multiplied element-wise.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels (after gating).
        ksize: Convolution kernel size.
        stride: Convolution stride.
        rate: Dilation rate.
        activation: Activation applied to the feature half. May be a class
            (instantiated with no arguments, e.g. the default ``nn.ELU``),
            an already-built module or callable (e.g. ``nn.ReLU()``), or
            ``None`` to leave the feature half un-activated.
    """

    def __init__(self, in_ch, out_ch, ksize=3, stride=1, rate=1, activation=nn.ELU):
        super(Gated_Conv, self).__init__()
        # With an odd kernel this padding keeps the spatial size at stride 1,
        # taking the dilation rate into account.
        padding = int(rate * (ksize - 1) / 2)
        # Emit twice the output channels: half become the gate, half the features.
        self.conv = nn.Conv2d(in_ch, 2 * out_ch, kernel_size=ksize, stride=stride,
                              padding=padding, dilation=rate)
        # The default is the ELU *class*. Calling a class on a tensor would
        # construct a module instead of activating the tensor, so build the
        # instance here; instances and None are stored unchanged.
        if isinstance(activation, type):
            activation = activation()
        self.activation = activation

    def forward(self, x):
        """Return ``activation(features) * sigmoid(gate)`` for input ``x``."""
        raw = self.conv(x)
        # Split the channel axis into two equal halves: (gate, features).
        gate_part, feat_part = raw.split(raw.shape[1] // 2, dim=1)
        gate = torch.sigmoid(gate_part)  # soft mask with values in (0, 1)
        if self.activation is not None:
            feat_part = self.activation(feat_part)
        return feat_part * gate
对于单一矩形孔洞,使用 local GAN 可以提升修复结果;但是对于自由形式的孔洞区域,这种局部鉴别器显然不太适用。受 global and local GANs、MarkovianGAN、perceptual loss 和 spectral-normalized GAN 的启发,作者提出了简单高效的 SN-PatchGAN,可以应对自由形式的孔洞破损。网络结构如下图所示:
网络的输入包括:破损图片、孔洞掩码 mask、用户指导的先验草图信息。网络的输出是 3D 的 feature map,而不是传统鉴别器输出的一个打分标量。网络堆叠了 6 个 kernel size 为 5、stride 为 2 的卷积层,用来捕获 Markovian patches 的特征统计信息。值得注意的是,输出特征图中每一个元素的感受野都覆盖了整个输入图,因此也就不再需要全局鉴别器了。同时还采用了 spectral normalization(借鉴自 SN-GANs)来进一步稳定 GAN 的训练。为了鉴别真图与假图,采用 hinge loss 作为目标函数。对于生成器 G:
$$loss_G=-\mathbb{E}_{z\sim p_z(z)}\left[D^{sn}(G(z))\right]$$
对于鉴别器:
$$loss_D=\mathbb{E}_{x\sim P_{data}(x)}\left[ReLU(1-D^{sn}(x))\right]+\mathbb{E}_{z\sim p_z(z)}\left[ReLU(1+D^{sn}(G(z)))\right]$$
这里的 $D^{sn}$ 代表 spectral-normalized discriminator,$G$ 是修复网络。
鉴别器网络结构实现如下:
# 1. Spectral normalization wrapper
class SpectralNorm(nn.Module):
    """Wrap a module and spectrally normalize one of its weight parameters.

    Modified from
    https://github.com/christiancosgrove/pytorch-spectral-normalization-gan/blob/master/spectral_normalization.py

    On construction the wrapped module's parameter ``name`` is removed and
    replaced by three parameters whose names are ``name`` plus a suffix:
    ``_bar`` (the raw weight) and ``_u`` / ``_v`` (power-iteration vectors
    created with ``requires_grad=False``). Every forward pass refreshes
    ``u`` and ``v``, sets attribute ``name`` on the wrapped module to the raw
    weight divided by the estimated ``sigma``, and then delegates the call.

    Args:
        module: The module whose parameter is normalized.
        name: Name of the parameter on ``module`` to normalize.
        power_iterations: Number of power-iteration steps per forward pass.
    """

    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        # Only split the weight into (u, v, bar) once per wrapped module.
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        """Run the power iteration and install the normalized weight."""
        target = self.module
        u = getattr(target, self.name + "_u")
        v = getattr(target, self.name + "_v")
        w = getattr(target, self.name + "_bar")

        rows = w.data.shape[0]
        # Detached 2-D view of the weight, used only for the u/v updates.
        flat = w.view(rows, -1).data
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(flat), u.data))
            u.data = l2normalize(torch.mv(flat, v.data))

        # sigma is computed from w itself (not .data), so it stays connected
        # to w in the autograd graph.
        sigma = u.dot(w.view(rows, -1).mv(v))
        setattr(target, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        """Return True when the u, v and bar parameters already exist."""
        return all(
            hasattr(self.module, self.name + suffix)
            for suffix in ("_u", "_v", "_bar")
        )

    def _make_params(self):
        """Replace the original parameter by its u, v and bar counterparts."""
        w = getattr(self.module, self.name)

        rows = w.data.shape[0]
        cols = w.view(rows, -1).data.shape[1]

        # Random unit vectors created from the weight's own tensor type.
        u = Parameter(w.data.new(rows).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(cols).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        # Drop the original parameter; forward() re-creates the attribute
        # as a plain tensor on every call.
        del self.module._parameters[self.name]

        for suffix, param in (("_u", u), ("_v", v), ("_bar", w_bar)):
            self.module.register_parameter(self.name + suffix, param)

    def forward(self, *args):
        """Refresh the normalized weight, then call the wrapped module."""
        self._update_u_v()
        return self.module.forward(*args)
# 2. Spectrally-normalized convolution layer
class SN_Conv(nn.Module):
    """``nn.Conv2d`` wrapped in ``SpectralNorm``, plus an optional activation.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        ksize: Convolution kernel size.
        stride: Convolution stride.
        rate: Dilation rate.
        activation: Callable applied to the convolution output, or ``None``
            to return the convolution output unchanged.
    """

    def __init__(self,in_ch,out_ch,ksize=3,stride=1,rate=1,activation=nn.LeakyReLU()):
        super(SN_Conv, self).__init__()
        pad = int(rate * (ksize - 1) / 2)
        base_conv = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=ksize,
            stride=stride,
            padding=pad,
            dilation=rate,
        )
        self.snconv = SpectralNorm(base_conv)
        self.activation = activation

    def forward(self, x):
        """Apply the normalized convolution, then the activation if any."""
        features = self.snconv(x)
        if self.activation is None:
            return features
        return self.activation(features)
# 3. SN discriminator network
class SNDiscriminator(nn.Module):
    """Six stacked ``SN_Conv`` layers, each with kernel size 5 and stride 2.

    Args:
        in_ch: Number of channels of the discriminator input.
        cnum: Base channel width; layer widths are cnum, 2*cnum, then 4*cnum
            for the remaining four layers.
    """

    def __init__(self, in_ch=5, cnum=64):
        super(SNDiscriminator, self).__init__()
        # (input channels, output channels) for conv1 .. conv6, in order.
        widths = [
            (in_ch, cnum),
            (cnum, 2 * cnum),
            (2 * cnum, 4 * cnum),
            (4 * cnum, 4 * cnum),
            (4 * cnum, 4 * cnum),
            (4 * cnum, 4 * cnum),
        ]
        stages = OrderedDict()
        for idx, (c_in, c_out) in enumerate(widths, start=1):
            stages['conv%d' % idx] = SN_Conv(in_ch=c_in, out_ch=c_out, ksize=5, stride=2)
        self.dislayer = nn.Sequential(stages)

    def forward(self, x):
        """Return the final feature map flattened to (batch, -1)."""
        feats = self.dislayer(x)
        # Keep the batch dimension and collapse everything else.
        return feats.view(feats.shape[0], -1)
整个修复网络分为两个阶段(粗阶段和细化阶段),卷积部分都采用了门控卷积:
# 1. Coarse stage: 5 input channels (3-channel damaged image, mask,
#    user-guided sketch), 3 output channels.
class CoarseNet(nn.Module):
    """Coarse inpainting stage built from gated convolutions.

    The layers are applied strictly in the order conv1 .. conv17: two
    stride-2 gated convolutions, a block of dilated gated convolutions
    (rates 2, 4, 8, 16), two ``Gated_Deconv`` layers (defined elsewhere;
    presumably the upsampling counterparts of the stride-2 layers), and a
    plain 3x3 convolution to 3 channels followed by tanh.

    Args:
        in_ch: Number of input channels.
        cnum: Base channel width.
    """

    def __init__(self, in_ch=5, cnum=48):
        super(CoarseNet, self).__init__()
        self.conv1 = Gated_Conv(in_ch=in_ch, out_ch=cnum, ksize=5)
        self.conv2_down = Gated_Conv(in_ch=cnum, out_ch=2 * cnum, stride=2)
        self.conv3 = Gated_Conv(in_ch=2 * cnum, out_ch=2 * cnum)
        self.conv4_down = Gated_Conv(in_ch=2 * cnum, out_ch=4 * cnum, stride=2)
        self.conv5 = Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum)
        self.conv6 = Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum)
        # Dilated convolutions with growing rate.
        self.conv7 = Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum, rate=2)
        self.conv8 = Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum, rate=4)
        self.conv9 = Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum, rate=8)
        self.conv10 = Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum, rate=16)
        self.conv11 = Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum)
        self.conv12 = Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum)
        self.conv13_up = Gated_Deconv(in_ch=4 * cnum, out_ch=2 * cnum)
        self.conv14 = Gated_Conv(in_ch=2 * cnum, out_ch=2 * cnum)
        self.conv15_up = Gated_Deconv(in_ch=2 * cnum, out_ch=cnum)
        self.conv16 = Gated_Conv(in_ch=cnum, out_ch=cnum // 2)
        # Final projection to 3 channels uses a plain (non-gated) convolution.
        self.conv17 = nn.Conv2d(in_channels=cnum // 2, out_channels=3,
                                kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Run all layers in order and squash the result with tanh."""
        pipeline = (
            self.conv1, self.conv2_down, self.conv3, self.conv4_down,
            self.conv5, self.conv6, self.conv7, self.conv8, self.conv9,
            self.conv10, self.conv11, self.conv12, self.conv13_up,
            self.conv14, self.conv15_up, self.conv16, self.conv17,
        )
        for layer in pipeline:
            x = layer(x)
        # torch.tanh replaces F.tanh: same result, matches the torch.* calls
        # used elsewhere in this file, and avoids the deprecation warning
        # that nn.functional.tanh emits on some PyTorch versions.
        return torch.tanh(x)
# 2. Refinement stage: takes the coarse-stage result as input and has two
#    branches (a convolution branch and an attention branch).
class RefineNet(nn.Module):
    """Two-branch refinement network whose branches are merged and decoded.

    Args:
        in_ch: Number of input channels.
        cnum: Base channel width.
    """

    def __init__(self, in_ch=3, cnum=48):
        super(RefineNet, self).__init__()

        # 1. Convolution branch (ends with dilated gated convolutions).
        self.xlayer = nn.Sequential(OrderedDict([
            ('xconv1', Gated_Conv(in_ch=in_ch, out_ch=cnum, ksize=5)),
            ('xconv2_down', Gated_Conv(in_ch=cnum, out_ch=cnum, stride=2)),
            ('xconv3', Gated_Conv(in_ch=cnum, out_ch=2 * cnum)),
            ('xconv4_down', Gated_Conv(in_ch=2 * cnum, out_ch=2 * cnum, stride=2)),
            ('xconv5', Gated_Conv(in_ch=2 * cnum, out_ch=4 * cnum)),
            ('xconv6', Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum)),
            ('xconv7_atrous', Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum, rate=2)),
            ('xconv8_atrous', Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum, rate=4)),
            ('xconv9_atrous', Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum, rate=8)),
            ('xconv10_atrous', Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum, rate=16)),
        ]))

        # 2. Attention branch: encoder, contextual attention, two more convs.
        self.pmlayer1 = nn.Sequential(OrderedDict([
            ('pmconv1', Gated_Conv(in_ch=in_ch, out_ch=cnum, ksize=5)),
            ('pmconv2_down', Gated_Conv(in_ch=cnum, out_ch=cnum, stride=2)),
            ('pmconv3', Gated_Conv(in_ch=cnum, out_ch=2 * cnum)),
            ('pmconv4_down', Gated_Conv(in_ch=2 * cnum, out_ch=4 * cnum, stride=2)),
            ('pmconv5', Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum)),
            ('pmconv6', Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum, activation=nn.ReLU())),
        ]))
        self.CA = Contextual_Attention(rate=2)
        self.pmlayer2 = nn.Sequential(OrderedDict([
            ('pmconv9', Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum)),
            ('pmconv10', Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum)),
        ]))

        # 3. Merged branch: consumes the channel-wise concatenation of both
        #    branches (hence 8 * cnum input channels) and decodes to 3 channels.
        self.colayer = nn.Sequential(OrderedDict([
            ('allconv11', Gated_Conv(in_ch=8 * cnum, out_ch=4 * cnum)),
            ('allconv12', Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum)),
            ('allconv13_up', Gated_Deconv(in_ch=4 * cnum, out_ch=2 * cnum)),
            ('allconv14', Gated_Conv(in_ch=2 * cnum, out_ch=2 * cnum)),
            ('allconv15_up', Gated_Deconv(in_ch=2 * cnum, out_ch=cnum)),
            ('allconv16', Gated_Conv(in_ch=cnum, out_ch=cnum // 2)),
            ('allconv17', nn.Conv2d(in_channels=cnum // 2, out_channels=3,
                                    kernel_size=3, padding=1)),
            ('tanh', nn.Tanh()),
        ]))

    def forward(self, xin, mask):
        """Return ``(refined_output, offset_flow)`` for ``xin`` and ``mask``.

        ``offset_flow`` is the second value returned by the contextual
        attention module.
        """
        conv_feats = self.xlayer(xin)

        attn_in = self.pmlayer1(xin)
        # Bring the mask to the spatial size of the attention features.
        mask_small = self.resize_mask_like(mask, attn_in)
        attn_out, offset_flow = self.CA(attn_in, attn_in, mask_small)
        attn_feats = self.pmlayer2(attn_out)

        merged = torch.cat((conv_feats, attn_feats), dim=1)
        return self.colayer(merged), offset_flow

    def resize_mask_like(self, mask, x):
        """Resize ``mask`` via ``down_sample`` to the height and width of ``x``."""
        height, width = x.shape[2], x.shape[3]
        return down_sample(mask, size=(height, width))
# 3. Complete inpainting network
class CAGenerator(nn.Module):
    """Two-stage generator: ``CoarseNet`` followed by ``RefineNet``.

    Args:
        in_ch: Number of input channels of the coarse stage.
        cnum: Base channel width shared by both stages.
    """

    def __init__(self,in_ch=5,cnum=48,):
        super(CAGenerator, self).__init__()
        self.stage_1 = CoarseNet(in_ch=in_ch, cnum=cnum)
        # The refinement stage always receives a 3-channel image.
        self.stage_2 = RefineNet(in_ch=3, cnum=cnum)

    def forward(self, xin, mask):
        """Return ``(stage1_out, stage2_out, offset_flow)``.

        The refinement input blends the two sources: the coarse prediction
        is weighted by ``mask`` and the first three channels of ``xin`` are
        weighted by ``1 - mask``.
        """
        coarse = self.stage_1(xin)
        known_rgb = xin[:, 0:3, :, :]
        blended = coarse * mask + known_rgb * (1. - mask)
        refined, offset_flow = self.stage_2(blended, mask)
        return coarse, refined, offset_flow
作者提出了一种基于端到端生成网络的新型自由形式图像修复系统,该网络采用门控卷积,并使用逐像素 l1 损失和 SN-PatchGAN 损失进行训练。实验证明门控卷积改善了修复的质量。