Contents
1. Preface
2. Dataset
3. Network Structure
4. Code
(1) net
(2) train
(3) test
5. Results
(1) loss
(2) Training visualization
(3) Test results
6. Full Code
1. Preface
pix2pix is demanding about its training data: it needs paired samples, which usually take considerable effort to collect. CycleGAN solves exactly this problem by translating between two domains. You only need to prepare two sets of images, one per style, and the GAN learns to map images from domain X into the style of domain Y, changing only the style while leaving the content of the original domain X image intact.
An intuitive approach would be to let G directly learn the mapping from domain X to domain Y and from domain Y to domain X. With unpaired data, however, this leaves G far too unconstrained: the generated image can end up completely unrelated to the input, which both defeats the purpose of CycleGAN and makes the output meaningless.
The authors argue that the translation should be cycle-consistent. In language translation, for instance, if you translate a passage from Chinese to English and then back from English to Chinese, the meaning should stay roughly the same; CycleGAN adopts exactly this idea. Let Ga be the generator from domain X to domain Y and Gb the generator from domain Y to domain X. Then for an image real_A from domain X, passing it through Ga yields fake_B, and passing fake_B through Gb yields rec_A, which should be highly similar to real_A. The same holds in the other direction, from domain Y back to domain X.
Accordingly, CycleGAN has two generators and two discriminators, one pair for the domain X to domain Y direction and one pair for domain Y to domain X.
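As a minimal sketch of the forward cycle and its reconstruction loss (Ga and Gb here are stand-ins for any two image-to-image generators, and the input is a random placeholder tensor):

import torch
import torch.nn as nn

l1 = nn.L1Loss()
Ga, Gb = nn.Identity(), nn.Identity()  # placeholders for the two generators
real_A = torch.randn(1, 3, 256, 256)   # an image from domain X

fake_B = Ga(real_A)             # X -> Y
rec_A = Gb(fake_B)              # Y -> X: should reconstruct real_A
cycle_loss = l1(rec_A, real_A)  # cycle-consistency (L1) loss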
2. Dataset
I use the monet2photo dataset here (Monet paintings -> real landscape photos); some samples are shown below.
Domain X (monet):
Domain Y (photo):
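The training script below imports CreateDatasets from mydatasets, which is not listed in this post. As a rough sketch under my own assumptions (the standard monet2photo folder layout with trainA/trainB and testA/testB subfolders, and random pairing of the two unaligned domains), it needs to behave roughly like this:

import os
import random
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms

class CreateDatasets(Dataset):
    # unaligned two-domain dataset: each item is one image from each domain
    def __init__(self, root, img_size, mode='train'):
        dirA = os.path.join(root, mode + 'A')
        dirB = os.path.join(root, mode + 'B')
        self.A_paths = sorted(os.path.join(dirA, f) for f in os.listdir(dirA))
        self.B_paths = sorted(os.path.join(dirB, f) for f in os.listdir(dirB))
        self.transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def __len__(self):
        return max(len(self.A_paths), len(self.B_paths))

    def __getitem__(self, idx):
        A = Image.open(self.A_paths[idx % len(self.A_paths)]).convert('RGB')
        B = Image.open(random.choice(self.B_paths)).convert('RGB')  # unpaired: pick B at random
        return self.transform(A), self.transform(B)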
3. Network Structure
The structure of generator G is shown in the figure below; discriminator D is the same PatchGAN as in pix2pix (see the pix2pix post for its structure).
4. Code
(1) net
Note that the weight initialization differs from the official implementation.
import torch.nn as nn
from torchsummary import summary
from collections import OrderedDict


# residual block: two 3x3 convs with reflection padding; the input is added back at the end
class Resnet_block(nn.Module):
    def __init__(self, in_channels):
        super(Resnet_block, self).__init__()
        block = []
        for i in range(2):
            block += [nn.ReflectionPad2d(1),
                      nn.Conv2d(in_channels, in_channels, 3, 1, 0),
                      nn.InstanceNorm2d(in_channels),
                      # ReLU only after the first conv; the second conv's output is added to x as-is
                      nn.ReLU(True) if i == 0 else nn.Identity()]
        self.block = nn.Sequential(*block)

    def forward(self, x):
        out = x + self.block(x)
        return out
class Cycle_Gan_G(nn.Module):
    def __init__(self):
        super(Cycle_Gan_G, self).__init__()
        net_dic = OrderedDict()
        # three conv layers
        net_dic.update({'first layer': nn.Sequential(
            nn.ReflectionPad2d(3),  # [3,256,256] -> [3,262,262]
            nn.Conv2d(3, 64, 7, 1),  # [3,262,262] -> [64,256,256]
            nn.InstanceNorm2d(64),
            nn.ReLU(True)
        )})
        net_dic.update({'second_conv': nn.Sequential(
            nn.Conv2d(64, 128, 3, 2, 1),  # [128,128,128]
            nn.InstanceNorm2d(128),
            nn.ReLU(True)
        )})
        net_dic.update({'three_conv': nn.Sequential(
            nn.Conv2d(128, 256, 3, 2, 1),  # [256,64,64]
            nn.InstanceNorm2d(256),
            nn.ReLU(True)
        )})
        # 6 ResNet blocks (the original paper uses 9 for 256x256 inputs)
        for i in range(6):
            net_dic.update({'Resnet_block{}'.format(i + 1): Resnet_block(256)})
        # up_sample
        net_dic.update({'up_sample1': nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(128),  # [128,128,128]
            nn.ReLU(True)
        )})
        net_dic.update({'up_sample2': nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(64),  # [64,256,256]
            nn.ReLU(True)
        )})
        net_dic.update({'last_layer': nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, 3, 7, 1),
            nn.Tanh()
        )})
        self.net_G = nn.Sequential(net_dic)
        self.init_weight()

    def init_weight(self):
        # InstanceNorm2d has affine=False by default (no weight/bias),
        # so the BatchNorm2d branch below never fires in this network
        for w in self.modules():
            if isinstance(w, nn.Conv2d):
                nn.init.kaiming_normal_(w.weight, mode='fan_out')
                if w.bias is not None:
                    nn.init.zeros_(w.bias)
            elif isinstance(w, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(w.weight, mode='fan_in')
            elif isinstance(w, nn.BatchNorm2d):
                nn.init.ones_(w.weight)
                nn.init.zeros_(w.bias)

    def forward(self, x):
        out = self.net_G(x)
        return out
class Cycle_Gan_D(nn.Module):
    def __init__(self):
        super(Cycle_Gan_D, self).__init__()

        # basic conv / norm / LeakyReLU block (the first layer, taking RGB input, skips normalization)
        def base_Conv_bn_lkrl(in_channels, out_channels, stride):
            if in_channels == 3:
                bn = nn.Identity
            else:
                bn = nn.InstanceNorm2d
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 4, stride, 1),
                bn(out_channels),
                nn.LeakyReLU(0.2, True)
            )

        D_dic = OrderedDict()
        in_channels = 3
        out_channels = 64
        for i in range(4):
            if i < 3:
                D_dic.update({'layer_{}'.format(i + 1): base_Conv_bn_lkrl(in_channels, out_channels, 2)})
            else:
                D_dic.update({'layer_{}'.format(i + 1): base_Conv_bn_lkrl(in_channels, out_channels, 1)})
            in_channels = out_channels
            out_channels *= 2
        D_dic.update({'last_layer': nn.Conv2d(512, 1, 4, 1, 1)})  # [batch,1,30,30]
        self.D_model = nn.Sequential(D_dic)

    def forward(self, x):
        return self.D_model(x)


if __name__ == '__main__':
    # G = Cycle_Gan_G().to('cuda')
    # summary(G, (3, 256, 256))
    D = Cycle_Gan_D().to('cuda')
    summary(D, (3, 256, 256))
(2) train
A few details matter during training. To reduce model oscillation and make training more stable, the paper keeps a buffer of images produced by G and updates the discriminators with previously generated images rather than only the latest ones. G has three kinds of losses (six terms across the two directions): GAN_loss, Cycle_loss and id_loss. GAN_loss is the usual adversarial loss that pushes the generated image to look real; Cycle_loss is the L1 loss between the reconstructed image and the original; id_loss stops G from arbitrarily shifting the color tone of an image (even if the discriminator would accept a different tone as belonging to domain Y, we want to change only the style and nothing else, hence this extra loss). Discriminator D is again a PatchGAN, and the training procedure is similar to pix2pix.
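Written out explicitly (my notation, matching the weights used in the code below: \lambda = 10 for the cycle terms, 0.5\lambda for the identity terms, with Ga: X -> Y and Gb: Y -> X), the generator objective is

\mathcal{L}_G = \mathcal{L}_{GAN}(G_a, D_a) + \mathcal{L}_{GAN}(G_b, D_b) + \lambda\big(\lVert G_b(G_a(x)) - x\rVert_1 + \lVert G_a(G_b(y)) - y\rVert_1\big) + 0.5\,\lambda\big(\lVert G_a(y) - y\rVert_1 + \lVert G_b(x) - x\rVert_1\big)

The script also imports ImagePool from image_pool, which implements the history buffer described above but is not listed in this post. Here is a minimal sketch that follows the behavior of the buffer in the official CycleGAN code (while the pool is not full, store and return the incoming image; once full, either pass the image straight through or swap it with a randomly stored one, each with probability 0.5):

import random
import torch

class ImagePool:
    # buffer of previously generated images, used when updating the discriminators
    def __init__(self, pool_size=50):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        if self.pool_size == 0:  # buffer disabled: pass everything through
            return images
        return_images = []
        for image in images:
            image = torch.unsqueeze(image.data, 0)
            if len(self.images) < self.pool_size:  # pool not full: store and return as-is
                self.images.append(image)
                return_images.append(image)
            elif random.uniform(0, 1) > 0.5:  # swap with a randomly stored image
                idx = random.randint(0, self.pool_size - 1)
                tmp = self.images[idx].clone()
                self.images[idx] = image
                return_images.append(tmp)
            else:  # pass through unchanged
                return_images.append(image)
        return torch.cat(return_images, 0)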
import itertools
from image_pool import ImagePool
from torch.utils.tensorboard import SummaryWriter
from cyclegan import Cycle_Gan_G, Cycle_Gan_D
import argparse
from mydatasets import CreateDatasets
import os
from torch.utils.data.dataloader import DataLoader
import torch
import torch.optim as optim
import torch.nn as nn
from utils import train_one_epoch, val
def train(opt):
    batch = opt.batch
    data_path = opt.dataPath
    print_every = opt.every
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    epochs = opt.epoch
    img_size = opt.imgsize
    if not os.path.exists(opt.savePath):
        os.mkdir(opt.savePath)

    # load the datasets
    train_datasets = CreateDatasets(data_path, img_size, mode='train')
    val_datasets = CreateDatasets(data_path, img_size, mode='test')
    train_loader = DataLoader(dataset=train_datasets, batch_size=batch, shuffle=True, num_workers=opt.numworker,
                              drop_last=True)
    val_loader = DataLoader(dataset=val_datasets, batch_size=batch, shuffle=True, num_workers=opt.numworker,
                            drop_last=True)

    # instantiate the networks
    Cycle_G_A = Cycle_Gan_G().to(device)
    Cycle_D_A = Cycle_Gan_D().to(device)
    Cycle_G_B = Cycle_Gan_G().to(device)
    Cycle_D_B = Cycle_Gan_D().to(device)

    # optimizers and loss functions (one optimizer for both G's, one for both D's)
    optim_G = optim.Adam(itertools.chain(Cycle_G_A.parameters(), Cycle_G_B.parameters()), lr=0.0002, betas=(0.5, 0.999))
    optim_D = optim.Adam(itertools.chain(Cycle_D_A.parameters(), Cycle_D_B.parameters()), lr=0.0002, betas=(0.5, 0.999))
    loss = nn.MSELoss()
    l1_loss = nn.L1Loss()
    start_epoch = 0
    A_fake_pool = ImagePool(50)
    B_fake_pool = ImagePool(50)

    # load pretrained weights if given
    if opt.weight != '':
        ckpt = torch.load(opt.weight)
        Cycle_G_A.load_state_dict(ckpt['Ga_model'], strict=False)
        Cycle_G_B.load_state_dict(ckpt['Gb_model'], strict=False)
        Cycle_D_A.load_state_dict(ckpt['Da_model'], strict=False)
        Cycle_D_B.load_state_dict(ckpt['Db_model'], strict=False)
        start_epoch = ckpt['epoch'] + 1

    writer = SummaryWriter('train_logs')

    # training loop
    for epoch in range(start_epoch, epochs):
        loss_mG, loss_mD = train_one_epoch(Ga=Cycle_G_A, Da=Cycle_D_A, Gb=Cycle_G_B, Db=Cycle_D_B,
                                           train_loader=train_loader,
                                           optim_G=optim_G, optim_D=optim_D, writer=writer, loss=loss, device=device,
                                           plot_every=print_every, epoch=epoch, l1_loss=l1_loss,
                                           A_fake_pool=A_fake_pool, B_fake_pool=B_fake_pool)
        writer.add_scalars(main_tag='train_loss', tag_scalar_dict={
            'loss_G': loss_mG,
            'loss_D': loss_mD
        }, global_step=epoch)

        # save a checkpoint (under opt.savePath, so the flag is actually honored)
        torch.save({
            'Ga_model': Cycle_G_A.state_dict(),
            'Gb_model': Cycle_G_B.state_dict(),
            'Da_model': Cycle_D_A.state_dict(),
            'Db_model': Cycle_D_B.state_dict(),
            'epoch': epoch
        }, os.path.join(opt.savePath, 'cycle_monent2photo.pth'))

        # validation
        val(Ga=Cycle_G_A, Da=Cycle_D_A, Gb=Cycle_G_B, Db=Cycle_D_B, val_loader=val_loader, loss=loss, l1_loss=l1_loss,
            device=device, epoch=epoch)
def cfg():
    parse = argparse.ArgumentParser()
    parse.add_argument('--batch', type=int, default=1)
    parse.add_argument('--epoch', type=int, default=100)
    parse.add_argument('--imgsize', type=int, default=256)
    parse.add_argument('--dataPath', type=str, default='../monet2photo', help='data root path')
    parse.add_argument('--weight', type=str, default='', help='load pre train weight')
    parse.add_argument('--savePath', type=str, default='./weights', help='weight save path')
    parse.add_argument('--numworker', type=int, default=4)
    parse.add_argument('--every', type=int, default=20, help='plot train result every * iters')
    opt = parse.parse_args()
    return opt


if __name__ == '__main__':
    opt = cfg()
    print(opt)
    train(opt)
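Assuming the script above is saved as train.py (with mydatasets.py, image_pool.py and utils.py next to it), training can be launched with, for example, python train.py --dataPath ../monet2photo --batch 1 --epoch 100, and resumed from a checkpoint with --weight ./weights/cycle_monent2photo.pth. The helper functions train_one_epoch and val used above live in utils, listed next.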
import torchvision
from tqdm import tqdm
import torch
import os
def train_one_epoch(Ga, Da, Gb, Db, train_loader, optim_G, optim_D, writer, loss, device, plot_every, epoch, l1_loss,
                    A_fake_pool, B_fake_pool):
    pd = tqdm(train_loader)
    loss_D, loss_G = 0, 0
    step = 0
    Ga.train()
    Da.train()
    Gb.train()
    Db.train()
    for idx, data in enumerate(pd):
        A_real = data[0].to(device)
        B_real = data[1].to(device)

        # forward passes through both generators
        B_fake = Ga(A_real)  # fake B produced by Ga
        A_rec = Gb(B_fake)   # A reconstructed by Gb
        A_fake = Gb(B_real)  # fake A produced by Gb
        B_rec = Ga(A_fake)   # B reconstructed by Ga

        # train G => six loss terms in total
        set_required_grad([Da, Db], requires_grad=False)  # freeze the discriminators
        optim_G.zero_grad()
        ls_G = train_G(Da=Da, Db=Db, B_fake=B_fake, loss=loss, A_fake=A_fake, l1_loss=l1_loss,
                       A_rec=A_rec,
                       A_real=A_real, B_rec=B_rec, B_real=B_real, Ga=Ga, Gb=Gb)
        ls_G.backward()
        optim_G.step()

        # train D on fakes drawn from the image pools
        set_required_grad([Da, Db], requires_grad=True)
        optim_D.zero_grad()
        A_fake_p = A_fake_pool.query(A_fake)
        B_fake_p = B_fake_pool.query(B_fake)
        ls_D = train_D(Da=Da, Db=Db, B_fake=B_fake_p, B_real=B_real, loss=loss, A_fake=A_fake_p, A_real=A_real)
        ls_D.backward()
        optim_D.step()

        # accumulate as python floats so no graph references are kept alive
        loss_D += ls_D.item()
        loss_G += ls_G.item()
        pd.desc = 'train_{} G_loss: {} D_loss: {}'.format(epoch, ls_G.item(), ls_D.item())

        # plot intermediate training results
        if idx % plot_every == 0:
            writer.add_images(tag='epoch{}_Ga'.format(epoch), img_tensor=0.5 * (torch.cat([A_real, B_fake], 0) + 1),
                              global_step=step)
            writer.add_images(tag='epoch{}_Gb'.format(epoch), img_tensor=0.5 * (torch.cat([B_real, A_fake], 0) + 1),
                              global_step=step)
            step += 1
    mean_lsG = loss_G / len(train_loader)
    mean_lsD = loss_D / len(train_loader)
    return mean_lsG, mean_lsD
@torch.no_grad()
def val(Ga, Da, Gb, Db, val_loader, loss, device, l1_loss, epoch):
    pd = tqdm(val_loader)
    loss_D, loss_G = 0, 0
    Ga.eval()
    Da.eval()
    Gb.eval()
    Db.eval()
    all_loss = 10000
    for idx, item in enumerate(pd):
        A_real_img = item[0].to(device)
        B_real_img = item[1].to(device)
        B_fake_img = Ga(A_real_img)
        A_fake_img = Gb(B_real_img)
        A_rec = Gb(B_fake_img)
        B_rec = Ga(A_fake_img)

        # D loss
        ls_D = train_D(Da=Da, Db=Db, B_fake=B_fake_img, B_real=B_real_img, loss=loss, A_fake=A_fake_img,
                       A_real=A_real_img)
        # G loss
        ls_G = train_G(Da=Da, Db=Db, B_fake=B_fake_img, loss=loss, A_fake=A_fake_img, l1_loss=l1_loss,
                       A_rec=A_rec,
                       A_real=A_real_img, B_rec=B_rec, B_real=B_real_img, Ga=Ga, Gb=Gb)
        loss_G += ls_G
        loss_D += ls_D
        pd.desc = 'val_{}: G_loss:{} D_Loss:{}'.format(epoch, ls_G.item(), ls_D.item())

        # keep track of the batch with the lowest combined loss
        all_ls = ls_G + ls_D
        if all_ls < all_loss:
            all_loss = all_ls
            best_image = torch.cat([A_real_img, B_fake_img, B_real_img, A_fake_img], 0)

    # save the best batch of this epoch, mapped from [-1,1] back to [0,1]
    result_img = (best_image + 1) * 0.5
    if not os.path.exists('./results'):
        os.mkdir('./results')
    torchvision.utils.save_image(result_img, './results/val_epoch{}_cycle.jpg'.format(epoch))
def set_required_grad(nets, requires_grad=False):
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for params in net.parameters():
                params.requires_grad = requires_grad
def train_G(Da, Db, B_fake, loss, A_fake, l1_loss, A_rec, A_real, B_rec, B_real, Ga, Gb):
    # GAN loss: each generator tries to fool its discriminator
    # (torch.ones_like / zeros_like keep the targets on the right device automatically)
    Da_out_fake = Da(B_fake)
    Ga_gan_loss = loss(Da_out_fake, torch.ones_like(Da_out_fake))
    Db_out_fake = Db(A_fake)
    Gb_gan_loss = loss(Db_out_fake, torch.ones_like(Db_out_fake))
    # cycle loss, weighted by lambda = 10
    Cycle_A_loss = l1_loss(A_rec, A_real) * 10
    Cycle_B_loss = l1_loss(B_rec, B_real) * 10
    # identity loss, weighted by 0.5 * lambda
    Ga_id_out = Ga(B_real)
    Gb_id_out = Gb(A_real)
    Ga_id_loss = l1_loss(Ga_id_out, B_real) * 10 * 0.5
    Gb_id_loss = l1_loss(Gb_id_out, A_real) * 10 * 0.5
    # total G loss
    ls_G = Ga_gan_loss + Gb_gan_loss + Cycle_A_loss + Cycle_B_loss + Ga_id_loss + Gb_id_loss
    return ls_G
def train_D(Da, Db, B_fake, B_real, loss, A_fake, A_real):
    # Da loss: real B -> 1, fake B -> 0
    Da_fake_out = Da(B_fake.detach()).squeeze()
    Da_real_out = Da(B_real).squeeze()
    ls_Da1 = loss(Da_fake_out, torch.zeros_like(Da_fake_out))
    ls_Da2 = loss(Da_real_out, torch.ones_like(Da_real_out))
    ls_Da = (ls_Da1 + ls_Da2) * 0.5
    # Db loss: real A -> 1, fake A -> 0
    Db_fake_out = Db(A_fake.detach()).squeeze()
    Db_real_out = Db(A_real.detach()).squeeze()
    ls_Db1 = loss(Db_fake_out, torch.zeros_like(Db_fake_out))
    ls_Db2 = loss(Db_real_out, torch.ones_like(Db_real_out))
    ls_Db = (ls_Db1 + ls_Db2) * 0.5
    # total D loss
    ls_D = ls_Da + ls_Db
    return ls_D
(3) test
from cyclegan import Cycle_Gan_G
import torch
import torchvision.transforms as transform
import matplotlib.pyplot as plt
import cv2
from PIL import Image
def test(img_path):
    # cv2 reads BGR, so flip to RGB for png inputs
    if img_path.endswith('.png'):
        img = cv2.imread(img_path)
        img = img[:, :, ::-1]
    else:
        img = Image.open(img_path)
    transforms = transform.Compose([
        transform.ToTensor(),
        transform.Resize((256, 256)),
        transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    img = transforms(img.copy())
    img = img[None].to('cuda')  # [1,3,256,256]

    # instantiate the generator (Gb: photo -> monet)
    Gb = Cycle_Gan_G().to('cuda')
    # load the trained weights
    ckpt = torch.load('weights/cycle_monent2photo.pth')
    Gb.load_state_dict(ckpt['Gb_model'], strict=False)
    Gb.eval()
    with torch.no_grad():
        out = Gb(img)[0]
    # map from [-1,1] back to [0,1] and move channels last for plotting
    out = out.permute(1, 2, 0)
    out = (0.5 * (out + 1)).cpu().numpy()
    plt.figure()
    plt.imshow(out)
    plt.show()


if __name__ == '__main__':
    test('123.jpg')
5. Results
Here is a selection of training and validation results.
Training set, monet -> photo:
Training set, photo -> monet:
Validation set results (left: monet -> photo; right: photo -> monet):
The image below shows a photo translated into Monet's style:
6. Full Code
Dataset: Baidu Netdisk, extraction code: s3e3
Code: Baidu Netdisk, extraction code: t0d5