目录
1、概述
2、网络结构
2.1、生成器部分的代码:
2.2、判别器部分的代码
3、代码(持续更新)
论文名称:Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks
原文地址:https://arxiv.org/abs/1703.10593
CycleGAN 用于实现两种不同风格(领域)图像之间的相互转换。其特点是训练数据不需要一一配对,只需分别准备两个领域的数据集即可,例如普通马的图片和斑马的图片。经过训练可以实现如下图所示的风格转换:
其训练时的网络结构如下图所示:
可以看出,CycleGAN 网络中一共包含两个生成器和两个判别器,分别负责两个领域数据的生成与判别。
import torch.nn as nn
from torchsummary import summary
from collections import OrderedDict
# 定义残差块
class Resnet_block(nn.Module):
    """Residual block: two reflection-padded 3x3 conv + InstanceNorm stages.

    The layout is ``pad -> conv -> norm -> ReLU -> pad -> conv -> norm`` and
    the result is added to the input (skip connection).  The ReLU sits
    between the two convolutions; the residual branch itself is left
    un-activated so it can contribute both positive and negative
    corrections to the input.

    Args:
        in_channels: number of input channels; the block keeps the channel
            count and (thanks to the reflection padding of 1 around a 3x3
            convolution with stride 1) the spatial size unchanged.
    """

    def __init__(self, in_channels):
        super(Resnet_block, self).__init__()
        block = []
        for i in range(2):
            block += [nn.ReflectionPad2d(1),
                      nn.Conv2d(in_channels, in_channels, 3, 1, 0),
                      nn.InstanceNorm2d(in_channels),
                      # Activate only after the FIRST conv stage.  The
                      # Identity placeholder keeps eight sub-modules, so the
                      # conv layers stay at indices 1 and 5 and existing
                      # state_dict keys remain valid.
                      nn.ReLU(True) if i == 0 else nn.Identity()]
        self.block = nn.Sequential(*block)

    def forward(self, x):
        """Return ``x + block(x)``."""
        out = x + self.block(x)
        return out
class Cycle_Gan_G(nn.Module):
    """CycleGAN ResNet-style generator.

    Structure: a 7x7 stem convolution, two stride-2 down-sampling
    convolutions, a stack of residual blocks at 256 channels, two stride-2
    transposed convolutions for up-sampling, and a 7x7 output convolution
    followed by ``tanh``.  Shape comments below are for a 3x256x256 input.

    Args:
        n_blocks: number of residual blocks in the bottleneck.  Defaults to
            6, which is what this class has always built (the old inline
            comment said 9).  Pass ``n_blocks=9`` for the deeper variant.

    Raises:
        ValueError: if ``n_blocks`` is negative.
    """

    def __init__(self, n_blocks=6):
        super(Cycle_Gan_G, self).__init__()
        if n_blocks < 0:
            raise ValueError('n_blocks must be >= 0, got {}'.format(n_blocks))
        net_dic = OrderedDict()
        # Three convolutional stages (stem + two down-sampling convs).
        # NOTE: the key 'first layer' contains a space; it is kept as-is
        # because it is part of the state_dict key names.
        net_dic.update({'first layer': nn.Sequential(
            nn.ReflectionPad2d(3),  # [3,256,256] -> [3,262,262]
            nn.Conv2d(3, 64, 7, 1),  # [3,262,262] -> [64,256,256]
            nn.InstanceNorm2d(64),
            nn.ReLU(True)
        )})
        net_dic.update({'second_conv': nn.Sequential(
            nn.Conv2d(64, 128, 3, 2, 1),  # [128,128,128]
            nn.InstanceNorm2d(128),
            nn.ReLU(True)
        )})
        net_dic.update({'three_conv': nn.Sequential(
            nn.Conv2d(128, 256, 3, 2, 1),  # [256,64,64]
            nn.InstanceNorm2d(256),
            nn.ReLU(True)
        )})
        # Residual bottleneck: n_blocks residual blocks at 256 channels.
        for i in range(n_blocks):
            net_dic.update({'Resnet_block{}'.format(i + 1): Resnet_block(256)})
        # Up-sampling back to the input resolution.
        net_dic.update({'up_sample1': nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(128),  # [128,128,128]
            nn.ReLU(True)
        )})
        net_dic.update({'up_sample2': nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(64),  # [64,256,256]
            nn.ReLU(True)
        )})
        net_dic.update({'last_layer': nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, 3, 7, 1),
            nn.Tanh()
        )})
        self.net_G = nn.Sequential(net_dic)
        self.init_weight()

    def init_weight(self):
        """Initialise conv weights with Kaiming-normal and zero all biases.

        Normalisation layers only carry parameters when built with
        ``affine=True``; the ``None`` checks make this a no-op for the
        default (non-affine) ``InstanceNorm2d`` layers used above.
        """
        for w in self.modules():
            if isinstance(w, nn.Conv2d):
                nn.init.kaiming_normal_(w.weight, mode='fan_out')
                if w.bias is not None:
                    nn.init.zeros_(w.bias)
            elif isinstance(w, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(w.weight, mode='fan_in')
                # Zero the bias here too, consistent with the Conv2d branch.
                if w.bias is not None:
                    nn.init.zeros_(w.bias)
            elif isinstance(w, (nn.BatchNorm2d, nn.InstanceNorm2d)):
                if w.weight is not None:
                    nn.init.ones_(w.weight)
                if w.bias is not None:
                    nn.init.zeros_(w.bias)

    def forward(self, x):
        """Run the generator on a batch of images and return the result."""
        out = self.net_G(x)
        return out
if __name__ == '__main__':
    # Local import: only this demo needs the top-level torch package.
    import torch

    # Fall back to the CPU when no CUDA device is present instead of
    # crashing inside .to('cuda').
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    G = Cycle_Gan_G().to(device)
    # Print a per-layer summary of the generator for a 3x256x256 input.
    summary(G, (3, 256, 256), device=device)
其中:残差块的结构如下:
整体网络结构如下:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
ReflectionPad2d-1 [-1, 3, 262, 262] 0
Conv2d-2 [-1, 64, 256, 256] 9,472
InstanceNorm2d-3 [-1, 64, 256, 256] 0
ReLU-4 [-1, 64, 256, 256] 0
Conv2d-5 [-1, 128, 128, 128] 73,856
InstanceNorm2d-6 [-1, 128, 128, 128] 0
ReLU-7 [-1, 128, 128, 128] 0
Conv2d-8 [-1, 256, 64, 64] 295,168
InstanceNorm2d-9 [-1, 256, 64, 64] 0
ReLU-10 [-1, 256, 64, 64] 0
ReflectionPad2d-11 [-1, 256, 66, 66] 0
Conv2d-12 [-1, 256, 64, 64] 590,080
InstanceNorm2d-13 [-1, 256, 64, 64] 0
Identity-14 [-1, 256, 64, 64] 0
ReflectionPad2d-15 [-1, 256, 66, 66] 0
Conv2d-16 [-1, 256, 64, 64] 590,080
InstanceNorm2d-17 [-1, 256, 64, 64] 0
ReLU-18 [-1, 256, 64, 64] 0
Resnet_block-19 [-1, 256, 64, 64] 0
ReflectionPad2d-20 [-1, 256, 66, 66] 0
Conv2d-21 [-1, 256, 64, 64] 590,080
InstanceNorm2d-22 [-1, 256, 64, 64] 0
Identity-23 [-1, 256, 64, 64] 0
ReflectionPad2d-24 [-1, 256, 66, 66] 0
Conv2d-25 [-1, 256, 64, 64] 590,080
InstanceNorm2d-26 [-1, 256, 64, 64] 0
ReLU-27 [-1, 256, 64, 64] 0
Resnet_block-28 [-1, 256, 64, 64] 0
ReflectionPad2d-29 [-1, 256, 66, 66] 0
Conv2d-30 [-1, 256, 64, 64] 590,080
InstanceNorm2d-31 [-1, 256, 64, 64] 0
Identity-32 [-1, 256, 64, 64] 0
ReflectionPad2d-33 [-1, 256, 66, 66] 0
Conv2d-34 [-1, 256, 64, 64] 590,080
InstanceNorm2d-35 [-1, 256, 64, 64] 0
ReLU-36 [-1, 256, 64, 64] 0
Resnet_block-37 [-1, 256, 64, 64] 0
ReflectionPad2d-38 [-1, 256, 66, 66] 0
Conv2d-39 [-1, 256, 64, 64] 590,080
InstanceNorm2d-40 [-1, 256, 64, 64] 0
Identity-41 [-1, 256, 64, 64] 0
ReflectionPad2d-42 [-1, 256, 66, 66] 0
Conv2d-43 [-1, 256, 64, 64] 590,080
InstanceNorm2d-44 [-1, 256, 64, 64] 0
ReLU-45 [-1, 256, 64, 64] 0
Resnet_block-46 [-1, 256, 64, 64] 0
ReflectionPad2d-47 [-1, 256, 66, 66] 0
Conv2d-48 [-1, 256, 64, 64] 590,080
InstanceNorm2d-49 [-1, 256, 64, 64] 0
Identity-50 [-1, 256, 64, 64] 0
ReflectionPad2d-51 [-1, 256, 66, 66] 0
Conv2d-52 [-1, 256, 64, 64] 590,080
InstanceNorm2d-53 [-1, 256, 64, 64] 0
ReLU-54 [-1, 256, 64, 64] 0
Resnet_block-55 [-1, 256, 64, 64] 0
ReflectionPad2d-56 [-1, 256, 66, 66] 0
Conv2d-57 [-1, 256, 64, 64] 590,080
InstanceNorm2d-58 [-1, 256, 64, 64] 0
Identity-59 [-1, 256, 64, 64] 0
ReflectionPad2d-60 [-1, 256, 66, 66] 0
Conv2d-61 [-1, 256, 64, 64] 590,080
InstanceNorm2d-62 [-1, 256, 64, 64] 0
ReLU-63 [-1, 256, 64, 64] 0
Resnet_block-64 [-1, 256, 64, 64] 0
ConvTranspose2d-65 [-1, 128, 128, 128] 295,040
InstanceNorm2d-66 [-1, 128, 128, 128] 0
ReLU-67 [-1, 128, 128, 128] 0
ConvTranspose2d-68 [-1, 64, 256, 256] 73,792
InstanceNorm2d-69 [-1, 64, 256, 256] 0
ReLU-70 [-1, 64, 256, 256] 0
ReflectionPad2d-71 [-1, 64, 262, 262] 0
Conv2d-72 [-1, 3, 256, 256] 9,411
Tanh-73 [-1, 3, 256, 256] 0
================================================================
Total params: 7,837,699
Trainable params: 7,837,699
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.75
Forward/backward pass size (MB): 788.18
Params size (MB): 29.90
Estimated Total Size (MB): 818.83
----------------------------------------------------------------
class Cycle_Gan_D(nn.Module):
    """CycleGAN convolutional discriminator.

    Four ``Conv2d(4x4) -> norm -> LeakyReLU(0.2)`` stages widen the
    features 3 -> 64 -> 128 -> 256 -> 512 (the first three with stride 2,
    the fourth with stride 1), followed by a final 4x4 convolution that
    produces a single-channel score map.  The stage fed by the 3-channel
    input uses ``Identity`` in place of ``InstanceNorm2d``.
    """

    def __init__(self):
        super(Cycle_Gan_D, self).__init__()

        def conv_stage(c_in, c_out, stride):
            # Basic conv / norm / leaky-ReLU stage; no normalisation on the
            # stage that consumes the raw 3-channel image.
            norm = nn.Identity(c_out) if c_in == 3 else nn.InstanceNorm2d(c_out)
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, 4, stride, 1),
                norm,
                nn.LeakyReLU(0.2, True),
            )

        widths = [3, 64, 128, 256, 512]
        strides = [2, 2, 2, 1]
        stages = OrderedDict()
        for idx, (c_in, c_out, stride) in enumerate(
                zip(widths[:-1], widths[1:], strides), start=1):
            stages['layer_{}'.format(idx)] = conv_stage(c_in, c_out, stride)
        # Score map; [batch,1,30,30] for a 256x256 input.
        stages['last_layer'] = nn.Conv2d(512, 1, 4, 1, 1)
        self.D_model = nn.Sequential(stages)

    def forward(self, x):
        """Return the discriminator score map for a batch of images."""
        return self.D_model(x)
if __name__ == '__main__':
    # Local import: only this demo needs the top-level torch package.
    import torch

    # Fall back to the CPU when no CUDA device is present instead of
    # crashing inside .to('cuda').
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    D = Cycle_Gan_D().to(device)
    # Print a per-layer summary of the discriminator for a 3x256x256 input.
    summary(D, (3, 256, 256), device=device)
网络整体架构如下:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 128, 128] 3,136
Identity-2 [-1, 64, 128, 128] 0
LeakyReLU-3 [-1, 64, 128, 128] 0
Conv2d-4 [-1, 128, 64, 64] 131,200
InstanceNorm2d-5 [-1, 128, 64, 64] 0
LeakyReLU-6 [-1, 128, 64, 64] 0
Conv2d-7 [-1, 256, 32, 32] 524,544
InstanceNorm2d-8 [-1, 256, 32, 32] 0
LeakyReLU-9 [-1, 256, 32, 32] 0
Conv2d-10 [-1, 512, 31, 31] 2,097,664
InstanceNorm2d-11 [-1, 512, 31, 31] 0
LeakyReLU-12 [-1, 512, 31, 31] 0
Conv2d-13 [-1, 1, 30, 30] 8,193
================================================================
Total params: 2,764,737
Trainable params: 2,764,737
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.75
Forward/backward pass size (MB): 53.27
Params size (MB): 10.55
Estimated Total Size (MB): 64.57
----------------------------------------------------------------