CycleGAN的简单实现(pytorch)

CycleGAN是2017年发表在ICCV上、由GAN发展而来的一种无监督机器学习算法,是一种用于图像风格转换的GAN网络。在此之前已有pix2pix可以实现图像风格转换,但pix2pix有很大的局限性:它要求两种风格的训练图像必须成对出现,而现实中很难找到内容相同、风格不同的成对图像,也很难通过拍摄获得。CycleGAN解决了这个问题,它可以在两类图像之间进行转换,而不需要成对的对应关系。例如把照片转换为油画风格、把照片中的橘子转换为苹果,或者在马与斑马之间相互转换等。

实现效果:

CycleGAN的简单实现(pytorch)_第1张图片CycleGAN的简单实现(pytorch)_第2张图片

CycleGAN的简单实现(pytorch)_第3张图片

 CycleGAN的简单实现(pytorch)_第4张图片

马转斑马

代码实现:

网络定义和训练代码

'''
Descripttion: 
version: 
Author: MAPLE
Date: 2022-06-12 23:23:54
LastEditors: MAPLE
LastEditTime: 2022-06-28 23:24:09
'''
import os
import torch
import random
import torch.nn as nn
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torch.nn import init
from torch.optim import lr_scheduler
from tqdm import tqdm
from torchvision.utils import save_image
import torch.optim as optim
import torchvision.transforms as transforms

# NOTE(review): the return value of this call is discarded, so this line has no
# visible effect; the actual CPU/GPU selection is done by DEVICE below.
torch.cuda.is_available()

def seed_torch(seed=2018):
    """Seed every random number generator used in training for reproducibility.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and every CUDA device),
    exports ``PYTHONHASHSEED`` and configures cuDNN for deterministic kernels.

    Args:
        seed: Integer seed applied to every generator. Defaults to 2018.
    """
    random.seed(seed)
    # Only affects processes started after this point (e.g. spawned DataLoader
    # workers); the running interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all also covers multi-GPU machines; torch.cuda.manual_seed
    # alone seeds only the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may select different (possibly non-deterministic)
    # algorithms from run to run, which would defeat ``deterministic = True``.
    torch.backends.cudnn.benchmark = False

# ---- Runtime / data ----------------------------------------------------------
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# Training and validation currently read from the same dataset root.
TRAIN_DIR = VAL_DIR = "data/horse2zebra"

# ---- Optimisation --------------------------------------------------------------
BATCH_SIZE = 1
NUM_WORKERS = 2          # DataLoader worker processes
LEARNING_RATE = 2e-4     # Adam learning rate (generators and discriminators)
LAMBDA_CYCLE = 10        # weight of the cycle-consistency loss
LAMBDA_IDENTITY = 5      # weight of the identity loss

# ---- Checkpointing -------------------------------------------------------------
LOAD_MODEL = True        # load model/optimizer state from the files below
SAVE_MODEL = True        # save model/optimizer state to the files below
# Where the model/optimizer checkpoints are stored.
CHECKPOINT_GEN_H, CHECKPOINT_GEN_Z = "genh.pth.tar", "genz.pth.tar"
CHECKPOINT_CRITIC_H, CHECKPOINT_CRITIC_Z = "critich.pth.tar", "criticz.pth.tar"

# ---- Learning-rate schedule hyper-parameters -----------------------------------
EPOCH_COUNT = 1          # starting epoch number; also offsets the LR schedule
N_EPOCHS = 100           # epochs at the initial learning rate
N_EPOCHS_DECAY = 100     # epochs of linear learning-rate decay

# Training-time preprocessing / augmentation:
#   1. resize so the shorter image side is 286 px (bicubic),
#   2. take a random 256x256 crop,
#   3. mirror horizontally with probability 0.5,
#   4. convert the PIL image to a float tensor,
#   5. normalise each channel with mean 0.5 / std 0.5, mapping [0, 1] to
#      [-1, 1] (the range produced by the generator's final Tanh).
# NOTE(review): this assignment rebinds ``transforms`` from the torchvision
# module to the Compose pipeline; later code relies on the pipeline object.
_pipeline_steps = [
    transforms.Resize(286, interpolation=Image.BICUBIC),
    transforms.RandomCrop(256),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]
transforms = transforms.Compose(_pipeline_steps)
del _pipeline_steps

# Custom weight initialisation applied to every layer of a multi-layer network.
def init_weights(net, init_type='normal', init_gain=0.02):
    """Move ``net`` to DEVICE and initialise its weights in place.

    Conv*/Linear layers get weights drawn by ``init_type`` and zero biases.
    BatchNorm2d layers (when affine) get weights ~ N(1.0, init_gain) and zero
    biases. Other layers are left untouched.

    Args:
        net: The ``nn.Module`` to initialise.
        init_type: 'normal' (N(0, init_gain)), 'xavier' (Xavier-normal with
            gain ``init_gain``), 'kaiming' (Kaiming-normal, fan_in) or
            'orthogonal' (gain ``init_gain``).
        init_gain: Standard deviation / gain used by the chosen scheme.

    Raises:
        NotImplementedError: If ``init_type`` is not one of the names above.
    """
    weight_inits = {
        'normal': lambda w: init.normal_(w, 0.0, init_gain),
        'xavier': lambda w: init.xavier_normal_(w, gain=init_gain),
        'kaiming': lambda w: init.kaiming_normal_(w, a=0, mode='fan_in'),
        'orthogonal': lambda w: init.orthogonal_(w, gain=init_gain),
    }
    # Fail fast instead of silently leaving the default initialisation while
    # printing that ``init_type`` was applied.
    if init_type not in weight_inits:
        raise NotImplementedError(
            'initialization method [%s] is not implemented' % init_type)
    init_weight = weight_inits[init_type]

    def init_func(m):
        classname = m.__class__.__name__
        # ``weight`` is None for e.g. BatchNorm2d(affine=False); skip those.
        weight = getattr(m, 'weight', None)
        if weight is not None and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            init_weight(weight.data)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)
        # BatchNorm weights are scale factors (not matrices), so they are
        # drawn around 1.0 with the normal distribution regardless of init_type.
        elif weight is not None and classname.find('BatchNorm2d') != -1:
            init.normal_(weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.to(DEVICE)
    net.apply(init_func)  # apply the initialization function recursively

class ImagePool():
    """History buffer of previously generated images.

    Lets the discriminator be trained on a mix of the current fakes and fakes
    produced earlier in training.
    """

    def __init__(self, pool_size):
        """Create an empty buffer.

        Args:
            pool_size: Maximum number of images stored. A value <= 0 disables
                the buffer, so ``query`` returns its input unchanged.
        """
        self.pool_size = pool_size
        # Always define the buffer state so a disabled pool never hits an
        # AttributeError (previously only created when pool_size > 0).
        self.num_imgs = 0
        self.images = []

    def query(self, images):
        """Return a batch drawn from the buffer, updating it with ``images``.

        While the buffer is not full, each incoming image is stored and
        returned as-is. Once full, each image is, with probability 0.5,
        swapped with a randomly chosen stored image (the stored one is
        returned, the new one takes its slot); otherwise it is returned
        directly.

        Args:
            images: Batch tensor whose first dimension indexes images.

        Returns:
            A batch with as many images as ``images``. When the buffer is
            enabled, the returned tensors carry no autograd history.
        """
        if self.pool_size <= 0:
            return images
        if len(images) == 0:
            return images  # nothing to concatenate; torch.cat([]) would raise

        selected = []
        for image in images:
            # ``.data`` drops autograd history so buffered images never
            # back-propagate into the generator.
            image = torch.unsqueeze(image.data, 0)
            if self.num_imgs < self.pool_size:
                self.num_imgs += 1
                self.images.append(image)
                selected.append(image)
            elif random.uniform(0, 1) > 0.5:
                # 50% chance: return an older image and store the new one.
                random_id = random.randint(0, self.pool_size - 1)  # inclusive
                previous = self.images[random_id].clone()
                self.images[random_id] = image
                selected.append(previous)
            else:
                # Other 50%: return the current image unchanged.
                selected.append(image)
        return torch.cat(selected, 0)

GLOBAL_SEED = 1


def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU, current GPU, all GPUs) with ``seed``."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


# Seed once at import time so model construction below is reproducible.
set_seed(GLOBAL_SEED)

# The bottleneck between the down- and up-sampling halves of the generator is a
# stack of Residual blocks (9 by default).
class ResnetBlock(nn.Module):
    """Residual block computing ``x + F(x)``, where F is two padded 3x3 convs."""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Build the block.

        Args:
            dim: Number of input (and output) channels.
            padding_type: 'reflect', 'replicate' or 'zero'; how each 3x3
                convolution is padded so the spatial size is preserved.
            norm_layer: Normalisation constructor, called as ``norm_layer(dim)``.
            use_dropout: If True, insert ``nn.Dropout(0.5)`` between the convs.
            use_bias: Whether the convolutions have a bias term.

        Raises:
            NotImplementedError: If ``padding_type`` is not supported.
        """
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(
            dim, padding_type, norm_layer, use_dropout, use_bias)

    @staticmethod
    def _padding_layers(padding_type):
        """Return (explicit padding layers, conv ``padding`` value) for ``padding_type``."""
        if padding_type == 'reflect':
            return [nn.ReflectionPad2d(1)], 0
        if padding_type == 'replicate':
            return [nn.ReplicationPad2d(1)], 0
        if padding_type == 'zero':
            return [], 1
        # Previously any other value silently used no padding, shrinking the
        # feature map so the residual addition failed with a shape mismatch.
        raise NotImplementedError('padding [%s] is not implemented' % padding_type)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Construct the residual branch F as an ``nn.Sequential``.

        Layout: pad -> conv3x3 -> norm -> ReLU -> [dropout] -> pad -> conv3x3 -> norm.
        """
        pad_layers, p = self._padding_layers(padding_type)
        conv_block = list(pad_layers)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p,
                                 bias=use_bias), norm_layer(dim), nn.ReLU(True)]
        # Author's note: empirically, dropout rarely helps in conv layers.
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        # Fresh padding modules for the second convolution.
        pad_layers, p = self._padding_layers(padding_type)
        conv_block += pad_layers
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3,
                                 padding=p, bias=use_bias), norm_layer(dim)]
        return nn.Sequential(*conv_block)

    def forward(self, x):
        """Forward function (with skip connection)."""
        return x + self.conv_block(x)

# Generator built around a stack of Residual blocks.
class ResnetGenerator(nn.Module):
    """ResNet-style image-to-image generator.

    Layout: 7x7 conv stem -> 2 stride-2 downsampling convs -> ``n_blocks``
    ResnetBlocks -> 2 stride-2 transposed-conv upsampling layers -> 7x7 conv
    -> Tanh (output values in [-1, 1]).
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=9, padding_type='reflect'):
        """Construct a ResNet-based generator.

        Args:
            input_nc: Channels of the input image.
            output_nc: Channels of the output image.
            ngf: Filters in the first conv layer; doubled by each downsampling.
            norm_layer: Normalisation constructor (a class, or a
                ``functools.partial`` wrapping one).
            use_dropout: Enable dropout inside every ResnetBlock.
            n_blocks: Number of ResnetBlocks in the bottleneck (>= 0).
            padding_type: Padding mode forwarded to every ResnetBlock.

        Raises:
            ValueError: If ``n_blocks`` is negative.
        """
        super(ResnetGenerator, self).__init__()
        if n_blocks < 0:
            raise ValueError('n_blocks must be non-negative, got %r' % (n_blocks,))

        # InstanceNorm2d defaults to affine=False (no learned shift), so the
        # convolutions keep a bias; BatchNorm2d's learned shift makes it
        # redundant. Unwrap functools.partial so partial(InstanceNorm2d, ...)
        # is recognised too.
        use_bias = getattr(norm_layer, 'func', norm_layer) == nn.InstanceNorm2d

        # Stem: 7x7 conv on a reflection-padded input (spatial size preserved).
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        # Downsampling: each stride-2 conv halves H/W and doubles the channels.
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        # Bottleneck: residual blocks at 1/4 resolution with ngf * 4 channels.
        mult = 2 ** n_downsampling
        for _ in range(n_blocks):
            model += [ResnetBlock(ngf * mult, padding_type=padding_type,
                                  norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

        # Upsampling: stride-2 transposed convs mirror the downsampling path.
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]

        # Output head: 7x7 conv to ``output_nc`` channels, Tanh to [-1, 1].
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)
        init_weights(self.model)

    def forward(self, input):
        """Translate an image batch to the other domain.

        With two stride-2 down/up-sampling stages, H and W should be divisible
        by 4 for the output size to match the input size exactly.
        """
        return self.model(input)

# Markovian discriminator (PatchGAN): purely convolutional, outputs an n x n
# map of predictions instead of a single score.
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator.

    Each output element scores one receptive-field patch of the input. With
    the defaults (``n_layers=3``) a 256x256 input yields a 1x30x30 map. The
    final conv has no activation, so scores are unbounded (train_fn compares
    them to 1/0 targets with an MSE loss).
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator.

        Args:
            input_nc: Channels of the input image.
            ndf: Filters in the first conv layer.
            n_layers: Number of stride-2 conv layers (including the first).
            norm_layer: Normalisation constructor (a class, or a
                ``functools.partial`` wrapping one).
        """
        super(NLayerDiscriminator, self).__init__()
        # See ResnetGenerator: InstanceNorm2d has no learned shift by default,
        # so convs need their own bias; unwrap functools.partial as well.
        use_bias = getattr(norm_layer, 'func', norm_layer) == nn.InstanceNorm2d

        kw = 4
        padw = 1
        # First layer: no normalisation.
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters (capped at 8x)
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        # One more stride-1 conv block enlarges the receptive field.
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]
        # Output a 1-channel prediction map.
        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
        self.model = nn.Sequential(*sequence)
        init_weights(self.model)

    def forward(self, input):
        """Return the per-patch score map for ``input``."""
        return self.model(input)


# Learning-rate scheduling.

def get_scheduler(optimizer, epoch_count=None, n_epochs=None, n_epochs_decay=None):
    """Return a LambdaLR scheduler: constant LR, then linear decay.

    At scheduler step ``epoch`` (0-based) each param group's initial LR is
    multiplied by::

        1.0 - max(0, epoch + epoch_count - n_epochs) / (n_epochs_decay + 1)

    so the LR stays constant for the first ``n_epochs`` epochs (counting from
    ``epoch_count``) and then decreases linearly. The multiplier reaches
    ``1 / (n_epochs_decay + 1)`` on the last decay epoch, not exactly 0.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        epoch_count: Starting epoch number; defaults to EPOCH_COUNT.
        n_epochs: Epochs at the initial LR; defaults to N_EPOCHS.
        n_epochs_decay: Epochs of linear decay; defaults to N_EPOCHS_DECAY.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR``; call ``step()`` once per epoch.
    """
    # Fall back to the module-level configuration when not given explicitly.
    if epoch_count is None:
        epoch_count = EPOCH_COUNT
    if n_epochs is None:
        n_epochs = N_EPOCHS
    if n_epochs_decay is None:
        n_epochs_decay = N_EPOCHS_DECAY

    def lambda_rule(epoch):
        return 1.0 - max(0, epoch + epoch_count - n_epochs) / float(n_epochs_decay + 1)

    return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)

def _set_requires_grad(nets, requires_grad):
    """Set ``requires_grad`` on every parameter of each module in ``nets``."""
    for net in nets:
        for param in net.parameters():
            param.requires_grad = requires_grad


def train_fn(disc_H, disc_Z, gen_H, gen_Z, loader, opt_disc, opt_gen, l1, mse,
             fake_H_pool=None, fake_Z_pool=None, image_dir="train_images"):
    """Train both CycleGAN directions for one pass over ``loader``.

    Naming (from the calls below): ``gen_H`` maps zebra -> horse and ``gen_Z``
    maps horse -> zebra; ``disc_H`` scores horse images, ``disc_Z`` zebra
    images. Each batch is a dict with ``'A'`` (horse) and ``'B'`` (zebra).

    Per batch:
      1. Discriminators: ``mse`` of real scores vs. 1 and of (pooled,
         detached) fake scores vs. 0, averaged over both discriminators.
      2. Generators: adversarial ``mse`` vs. 1, plus ``l1`` cycle-consistency
         (weighted by LAMBDA_CYCLE) and identity (weighted by LAMBDA_IDENTITY)
         terms.

    Args:
        disc_H, disc_Z: Discriminators, optimised together by ``opt_disc``.
        gen_H, gen_Z: Generators, optimised together by ``opt_gen``.
        loader: Iterable of ``{'A': tensor, 'B': tensor}`` batches.
        opt_disc, opt_gen: Optimisers for the discriminators / generators.
        l1: Loss for the cycle and identity terms.
        mse: Loss for the adversarial terms.
        fake_H_pool, fake_Z_pool: Optional ImagePool histories of generated
            horses / zebras. When omitted, fresh 50-image pools are created and
            the history lasts only for this call; pass persistent pools to keep
            it across epochs.
        image_dir: Directory for sample images saved every 200 batches
            (created if missing).
    """
    if fake_H_pool is None:
        fake_H_pool = ImagePool(50)
    if fake_Z_pool is None:
        fake_Z_pool = ImagePool(50)
    # save_image does not create missing directories.
    os.makedirs(image_dir, exist_ok=True)

    # Running sums of the mean discriminator outputs, shown in the progress bar.
    H_reals = H_fakes = Z_reals = Z_fakes = 0.0
    loop = tqdm(loader, leave=True)

    for idx, data in enumerate(loop):
        zebra = data['B'].to(DEVICE)
        horse = data['A'].to(DEVICE)

        # ---- Train discriminators H and Z ----
        fake_horse = gen_H(zebra)
        fake_horse_train = fake_H_pool.query(fake_horse)
        D_H_real = disc_H(horse)
        # .detach() keeps the discriminator loss from reaching gen_H (the pool
        # already detaches, except when it is disabled).
        D_H_fake = disc_H(fake_horse_train.detach())
        H_reals += D_H_real.mean().item()
        H_fakes += D_H_fake.mean().item()
        D_H_loss = (mse(D_H_real, torch.ones_like(D_H_real))
                    + mse(D_H_fake, torch.zeros_like(D_H_fake)))

        fake_zebra = gen_Z(horse)
        fake_zebra_train = fake_Z_pool.query(fake_zebra)
        D_Z_real = disc_Z(zebra)
        D_Z_fake = disc_Z(fake_zebra_train.detach())
        Z_reals += D_Z_real.mean().item()
        Z_fakes += D_Z_fake.mean().item()
        D_Z_loss = (mse(D_Z_real, torch.ones_like(D_Z_real))
                    + mse(D_Z_fake, torch.zeros_like(D_Z_fake)))

        D_loss = (D_H_loss + D_Z_loss) / 2
        opt_disc.zero_grad()
        D_loss.backward()
        opt_disc.step()

        # ---- Train generators H and Z ----
        # Freeze the discriminators: gradients still flow through them into
        # the generators, but no (unused) weight gradients are computed for D.
        _set_requires_grad((disc_H, disc_Z), False)

        # Adversarial loss: generators try to make D output 1 on their fakes.
        D_H_fake = disc_H(fake_horse)
        D_Z_fake = disc_Z(fake_zebra)
        loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
        loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake))

        # Cycle consistency: zebra -> horse -> zebra and horse -> zebra -> horse.
        cycle_zebra = gen_Z(fake_horse)
        cycle_horse = gen_H(fake_zebra)
        cycle_zebra_loss = l1(zebra, cycle_zebra)
        cycle_horse_loss = l1(horse, cycle_horse)

        # Identity: a generator fed an image already in its target domain
        # should leave it unchanged.
        identity_zebra = gen_Z(zebra)
        identity_horse = gen_H(horse)
        identity_zebra_loss = l1(zebra, identity_zebra)
        identity_horse_loss = l1(horse, identity_horse)

        G_loss = (
            loss_G_Z
            + loss_G_H
            + cycle_zebra_loss * LAMBDA_CYCLE
            + cycle_horse_loss * LAMBDA_CYCLE
            + identity_horse_loss * LAMBDA_IDENTITY
            + identity_zebra_loss * LAMBDA_IDENTITY
        )

        opt_gen.zero_grad()
        G_loss.backward()
        opt_gen.step()
        # Make the discriminators trainable again for the next batch.
        _set_requires_grad((disc_H, disc_Z), True)

        if idx % 200 == 0:
            # Undo the [-1, 1] normalisation before writing the PNGs.
            save_image(fake_horse*0.5+0.5, os.path.join(image_dir, f"horse_{idx}.png"))
            save_image(fake_zebra*0.5+0.5, os.path.join(image_dir, f"zebra_{idx}.png"))

        n = idx + 1
        # H_* track disc_H (horse) outputs, Z_* track disc_Z (zebra) outputs.
        loop.set_postfix(H_real=H_reals / n, H_fake=H_fakes / n,
                         Z_real=Z_reals / n, Z_fake=Z_fakes / n)

class CombineDataset(Dataset):
    """Unpaired two-domain image dataset yielding ``{'A': ..., 'B': ...}``.

    Its length is the size of the larger domain; index ``i`` pairs
    ``A[i % len(A)]`` with ``B[i % len(B)]``, so the smaller domain wraps around.
    """

    def __init__(self, root_A, root_B, transform):
        """Index the files of both domain directories.

        Args:
            root_A: Directory holding domain-A images.
            root_B: Directory holding domain-B images.
            transform: Callable applied to each image after RGB conversion.

        Raises:
            ValueError: If either directory contains no files.
        """
        self.root_A = root_A
        self.root_B = root_B
        self.transform = transform

        # Sorted for a deterministic index -> file mapping (os.listdir order is
        # arbitrary); sub-directories are skipped.
        self.A_paths = self._list_files(root_A)
        self.B_paths = self._list_files(root_B)
        self.A_len = len(self.A_paths)
        self.B_len = len(self.B_paths)
        if self.A_len == 0 or self.B_len == 0:
            # Would otherwise surface later as ZeroDivisionError in __getitem__.
            raise ValueError('empty image directory: %r has %d files, %r has %d files'
                             % (root_A, self.A_len, root_B, self.B_len))
        self.length_dataset = max(self.A_len, self.B_len)

    @staticmethod
    def _list_files(root):
        """Return the sorted names of the regular files directly inside ``root``."""
        return sorted(name for name in os.listdir(root)
                      if os.path.isfile(os.path.join(root, name)))

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path`` and return an RGB copy, closing the file."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def __len__(self):
        return self.length_dataset

    def __getitem__(self, index):
        # os.path.join works whether or not the roots end with a separator.
        A_path = os.path.join(self.root_A, self.A_paths[index % self.A_len])
        B_path = os.path.join(self.root_B, self.B_paths[index % self.B_len])

        A = self.transform(self._load_rgb(A_path))
        B = self.transform(self._load_rgb(B_path))

        return {'A': A, 'B': B}

def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Write ``model`` and ``optimizer`` state to ``filename`` via ``torch.save``.

    The saved object is a dict with keys ``"state_dict"`` (model) and
    ``"optimizer"`` (optimizer).
    """
    print("=> Saving checkpoint")
    torch.save(
        dict(state_dict=model.state_dict(), optimizer=optimizer.state_dict()),
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore ``model`` and ``optimizer`` from ``checkpoint_file``.

    Tensors are mapped onto the module-level DEVICE while loading. Afterwards
    every optimizer param group's learning rate is overwritten with ``lr``.
    """
    print("=> Loading checkpoint")
    state = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(state["state_dict"])
    optimizer.load_state_dict(state["optimizer"])

    # Use the caller's current learning rate instead of the one stored in the file.
    for group in optimizer.param_groups:
        group.update(lr=lr)

if __name__ == "__main__":
    # Everything below runs only when this file is executed as a script.
    # DataLoader workers started with the "spawn" method (the default on
    # Windows and macOS) re-import the main module; without this guard each
    # worker would re-run model construction and try to start training.
    dataset = CombineDataset(root_A=TRAIN_DIR+"/trainA/",
                             root_B=TRAIN_DIR+"/trainB/", transform=transforms)
    data_loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        # Pinned host memory only helps host -> GPU copies.
        pin_memory=(DEVICE == "cuda"),
    )
    # len(data_loader) counts batches; report the number of samples instead.
    dataset_size = len(dataset)
    print('The number of training images = %d' % dataset_size)

    disc_H = NLayerDiscriminator(input_nc=3).to(DEVICE)
    disc_Z = NLayerDiscriminator(input_nc=3).to(DEVICE)
    gen_Z = ResnetGenerator(input_nc=3, output_nc=3).to(DEVICE)
    gen_H = ResnetGenerator(input_nc=3, output_nc=3).to(DEVICE)

    opt_disc = optim.Adam(
        list(disc_H.parameters()) + list(disc_Z.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )
    opt_gen = optim.Adam(
        list(gen_Z.parameters()) + list(gen_H.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    scheduler_disc = get_scheduler(opt_disc)
    scheduler_gen = get_scheduler(opt_gen)
    L1 = nn.L1Loss()
    mse = nn.MSELoss()

    # (checkpoint file, model, optimizer), in load/save order.
    checkpoints = [
        (CHECKPOINT_GEN_H, gen_H, opt_gen),
        (CHECKPOINT_GEN_Z, gen_Z, opt_gen),
        (CHECKPOINT_CRITIC_H, disc_H, opt_disc),
        (CHECKPOINT_CRITIC_Z, disc_Z, opt_disc),
    ]

    if LOAD_MODEL:
        missing = [path for path, _, _ in checkpoints if not os.path.isfile(path)]
        if missing:
            # First run (or incomplete checkpoint set): start from the fresh
            # initialisation instead of crashing with FileNotFoundError.
            print("=> Checkpoint(s) not found: %s; training from scratch" % ", ".join(missing))
        else:
            for path, model, optimizer in checkpoints:
                load_checkpoint(path, model, optimizer, LEARNING_RATE)
    # NOTE(review): resuming restores weights and optimizer state only; the
    # epoch loop and LR schedule below still restart from EPOCH_COUNT.

    for epoch in range(EPOCH_COUNT, N_EPOCHS + N_EPOCHS_DECAY + 1):
        train_fn(disc_H, disc_Z, gen_H, gen_Z,
                 data_loader, opt_disc, opt_gen, L1, mse)
        # Advance the LR schedule once per epoch, after the optimizer steps.
        scheduler_disc.step()
        scheduler_gen.step()
        if SAVE_MODEL:
            for path, model, optimizer in checkpoints:
                save_checkpoint(model, optimizer, filename=path)





如需完整工程、训练参数和数据集,请留言。

你可能感兴趣的:(深度学习,人工智能,pytorch)