CycleGAN PyTorch code

网络效果

（图 1：CycleGAN 网络效果示例）

网络结构

这部分主要参考：https://github.com/aladdinpersson/Machine-Learning-Collection/tree/master/ML/Pytorch/GANs/CycleGAN

Generator

import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    """Convolution (or transposed convolution) -> InstanceNorm -> optional ReLU.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        down: if True use a reflect-padded ``nn.Conv2d``; otherwise ``nn.ConvTranspose2d``.
        use_act: if True finish with an in-place ReLU, otherwise with ``nn.Identity``.
        **kwargs: forwarded to the convolution (kernel_size, stride, padding, ...).
    """

    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()
        if down:
            conv_layer = nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
        else:
            conv_layer = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)
        activation = nn.ReLU(inplace=True) if use_act else nn.Identity()
        # The attribute name and layer order are kept fixed so checkpoint keys
        # ("conv.0.weight", ...) stay the same.
        self.conv = nn.Sequential(conv_layer, nn.InstanceNorm2d(out_channels), activation)

    def forward(self, x):
        return self.conv(x)


class ResidualBlock(nn.Module):
    """Two 3x3 ``ConvBlock``s (the second without activation) wrapped in a skip connection.

    Both convolutions keep the channel count at ``channels``. With padding=1 and the
    default stride they also keep the spatial size, so the input can be added to the result.
    """

    def __init__(self, channels):
        super().__init__()
        first = ConvBlock(channels, channels, kernel_size=3, padding=1)
        second = ConvBlock(channels, channels, use_act=False, kernel_size=3, padding=1)
        self.block = nn.Sequential(first, second)

    def forward(self, x):
        residual = self.block(x)
        return x + residual


class Generator(nn.Module):
    """ResNet-style image-to-image generator.

    The pipeline is:
    - a 7x7 stem;
    - two stride-2 downsampling ``ConvBlock``s that double the width each time;
    - ``num_residuals`` ``ResidualBlock``s at 4x width;
    - two stride-2 transposed-conv ``ConvBlock``s back to the stem width;
    - a final 7x7 conv followed by ``tanh``, so outputs lie in (-1, 1).

    Args:
        img_channels: channels of both the input and the output image.
        num_features: width of the stem; the bottleneck uses ``4 * num_features``.
        num_residuals: number of residual blocks in the bottleneck.
    """

    def __init__(self, img_channels, num_features=64, num_residuals=6):
        super().__init__()
        f1, f2, f4 = num_features, num_features * 2, num_features * 4

        self.initial = nn.Sequential(
            nn.Conv2d(img_channels, f1, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
            nn.InstanceNorm2d(f1),
            nn.ReLU(inplace=True),
        )

        down_cfg = dict(kernel_size=3, stride=2, padding=1)
        self.down_blocks = nn.ModuleList([
            ConvBlock(f1, f2, **down_cfg),
            ConvBlock(f2, f4, **down_cfg),
        ])

        self.residual_blocks = nn.Sequential(*(ResidualBlock(f4) for _ in range(num_residuals)))

        up_cfg = dict(down=False, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.up_blocks = nn.ModuleList([
            ConvBlock(f4, f2, **up_cfg),
            ConvBlock(f2, f1, **up_cfg),
        ])

        self.last = nn.Conv2d(f1, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")

    def forward(self, x):
        x = self.initial(x)
        # downsampling -> residual bottleneck -> upsampling, applied in order
        for stage in (*self.down_blocks, self.residual_blocks, *self.up_blocks):
            x = stage(x)
        return torch.tanh(self.last(x))

def test(img_channels=3, img_size=256, batch_size=2):
    """Smoke-test the Generator: a random batch must come back with its input shape.

    The network downsamples twice by 2 and upsamples twice by 2. ``img_size`` should
    therefore be divisible by 4 for the output size to match the input.
    """
    # Use img_size for BOTH spatial dims (the width used to be hard-coded to 256).
    x = torch.randn((batch_size, img_channels, img_size, img_size))
    model = Generator(img_channels, num_residuals=6)
    preds = model(x)
    # Image-to-image mapping: output shape must equal input shape.
    assert preds.shape == x.shape, f"unexpected output shape {tuple(preds.shape)}"
    print(preds.shape)


if __name__ == "__main__":
    test()

Discriminator

import torch
import torch.nn as nn


class Block(nn.Module):
    """Discriminator stage: 4x4 reflect-padded conv (padding 1) -> InstanceNorm -> LeakyReLU(0.2).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        stride: convolution stride.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        layers = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=4,
                stride=stride,
                padding=1,
                bias=True,
                padding_mode="reflect",
            ),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        ]
        # Attribute name kept as `conv` so checkpoint keys are unchanged.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        return out


### Input image size: 3x256x256
class Discriminator(nn.Module):
    """Convolutional discriminator producing a single-channel map of scores in (0, 1).

    The layers are:
    - a stride-2 4x4 conv followed by LeakyReLU(0.2), with no normalisation;
    - one ``Block`` per remaining entry of ``features``, using stride 2 except for the
      last one, which uses stride 1;
    - a 4x4 conv down to one channel;
    - a sigmoid.

    Args:
        in_channels: channels of the input image.
        features: output channels of each stage. A tuple default avoids the shared
            mutable-default pitfall; any sequence of ints is accepted.
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        features = list(features)
        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
            ),
            nn.LeakyReLU(0.2),
        )

        layers = []
        channels = features[0]
        last_index = len(features) - 1
        for i, feature in enumerate(features[1:], start=1):
            # Pick the stride by POSITION, not by value. A value comparison would give
            # stride 1 to every stage whose width equals the last one.
            layers.append(Block(channels, feature, stride=1 if i == last_index else 2))
            channels = feature
        layers.append(nn.Conv2d(channels, 1, kernel_size=4, stride=1, padding=1, padding_mode='reflect'))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        x = self.initial(x)
        x = self.model(x)
        return torch.sigmoid(x)



def test():
    """Print the output shape of a default Discriminator for five random 3x256x256 images."""
    batch = torch.randn((5, 3, 256, 256))
    model = Discriminator()
    scores = model(batch)
    print(scores.shape)


if __name__ == "__main__":
    test()


Dataset

from PIL import Image
import os
from torch.utils.data import Dataset
import numpy as np


class HorseZebraDataset(Dataset):
    """Unpaired horse/zebra image dataset for CycleGAN training.

    The dataset length is the size of the larger folder. Indices wrap around (modulo)
    in the smaller folder, so each item pairs one zebra image with one horse image.

    Args:
        root_zebra: directory of zebra images (returned under key ``'B'``).
        root_horse: directory of horse images (returned under key ``'A'``).
        transform: optional callable invoked as ``transform(image=zebra, image0=horse)``.
            It must return a mapping with keys ``"image"`` (zebra) and ``"image0"`` (horse).

    Raises:
        ValueError: if either directory has no entries.
    """

    def __init__(self, root_zebra, root_horse, transform=None):
        self.root_zebra = root_zebra
        self.root_horse = root_horse
        self.transform = transform

        # os.listdir returns entries in arbitrary, platform-dependent order; sort
        # so the index -> file mapping is reproducible across runs and machines.
        self.zebra_images = sorted(os.listdir(root_zebra))
        self.horse_images = sorted(os.listdir(root_horse))

        self.zebra_len = len(self.zebra_images)
        self.horse_len = len(self.horse_images)
        # Fail fast: an empty folder would otherwise surface later as a
        # ZeroDivisionError from the modulo in __getitem__.
        if self.zebra_len == 0 or self.horse_len == 0:
            raise ValueError(
                f"empty image folder: root_zebra={root_zebra!r} has {self.zebra_len} entries, "
                f"root_horse={root_horse!r} has {self.horse_len} entries"
            )
        self.length_dataset = max(self.zebra_len, self.horse_len)

    def __len__(self):
        return self.length_dataset

    @staticmethod
    def _load_rgb(path):
        """Read the image at ``path`` as an RGB numpy array, closing the file afterwards."""
        with Image.open(path) as img:
            return np.array(img.convert("RGB"))

    def __getitem__(self, idx):
        zebra_path = os.path.join(self.root_zebra, self.zebra_images[idx % self.zebra_len])
        horse_path = os.path.join(self.root_horse, self.horse_images[idx % self.horse_len])

        zebra_img = self._load_rgb(zebra_path)
        horse_img = self._load_rgb(horse_path)

        # `is not None` rather than truthiness: an empty transform container can be falsy.
        if self.transform is not None:
            augmentations = self.transform(image=zebra_img, image0=horse_img)
            horse_img = augmentations["image0"]
            zebra_img = augmentations["image"]

        # 'A' is the horse domain, 'B' the zebra domain.
        return {'A': horse_img, 'B': zebra_img}

网络训练

这部分主要参考：https://github.com/aitorzip/PyTorch-CycleGAN

Training

import argparse
import itertools
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch
import torch.nn as nn
import torch.optim as optim

from generator_model import Generator
from discriminator_model import Discriminator

from dataset import HorseZebraDataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
from utils import ReplayBuffer

def _discriminator_step(net_d, optimizer, criterion, real, fake, buffer):
    """Run one optimisation step for a single discriminator and return its loss.

    The fake batch goes through the replay ``buffer`` and is then detached, so no
    gradient flows back into the generators. The loss is
    ``0.5 * (criterion(D(real), 1) + criterion(D(fake), 0))``.
    """
    optimizer.zero_grad()

    # Real loss
    pred_real = net_d(real)
    loss_real = criterion(pred_real, torch.ones_like(pred_real))

    # Fake loss
    fake = buffer.push_and_pop(fake)
    pred_fake = net_d(fake.detach())
    loss_fake = criterion(pred_fake, torch.zeros_like(pred_fake))

    loss = (loss_real + loss_fake) * 0.5
    loss.backward()
    optimizer.step()
    return loss


def main(opt):
    """Train a horse (A) <-> zebra (B) CycleGAN using the parsed command-line options.

    Images come from ``<data_root>trainA`` (horses) and ``<data_root>trainB`` (zebras).
    After every epoch the state dicts of the four networks are written to ``./output/``.
    """
    import os  # only needed to create the checkpoint directory

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    ### Load the data
    # Resize to opt.size x opt.size, random horizontal flip, normalise with mean/std 0.5, then
    # convert to a CHW tensor. "image0" receives the same transform as "image".
    transforms = A.Compose(
        [
            A.Resize(width=opt.size, height=opt.size),
            A.HorizontalFlip(p=0.5),
            A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ToTensorV2(),
        ],
        additional_targets={"image0": "image"},
    )

    dataset = HorseZebraDataset(root_horse=opt.data_root + "trainA", root_zebra=opt.data_root + "trainB",
                                transform=transforms)
    loader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=4)

    ### Building the Network
    netG_A2B = Generator(opt.input_nc).to(device)
    netG_B2A = Generator(opt.input_nc).to(device)

    netD_A = Discriminator(opt.input_nc).to(device)
    netD_B = Discriminator(opt.input_nc).to(device)

    # Losses: MSE for the adversarial terms, L1 for cycle-consistency and identity.
    criterion_GAN = nn.MSELoss()
    criterion_cycle = nn.L1Loss()
    criterion_identity = nn.L1Loss()
    lambda_identity = 5.0
    lambda_cycle = 10.0

    # Optimizers (constant learning rate; no LR scheduler is used)
    optimizer_G = optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
                             lr=opt.lr, betas=(0.5, 0.999))
    optimizer_D_A = optim.Adam(netD_A.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    optimizer_D_B = optim.Adam(netD_B.parameters(), lr=opt.lr, betas=(0.5, 0.999))

    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    # torch.save does not create directories; make sure the checkpoint folder exists.
    os.makedirs('./output', exist_ok=True)

    for epoch in range(opt.n_epochs):
        for idx, batch in enumerate(loader):
            # Use the batch as delivered. The last batch may be smaller than opt.batch_size;
            # copying it into a fixed-size buffer would silently broadcast/duplicate samples.
            real_A = batch['A'].to(device)
            real_B = batch['B'].to(device)

            ### Generators A2B and B2A ###
            optimizer_G.zero_grad()

            # Identity loss: G_A2B(B) should equal B, G_B2A(A) should equal A.
            same_B = netG_A2B(real_B)
            loss_identity_B = criterion_identity(same_B, real_B) * lambda_identity
            same_A = netG_B2A(real_A)
            loss_identity_A = criterion_identity(same_A, real_A) * lambda_identity

            # GAN loss: generators are rewarded when the discriminators output 1 ("real").
            fake_B = netG_A2B(real_A)
            pred_fake = netD_B(fake_B)
            loss_GAN_A2B = criterion_GAN(pred_fake, torch.ones_like(pred_fake))

            fake_A = netG_B2A(real_B)
            pred_fake = netD_A(fake_A)
            loss_GAN_B2A = criterion_GAN(pred_fake, torch.ones_like(pred_fake))

            # Cycle loss: A -> B -> A and B -> A -> B should reconstruct the inputs.
            cycle_A = netG_B2A(fake_B)
            cycle_B = netG_A2B(fake_A)
            loss_cycle = (criterion_cycle(cycle_A, real_A) + criterion_cycle(cycle_B, real_B)) * lambda_cycle

            loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle
            loss_G.backward()
            optimizer_G.step()

            ### Discriminators A and B ###
            # zero_grad inside the helper discards gradients that loss_G.backward() left in D.
            loss_D_A = _discriminator_step(netD_A, optimizer_D_A, criterion_GAN, real_A, fake_A, fake_A_buffer)
            loss_D_B = _discriminator_step(netD_B, optimizer_D_B, criterion_GAN, real_B, fake_B, fake_B_buffer)

            if idx % 50 == 0:
                print(
                    f"Epoch [{epoch}/{opt.n_epochs}] Batch {idx}/{len(loader)} "
                    f"Loss G: {loss_G.item():.4f}, loss_cycle: {loss_cycle.item():.4f}, "
                    f"loss_D_A: {loss_D_A.item():.4f}, loss_D_B: {loss_D_B.item():.4f}"
                )

        torch.save(netG_A2B.state_dict(), './output/netG_A2B.pth')
        torch.save(netG_B2A.state_dict(), './output/netG_B2A.pth')
        torch.save(netD_A.state_dict(), './output/netD_A.pth')
        torch.save(netD_B.state_dict(), './output/netD_B.pth')



if __name__ == "__main__":
    # Command-line options: (flag, argparse keyword arguments), registered in order.
    _ARG_SPECS = (
        ('--n_epochs', dict(type=int, default=200, help="number of epochs of training")),
        ('--batch_size', dict(type=int, default=2, help="size of the batches")),
        ('--data_root', dict(type=str, default='./data/horse2zebra/', help="root directory of the dataset")),
        ('--lr', dict(type=float, default=0.0002, help='initial learning rate')),
        ('--size', dict(type=int, default=256, help='size of data crop(squared assumed)')),
        ('--input_nc', dict(type=int, default=3, help='number of channels of input data')),
        ('--output_nc', dict(type=int, default=3, help='number of channels of output data')),
    )

    parser = argparse.ArgumentParser()
    for flag, spec in _ARG_SPECS:
        parser.add_argument(flag, **spec)

    opt = parser.parse_args()
    print(opt)

    main(opt)

Predict

import argparse
import sys
import os
from PIL import Image
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch
import numpy as np
from generator_model import Generator
from dataset import HorseZebraDataset

# Translate one horse image (A -> B) and one zebra image (B -> A) with the trained
# generators, then plot each input next to its output.
parser = argparse.ArgumentParser()
parser.add_argument('--batchSize', type=int, default=1, help='size of the batches')
parser.add_argument('--dataroot', type=str, default='datasets/horse2zebra/', help='root directory of the dataset')
parser.add_argument('--input_nc', type=int, default=3, help='number of channels of input data')
parser.add_argument('--output_nc', type=int, default=3, help='number of channels of output data')
parser.add_argument('--size', type=int, default=256, help='size of the data (squared assumed)')
parser.add_argument('--n_cpu', type=int, default=1, help='number of cpu threads to use during batch generation')
parser.add_argument('--generator_A2B', type=str, default='./output/netG_A2B.pth', help='A2B generator checkpoint file')
parser.add_argument('--generator_B2A', type=str, default='./output/netG_B2A.pth', help='B2A generator checkpoint file')
# Input images are configurable; the defaults are the previously hard-coded paths.
parser.add_argument('--horse_img', type=str, default="D:/d2l/CycleGAN/data/horse2zebra/trainA/n02381460_36.jpg",
                    help='domain-A (horse) image to translate to domain B')
parser.add_argument('--zebra_img', type=str, default="D:/d2l/CycleGAN/data/horse2zebra/trainB/n02391049_77.jpg",
                    help='domain-B (zebra) image to translate to domain A')
opt = parser.parse_args()
print(opt)


###### Definition of variables ######
# Networks
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

netG_A2B = Generator(opt.input_nc).to(device)
netG_B2A = Generator(opt.output_nc).to(device)

# Load state dicts. map_location lets GPU-saved checkpoints load on a CPU-only machine.
netG_A2B.load_state_dict(torch.load(opt.generator_A2B, map_location=device))
netG_B2A.load_state_dict(torch.load(opt.generator_B2A, map_location=device))

# Set model's test mode
netG_A2B.eval()
netG_B2A.eval()

# Read the inputs as RGB arrays; `with` closes the underlying files.
with Image.open(opt.horse_img) as img:
    horse_img = np.array(img.convert("RGB"))
with Image.open(opt.zebra_img) as img:
    zebra_img = np.array(img.convert("RGB"))

# Same normalisation as training: [0, 255] -> [0.0, 1.0] -> [-1.0, 1.0]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# .to(device) instead of .cuda(): works whether or not a GPU is available.
real_A = transform(horse_img).unsqueeze(0).to(device)
real_B = transform(zebra_img).unsqueeze(0).to(device)

# Inference only: no autograd graph needed.
with torch.no_grad():
    # Generator outputs are tanh values in (-1, 1); map them back to [0, 1] for display.
    fake_A = 0.5 * (netG_B2A(real_B) + 1.0)
    fake_B = 0.5 * (netG_A2B(real_A) + 1.0)

# Drop the batch dimension and convert CHW -> HWC for matplotlib.
img_1 = np.transpose(fake_B.squeeze(0).cpu().numpy(), (1, 2, 0))
img_2 = np.transpose(fake_A.squeeze(0).cpu().numpy(), (1, 2, 0))

import matplotlib.pyplot as plt
panels = [
    (horse_img, "input image"),
    (img_1, "output image"),
    (zebra_img, "input image"),
    (img_2, "output image"),
]
for position, (image, title) in enumerate(panels, start=1):
    plt.subplot(2, 2, position)
    plt.imshow(image)
    plt.title(title)
    plt.axis("off")
# Without show() the figure is never displayed when run as a script.
plt.show()

训练好的网络

训练好的网络

你可能感兴趣的：深度学习、Python 学习、PyTorch