[pytorch] WACA论文,photo2sketch的复现

WACA论文,photo2sketch的复现:让你的照片自动变成简笔画

  • 论文idea
    • 数据集
    • 网络结构
    • 损失函数
  • 代码复现
    • networks
      • Unet generator
        • unet module
        • build unet
      • patchGAN discriminator
        • build net
      • network utilities
    • Loss function
  • 结果分析

[pytorch] WACA论文,photo2sketch的复现_第1张图片
生成对抗网络(GAN)虽然近年来热度有所下降,不过它展现了神经网络所拥有的巨大创造潜力。这篇被WACA收录的论文 Photo-Sketching:
Inferring Contour Drawings from Images就通过生成对抗网络的框架,成功将图片转化为简笔画。

论文idea

与轮廓检测不同的是,简笔画(countour drawing)需要画出主体事物内部的重要特征;而边缘检测则对内部特征的处理太过繁琐,忽略了简笔画的艺术性。
[pytorch] WACA论文,photo2sketch的复现_第2张图片
作者Mengtian Li, Zhe Lin, Radomır Mech, Ersin Yumer, Deva Ramanan为了解决这一问题,将图片之简笔画的变化看成两种不同风格的图像的转换,用pix2pix的框架巧妙解决了这一问题。

数据集

作者使用Amazon Mechanical Turk收集用户描绘的简笔画,数据库中一共有1000张原始图片,每张选取5个质量比较好的,一共有5000张草稿。草稿还分为线条宽度为1,3,5的不同图片,可以进行选择。
[pytorch] WACA论文,photo2sketch的复现_第3张图片
数据集的下载地址:http://www.cs.cmu.edu/~mengtial/proj/sketch/
数据集中第一张图片的原图和对应的三种线条宽度的草稿图如下:

[pytorch] WACA论文,photo2sketch的复现_第4张图片原图
[pytorch] WACA论文,photo2sketch的复现_第5张图片width-1
[pytorch] WACA论文,photo2sketch的复现_第6张图片width-3
[pytorch] WACA论文,photo2sketch的复现_第7张图片width-5

网络结构

[pytorch] WACA论文,photo2sketch的复现_第8张图片
模型的网络结构还是采用的cGAN的框架,即对下面的表达式进行优化:
L c G A N ( x , y , z ) = min ⁡ G max ⁡ D E x , y [ log ⁡ D ( x , y ) ] + E x , z [ log ⁡ 1 − D ( x , G ( x , z ) ) ] \mathcal L_{cGAN}(x,y,z)=\min_G \max_D \mathbb E_{x,y}[\log {D(x,y)}]+\mathbb E_{x,z}[\log {1-D(x,G(x,z))}] LcGAN(x,y,z)=GminDmaxEx,y[logD(x,y)]+Ex,z[log1D(x,G(x,z))]其中 x x x是条件, y y y是映射之后的图片, z z z是原始输入的噪声。不过一般情况下, z z z是可以忽略的,即:
L c G A N ( x , y ) = min ⁡ G max ⁡ D E x , y [ log ⁡ D ( x , y ) ] + E x [ log ⁡ 1 − D ( x , G ( x ) ) ] \mathcal L_{cGAN}(x,y)=\min_G \max_D \mathbb E_{x,y}[\log {D(x,y)}]+\mathbb E_{x}[\log {1-D(x,G(x))}] LcGAN(x,y)=GminDmaxEx,y[logD(x,y)]+Ex[log1D(x,G(x))]

损失函数

根据草稿图稀疏性的特征,模型使用了 L 1 \mathcal L_1 L1损失函数。这里不用 L 2 \mathcal L_2 L2的原因是, L 2 \mathcal L_2 L2损失函数表示的是均值,这会大大模糊生成的草稿图。 L 1 \mathcal L_1 L1损失函数的数学意义是中位数,可以选择较好的一个结果输出。
L c = λ L c G A N ( x , y ) + L 1 ( x , y ) \mathcal L_c =\lambda\mathcal L_{cGAN}(x,y)+\mathcal L_1(x,y) Lc=λLcGAN(x,y)+L1(x,y)TIP:同样的点在noise2noise论文中也有提到,我觉得非常有意思:
[pytorch] WACA论文,photo2sketch的复现_第9张图片
[pytorch] WACA论文,photo2sketch的复现_第10张图片
同时由于数据集比较特殊,即一个输入有多个合理的输出草图,所以作者定一个一个全新的损失函数,他们称之为MM-loss(Min-Mean-loss)。
假设输入的图像为 x i x_i xi,不同的输出目标为 ( y i ( 1 ) , y i ( 2 ) , . . . , y i ( M i ) ) (y_i^{(1)}, y_i^{(2)}, ...,y_i^{(M_i)}) (yi(1),yi(2),...,yi(Mi))。普通的训练策略是将顺序打乱,同时把 ( x 1 , y 1 ( 1 ) ) , ( x 1 , y 1 ( 2 ) ) , . . . , ( x N , y 1 ( M ) ) (x_1,y_1^{(1)}),(x_1,y_1^{(2)}),...,(x_N,y_1^{(M)}) (x1,y1(1)),(x1,y1(2)),...,(xN,y1(M))分别看成训练的对。然而本文采用的方法是,把数据 ( x i , y i ( 1 ) , y i ( 2 ) , . . . , y i ( M i ) ) (x_i,y_i^{(1)}, y_i^{(2)}, ...,y_i^{(M_i)}) (xi,yi(1),yi(2),...,yi(Mi))看成单独的一条,并定义如下的损失函数:
L ( x i , y i ( 1 ) , y i ( 2 ) , . . . , y i ( M i ) ) = λ M ∑ j = 1 M i L c G A N ( x i , y i ( j ) ) + min ⁡ j ∈ 1 , . . . , M i L 1 ( x i , y i ( j ) ) \mathcal L(x_i,y_i^{(1)}, y_i^{(2)}, ...,y_i^{(M_i)})=\frac{\lambda}{M} \sum_{j=1}^{M_i}{\mathcal L_{cGAN}{(x_i,y_i^{(j)})}}+\min_{j\in{1,...,M_i}}L_{1}{(x_i,y_i^{(j)})} L(xi,yi(1),yi(2),...,yi(Mi))=Mλj=1MiLcGAN(xi,yi(j))+j1,...,MiminL1(xi,yi(j))上述的损失函数中,第一项平均值让生成器平等的对待所有的目标草稿图;第二项最小值则让生成器生成一个效果最好的,防止模糊图片的出现。
生成器的结构采用Resnet,不像其他GAN采用Unet的原因是,Unet中的skip-connection模块会让生成器选择保留比较低级的特征,比如纹理等等。判别器也不采用PatchGAN的结构而是传统的判别器结构,因为PatchGAN将图片分为几个区块,影响了草图整体结构的生成。(虽然我的复现用的还是Unet和PatchGAN,。之后有时间会再修改)

代码复现

networks

Unet generator

Unet的设计思路参照的是https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix,即把U型的底部作为子模块,不断的迭代生长。

unet module

class unetModule(nn.Module):
    def __init__(self, input_nc, inner_nc, output_nc=None, sub_module=None, is_outest=False):
        super(unetModule, self).__init__()
        self.is_outest = is_outest
        
        conv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1)
        
        down_norm = nn.BatchNorm2d(inner_nc)
        up_norm = nn.BatchNorm2d(input_nc)
        
        down_relu = nn.LeakyReLU(0.2, True)
        up_relu = nn.ReLU(True)
        tanh = nn.Tanh()
        
        if is_outest:
            assert(output_nc != None)
            
            convT = nn.ConvTranspose2d(inner_nc * 2, output_nc, kernel_size=4, stride=2, padding=1)
            up = [convT] + [tanh]
            
            down = [conv] + [down_relu]
        elif sub_module:
            convT = nn.ConvTranspose2d(inner_nc * 2, input_nc, kernel_size=4, stride=2, padding=1)
            up = [convT] + [up_norm] + [up_relu]
            
            down = [conv] + [down_norm] + [down_relu]
        else:
            convT = nn.ConvTranspose2d(inner_nc, input_nc, kernel_size=4, stride=2, padding=1)
            up = [convT] + [up_norm] + [up_relu]
            
            down = [conv] + [down_relu]
            
        if sub_module:
            model = down + [sub_module] + up
        else:
            model = down + up
        self.model = nn.Sequential(*model)
    
    def forward(self, x):
        if self.is_outest:
            return self.model(x)
        else:
            return torch.cat((x, self.model(x)), 1)

inner_test = torch.ones(16, 256, 4, 4)
unetModule_test = unetModule(256, 512)
rslt = unetModule_test(inner_test)
rslt.shape

OUT:

torch.Size([16, 512, 4, 4])

build unet

class unetG(nn.Module):
    def __init__(self, input_nc, output_nc, num_wrapper, first_nc=64):
        super(unetG, self).__init__()
        
        unet_submodule = unetModule(first_nc * 2**3, first_nc * 2**3)
        for _ in range(num_wrapper-4):
            unet_submodule = unetModule(first_nc * 2**3, first_nc * 2**3, sub_module=unet_submodule)
        
        unet_submodule = unetModule(first_nc * 2**2, first_nc * 2**3, sub_module=unet_submodule)
        unet_submodule = unetModule(first_nc * 2**1, first_nc * 2**2, sub_module=unet_submodule)
        unet_submodule = unetModule(first_nc, first_nc * 2**1, sub_module=unet_submodule)
        unet_submodule = unetModule(input_nc, first_nc, sub_module=unet_submodule, 
                                    output_nc=output_nc, is_outest=True)
        
        self.model = unet_submodule
        
    def forward(self, x):
        return self.model(x)

x_test = torch.ones(16, 3, 256, 256)
unetG_test = unetG(3, 1, num_wrapper=7)
rslt = unetG_test(x_test)
print(rslt.shape)

OUT

torch.Size([16, 1, 256, 256])

patchGAN discriminator

PatchGAN的思想其实比较简单,就是将图片分为若干个部分,对每个部分进行真假判断后在组合。有关它的理解可以看这里:https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/39

build net

class patchDiscriminator(nn.Module):
    def __init__(self, input_nc, first_nc=64, num_layers=3):
        super(patchDiscriminator, self).__init__()
        
        model = [nn.Conv2d(input_nc, first_nc, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
        for i_layer in range(num_layers-1):
            conv = nn.Conv2d(first_nc * 2**i_layer, first_nc * 2**(i_layer+1), kernel_size=4, stride=2, padding=1)
            batch_norm = nn.BatchNorm2d(first_nc * 2**(i_layer+1))
            relu = nn.LeakyReLU(0.2, True)
            model.extend([conv]+[batch_norm]+[relu])
        
        conv = nn.Conv2d(first_nc * 4, first_nc * 8, kernel_size=4, stride=1, padding=1)
        batch_norm = nn.BatchNorm2d(first_nc * 8)
        relu = nn.LeakyReLU(0.2, True)
        model.extend([conv]+[batch_norm]+[relu])
        
        conv = nn.Conv2d(first_nc * 8, 1, kernel_size=4, stride=1, padding=1)
        model.append(conv)
        
        self.model = nn.Sequential(*model)
    
    
    def forward(self, x):
        return self.model(x)

x_test = torch.ones(16, 3, 512, 512)
test_patchDiscriminator = patchDiscriminator(3)
rslt = test_patchDiscriminator(x_test)
rslt.shape

OUT

torch.Size([16, 1, 62, 62])

network utilities

def init_weight(network):
    def weights_init_normal(m):
        classname = m.__class__.__name__
        # print(classname)
        if classname.find('Conv') != -1:
            init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find('Linear') != -1:
            init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find('BatchNorm2d') != -1:
            init.normal_(m.weight.data, 1.0, 0.02)
            init.constant_(m.bias.data, 0.0)
    network.apply(weights_init_normal)
    
# test_init = init_weight(unetG_test)

Loss function

class baseGANLoss:
    def __init__(self, gan_type='BCE', device='cuda'):
        self.device = device
        self.loss = nn.BCEWithLogitsLoss().to(device)
    
    def __call__(self, prediction, is_real):
        if is_real:
            target_tensor = torch.ones(prediction.shape).to(self.device)
        else:
            target_tensor = torch.zeros(prediction.shape).to(self.device)
        return self.loss(prediction, target_tensor)

上述是比较关键的代码,由于篇幅原因,我把代码放在我的github上了;也可以登陆Google Colab查看我的训练结果。

结果分析

[pytorch] WACA论文,photo2sketch的复现_第11张图片epoch 2
[pytorch] WACA论文,photo2sketch的复现_第12张图片epoch 50
[pytorch] WACA论文,photo2sketch的复现_第13张图片epoch 100
[pytorch] WACA论文,photo2sketch的复现_第14张图片epoch 150
[pytorch] WACA论文,photo2sketch的复现_第15张图片true output
可以看到随着训练进行,简笔画与真实值相比不断接近。但是可能还是因为用了unet的原因,一些比较低级的纹理被保留了下来,导致图片有些多余的线条。

不过这与论文作者的训练结果相比,还是差了很远,以后有时间会继续改进。

你可能感兴趣的:(pytorch,GAN,深度学习,神经网络,python)