CycleGAN是2017年发表在ICCV上的、由GAN发展而来的一种无监督机器学习算法,是一种实现图像风格转换功能的GAN网络。在此之前已有pix2pix实现图像风格转换,但pix2pix具有很大的局限性:它要求两种风格的图像成对出现,而现实中很难找到风格不同、内容相同的成对图像,也很难通过拍摄获得。CycleGAN解决了这个问题,可以在两种类型的图像之间进行转换,而不需要一一对应的配对关系。比如把照片转换为油画风格,或者把照片中的橘子转换为苹果、实现马与斑马之间的转换等。
实现效果:
马转斑马
代码实现:
网络定义和训练代码
'''
Descripttion:
version:
Author: MAPLE
Date: 2022-06-12 23:23:54
LastEditors: MAPLE
LastEditTime: 2022-06-28 23:24:09
'''
import os
import torch
import random
import torch.nn as nn
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torch.nn import init
from torch.optim import lr_scheduler
from tqdm import tqdm
from torchvision.utils import save_image
import torch.optim as optim
import torchvision.transforms as transforms
def seed_torch(seed=2018):
    """Seed every RNG source (Python, NumPy, PyTorch CPU/GPU) for reproducibility.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # also seed every visible GPU, not just the current device
    torch.cuda.manual_seed_all(seed)
    # force deterministic cuDNN kernels (slower, but reproducible runs)
    torch.backends.cudnn.deterministic = True
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Data roots (each is expected to contain trainA/ and trainB/ sub-folders).
TRAIN_DIR = "data/horse2zebra"
VAL_DIR = "data/horse2zebra"
BATCH_SIZE = 1
LEARNING_RATE = 2e-4  # learning rate
LAMBDA_IDENTITY = 5  # weight of the identity loss
LAMBDA_CYCLE = 10  # weight of the cycle-consistency loss
NUM_WORKERS = 2
LOAD_MODEL = True  # resume from saved checkpoints at startup
SAVE_MODEL = True  # save checkpoints after every epoch
# Checkpoint file locations
CHECKPOINT_GEN_H = "genh.pth.tar"
CHECKPOINT_GEN_Z = "genz.pth.tar"
CHECKPOINT_CRITIC_H = "critich.pth.tar"
CHECKPOINT_CRITIC_Z = "criticz.pth.tar"
# Learning-rate schedule hyperparameters: constant LR for N_EPOCHS epochs,
# then linear decay to zero over N_EPOCHS_DECAY more epochs.
EPOCH_COUNT = 1
N_EPOCHS = 100
N_EPOCHS_DECAY = 100
# Training-time preprocessing / augmentation pipeline.
# NOTE(review): this rebinding shadows the imported torchvision.transforms
# module — after this line only the composed pipeline is reachable under
# the name `transforms`.
transforms = transforms.Compose(
    [
        transforms.Resize(286, Image.BICUBIC),  # upscale so the crop below has slack
        transforms.RandomCrop(256),  # random 256x256 crop (augmentation)
        transforms.RandomHorizontalFlip(p=0.5),  # random horizontal flip
        transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # map to [-1, 1], matching Tanh output
    ]
)
# Custom parameter initialization, applied recursively to every sub-module.
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights and move the network to DEVICE.

    Conv / Linear weights are drawn from N(0, init_gain); BatchNorm2d
    weights from N(1, init_gain); all biases are zeroed.

    Args:
        net: the network to initialize (modified in place).
        init_type: initialization scheme; only 'normal' is implemented.
        init_gain: standard deviation of the normal distribution.

    Raises:
        NotImplementedError: if init_type is not 'normal' (previously an
            unknown scheme silently left the weights untouched).
    """
    if init_type != 'normal':
        raise NotImplementedError(
            'initialization method [%s] is not implemented' % init_type)

    def init_func(m):  # applied to every sub-module by net.apply below
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            init.normal_(m.weight.data, 0.0, init_gain)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            # BatchNorm weight is a scale vector, centered at 1 rather than 0
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.to(DEVICE)
    net.apply(init_func)
class ImagePool():
    """History buffer of previously generated images.

    Feeding the discriminators a mix of the newest fakes and older buffered
    ones stabilizes GAN training (Shrivastava et al., 2017).
    """

    def __init__(self, pool_size):
        # pool_size == 0 disables pooling entirely
        self.pool_size = pool_size
        if self.pool_size > 0:
            # start with an empty buffer
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return a batch, possibly swapping some images for buffered ones."""
        if self.pool_size == 0:
            # pooling disabled: hand the batch straight back
            return images
        selected = []
        for img in images:
            img = torch.unsqueeze(img.data, 0)
            if self.num_imgs < self.pool_size:
                # buffer not yet full: store the new image and return it
                self.num_imgs = self.num_imgs + 1
                self.images.append(img)
                selected.append(img)
            elif random.uniform(0, 1) > 0.5:
                # 50%: return a random buffered image, replace it with the new one
                slot = random.randint(0, self.pool_size - 1)  # randint is inclusive
                older = self.images[slot].clone()
                self.images[slot] = img
                selected.append(older)
            else:
                # 50%: return the freshly generated image unchanged
                selected.append(img)
        # reassemble the chosen images into one batch tensor
        return torch.cat(selected, 0)
GLOBAL_SEED = 1


def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and all GPUs) RNGs."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


set_seed(GLOBAL_SEED)
# The middle of the generator stacks residual blocks (9 by default).
class ResnetBlock(nn.Module):
    """A ResNet block: two 3x3 convolutions plus a skip connection."""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Initialize the Resnet block.

        Args:
            dim: number of channels (input == output, required by the skip).
            padding_type: 'reflect' | 'replicate' | 'zero'.
            norm_layer: normalization layer class (e.g. nn.InstanceNorm2d).
            use_dropout: insert Dropout(0.5) between the two convolutions.
            use_bias: whether the convolutions use a bias term.
        """
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(
            dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Construct the pad-conv-norm-ReLU(-dropout)-pad-conv-norm sequence.

        Previously only 'reflect' was handled; any other padding_type left
        p == 0 with no pad layer, producing a shape mismatch (crash) in
        forward(). 'replicate' and 'zero' are now supported and unknown
        types raise immediately.
        """
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError(
                'padding [%s] is not implemented' % padding_type)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p,
                                 bias=use_bias), norm_layer(dim), nn.ReLU(True)]
        # dropout is rarely useful in conv nets, but kept as an option
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        # second conv has no activation: the skip connection follows
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3,
                                 padding=p, bias=use_bias), norm_layer(dim)]
        return nn.Sequential(*conv_block)

    def forward(self, x):
        """Forward with skip connection: out = x + F(x)."""
        out = x + self.conv_block(x)
        return out
# Generator built around residual blocks.
class ResnetGenerator(nn.Module):
    """Resnet-based generator: 7x7 stem, two stride-2 downsampling convs,
    n_blocks residual blocks, two mirrored upsampling deconvs, 7x7 head
    with Tanh output."""

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=9, padding_type='reflect'):
        """Construct a Resnet-based generator."""
        super(ResnetGenerator, self).__init__()
        # InstanceNorm has no affine parameters by default, so the
        # convolutions need their own bias in that case
        use_bias = norm_layer == nn.InstanceNorm2d
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
            norm_layer(ngf),
            nn.ReLU(True),
        ]
        n_downsampling = 2
        # downsampling: channels double while spatial size halves
        for i in range(n_downsampling):
            cin = ngf * (2 ** i)
            layers += [
                nn.Conv2d(cin, cin * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                norm_layer(cin * 2),
                nn.ReLU(True),
            ]
        # residual bottleneck at constant width
        width = ngf * (2 ** n_downsampling)
        layers += [
            ResnetBlock(width, padding_type=padding_type, norm_layer=norm_layer,
                        use_dropout=use_dropout, use_bias=use_bias)
            for _ in range(n_blocks)
        ]
        # upsampling: mirror image of the downsampling path
        for i in range(n_downsampling):
            cin = ngf * (2 ** (n_downsampling - i))
            layers += [
                nn.ConvTranspose2d(cin, cin // 2, kernel_size=3, stride=2,
                                   padding=1, output_padding=1, bias=use_bias),
                norm_layer(cin // 2),
                nn.ReLU(True),
            ]
        layers += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)
        init_weights(self.model)

    def forward(self, input):
        """Standard forward pass."""
        return self.model(input)
# Markovian discriminator (PatchGAN): a stack of strided convolutions whose
# final layer emits an N x N map of per-patch real/fake scores.
class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator."""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator."""
        super(NLayerDiscriminator, self).__init__()
        # InstanceNorm has no affine parameters, so convs then need a bias
        use_bias = norm_layer == nn.InstanceNorm2d
        kw, padw = 4, 1
        # first layer: no normalization
        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True),
        ]
        prev_mult = 1
        # intermediate stride-2 layers: filter count doubles, capped at 8x
        for n in range(1, n_layers):
            mult = min(2 ** n, 8)
            layers += [
                nn.Conv2d(ndf * prev_mult, ndf * mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * mult),
                nn.LeakyReLU(0.2, True),
            ]
            prev_mult = mult
        # penultimate layer keeps stride 1
        mult = min(2 ** n_layers, 8)
        layers += [
            nn.Conv2d(ndf * prev_mult, ndf * mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * mult),
            nn.LeakyReLU(0.2, True),
        ]
        # final 1-channel prediction map
        layers += [nn.Conv2d(ndf * mult, 1, kernel_size=kw, stride=1, padding=padw)]
        self.model = nn.Sequential(*layers)
        init_weights(self.model)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)
# Learning-rate schedule
def get_scheduler(optimizer):
    """Return a LambdaLR scheduler for `optimizer`.

    The learning rate stays at its initial value for the first N_EPOCHS
    epochs, then decays linearly to zero over the next N_EPOCHS_DECAY.
    """
    def lambda_rule(epoch):
        # how far past the constant-LR phase we are (0 while inside it)
        overshoot = max(0, epoch + EPOCH_COUNT - N_EPOCHS)
        return 1.0 - overshoot / float(N_EPOCHS_DECAY + 1)

    return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
def train_fn(disc_H, disc_Z, gen_H, gen_Z, loader, opt_disc, opt_gen, l1, mse):
    """Run one training epoch over `loader`.

    Alternates one discriminator update and one generator update per batch,
    using LSGAN (MSE) adversarial losses plus L1 cycle-consistency and
    identity terms.

    Args:
        disc_H, disc_Z: discriminators for the horse / zebra domains.
        gen_H, gen_Z: generators producing horses (from zebras) / zebras (from horses).
        loader: yields dicts with keys 'A' (horse) and 'B' (zebra) image batches.
        opt_disc: optimizer covering both discriminators.
        opt_gen: optimizer covering both generators.
        l1: L1 loss, used for the cycle and identity terms.
        mse: MSE loss, used for the adversarial terms.
    """
    # history buffers of generated images shown to the discriminators
    fake_H_pool = ImagePool(50)
    fake_Z_pool = ImagePool(50)
    # running sums of mean discriminator outputs, for the progress bar
    H_reals = 0
    H_fakes = 0
    Z_reals = 0
    Z_fakes = 0
    loop = tqdm(loader, leave=True)
    for idx, data in enumerate(loop):
        zebra = data['B'].to(DEVICE)
        horse = data['A'].to(DEVICE)
        # ---- Train discriminators H and Z ----
        fake_horse = gen_H(zebra)
        # discriminator sees a (possibly older) pooled fake
        fake_horse_train = fake_H_pool.query(fake_horse)
        D_H_real = disc_H(horse)
        # detach: no generator gradients during the discriminator step
        D_H_fake = disc_H(fake_horse_train.detach())
        H_reals += D_H_real.mean().item()
        H_fakes += D_H_fake.mean().item()
        # LSGAN targets: 1 for real, 0 for fake
        D_H_real_loss = mse(D_H_real, torch.ones_like(D_H_real))
        D_H_fake_loss = mse(D_H_fake, torch.zeros_like(D_H_fake))
        D_H_loss = D_H_real_loss + D_H_fake_loss
        fake_zebra = gen_Z(horse)
        fake_zebra_train = fake_Z_pool.query(fake_zebra)
        D_Z_real = disc_Z(zebra)
        D_Z_fake = disc_Z(fake_zebra_train.detach())
        Z_reals += D_Z_real.mean().item()
        Z_fakes += D_Z_fake.mean().item()
        D_Z_real_loss = mse(D_Z_real, torch.ones_like(D_Z_real))
        D_Z_fake_loss = mse(D_Z_fake, torch.zeros_like(D_Z_fake))
        D_Z_loss = D_Z_real_loss + D_Z_fake_loss
        # combine both discriminator losses into one step
        D_loss = (D_H_loss + D_Z_loss)/2
        opt_disc.zero_grad()
        D_loss.backward()
        opt_disc.step()
        # ---- Train generators H and Z ----
        # adversarial loss for both generators (fresh fakes, not pooled ones)
        D_H_fake = disc_H(fake_horse)
        D_Z_fake = disc_Z(fake_zebra)
        loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
        loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake))
        # cycle-consistency loss: translate there and back, compare to input
        cycle_zebra = gen_Z(fake_horse)
        cycle_horse = gen_H(fake_zebra)
        cycle_zebra_loss = l1(zebra, cycle_zebra)
        cycle_horse_loss = l1(horse, cycle_horse)
        # identity loss (remove these for efficiency if you set lambda_identity=0)
        identity_zebra = gen_Z(zebra)
        identity_horse = gen_H(horse)
        identity_zebra_loss = l1(zebra, identity_zebra)
        identity_horse_loss = l1(horse, identity_horse)
        # weighted sum of all generator loss terms
        G_loss = (
            loss_G_Z
            + loss_G_H
            + cycle_zebra_loss * LAMBDA_CYCLE
            + cycle_horse_loss * LAMBDA_CYCLE
            + identity_horse_loss * LAMBDA_IDENTITY
            + identity_zebra_loss * LAMBDA_IDENTITY
        )
        opt_gen.zero_grad()
        G_loss.backward()
        opt_gen.step()
        # periodically dump sample translations (denormalize from [-1,1] to [0,1])
        if idx % 200 == 0:
            save_image(fake_horse*0.5+0.5, f"train_images/horse_{idx}.png")
            save_image(fake_zebra*0.5+0.5, f"train_images/zebra_{idx}.png")
        # NOTE(review): the D_real/D_fake labels actually display the
        # zebra-discriminator running averages (Z_reals / Z_fakes)
        loop.set_postfix(H_real=H_reals/(idx+1), H_fake=H_fakes /
                         (idx+1), D_real=Z_reals/(idx+1), D_fake=Z_fakes/(idx+1))
class CombineDataset(Dataset):
    """Unpaired two-domain dataset: each item pairs an image from domain A
    with an (unrelated) image from domain B.

    The dataset length is the size of the larger domain; indices into the
    smaller domain wrap around via the modulo operator.
    """

    def __init__(self, root_A, root_B, transform):
        """
        Args:
            root_A: directory containing domain-A images.
            root_B: directory containing domain-B images.
            transform: transform applied to both images in __getitem__.
        """
        self.root_A = root_A
        self.root_B = root_B
        self.transform = transform
        # sorted() for a deterministic ordering — os.listdir order is
        # arbitrary and varies across platforms/filesystems
        self.A_paths = sorted(os.listdir(root_A))
        self.B_paths = sorted(os.listdir(root_B))
        self.A_len = len(self.A_paths)
        self.B_len = len(self.B_paths)
        self.length_dataset = max(self.A_len, self.B_len)

    def __len__(self):
        return self.length_dataset

    def __getitem__(self, index):
        # wrap the index so the smaller domain repeats
        A_path = self.A_paths[index % self.A_len]
        B_path = self.B_paths[index % self.B_len]
        # os.path.join instead of raw string concatenation: correct whether
        # or not the root was given with a trailing separator
        A_img = Image.open(os.path.join(self.root_A, A_path)).convert("RGB")
        B_img = Image.open(os.path.join(self.root_B, B_path)).convert("RGB")
        A = self.transform(A_img)
        B = self.transform(B_img)
        return {'A': A, 'B': B}
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Serialize the model and optimizer state dicts to `filename`."""
    print("=> Saving checkpoint")
    torch.save(
        {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        filename,
    )
def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore model/optimizer state from `checkpoint_file`, then reset the LR.

    The saved optimizer state carries the learning rate it was trained
    with, so every param group is overwritten with `lr` after loading.
    """
    print("=> Loading checkpoint")
    state = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(state["state_dict"])
    optimizer.load_state_dict(state["optimizer"])
    for group in optimizer.param_groups:
        group["lr"] = lr
def main():
    """Build the data pipeline, models and optimizers; optionally resume
    from checkpoints; then run the full training schedule."""
    dataset = CombineDataset(root_A=TRAIN_DIR+"/trainA/",
                             root_B=TRAIN_DIR+"/trainB/", transform=transforms)
    data_loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True
    )
    dataset_size = len(data_loader)
    print('The number of training images = %d' % dataset_size)

    disc_H = NLayerDiscriminator(input_nc=3).to(DEVICE)
    disc_Z = NLayerDiscriminator(input_nc=3).to(DEVICE)
    gen_Z = ResnetGenerator(input_nc=3, output_nc=3).to(DEVICE)
    gen_H = ResnetGenerator(input_nc=3, output_nc=3).to(DEVICE)

    # one optimizer per role, each covering both networks of that role
    opt_disc = optim.Adam(
        list(disc_H.parameters()) + list(disc_Z.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )
    opt_gen = optim.Adam(
        list(gen_Z.parameters()) + list(gen_H.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )
    scheduler_disc = get_scheduler(opt_disc)
    scheduler_gen = get_scheduler(opt_gen)

    L1 = nn.L1Loss()      # cycle / identity losses
    mse = nn.MSELoss()    # LSGAN adversarial loss

    if LOAD_MODEL:
        load_checkpoint(
            CHECKPOINT_GEN_H, gen_H, opt_gen, LEARNING_RATE,
        )
        load_checkpoint(
            CHECKPOINT_GEN_Z, gen_Z, opt_gen, LEARNING_RATE,
        )
        load_checkpoint(
            CHECKPOINT_CRITIC_H, disc_H, opt_disc, LEARNING_RATE,
        )
        load_checkpoint(
            CHECKPOINT_CRITIC_Z, disc_Z, opt_disc, LEARNING_RATE,
        )

    # train_fn saves sample images here; create it up front so the first
    # save_image call cannot fail on a missing directory
    os.makedirs("train_images", exist_ok=True)

    for epoch in range(EPOCH_COUNT, N_EPOCHS+N_EPOCHS_DECAY+1):
        train_fn(disc_H, disc_Z, gen_H, gen_Z,
                 data_loader, opt_disc, opt_gen, L1, mse)
        scheduler_disc.step()
        scheduler_gen.step()
        if SAVE_MODEL:
            save_checkpoint(gen_H, opt_gen, filename=CHECKPOINT_GEN_H)
            save_checkpoint(gen_Z, opt_gen, filename=CHECKPOINT_GEN_Z)
            save_checkpoint(disc_H, opt_disc, filename=CHECKPOINT_CRITIC_H)
            save_checkpoint(disc_Z, opt_disc, filename=CHECKPOINT_CRITIC_Z)


# The guard is required: with num_workers > 0 the DataLoader worker
# processes re-import this module (spawn start method on Windows/macOS),
# which without the guard would recursively re-run the whole script.
if __name__ == "__main__":
    main()
完整工程、训练参数与数据集,如有需要请留言。