GitHub code: https://github.com/tgeorgy/mgan
Main contributions of the paper:
1. The generator takes an input x and outputs a segmentation mask together with an intermediate image y. The mask is used to blend the input x with the intermediate image y to produce the generated image, so the result keeps the background of x while the foreground comes from the generator (a small sketch of this blending follows below).
2. Training is end-to-end: on top of the CycleGAN losses, an additional constraint is imposed on the generated output.
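The blending step is a per-pixel convex combination, out = mask * y + (1 - mask) * x. A minimal sketch of this operation (the tensor names and sizes below are illustrative, not taken from the repo):
import torch

x = torch.rand(1, 3, 256, 256)     # input image: supplies the background where mask ~ 0
y = torch.rand(1, 3, 256, 256)     # intermediate image: supplies the foreground where mask ~ 1
mask = torch.rand(1, 1, 256, 256)  # per-pixel weights in [0, 1], broadcast over the 3 channels

out = mask * y + (1 - mask) * x    # blended output: generated foreground, original background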
The model structure is as follows.
The generator first outputs a segmentation mask and an intermediate image y; the mask and the intermediate image are then blended, and the result is the final generated image. The generator code is as follows:
import torch.nn as nn
import torch.nn.functional as F

class Generator(nn.Module):
    def __init__(self, input_nc=3, output_nc=4, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6):
        assert(n_blocks >= 0)
        super(Generator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        # Initial 7x7 convolution
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
                 norm_layer(ngf),
                 nn.ReLU(True)]
        # Two stride-2 downsampling convolutions
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]
        # Residual blocks (ResnetBlock is defined elsewhere in the repo)
        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, norm_layer=norm_layer, use_dropout=use_dropout)]
        # Upsampling via 1x1 convolution + PixelShuffle instead of transposed convolution
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ReflectionPad2d(1),
                      nn.Conv2d(ngf * mult, int(ngf * mult / 2),
                                kernel_size=3, stride=1),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True),
                      nn.Conv2d(int(ngf * mult / 2), int(ngf * mult / 2)*4,
                                kernel_size=1, stride=1),
                      nn.PixelShuffle(2),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True),
                      ]
        # Final 7x7 convolution producing mask + RGB channels
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        self.model = nn.Sequential(*model)
In this code the generator has 3 input channels and 4 output channels: the first output channel is the mask, and the remaining three channels form the intermediate generated image.
    def forward(self, input):
        output = self.model(input)
        mask = F.sigmoid(output[:, :1])      # channel 0 -> mask in [0, 1]
        oimg = output[:, 1:]                 # channels 1-3 -> intermediate image
        mask = mask.repeat(1, 3, 1, 1)
        oimg = oimg*mask + input*(1-mask)    # blend: generated foreground, original background
        return oimg, mask
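A quick shape check for the generator, assuming the repo's ResnetBlock definition is importable (the input size 256x256 is only an example):
import torch

net = Generator(input_nc=3, output_nc=4)
x = torch.rand(1, 3, 256, 256)
oimg, mask = net(x)
print(oimg.shape, mask.shape)   # expected: (1, 3, 256, 256) and (1, 3, 256, 256)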
The overall setup follows CycleGAN, i.e. it contains two generators and two discriminators.
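In the training code below, netP2N and netN2P are the two generators (positive-to-negative and negative-to-positive) and netDN, netDP are the corresponding discriminators. A minimal setup sketch; the Discriminator class name is an assumption here, since the discriminator definition is not shown in this excerpt:
netP2N = Generator()      # translates positive-domain images to the negative domain
netN2P = Generator()      # translates negative-domain images to the positive domain
netDN = Discriminator()   # judges negative-domain images (class name assumed)
netDP = Discriminator()   # judges positive-domain images (class name assumed)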
For each generator, the loss has three parts. The first, loss_P2N_cyc, is the same as the CycleGAN cycle-consistency loss: the output of generator g1 is fed into generator g2, and the reconstruction should match the original input. The second, loss_P2N_gan, is the GAN loss, i.e. the discriminator should score the generated image as real. The third, loss_N2P_idnt, is an identity-style loss that keeps the generator's output close to its input (criterion_identity(fake_neg, real_pos_v) in the code); the paper trains end-to-end with paired input-label data, whereas CycleGAN is unpaired and therefore does not include this term.
criterion_cycle = nn.L1Loss()       # cycle-consistency loss
criterion_identity = nn.L1Loss()    # identity loss (output vs. input)
criterion_gan = nn.MSELoss()        # least-squares GAN loss
# Train P2N Generator
real_pos_v = Variable(real_pos)
fake_neg, mask_neg = netP2N(real_pos_v)   # positive -> negative
rec_pos, _ = netN2P(fake_neg)             # negative -> reconstructed positive
fake_neg_lbl = netDN(fake_neg)            # discriminator score for the generated negative
loss_P2N_cyc = criterion_cycle(rec_pos, real_pos_v)
loss_P2N_gan = criterion_gan(fake_neg_lbl, Variable(real_lbl))
loss_N2P_idnt = criterion_identity(fake_neg, real_pos_v)
# Train N2P Generator
real_neg_v = Variable(real_neg)
fake_pos, mask_pos = netN2P(real_neg_v)   # negative -> positive
rec_neg, _ = netP2N(fake_pos)             # positive -> reconstructed negative
fake_pos_lbl = netDP(fake_pos)            # discriminator score for the generated positive
loss_N2P_cyc = criterion_cycle(rec_neg, real_neg_v)
loss_N2P_gan = criterion_gan(fake_pos_lbl, Variable(real_lbl))
loss_P2N_idnt = criterion_identity(fake_pos, real_neg_v)
# Total generator loss: GAN + cycle-consistency + identity terms
loss_G = ((loss_P2N_gan + loss_N2P_gan)*0.5 +
          (loss_P2N_cyc + loss_N2P_cyc)*lambda_cycle +
          (loss_P2N_idnt + loss_N2P_idnt)*lambda_identity)
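The snippet above only builds loss_G; the actual generator update (zeroing gradients, backpropagation, optimizer step) is not part of this excerpt. A minimal sketch, assuming a single optimizer optimizer_G over both generators' parameters (the optimizer name is an assumption):
netP2N.zero_grad()
netN2P.zero_grad()
loss_G.backward()
optimizer_G.step()   # optimizer_G is assumed, e.g. Adam over the parameters of both generators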
The discriminators are trained to tell real images from generated ones:
# Train Discriminators
netDN.zero_grad()
netDP.zero_grad()
# Generated images are detached so gradients do not flow back into the generators
fake_neg_score = netDN(fake_neg.detach())
loss_D = criterion_gan(fake_neg_score, Variable(fake_lbl))
fake_pos_score = netDP(fake_pos.detach())
loss_D += criterion_gan(fake_pos_score, Variable(fake_lbl))
# Real images should be scored as real
real_neg_score = netDN.forward(real_neg_v)
loss_D += criterion_gan(real_neg_score, Variable(real_lbl))
real_pos_score = netDP.forward(real_pos_v)
loss_D += criterion_gan(real_pos_score, Variable(real_lbl))
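As with the generators, the discriminator update step is not shown in the excerpt. A short sketch, assuming an optimizer optimizer_D over both discriminators (the name is an assumption):
loss_D.backward()
optimizer_D.step()   # optimizer_D is assumed, e.g. Adam over the parameters of netDN and netDP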