# This part is mainly based on: https://github.com/aladdinpersson/Machine-Learning-Collection/tree/master/ML/Pytorch/GANs/CycleGAN
import torch
import torch.nn as nn
class ConvBlock(nn.Module):
    """Convolution (or transposed convolution), instance norm, optional ReLU.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution.
        down: If true, use a reflect-padded ``nn.Conv2d``; otherwise use an
            ``nn.ConvTranspose2d``.
        use_act: If true, finish with an in-place ReLU; otherwise with
            ``nn.Identity``.
        **kwargs: Forwarded unchanged to the convolution constructor
            (e.g. ``kernel_size``, ``stride``, ``padding``).
    """

    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()

        # Pick the convolution flavour first, then the trailing activation.
        if down:
            conv_layer = nn.Conv2d(
                in_channels, out_channels, padding_mode="reflect", **kwargs
            )
        else:
            conv_layer = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)

        if use_act:
            activation = nn.ReLU(inplace=True)
        else:
            activation = nn.Identity()

        self.conv = nn.Sequential(
            conv_layer,
            nn.InstanceNorm2d(out_channels),
            activation,
        )

    def forward(self, x):
        """Apply convolution, normalisation and activation to ``x``."""
        return self.conv(x)
class ResidualBlock(nn.Module):
    """Two 3x3 ``ConvBlock`` layers wrapped in an additive skip connection.

    The first ConvBlock keeps its activation; the second is built with
    ``use_act=False``. The channel count is the same on input and output.

    Args:
        channels: Number of channels of the feature map.
    """

    def __init__(self, channels):
        super().__init__()
        conv_args = dict(kernel_size=3, padding=1)
        with_activation = ConvBlock(channels, channels, **conv_args)
        without_activation = ConvBlock(channels, channels, use_act=False, **conv_args)
        self.block = nn.Sequential(with_activation, without_activation)

    def forward(self, x):
        """Return ``x`` plus the output of the two stacked ConvBlocks."""
        residual = self.block(x)
        return x + residual
class Generator(nn.Module):
    """Image-to-image generator: stem, two downsampling blocks, residual
    blocks, two upsampling blocks and a final 7x7 convolution with tanh.

    Args:
        img_channels: Channel count of both the input and the output image.
        num_features: Width of the stem; the downsampling blocks widen it to
            ``2 * num_features`` and ``4 * num_features``.
        num_residuals: Number of ``ResidualBlock`` layers in the bottleneck.
    """

    def __init__(self, img_channels, num_features=64, num_residuals=6):
        super().__init__()
        base = num_features
        mid = num_features * 2
        deep = num_features * 4

        # 7x7 reflect-padded stem at full resolution.
        stem_conv = nn.Conv2d(
            img_channels,
            base,
            kernel_size=7,
            stride=1,
            padding=3,
            padding_mode="reflect",
        )
        self.initial = nn.Sequential(
            stem_conv,
            nn.InstanceNorm2d(base),
            nn.ReLU(inplace=True),
        )

        # Two stride-2 convolutions.
        down_cfg = dict(kernel_size=3, stride=2, padding=1)
        self.down_blocks = nn.ModuleList(
            [
                ConvBlock(base, mid, **down_cfg),
                ConvBlock(mid, deep, **down_cfg),
            ]
        )

        bottleneck = [ResidualBlock(deep) for _ in range(num_residuals)]
        self.residual_blocks = nn.Sequential(*bottleneck)

        # Two stride-2 transposed convolutions mirroring the encoder widths.
        up_cfg = dict(down=False, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.up_blocks = nn.ModuleList(
            [
                ConvBlock(deep, mid, **up_cfg),
                ConvBlock(mid, base, **up_cfg),
            ]
        )

        self.last = nn.Conv2d(
            base,
            img_channels,
            kernel_size=7,
            stride=1,
            padding=3,
            padding_mode="reflect",
        )

    def forward(self, x):
        """Run ``x`` through the network and squash the result with tanh."""
        out = self.initial(x)
        for down in self.down_blocks:
            out = down(out)
        out = self.residual_blocks(out)
        for up in self.up_blocks:
            out = up(out)
        return torch.tanh(self.last(out))
def test():
    """Smoke test: push a random batch through the generator and print the
    shape of the result."""
    img_channels = 3
    img_size = 256
    # Both spatial dimensions follow img_size (the width used to be a
    # hard-coded 256 that ignored the variable).
    x = torch.randn((2, img_channels, img_size, img_size))
    model = Generator(img_channels, num_residuals=6)
    preds = model(x)
    print(preds.shape)


if __name__ == "__main__":
    test()
import torch
import torch.nn as nn
class Block(nn.Module):
    """Discriminator building block: 4x4 reflect-padded convolution,
    instance norm and LeakyReLU(0.2).

    Args:
        in_channels: Channel count of the input feature map.
        out_channels: Channel count produced by the convolution.
        stride: Stride of the convolution.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        conv_layer = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=True,
            padding_mode='reflect',
        )
        norm_layer = nn.InstanceNorm2d(out_channels)
        activation = nn.LeakyReLU(0.2)
        self.conv = nn.Sequential(conv_layer, norm_layer, activation)

    def forward(self, x):
        """Apply convolution, normalisation and activation to ``x``."""
        return self.conv(x)
### Input image size: 3x256x256
class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a map of sigmoid scores.

    The network is an initial 4x4 stride-2 convolution with LeakyReLU, one
    ``Block`` per entry of ``features[1:]``, and a final 4x4 convolution down
    to a single channel.

    Args:
        in_channels: Channel count of the input image.
        features: Channel widths of the successive stages. Any sequence that
            supports indexing and slicing works; the default is a tuple so
            that no mutable object is shared between calls.
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        # The first stage has no normalisation layer.
        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
            ),
            nn.LeakyReLU(0.2),
        )

        layers = []
        in_channels = features[0]
        for feature in features[1:]:
            # Blocks whose width equals the last entry use stride 1; every
            # other block halves the resolution with stride 2.
            stride = 1 if feature == features[-1] else 2
            layers.append(Block(in_channels, feature, stride=stride))
            in_channels = feature
        # Collapse to a single output channel.
        layers.append(
            nn.Conv2d(
                in_channels,
                1,
                kernel_size=4,
                stride=1,
                padding=1,
                padding_mode='reflect',
            )
        )
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the sigmoid of the single-channel output map for ``x``."""
        x = self.initial(x)
        x = self.model(x)
        return torch.sigmoid(x)
def test():
    """Smoke test: feed a random batch of five 3x256x256 images to a default
    discriminator and print the shape of its output."""
    batch = torch.randn((5, 3, 256, 256))
    discriminator = Discriminator()
    scores = discriminator(batch)
    print(scores.shape)


if __name__ == "__main__":
    test()
from PIL import Image
import os
from torch.utils.data import Dataset
import numpy as np
class HorseZebraDataset(Dataset):
    """Unpaired dataset that yields one horse image and one zebra image.

    The dataset length is the size of the larger directory; the smaller
    directory is cycled through with a modulo on the index.

    Args:
        root_zebra: Directory containing the zebra images.
        root_horse: Directory containing the horse images.
        transform: Optional callable invoked as
            ``transform(image=zebra, image0=horse)`` that returns a mapping
            with the keys ``"image"`` and ``"image0"``.
    """

    def __init__(self, root_zebra, root_horse, transform=None):
        self.root_zebra = root_zebra
        self.root_horse = root_horse
        self.transform = transform

        self.zebra_images = os.listdir(root_zebra)
        self.horse_images = os.listdir(root_horse)
        self.zebra_len = len(self.zebra_images)
        self.horse_len = len(self.horse_images)
        # Iterate as long as the larger of the two image sets.
        self.length_dataset = max(self.zebra_len, self.horse_len)

    def __len__(self):
        """Return the number of samples (size of the larger directory)."""
        return self.length_dataset

    def __getitem__(self, idx):
        """Return ``{'A': horse_image, 'B': zebra_image}`` for index ``idx``."""
        # Wrap the index so the shorter list is reused.
        zebra_name = self.zebra_images[idx % self.zebra_len]
        horse_name = self.horse_images[idx % self.horse_len]

        zebra_path = os.path.join(self.root_zebra, zebra_name)
        horse_path = os.path.join(self.root_horse, horse_name)

        # Context managers close the image files deterministically instead of
        # leaving the file handles to the garbage collector.
        with Image.open(zebra_path) as zebra_file:
            zebra_img = np.array(zebra_file.convert("RGB"))
        with Image.open(horse_path) as horse_file:
            horse_img = np.array(horse_file.convert("RGB"))

        if self.transform:
            augmentations = self.transform(image=zebra_img, image0=horse_img)
            horse_img = augmentations["image0"]
            zebra_img = augmentations["image"]

        return {'A': horse_img, 'B': zebra_img}
# This part is mainly based on: https://github.com/aitorzip/PyTorch-CycleGAN
import argparse
import itertools
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch
import torch.nn as nn
import torch.optim as optim
from generator_model import Generator
from discriminator_model import Discriminator
from dataset import HorseZebraDataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
from utils import ReplayBuffer
def _update_discriminator(netD, optimizer, criterion, real, fake, fake_buffer):
    """Run one optimisation step for a single discriminator.

    Args:
        netD: Discriminator to update.
        optimizer: Optimizer that owns ``netD``'s parameters.
        criterion: Loss applied to the discriminator predictions.
        real: Batch of real images from the discriminator's domain.
        fake: Batch of generated images for the same domain.
        fake_buffer: Object exposing ``push_and_pop``; the generated batch is
            routed through it before scoring (presumably mixing in earlier
            fakes - see ``utils.ReplayBuffer``).

    Returns:
        The discriminator loss, averaged over the real and fake terms.
    """
    optimizer.zero_grad()

    # Real images should be scored as 1.
    pred_real = netD(real)
    loss_real = criterion(pred_real, torch.ones_like(pred_real))

    # Generated images should be scored as 0. detach() keeps the gradient
    # from flowing back into the generator.
    fake = fake_buffer.push_and_pop(fake)
    pred_fake = netD(fake.detach())
    # The target is shaped after pred_fake itself (it used to be shaped after
    # pred_real, which only works while both happen to have the same shape).
    loss_fake = criterion(pred_fake, torch.zeros_like(pred_fake))

    loss = (loss_real + loss_fake) * 0.5
    loss.backward()
    optimizer.step()
    return loss


def main(opt):
    """Train the two generators and two discriminators of a CycleGAN.

    Args:
        opt: Parsed command-line namespace. Reads ``data_root``,
            ``batch_size``, ``size``, ``input_nc``, ``lr`` and ``n_epochs``.

    Side effects:
        Prints the losses every 50 batches and writes the four network
        state dicts to ``./output`` at the end of every epoch.
    """
    import os  # local import so this function does not rely on file-level imports

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Create the checkpoint directory before training so that torch.save
    # cannot fail on a missing folder after hours of work.
    os.makedirs('./output', exist_ok=True)

    ### Load the data
    # image pre-processing; images are resized to opt.size (default 256)
    transforms = A.Compose(
        [
            A.Resize(width=opt.size, height=opt.size),
            A.HorizontalFlip(p=0.5),
            A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ToTensorV2(),
        ],
        additional_targets={"image0": "image"},
    )
    # os.path.join tolerates a data_root with or without a trailing slash.
    datasets = HorseZebraDataset(
        root_horse=os.path.join(opt.data_root, "trainA"),
        root_zebra=os.path.join(opt.data_root, "trainB"),
        transform=transforms,
    )
    loader = DataLoader(datasets, batch_size=opt.batch_size, shuffle=True, num_workers=4)

    ### Building the Network
    netG_A2B = Generator(opt.input_nc).to(device)
    netG_B2A = Generator(opt.input_nc).to(device)
    netD_A = Discriminator(opt.input_nc).to(device)
    netD_B = Discriminator(opt.input_nc).to(device)

    # Losses
    criterion_GAN = nn.MSELoss()
    criterion_cycle = nn.L1Loss()
    criterion_identity = nn.L1Loss()

    # Optimizers: both generators share one optimizer
    optimizer_G = optim.Adam(
        itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
        lr=opt.lr,
        betas=(0.5, 0.999),
    )
    optimizer_D_A = optim.Adam(netD_A.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    optimizer_D_B = optim.Adam(netD_B.parameters(), lr=opt.lr, betas=(0.5, 0.999))

    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    for epoch in range(opt.n_epochs):
        for idx, batch in enumerate(loader):
            # Move the batch to the training device directly. Copying into a
            # pre-allocated (batch_size, C, size, size) buffer would fail or
            # silently broadcast on a smaller final batch.
            real_A = batch['A'].to(device)
            real_B = batch['B'].to(device)

            ### generate A2B and B2A ###
            optimizer_G.zero_grad()

            # Identity loss
            # G_A2B(B) should equal B if real B is fed
            same_B = netG_A2B(real_B)
            loss_identity_B = criterion_identity(same_B, real_B) * 5.0
            # G_B2A(A) should equal A if real A is fed
            same_A = netG_B2A(real_A)
            loss_identity_A = criterion_identity(same_A, real_A) * 5.0

            # GAN loss: each generator wants its fakes to be scored as real
            fake_B = netG_A2B(real_A)
            pred_fake = netD_B(fake_B)
            loss_GAN_A2B = criterion_GAN(pred_fake, torch.ones_like(pred_fake))

            fake_A = netG_B2A(real_B)
            pred_fake = netD_A(fake_A)
            loss_GAN_B2A = criterion_GAN(pred_fake, torch.ones_like(pred_fake))

            # Cycle loss: translating there and back should recover the input
            cycle_A = netG_B2A(fake_B)
            cycle_B = netG_A2B(fake_A)
            loss_cycle = criterion_cycle(cycle_A, real_A) + criterion_cycle(cycle_B, real_B)
            loss_cycle *= 10.0

            loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle
            loss_G.backward()
            optimizer_G.step()

            ### Discriminator A ###
            loss_D_A = _update_discriminator(
                netD_A, optimizer_D_A, criterion_GAN, real_A, fake_A, fake_A_buffer
            )

            ### Discriminator B ###
            loss_D_B = _update_discriminator(
                netD_B, optimizer_D_B, criterion_GAN, real_B, fake_B, fake_B_buffer
            )

            if idx % 50 == 0:
                print(
                    f"Epoch [{epoch}/{opt.n_epochs}] Batch {idx}/{len(loader)} "
                    f"Loss G: {loss_G:.4f}, loss_cycle: {loss_cycle:.4f}, loss_D_A: {loss_D_A:.4f},"
                )

        # Overwrite the checkpoints after every epoch so an interrupted run
        # still leaves usable weights; the files after the last epoch hold
        # the final weights.
        torch.save(netG_A2B.state_dict(), './output/netG_A2B.pth')
        torch.save(netG_B2A.state_dict(), './output/netG_B2A.pth')
        torch.save(netD_A.state_dict(), './output/netD_A.pth')
        torch.save(netD_B.state_dict(), './output/netD_B.pth')
if __name__ == "__main__":
    # One (flag, type, default, help) row per command-line option.
    cli_options = (
        ('--n_epochs', int, 200, "number of epochs of training"),
        ('--batch_size', int, 2, "size of the batches"),
        ('--data_root', str, './data/horse2zebra/', "root directory of the dataset"),
        ('--lr', float, 0.0002, 'initial learning rate'),
        ('--size', int, 256, 'size of data crop(squared assumed)'),
        ('--input_nc', int, 3, 'number of channels of input data'),
        ('--output_nc', int, 3, 'number of channels of output data'),
    )

    parser = argparse.ArgumentParser()
    for flag, arg_type, default_value, help_text in cli_options:
        parser.add_argument(flag, type=arg_type, default=default_value, help=help_text)

    opt = parser.parse_args()
    print(opt)
    main(opt)
import argparse
import sys
import os
from PIL import Image
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch
import numpy as np
from generator_model import Generator
from dataset import HorseZebraDataset
# Inference script: loads the two trained generators, translates one horse
# image and one zebra image, and plots inputs next to outputs.
parser = argparse.ArgumentParser()
parser.add_argument('--batchSize', type=int, default=1, help='size of the batches')
parser.add_argument('--dataroot', type=str, default='datasets/horse2zebra/', help='root directory of the dataset')
parser.add_argument('--input_nc', type=int, default=3, help='number of channels of input data')
parser.add_argument('--output_nc', type=int, default=3, help='number of channels of output data')
parser.add_argument('--size', type=int, default=256, help='size of the data (squared assumed)')
parser.add_argument('--n_cpu', type=int, default=1, help='number of cpu threads to use during batch generation')
parser.add_argument('--generator_A2B', type=str, default='./output/netG_A2B.pth', help='A2B generator checkpoint file')
parser.add_argument('--generator_B2A', type=str, default='./output/netG_B2A.pth', help='B2A generator checkpoint file')
opt = parser.parse_args()
print(opt)

###### Definition of variables ######
# Networks
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
netG_A2B = Generator(opt.input_nc).to(device)
netG_B2A = Generator(opt.output_nc).to(device)

# Load state dicts. map_location lets checkpoints saved on a GPU load on a
# CPU-only machine as well.
netG_A2B.load_state_dict(torch.load(opt.generator_A2B, map_location=device))
netG_B2A.load_state_dict(torch.load(opt.generator_B2A, map_location=device))

# Set model's test mode
netG_A2B.eval()
netG_B2A.eval()

# Load one sample image per domain.
# NOTE(review): these absolute paths are machine-specific and ignore
# --dataroot; adjust them (or derive them from opt.dataroot) before running.
horse_path = "D:/d2l/CycleGAN/data/horse2zebra/trainA/n02381460_36.jpg"
horse_img = np.array(Image.open(horse_path).convert("RGB"))
zebra_path = "D:/d2l/CycleGAN/data/horse2zebra/trainB/n02391049_77.jpg"
zebra_img = np.array(Image.open(zebra_path).convert("RGB"))

# Same normalisation as in training: scale to [0, 1], then map to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# Use the selected device instead of an unconditional .cuda(), which crashes
# on machines without a GPU.
real_A = transform(horse_img).unsqueeze(0).to(device)
real_B = transform(zebra_img).unsqueeze(0).to(device)

# No gradients are needed for inference; the tanh output in [-1, 1] is mapped
# back to [0, 1] for display.
with torch.no_grad():
    fake_A = 0.5 * (netG_B2A(real_B) + 1.0)
    fake_B = 0.5 * (netG_A2B(real_A) + 1.0)

# Convert from (C, H, W) tensors to (H, W, C) arrays for matplotlib.
out = fake_B.squeeze().cpu().numpy()
img_1 = np.transpose(out, (1, 2, 0))
out = fake_A.squeeze().cpu().numpy()
img_2 = np.transpose(out, (1, 2, 0))

import matplotlib.pyplot as plt

# 2x2 grid: each row shows an input image followed by its translation.
# NOTE(review): there is no plt.show() here, so outside a notebook the figure
# is built but never displayed - confirm how this script is meant to be run.
panels = (
    (221, horse_img, "input image"),
    (222, img_1, "output image"),
    (223, zebra_img, "input image"),
    (224, img_2, "output image"),
)
for position, image, title in panels:
    plt.subplot(position)
    plt.imshow(image)
    plt.title(title)
    plt.axis("off")
# The trained network