CycleGAN in PyTorch: implementing style transfer

Contents

1. Overview

2. Network architecture

2.1 Generator code

2.2 Discriminator code

3. Code (continuously updated)


Paper: Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks

Link: https://arxiv.org/abs/1703.10593

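1. Overview

CycleGAN translates images between two different styles in both directions. Its distinguishing feature is that the training data do not have to come in one-to-one pairs: you only need one dataset for each of the two domains, for example photos of ordinary horses and photos of zebras. After training, the model can convert an image from one style into the other, such as turning a horse into a zebra.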
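To make "unpaired" concrete, here is a minimal sketch of a two-domain dataset, assuming both domains are already loaded as lists of image tensors. The class name `UnpairedDataset` and its arguments are hypothetical and not part of the original code. Domain A is read in order while domain B is sampled at random, so no image in A has a fixed partner in B.

```python
import random

import torch
from torch.utils.data import Dataset


class UnpairedDataset(Dataset):
    """Hypothetical two-domain dataset whose samples are NOT paired.

    domain_a / domain_b are assumed to be lists of preprocessed image
    tensors; a real loader would read and transform image files instead.
    """

    def __init__(self, domain_a, domain_b, seed=0):
        self.domain_a = domain_a
        self.domain_b = domain_b
        self.rng = random.Random(seed)

    def __len__(self):
        # One pass covers the larger of the two domains
        return max(len(self.domain_a), len(self.domain_b))

    def __getitem__(self, idx):
        # Domain A is read in order; domain B is drawn at random,
        # so no image in A has a fixed partner in B
        a = self.domain_a[idx % len(self.domain_a)]
        b = self.domain_b[self.rng.randrange(len(self.domain_b))]
        return a, b


if __name__ == '__main__':
    horses = [torch.randn(3, 256, 256) for _ in range(5)]
    zebras = [torch.randn(3, 256, 256) for _ in range(3)]
    dataset = UnpairedDataset(horses, zebras)
    a, b = dataset[0]
    print(len(dataset), a.shape, b.shape)
```

The two lists may have different lengths, which is exactly the situation with unpaired data.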

2. Network architecture

The network structure used during training is shown in the figure below:

[Figure 1: CycleGAN network structure during training]

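As the figure shows, CycleGAN contains two generators and two discriminators, one pair per domain: each generator translates images into one domain, and that domain's discriminator judges whether an image is real or generated. In the paper's notation, G: X → Y and F: Y → X are the generators, and D_Y and D_X are the corresponding discriminators.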
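The paper combines these four networks into one objective, L(G, F, D_X, D_Y) = L_GAN(G, D_Y, X, Y) + L_GAN(F, D_X, Y, X) + λ·L_cyc(G, F), with λ = 10, a least-squares loss for the GAN terms and an L1 loss for the cycle term. Below is a minimal sketch of the generator-side loss for one batch, covering only those three terms. The function and argument names are illustrative assumptions (G_AB plays the role of G, G_BA of F, D_B of D_Y, D_A of D_X).

```python
import torch
import torch.nn as nn


def generator_objective(G_AB, G_BA, D_A, D_B, real_A, real_B, lambda_cyc=10.0):
    """Generator-side CycleGAN loss for one batch (illustrative sketch).

    G_AB maps domain A to B, G_BA maps B to A; D_A scores domain-A images
    and D_B scores domain-B images. All names are hypothetical.
    """
    mse = nn.MSELoss()  # least-squares GAN loss
    l1 = nn.L1Loss()    # cycle-consistency loss

    fake_B = G_AB(real_A)
    fake_A = G_BA(real_B)

    # Adversarial terms: each generator wants its discriminator to output 1
    pred_fake_B = D_B(fake_B)
    pred_fake_A = D_A(fake_A)
    loss_gan = (mse(pred_fake_B, torch.ones_like(pred_fake_B))
                + mse(pred_fake_A, torch.ones_like(pred_fake_A)))

    # Cycle terms: A -> B -> A and B -> A -> B should give back the input
    loss_cyc = l1(G_BA(fake_B), real_A) + l1(G_AB(fake_A), real_B)

    return loss_gan + lambda_cyc * loss_cyc
```

Any modules with matching shapes can be passed in, e.g. two Cycle_Gan_G() and two Cycle_Gan_D() instances from the code below. In a training step this value would be backpropagated into the generators, while the discriminators are updated with their own loss.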

2.1 Generator code:

import torch
import torch.nn as nn
from torchsummary import summary
from collections import OrderedDict


# Residual block: two rounds of (reflection pad -> 3x3 conv -> InstanceNorm)
# plus a skip connection. Note: here the first conv is followed by Identity
# and the second by ReLU; the authors' reference code applies ReLU after the
# first conv and no activation after the second.
class Resnet_block(nn.Module):
    def __init__(self, in_channels):
        super(Resnet_block, self).__init__()
        block = []
        for i in range(2):
            block += [nn.ReflectionPad2d(1),
                      nn.Conv2d(in_channels, in_channels, 3, 1, 0),
                      nn.InstanceNorm2d(in_channels),
                      nn.ReLU(True) if i > 0 else nn.Identity()]
        self.block = nn.Sequential(*block)

    def forward(self, x):
        out = x + self.block(x)
        return out


class Cycle_Gan_G(nn.Module):
    def __init__(self):
        super(Cycle_Gan_G, self).__init__()
        net_dic = OrderedDict()
        # Three conv layers: a 7x7 stem, then two stride-2 downsampling convs
        net_dic.update({'first_layer': nn.Sequential(
            nn.ReflectionPad2d(3),  # [3,256,256]  ->  [3,262,262]
            nn.Conv2d(3, 64, 7, 1),  # [3,262,262]  ->  [64,256,256]
            nn.InstanceNorm2d(64),
            nn.ReLU(True)
        )})
        net_dic.update({'second_conv': nn.Sequential(
            nn.Conv2d(64, 128, 3, 2, 1),  # [128,128,128]
            nn.InstanceNorm2d(128),
            nn.ReLU(True)
        )})
        net_dic.update({'third_conv': nn.Sequential(
            nn.Conv2d(128, 256, 3, 2, 1),  # [256,64,64]
            nn.InstanceNorm2d(256),
            nn.ReLU(True)
        )})

        # 6 residual blocks (the paper uses 6 blocks for 128x128 training
        # images and 9 for 256x256; use range(9) to follow it exactly)
        for i in range(6):
            net_dic.update({'Resnet_block{}'.format(i + 1): Resnet_block(256)})

        # Upsampling: two stride-2 transposed convs
        net_dic.update({'up_sample1': nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(128),  # [128,128,128]
            nn.ReLU(True)
        )})
        net_dic.update({'up_sample2': nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(64),  # [64,256,256]
            nn.ReLU(True)
        )})

        # Output layer: back to 3 channels; Tanh maps values into [-1, 1]
        net_dic.update({'last_layer': nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, 3, 7, 1),
            nn.Tanh()
        )})

        self.net_G = nn.Sequential(net_dic)
        self.init_weight()

    def init_weight(self):
        # Kaiming init is this post's choice; the paper initializes weights
        # from a Gaussian N(0, 0.02)
        for w in self.modules():
            if isinstance(w, nn.Conv2d):
                nn.init.kaiming_normal_(w.weight, mode='fan_out')
                if w.bias is not None:
                    nn.init.zeros_(w.bias)
            elif isinstance(w, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(w.weight, mode='fan_in')
                if w.bias is not None:
                    nn.init.zeros_(w.bias)
            # nn.InstanceNorm2d defaults to affine=False, so it has no
            # weight or bias to initialize

    def forward(self, x):
        out = self.net_G(x)
        return out


if __name__ == '__main__':
    # Fall back to the CPU when no GPU is available
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    G = Cycle_Gan_G().to(device)
    summary(G, (3, 256, 256), device=device)

The residual block has the following structure:

[Figure 2: residual block structure]
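In the Resnet_block above, the first convolution is followed by Identity and the second by ReLU (Identity-14 and ReLU-18 in the summary below). For comparison, here is a sketch of the block as laid out in the authors' public PyTorch code (junyanz/pytorch-CycleGAN-and-pix2pix), which applies ReLU after the first convolution and no activation before the addition, so the residual branch can also take negative values. The class name `ResnetBlockRef` is hypothetical, and the optional dropout of that code is omitted.

```python
import torch
import torch.nn as nn


class ResnetBlockRef(nn.Module):
    """Residual block with ReLU between the two convs, none after the second."""

    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3),
            nn.InstanceNorm2d(channels),
            nn.ReLU(True),                # activation after the FIRST conv
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3),
            nn.InstanceNorm2d(channels),  # no activation before the skip add
        )

    def forward(self, x):
        # Skip connection: output has the same shape as the input
        return x + self.block(x)
```

Swapping it in only needs Resnet_block(256) replaced by ResnetBlockRef(256); the parameter count is unchanged because only the activation position differs.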

The overall generator structure, as printed by summary(), is:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
   ReflectionPad2d-1          [-1, 3, 262, 262]               0
            Conv2d-2         [-1, 64, 256, 256]           9,472
    InstanceNorm2d-3         [-1, 64, 256, 256]               0
              ReLU-4         [-1, 64, 256, 256]               0
            Conv2d-5        [-1, 128, 128, 128]          73,856
    InstanceNorm2d-6        [-1, 128, 128, 128]               0
              ReLU-7        [-1, 128, 128, 128]               0
            Conv2d-8          [-1, 256, 64, 64]         295,168
    InstanceNorm2d-9          [-1, 256, 64, 64]               0
             ReLU-10          [-1, 256, 64, 64]               0
  ReflectionPad2d-11          [-1, 256, 66, 66]               0
           Conv2d-12          [-1, 256, 64, 64]         590,080
   InstanceNorm2d-13          [-1, 256, 64, 64]               0
         Identity-14          [-1, 256, 64, 64]               0
  ReflectionPad2d-15          [-1, 256, 66, 66]               0
           Conv2d-16          [-1, 256, 64, 64]         590,080
   InstanceNorm2d-17          [-1, 256, 64, 64]               0
             ReLU-18          [-1, 256, 64, 64]               0
     Resnet_block-19          [-1, 256, 64, 64]               0
  ReflectionPad2d-20          [-1, 256, 66, 66]               0
           Conv2d-21          [-1, 256, 64, 64]         590,080
   InstanceNorm2d-22          [-1, 256, 64, 64]               0
         Identity-23          [-1, 256, 64, 64]               0
  ReflectionPad2d-24          [-1, 256, 66, 66]               0
           Conv2d-25          [-1, 256, 64, 64]         590,080
   InstanceNorm2d-26          [-1, 256, 64, 64]               0
             ReLU-27          [-1, 256, 64, 64]               0
     Resnet_block-28          [-1, 256, 64, 64]               0
  ReflectionPad2d-29          [-1, 256, 66, 66]               0
           Conv2d-30          [-1, 256, 64, 64]         590,080
   InstanceNorm2d-31          [-1, 256, 64, 64]               0
         Identity-32          [-1, 256, 64, 64]               0
  ReflectionPad2d-33          [-1, 256, 66, 66]               0
           Conv2d-34          [-1, 256, 64, 64]         590,080
   InstanceNorm2d-35          [-1, 256, 64, 64]               0
             ReLU-36          [-1, 256, 64, 64]               0
     Resnet_block-37          [-1, 256, 64, 64]               0
  ReflectionPad2d-38          [-1, 256, 66, 66]               0
           Conv2d-39          [-1, 256, 64, 64]         590,080
   InstanceNorm2d-40          [-1, 256, 64, 64]               0
         Identity-41          [-1, 256, 64, 64]               0
  ReflectionPad2d-42          [-1, 256, 66, 66]               0
           Conv2d-43          [-1, 256, 64, 64]         590,080
   InstanceNorm2d-44          [-1, 256, 64, 64]               0
             ReLU-45          [-1, 256, 64, 64]               0
     Resnet_block-46          [-1, 256, 64, 64]               0
  ReflectionPad2d-47          [-1, 256, 66, 66]               0
           Conv2d-48          [-1, 256, 64, 64]         590,080
   InstanceNorm2d-49          [-1, 256, 64, 64]               0
         Identity-50          [-1, 256, 64, 64]               0
  ReflectionPad2d-51          [-1, 256, 66, 66]               0
           Conv2d-52          [-1, 256, 64, 64]         590,080
   InstanceNorm2d-53          [-1, 256, 64, 64]               0
             ReLU-54          [-1, 256, 64, 64]               0
     Resnet_block-55          [-1, 256, 64, 64]               0
  ReflectionPad2d-56          [-1, 256, 66, 66]               0
           Conv2d-57          [-1, 256, 64, 64]         590,080
   InstanceNorm2d-58          [-1, 256, 64, 64]               0
         Identity-59          [-1, 256, 64, 64]               0
  ReflectionPad2d-60          [-1, 256, 66, 66]               0
           Conv2d-61          [-1, 256, 64, 64]         590,080
   InstanceNorm2d-62          [-1, 256, 64, 64]               0
             ReLU-63          [-1, 256, 64, 64]               0
     Resnet_block-64          [-1, 256, 64, 64]               0
  ConvTranspose2d-65        [-1, 128, 128, 128]         295,040
   InstanceNorm2d-66        [-1, 128, 128, 128]               0
             ReLU-67        [-1, 128, 128, 128]               0
  ConvTranspose2d-68         [-1, 64, 256, 256]          73,792
   InstanceNorm2d-69         [-1, 64, 256, 256]               0
             ReLU-70         [-1, 64, 256, 256]               0
  ReflectionPad2d-71         [-1, 64, 262, 262]               0
           Conv2d-72          [-1, 3, 256, 256]           9,411
             Tanh-73          [-1, 3, 256, 256]               0
================================================================
Total params: 7,837,699
Trainable params: 7,837,699
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.75
Forward/backward pass size (MB): 788.18
Params size (MB): 29.90
Estimated Total Size (MB): 818.83
----------------------------------------------------------------
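The shape comments in the generator code and the summary above follow from PyTorch's size formulas with dilation 1: out = floor((n + 2p - k) / s) + 1 for Conv2d and out = (n - 1)s - 2p + k + output_padding for ConvTranspose2d. The helper below (its function names are illustrative) traces one 256-pixel side through the generator.

```python
def conv_out(n, k, s=1, p=0):
    # nn.Conv2d output size along one spatial axis (dilation = 1)
    return (n + 2 * p - k) // s + 1


def conv_transpose_out(n, k, s=1, p=0, output_padding=0):
    # nn.ConvTranspose2d output size along one spatial axis (dilation = 1)
    return (n - 1) * s - 2 * p + k + output_padding


def generator_sizes(n=256):
    """Spatial size after each stage of Cycle_Gan_G for an n x n input."""
    sizes = []
    n = conv_out(n + 2 * 3, 7)             # ReflectionPad2d(3) + 7x7 conv
    sizes.append(n)
    n = conv_out(n, 3, 2, 1)               # second conv (stride 2)
    sizes.append(n)
    n = conv_out(n, 3, 2, 1)               # third conv (stride 2)
    sizes.append(n)
    n = conv_out(n + 2 * 1, 3)             # any residual conv: pad 1 + 3x3
    sizes.append(n)
    n = conv_transpose_out(n, 3, 2, 1, 1)  # up_sample1
    sizes.append(n)
    n = conv_transpose_out(n, 3, 2, 1, 1)  # up_sample2
    sizes.append(n)
    n = conv_out(n + 2 * 3, 7)             # last_layer: pad 3 + 7x7 conv
    sizes.append(n)
    return sizes


print(generator_sizes(256))  # [256, 128, 64, 64, 128, 256, 256]
```

Because each stride-2 conv is undone by a stride-2 transposed conv, the output matches the input size only when the side length is divisible by 4 (for example, 250 comes back as 252).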

2.2 Discriminator code

import torch
import torch.nn as nn
from torchsummary import summary
from collections import OrderedDict


class Cycle_Gan_D(nn.Module):
    def __init__(self):
        super(Cycle_Gan_D, self).__init__()

        # Basic block: 4x4 conv -> InstanceNorm -> LeakyReLU(0.2);
        # the first layer (3 input channels) skips normalization
        def base_Conv_bn_lkrl(in_channels, out_channels, stride):
            if in_channels == 3:
                bn = nn.Identity
            else:
                bn = nn.InstanceNorm2d
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 4, stride, 1),
                bn(out_channels),
                nn.LeakyReLU(0.2, True)
            )

        D_dic = OrderedDict()
        in_channels = 3
        out_channels = 64
        for i in range(4):
            # Layers 1-3 downsample with stride 2; layer 4 uses stride 1
            if i < 3:
                D_dic.update({'layer_{}'.format(i + 1): base_Conv_bn_lkrl(in_channels, out_channels, 2)})
            else:
                D_dic.update({'layer_{}'.format(i + 1): base_Conv_bn_lkrl(in_channels, out_channels, 1)})
            in_channels = out_channels
            out_channels *= 2
        # Final conv: 512 channels -> a 1-channel map of patch scores
        D_dic.update({'last_layer': nn.Conv2d(512, 1, 4, 1, 1)})  # [batch,1,30,30]
        self.D_model = nn.Sequential(D_dic)

    def forward(self, x):
        return self.D_model(x)


if __name__ == '__main__':
    # Fall back to the CPU when no GPU is available
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    D = Cycle_Gan_D().to(device)
    summary(D, (3, 256, 256), device=device)
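The discriminator outputs a 1×30×30 map rather than a single number, with each entry scoring one patch of the input. Here is a minimal sketch of training such a map with the least-squares loss the paper uses, assuming the targets are maps of ones (real) and zeros (fake) shaped like D's output. `discriminator_loss` is a hypothetical helper; the factor 0.5 follows the paper's note that the objective is divided by 2 while optimizing D.

```python
import torch
import torch.nn as nn


def discriminator_loss(D, real, fake):
    """Least-squares discriminator loss on patch-score maps (sketch)."""
    mse = nn.MSELoss()
    pred_real = D(real)
    pred_fake = D(fake.detach())  # detach: do not backprop into the generator
    # Targets are full maps of 1s (real) and 0s (fake), same shape as D's output
    loss = (mse(pred_real, torch.ones_like(pred_real))
            + mse(pred_fake, torch.zeros_like(pred_fake)))
    return 0.5 * loss  # halved to slow D down relative to G
```

With D = Cycle_Gan_D(), real images from one domain and fakes produced by the generator for that domain, this gives a scalar that can be backpropagated into D alone.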

The overall discriminator structure, as printed by summary(), is:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 128, 128]           3,136
          Identity-2         [-1, 64, 128, 128]               0
         LeakyReLU-3         [-1, 64, 128, 128]               0
            Conv2d-4          [-1, 128, 64, 64]         131,200
    InstanceNorm2d-5          [-1, 128, 64, 64]               0
         LeakyReLU-6          [-1, 128, 64, 64]               0
            Conv2d-7          [-1, 256, 32, 32]         524,544
    InstanceNorm2d-8          [-1, 256, 32, 32]               0
         LeakyReLU-9          [-1, 256, 32, 32]               0
           Conv2d-10          [-1, 512, 31, 31]       2,097,664
   InstanceNorm2d-11          [-1, 512, 31, 31]               0
        LeakyReLU-12          [-1, 512, 31, 31]               0
           Conv2d-13            [-1, 1, 30, 30]           8,193
================================================================
Total params: 2,764,737
Trainable params: 2,764,737
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.75
Forward/backward pass size (MB): 53.27
Params size (MB): 10.55
Estimated Total Size (MB): 64.57
----------------------------------------------------------------
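Both the 30×30 output size and the "70×70 PatchGAN" name used in the paper can be checked from the kernel, stride and padding of the five convolutions. The helper names below are illustrative.

```python
# (kernel_size, stride, padding) of the five convs in Cycle_Gan_D, in order
D_LAYERS = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 1), (4, 1, 1)]


def output_size(n, layers):
    # Apply the Conv2d size formula layer by layer (dilation = 1)
    for k, s, p in layers:
        n = (n + 2 * p - k) // s + 1
    return n


def receptive_field(layers):
    # Walk back from one output pixel: r_in = (r_out - 1) * stride + kernel
    r = 1
    for k, s, _ in reversed(layers):
        r = (r - 1) * s + k
    return r


print(output_size(256, D_LAYERS), receptive_field(D_LAYERS))  # 30 70
```

So each of the 30×30 output values judges a 70×70 patch of the 256×256 input.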

3. Code (continuously updated)
