与轮廓检测不同的是,简笔画(countour drawing)需要画出主体事物内部的重要特征;而边缘检测则对内部特征的处理太过繁琐,忽略了简笔画的艺术性。
作者Mengtian Li, Zhe Lin, Radomır Mech, Ersin Yumer, Deva Ramanan为了解决这一问题,将图片之简笔画的变化看成两种不同风格的图像的转换,用pix2pix的框架巧妙解决了这一问题。
作者使用Amazon Mechanical Turk收集用户描绘的简笔画,数据库中一共有1000张原始图片,每张选取5个质量比较好的,一共有5000张草稿。草稿还分为线条宽度为1,3,5的不同图片,可以进行选择。
数据集的下载地址:http://www.cs.cmu.edu/~mengtial/proj/sketch/
数据集中第一张图片的原图和对应的三种线条宽度的草稿图如下:
|
|
|
|
模型的网络结构还是采用的cGAN的框架,即对下面的表达式进行优化:
L c G A N ( x , y , z ) = min G max D E x , y [ log D ( x , y ) ] + E x , z [ log 1 − D ( x , G ( x , z ) ) ] \mathcal L_{cGAN}(x,y,z)=\min_G \max_D \mathbb E_{x,y}[\log {D(x,y)}]+\mathbb E_{x,z}[\log {1-D(x,G(x,z))}] LcGAN(x,y,z)=GminDmaxEx,y[logD(x,y)]+Ex,z[log1−D(x,G(x,z))]其中 x x x是条件, y y y是映射之后的图片, z z z是原始输入的噪声。不过一般情况下, z z z是可以忽略的,即:
L c G A N ( x , y ) = min G max D E x , y [ log D ( x , y ) ] + E x [ log 1 − D ( x , G ( x ) ) ] \mathcal L_{cGAN}(x,y)=\min_G \max_D \mathbb E_{x,y}[\log {D(x,y)}]+\mathbb E_{x}[\log {1-D(x,G(x))}] LcGAN(x,y)=GminDmaxEx,y[logD(x,y)]+Ex[log1−D(x,G(x))]
根据草稿图稀疏性的特征,模型使用了 L 1 \mathcal L_1 L1损失函数。这里不用 L 2 \mathcal L_2 L2的原因是, L 2 \mathcal L_2 L2损失函数表示的是均值,这会大大模糊生成的草稿图。 L 1 \mathcal L_1 L1损失函数的数学意义是中位数,可以选择较好的一个结果输出。
L c = λ L c G A N ( x , y ) + L 1 ( x , y ) \mathcal L_c =\lambda\mathcal L_{cGAN}(x,y)+\mathcal L_1(x,y) Lc=λLcGAN(x,y)+L1(x,y)TIP
:同样的点在noise2noise论文中也有提到,我觉得非常有意思:
同时由于数据集比较特殊,即一个输入有多个合理的输出草图,所以作者定一个一个全新的损失函数,他们称之为MM-loss(Min-Mean-loss)。
假设输入的图像为 x i x_i xi,不同的输出目标为 ( y i ( 1 ) , y i ( 2 ) , . . . , y i ( M i ) ) (y_i^{(1)}, y_i^{(2)}, ...,y_i^{(M_i)}) (yi(1),yi(2),...,yi(Mi))。普通的训练策略是将顺序打乱,同时把 ( x 1 , y 1 ( 1 ) ) , ( x 1 , y 1 ( 2 ) ) , . . . , ( x N , y 1 ( M ) ) (x_1,y_1^{(1)}),(x_1,y_1^{(2)}),...,(x_N,y_1^{(M)}) (x1,y1(1)),(x1,y1(2)),...,(xN,y1(M))分别看成训练的对。然而本文采用的方法是,把数据 ( x i , y i ( 1 ) , y i ( 2 ) , . . . , y i ( M i ) ) (x_i,y_i^{(1)}, y_i^{(2)}, ...,y_i^{(M_i)}) (xi,yi(1),yi(2),...,yi(Mi))看成单独的一条,并定义如下的损失函数:
L ( x i , y i ( 1 ) , y i ( 2 ) , . . . , y i ( M i ) ) = λ M ∑ j = 1 M i L c G A N ( x i , y i ( j ) ) + min j ∈ 1 , . . . , M i L 1 ( x i , y i ( j ) ) \mathcal L(x_i,y_i^{(1)}, y_i^{(2)}, ...,y_i^{(M_i)})=\frac{\lambda}{M} \sum_{j=1}^{M_i}{\mathcal L_{cGAN}{(x_i,y_i^{(j)})}}+\min_{j\in{1,...,M_i}}L_{1}{(x_i,y_i^{(j)})} L(xi,yi(1),yi(2),...,yi(Mi))=Mλj=1∑MiLcGAN(xi,yi(j))+j∈1,...,MiminL1(xi,yi(j))上述的损失函数中,第一项平均值让生成器平等的对待所有的目标草稿图;第二项最小值则让生成器生成一个效果最好的,防止模糊图片的出现。
生成器的结构采用Resnet,不像其他GAN采用Unet的原因是,Unet中的skip-connection模块会让生成器选择保留比较低级的特征,比如纹理等等。判别器也不采用PatchGAN的结构而是传统的判别器结构,因为PatchGAN将图片分为几个区块,影响了草图整体结构的生成。(虽然我的复现用的还是Unet和PatchGAN,。之后有时间会再修改)
Unet的设计思路参照的是https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix,即把U型的底部作为子模块,不断的迭代生长。
class unetModule(nn.Module):
def __init__(self, input_nc, inner_nc, output_nc=None, sub_module=None, is_outest=False):
super(unetModule, self).__init__()
self.is_outest = is_outest
conv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1)
down_norm = nn.BatchNorm2d(inner_nc)
up_norm = nn.BatchNorm2d(input_nc)
down_relu = nn.LeakyReLU(0.2, True)
up_relu = nn.ReLU(True)
tanh = nn.Tanh()
if is_outest:
assert(output_nc != None)
convT = nn.ConvTranspose2d(inner_nc * 2, output_nc, kernel_size=4, stride=2, padding=1)
up = [convT] + [tanh]
down = [conv] + [down_relu]
elif sub_module:
convT = nn.ConvTranspose2d(inner_nc * 2, input_nc, kernel_size=4, stride=2, padding=1)
up = [convT] + [up_norm] + [up_relu]
down = [conv] + [down_norm] + [down_relu]
else:
convT = nn.ConvTranspose2d(inner_nc, input_nc, kernel_size=4, stride=2, padding=1)
up = [convT] + [up_norm] + [up_relu]
down = [conv] + [down_relu]
if sub_module:
model = down + [sub_module] + up
else:
model = down + up
self.model = nn.Sequential(*model)
def forward(self, x):
if self.is_outest:
return self.model(x)
else:
return torch.cat((x, self.model(x)), 1)
inner_test = torch.ones(16, 256, 4, 4)
unetModule_test = unetModule(256, 512)
rslt = unetModule_test(inner_test)
rslt.shape
OUT:
torch.Size([16, 512, 4, 4])
class unetG(nn.Module):
def __init__(self, input_nc, output_nc, num_wrapper, first_nc=64):
super(unetG, self).__init__()
unet_submodule = unetModule(first_nc * 2**3, first_nc * 2**3)
for _ in range(num_wrapper-4):
unet_submodule = unetModule(first_nc * 2**3, first_nc * 2**3, sub_module=unet_submodule)
unet_submodule = unetModule(first_nc * 2**2, first_nc * 2**3, sub_module=unet_submodule)
unet_submodule = unetModule(first_nc * 2**1, first_nc * 2**2, sub_module=unet_submodule)
unet_submodule = unetModule(first_nc, first_nc * 2**1, sub_module=unet_submodule)
unet_submodule = unetModule(input_nc, first_nc, sub_module=unet_submodule,
output_nc=output_nc, is_outest=True)
self.model = unet_submodule
def forward(self, x):
return self.model(x)
x_test = torch.ones(16, 3, 256, 256)
unetG_test = unetG(3, 1, num_wrapper=7)
rslt = unetG_test(x_test)
print(rslt.shape)
OUT
torch.Size([16, 1, 256, 256])
PatchGAN的思想其实比较简单,就是将图片分为若干个部分,对每个部分进行真假判断后在组合。有关它的理解可以看这里:https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/39
class patchDiscriminator(nn.Module):
def __init__(self, input_nc, first_nc=64, num_layers=3):
super(patchDiscriminator, self).__init__()
model = [nn.Conv2d(input_nc, first_nc, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
for i_layer in range(num_layers-1):
conv = nn.Conv2d(first_nc * 2**i_layer, first_nc * 2**(i_layer+1), kernel_size=4, stride=2, padding=1)
batch_norm = nn.BatchNorm2d(first_nc * 2**(i_layer+1))
relu = nn.LeakyReLU(0.2, True)
model.extend([conv]+[batch_norm]+[relu])
conv = nn.Conv2d(first_nc * 4, first_nc * 8, kernel_size=4, stride=1, padding=1)
batch_norm = nn.BatchNorm2d(first_nc * 8)
relu = nn.LeakyReLU(0.2, True)
model.extend([conv]+[batch_norm]+[relu])
conv = nn.Conv2d(first_nc * 8, 1, kernel_size=4, stride=1, padding=1)
model.append(conv)
self.model = nn.Sequential(*model)
def forward(self, x):
return self.model(x)
x_test = torch.ones(16, 3, 512, 512)
test_patchDiscriminator = patchDiscriminator(3)
rslt = test_patchDiscriminator(x_test)
rslt.shape
OUT
torch.Size([16, 1, 62, 62])
def init_weight(network):
def weights_init_normal(m):
classname = m.__class__.__name__
# print(classname)
if classname.find('Conv') != -1:
init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('Linear') != -1:
init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm2d') != -1:
init.normal_(m.weight.data, 1.0, 0.02)
init.constant_(m.bias.data, 0.0)
network.apply(weights_init_normal)
# test_init = init_weight(unetG_test)
class baseGANLoss:
def __init__(self, gan_type='BCE', device='cuda'):
self.device = device
self.loss = nn.BCEWithLogitsLoss().to(device)
def __call__(self, prediction, is_real):
if is_real:
target_tensor = torch.ones(prediction.shape).to(self.device)
else:
target_tensor = torch.zeros(prediction.shape).to(self.device)
return self.loss(prediction, target_tensor)
上述是比较关键的代码,由于篇幅原因,我把代码放在我的github上了;也可以登陆Google Colab查看我的训练结果。
|
|
|
|
|
不过这与论文作者的训练结果相比,还是差了很远,以后有时间会继续改进。