GAN的训练数据是没有标签的,如果我们要做有标签的训练,则需要用到CGAN。
对于图像来说,我们既要让输出的图片真实,也要让输出的图片符合标签c。Discriminator输入便被改成了同时输入c和x,输出要做两件事情,一个是判断x是否是真实图片,另一个是x和c是否是匹配的。
在下面两个情况中,左边虽然输出图片清晰,但不符合c;右边输出图片不真实。因此两种情况中D的输出都会是0。
我们来看下简单的示例代码:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
from torch.utils import data
import os
import glob
from PIL import Image
# 独热编码
# 输入x代表默认的torchvision返回的类比值,class_count类别值为10
def one_hot(x, class_count=10):
return torch.eye(class_count)[x, :] # 切片选取,第一维选取第x个,第二维全要
transform =transforms.Compose([transforms.ToTensor(),
transforms.Normalize(0.5, 0.5)])
dataset = torchvision.datasets.MNIST('data',
train=True,
transform=transform,
target_transform=one_hot,
download=False)
dataloader = data.DataLoader(dataset, batch_size=64, shuffle=True)
# 定义生成器
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.linear1 = nn.Linear(10, 128 * 7 * 7)
self.bn1 = nn.BatchNorm1d(128 * 7 * 7)
self.linear2 = nn.Linear(100, 128 * 7 * 7)
self.bn2 = nn.BatchNorm1d(128 * 7 * 7)
self.deconv1 = nn.ConvTranspose2d(256, 128,
kernel_size=(3, 3),
padding=1)
self.bn3 = nn.BatchNorm2d(128)
self.deconv2 = nn.ConvTranspose2d(128, 64,
kernel_size=(4, 4),
stride=2,
padding=1)
self.bn4 = nn.BatchNorm2d(64)
self.deconv3 = nn.ConvTranspose2d(64, 1,
kernel_size=(4, 4),
stride=2,
padding=1)
def forward(self, x1, x2):
x1 = F.relu(self.linear1(x1))
x1 = self.bn1(x1)
x1 = x1.view(-1, 128, 7, 7)
x2 = F.relu(self.linear2(x2))
x2 = self.bn2(x2)
x2 = x2.view(-1, 128, 7, 7)
x = torch.cat([x1, x2], axis=1)
x = F.relu(self.deconv1(x))
x = self.bn3(x)
x = F.relu(self.deconv2(x))
x = self.bn4(x)
x = torch.tanh(self.deconv3(x))
return x
# 定义判别器
# input:1,28,28的图片以及长度为10的condition
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.linear = nn.Linear(10, 1*28*28)
self.conv1 = nn.Conv2d(2, 64, kernel_size=3, stride=2)
self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2)
self.bn = nn.BatchNorm2d(128)
self.fc = nn.Linear(128*6*6, 1) # 输出一个概率值
def forward(self, x1, x2):
x1 =F.leaky_relu(self.linear(x1))
x1 = x1.view(-1, 1, 28, 28)
x = torch.cat([x1, x2], axis=1)
x = F.dropout2d(F.leaky_relu(self.conv1(x)))
x = F.dropout2d(F.leaky_relu(self.conv2(x)))
x = self.bn(x)
x = x.view(-1, 128*6*6)
x = torch.sigmoid(self.fc(x))
return x
# 初始化模型
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discriminator().to(device)
# 损失计算函数
loss_function = torch.nn.BCELoss()
# 定义优化器
d_optim = torch.optim.Adam(dis.parameters(), lr=1e-5)
g_optim = torch.optim.Adam(gen.parameters(), lr=1e-4)
# 定义可视化函数
def generate_and_save_images(model, epoch, label_input, noise_input):
predictions = np.squeeze(model(label_input, noise_input).cpu().numpy())
fig = plt.figure(figsize=(4, 4))
for i in range(predictions.shape[0]):
plt.subplot(4, 4, i + 1)
plt.imshow((predictions[i] + 1) / 2, cmap='gray')
plt.axis("off")
plt.show()
noise_seed = torch.randn(16, 100, device=device)
label_seed = torch.randint(0, 10, size=(16,))
label_seed_onehot = one_hot(label_seed).to(device)
print(label_seed)
# print(label_seed_onehot)
# 开始训练
D_loss = []
G_loss = []
# 训练循环
for epoch in range(150):
d_epoch_loss = 0
g_epoch_loss = 0
count = len(dataloader.dataset)
# 对全部的数据集做一次迭代
for step, (img, label) in enumerate(dataloader):
img = img.to(device)
label = label.to(device)
size = img.shape[0]
random_noise = torch.randn(size, 100, device=device)
d_optim.zero_grad()
real_output = dis(label, img)
d_real_loss = loss_function(real_output,
torch.ones_like(real_output, device=device)
)
d_real_loss.backward() #求解梯度
# 得到判别器在生成图像上的损失
gen_img = gen(label,random_noise)
fake_output = dis(label, gen_img.detach()) # 判别器输入生成的图片,f_o是对生成图片的预测结果
d_fake_loss = loss_function(fake_output,
torch.zeros_like(fake_output, device=device))
d_fake_loss.backward()
d_loss = d_real_loss + d_fake_loss
d_optim.step() # 优化
# 得到生成器的损失
g_optim.zero_grad()
fake_output = dis(label, gen_img)
g_loss = loss_function(fake_output,
torch.ones_like(fake_output, device=device))
g_loss.backward()
g_optim.step()
with torch.no_grad():
d_epoch_loss += d_loss.item()
g_epoch_loss += g_loss.item()
with torch.no_grad():
d_epoch_loss /= count
g_epoch_loss /= count
D_loss.append(d_epoch_loss)
G_loss.append(g_epoch_loss)
if epoch % 10 == 0:
print('Epoch:', epoch)
generate_and_save_images(gen, epoch, label_seed_onehot, noise_seed)
这里是尝试地址:https://affinelayer.com/pixsrv/
使用Pix2Pix神经网络模型实现论文中预定义的任务:黑白简笔画到彩图、平面房屋到立体房屋和航拍图到地图等功能:
Pix2pixgan本质上是一个cgan,图片x作为此cGAN的条件, 需要输入到G和D中。 G的输入是x(x 是需要转换的图片),输出是生成的图片G(x)。 D则需要分辨出{x,G(x)}和{x, y}。
这里的生成器模型我们采用U-Net:
在pix2pix中,作者就是把L1 loss 和GAN loss相结合使用,因为作者认为L1 loss 可以恢复图像的低频部分,而GAN loss可以恢复图像的高频部分。判别器使用patchGAN。
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.down1 = Downsample(3, 64)
self.down2 = Downsample(64, 128)
self.down3 = Downsample(128, 256)
self.down4 = Downsample(256, 512)
self.down5 = Downsample(512, 512)
self.down6 = Downsample(512, 512)
self.up1 = Upsample(512, 512)
self.up2 = Upsample(1024, 512)
self.up3 = Upsample(1024, 256)
self.up4 = Upsample(512, 128)
self.up5 = Upsample(256, 64)
self.last = nn.ConvTranspose2d(128, 3,
kernel_size=3,
stride=2,
padding=1,
output_padding=1)
def forward(self, x):
x1 = self.down1(x, is_bn=False) # torch.Size([8, 64, 128, 128])
x2 = self.down2(x1) # torch.Size([8, 128, 64, 64])
x3 = self.down3(x2) # torch.Size([8, 256, 32, 32])
x4 = self.down4(x3) # torch.Size([8, 512, 16, 16])
x5 = self.down5(x4) # torch.Size([8, 512, 8, 8])
x6 = self.down6(x5) # torch.Size([8, 512, 4, 4])
x6 = self.up1(x6, is_drop=True) # torch.Size([8, 512, 8, 8])
x6 = torch.cat([x5, x6], dim=1) # torch.Size([8, 1024, 8, 8])
x6 = self.up2(x6, is_drop=True) # torch.Size([8, 512, 16, 16])
x6 = torch.cat([x4, x6], dim=1) # torch.Size([8, 1024, 16, 16])
x6 = self.up3(x6, is_drop=True)
x6 = torch.cat([x3, x6], dim=1)
x6 = self.up4(x6)
x6 = torch.cat([x2, x6], dim=1)
x6 = self.up5(x6)
x6 = torch.cat([x1, x6], dim=1)
x6 = torch.tanh(self.last(x6))
return x6
# 判别器
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator,self).__init__()
self.down1 = Downsample(6, 64)
self.down2 = Downsample(64, 128)
self.down3 = Downsample(128, 256)
self.conv = nn.Conv2d(256, 512, 3, 1, 1)
self.bn = nn.BatchNorm2d(512)
self.last = nn.Conv2d(512, 1, 3, 1)
def forward(self, anno, img):
x = torch.cat([anno, img], dim=1) # batch*6*H*W
x = self.down1(x, is_bn=False)
x = self.down2(x)
x = F.dropout2d(self.down3(x))
x = F.dropout2d(F.leaky_relu(self.conv(x)))
x = F.dropout2d(self.bn(x))
x = torch.sigmoid(self.last(x))
return x
在pix2pix的基础上,增加了一个“从糙到精生成器(coarse-to-fine generator)”、一个多尺度鉴别器架构和一个健壮的对抗学习目标函数。
1)生成器部分提高分辨率:将生成器U-net拆分成两个子网络G1和G2进行训练:前者输入和输出的分辨率保持一致(如 1024 x 512),后者输出尺寸(2048x1024)是输入尺寸(1024x512)的4倍(长宽各两倍)。如果想要得到更高分辨率的图像,只需要增加更多的局部增强网络即可(如 G={G1,G2,G3})
2)判别器部分将深度改为宽度:使用三个相同结构的判别器,分别处理不同尺寸的输入。
3)损失函数更稳健:除了PatchGAN的损失,还加上了样本与GT使用判别器网络和VGG16网络提取特征后进行的Element-wise loss
4)输入加入高频特征向量,例如图像的边缘信息,与输入的语义标签连接到一起作为输入。
5)额外学习一个Feature encoder网络,可以将原图转化为features,用来控制图像的颜色、纹理信息。
pix2pixGAN有一个明显的缺点就是,在进行训练的时候必须提供成对的数据集。比如当我们想生成梵高风格的画时,梵高本人画的作品肯定是相对较少的,这个时候就可以考虑使用cycleGAN。cycleGAN适用于非配对的图像到图像转换:
其原理可以概括为将一类图片转成成另一类图片,比如,现有两个样本空间X、Y,我们希望把X空间中的样本转换成Y空间中的样本。这种转换只是风格上的转换,实际X Y 的内容是不一样的。实际的目标就是学习从X到Y的映射,假设该映射为F,它就对应着GAN中的生成器,F就可以将X中的图片A转换为Y中的图片F(A)。
为了实现这个过程,我们需要两个生成器 G_AB 和 G_BA:
首先是单向loss的组成:
判别 loss: 判别器 D_B 是用来判断输入的图片是否是真实的 B 图片,这个流程和GAN是一致的。
生成 loss:生成器用来重建图片 a,目的是希望生成的图片 G_BA(G_AB(a)) 和原图 a 尽可能的相似,那么可以很简单的采取 L1 loss 或者 L2 loss。除了GAN loss,还包含如下loss:
① cycle-loss:也就是循环一致损失。因为网络需要保证生成的图像必须保留有原 始图像的特性,所以如果我们使用生成器GA-B生成一张假图像,那么要能够使用另外一个生成器 GB-A来努力恢复成原始图像。此过程必须满足循环一致性
② 等价loss:我们要求 G A B ( b ) = b G_{AB}(b)=b GAB(b)=b,以及 G B A ( a ) = a G_{BA}(a)=a GBA(a)=a。
下面来看下示例代码:
获取苹果橙子数据:
# 加载训练数据
apples_path = glob.glob('data/trainA/*.jpg')
oranges_path = glob.glob('data/trainB/*.jpg')
transform = transforms.Compose([transforms.ToTensor(), # 0-1归一化
transforms.Normalize(0.5, 0.5), # -1,1])
class AppleOrangeDataset(data.Dataset):
def __init__(self, img_path):
self.img_path = img_path
def __getitem__(self, index):
img_path = self.img_path[index]
pil_img = Image.open(img_path)
pil_img = transform(pil_img)
return pil_img
def __len__(self):
return len(self.img_path)
apple_dataset = AppleOrangeDataset(apples_path)
oranges_dataset = AppleOrangeDataset(oranges_path)
基于Unet结构定义上 / 下采样模块,接着定义生成器:
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.down1 = Downsample(3, 64)
self.down2 = Downsample(64, 128)
self.down3 = Downsample(128, 256)
self.down4 = Downsample(256, 512)
self.down5 = Downsample(512, 512)
self.down6 = Downsample(512, 512)
self.up1 = Upsample(512, 512)
self.up2 = Upsample(1024, 512)
self.up3 = Upsample(1024, 256)
self.up4 = Upsample(512, 128)
self.up5 = Upsample(256, 64)
self.last = nn.ConvTranspose2d(128, 3,
kernel_size=3,
stride=2,
padding=1,
output_padding=1)
def forward(self, x):
x1 = self.down1(x, is_bn=False) # torch.Size([8, 64, 128, 128])
x2 = self.down2(x1) # torch.Size([8, 128, 64, 64])
x3 = self.down3(x2) # torch.Size([8, 256, 32, 32])
x4 = self.down4(x3) # torch.Size([8, 512, 16, 16])
x5 = self.down5(x4) # torch.Size([8, 512, 8, 8])
x6 = self.down6(x5) # torch.Size([8, 512, 4, 4])
x6 = self.up1(x6, is_drop=True) # torch.Size([8, 512, 8, 8])
x6 = torch.cat([x5, x6], dim=1) # torch.Size([8, 1024, 8, 8])
x6 = self.up2(x6, is_drop=True) # torch.Size([8, 512, 16, 16])
x6 = torch.cat([x4, x6], dim=1) # torch.Size([8, 1024, 16, 16])
x6 = self.up3(x6, is_drop=True)
x6 = torch.cat([x3, x6], dim=1)
x6 = self.up4(x6)
x6 = torch.cat([x2, x6], dim=1)
x6 = self.up5(x6)
x6 = torch.cat([x1, x6], dim=1)
x6 = torch.tanh(self.last(x6))
return x6
接下来是鉴别器:
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.down1 = Downsample(3, 64) # 128
self.down2 = Downsample(64, 128) # 64
self.last = nn.Conv2d(128, 1, 3)
def forward(self, img):
x = self.down1(img)
x = self.down2(x)
x = torch.sigmoid(self.last(x))
return x
我们需要定义两个生成器和两个鉴别器:
gen_AB = Generator().to(device)
gen_BA = Generator().to(device)
dis_A = Discriminator().to(device)
dis_B = Discriminator().to(device)
# 同时对两个生成器进行优化
gen_optimizer = torch.optim.Adam(itertools.chain(gen_AB.parameters(), gen_BA.parameters()),
lr=2e-4, betas=(0.5, 0.999))
dis_A_optimizer = torch.optim.Adam(dis_A.parameters(), lr=2e-4, betas=(0.5, 0.999))
dis_B_optimizer = torch.optim.Adam(dis_B.parameters(), lr=2e-4, betas=(0.5, 0.999))
训练过程如下:
D_loss = [] # 记录训练过程中判别器loss变化
G_loss = [] # 记录训练过程中生成器loss变化
# 开始训练
for epoch in range(50):
D_epoch_loss = 0
G_epoch_loss = 0
for step, (real_A, real_B) in enumerate(zip(apples_dl, oranges_dl)):
# GAN 训练
gen_optimizer.zero_grad()
# identity loss
same_B = gen_AB(real_B)
identity_B_loss = l1loss_fn(same_B, real_B)
same_A = gen_BA(real_A)
identity_A_loss = l1loss_fn(same_A, real_A)
# GAN loss
fake_B = gen_AB(real_A)
D_pred_fake_B = dis_B(fake_B)
gan_loss_AB = bceloss_fn(D_pred_fake_B,
torch.ones_like(D_pred_fake_B, device=device))
fake_A = gen_BA(real_B)
D_pred_fake_A = dis_A(fake_A)
gan_loss_BA = bceloss_fn(D_pred_fake_A,
torch.ones_like(D_pred_fake_A, device=device))
# cycle consistanse loss
recovered_A = gen_BA(fake_B)
cycle_loss_ABA = l1loss_fn(recovered_A, real_A)
recovered_B = gen_AB(fake_A)
cycle_loss_BAB = l1loss_fn(recovered_B, real_B)
# total_loss
g_loss = (identity_B_loss + identity_A_loss + gan_loss_AB + gan_loss_BA
+ cycle_loss_ABA + cycle_loss_BAB)
g_loss.backward()
gen_optimizer.step()
# dis_A 训练
dis_A_optimizer.zero_grad()
dis_A_real_output = dis_A(real_A) # 判别器输入真实图片
dis_A_real_loss = bceloss_fn(dis_A_real_output,
torch.ones_like(dis_A_real_output, device=device))
dis_A_fake_output = dis_A(fake_A.detach()) # 判别器输入生成图片
dis_A_fake_loss = bceloss_fn(dis_A_fake_output,
torch.zeros_like(dis_A_fake_output, device=device))
dis_A_loss = (dis_A_real_loss + dis_A_fake_loss) * 0.5
dis_A_loss.backward()
dis_A_optimizer.step()
# dis_B 训练
dis_B_optimizer.zero_grad()
dis_B_real_output = dis_B(real_B) # 判别器输入真实图片
dis_B_real_loss = bceloss_fn(dis_B_real_output,
torch.ones_like(dis_B_real_output, device=device))
dis_B_fake_output = dis_B(fake_B.detach()) # 判别器输入生成图片
dis_B_fake_loss = bceloss_fn(dis_B_fake_output,
torch.zeros_like(dis_B_fake_output, device=device))
dis_B_loss = (dis_B_real_loss + dis_B_fake_loss) * 0.5
dis_B_loss.backward()
dis_B_optimizer.step()