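GANs (Generative Adversarial Networks) are built on an idea similar to "left hand sparring with right hand", the self-sparring technique of Zhou Botong, a character in Jin Yong's martial-arts novels. So, jokingly, GANs' own "Zhou Botong" (English name: Ian Goodfellow) followed exactly this principle when designing them. The "left hand" and "right hand" are the GAN's generator (Generator) and discriminator (Discriminator).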
[Figure. Image source: the web]
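The generator's job is to turn random input into an image of a given format (here, a 28×28 MNIST digit). The discriminator's job is to judge whether an input image is real or fake. The figure below shows the original GAN architecture.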
[Figure: the original GAN architecture. Image source: the web]
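So the core of a GAN is implementing the generator and the discriminator. Below we implement GANs in two ways: in method 1 the generator and discriminator are simple fully connected neural networks, and in method 2 they are convolutional neural networks.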
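Below I focus on how to implement the generator and the discriminator and how to define the model's loss and optimization. The complete code is posted at the end. Method 1 comes first, starting with the discriminator. Its architecture is simple: a three-layer network of input layer, hidden layer and output layer. MNIST images are 28*28, so the input layer takes 28*28 = 784 pixel values, and the activation function is LeakyReLU.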
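```python
import torch

class Discriminator(torch.nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # input layer (28*28 pixels) -> hidden layer (128 units) -> output layer (1 logit)
        self.discriminator = torch.nn.Sequential(
            torch.nn.Linear(28*28, 128),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(128, 1)
        )

    def forward(self, input):
        # flatten (N, 1, 28, 28) images to (N, 784) so the Linear layer accepts them
        input = input.view(-1, 28*28)
        output = self.discriminator(input)
        # raw logit, no sigmoid: BCEWithLogitsLoss applies the sigmoid itself
        return output
```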
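LeakyReLU differs from ReLU only for negative inputs: instead of zeroing them it scales them by a small slope, so gradient still flows through units that are "off". A minimal pure-Python sketch of the element-wise rule, assuming PyTorch's default slope of 0.01 (what `torch.nn.LeakyReLU()` uses when called without arguments, as in the code above):

```python
def leaky_relu(x, negative_slope=0.01):
    """Element-wise LeakyReLU: x for x >= 0, negative_slope * x otherwise."""
    return x if x >= 0 else negative_slope * x

# a few sample activations: negatives are shrunk, not zeroed
for x in (-2.0, -0.5, 0.0, 3.0):
    print(x, "->", leaky_relu(x))
```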
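Next comes the generator. It takes a random vector of a fixed size (100 here) and produces a 28*28 image. The closer its images are to real ones, the better the generator.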
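```python
class Generator(torch.nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # 100-dim noise -> hidden layer (128 units) -> 28*28 pixels
        self.generator = torch.nn.Sequential(
            torch.nn.Linear(100, 128),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(128, 28*28),
            torch.nn.Tanh()  # squashes every pixel into (-1, 1)
        )

    def forward(self, input):
        # output shape (N, 784); reshape to (N, 1, 28, 28) to treat it as images
        output = self.generator(input)
        return output
```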
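Because the generator ends in Tanh, every output pixel lies in (-1, 1), while `transforms.ToTensor()` produces MNIST pixels in [0, 1]. A small pure-Python sketch of the affine maps between the two ranges (the helper names `to_tanh_range` and `to_pixel_range` are hypothetical, for illustration only; the first is what torchvision's `Normalize` with mean 0.5 and std 0.5 computes):

```python
import math

def to_tanh_range(p):
    """Map a pixel in [0, 1] to [-1, 1]: (p - mean) / std with mean = std = 0.5."""
    return (p - 0.5) / 0.5

def to_pixel_range(t):
    """Map a Tanh output in [-1, 1] back to [0, 1] for display."""
    return (t + 1) / 2

raw = 0.3               # a pre-activation value from the last Linear layer
out = math.tanh(raw)    # generator output, always inside (-1, 1)
print(out, to_pixel_range(out))
```

Keeping real and generated images on the same scale matters: otherwise the discriminator can tell them apart from the value range alone.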
We define a function that produces the random input the generator needs: a batch of vectors drawn uniformly from [-1, 1].
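```python
import numpy as np
import torch

def rand_img(batchsize, output_size):
    # noise drawn uniformly from [-1, 1], shape (batchsize, output_size)
    Z = np.random.uniform(-1., 1., size=(batchsize, output_size))
    Z = np.float32(Z)
    Z = torch.from_numpy(Z)
    # move to the GPU; wrapping in Variable is unnecessary since PyTorch 0.4
    return Z.cuda()
```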
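Next comes the loss. Just keep two principles in mind: the discriminator should classify every real image as 1 and every fake (generated) image as 0, while the generator wants its images to be classified as 1 when fed to the discriminator. This is the essence of GANs. The implementation is as follows.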
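```python
# Assumes Discriminator_conv / Generator_conv (defined below), rand_img, batchsize,
# and X_train, a batch of real MNIST images taken from the DataLoader.
loss_f = torch.nn.BCEWithLogitsLoss()  # binary cross-entropy on raw logits
model_discriminator = Discriminator_conv().cuda()
model_generator = Generator_conv().cuda()

# Discriminator loss: real images should be scored 1, generated images 0
Z = rand_img(batchsize=batchsize, output_size=100)
X_gen = model_generator(Z)
X_gen = X_gen.view(-1, 1, 28, 28)
X_train = X_train.view(-1, 1, 28, 28)
logits_real = model_discriminator(X_train)
logits_fake = model_discriminator(X_gen)
d_loss = (loss_f(logits_real, torch.ones_like(logits_real))
          + loss_f(logits_fake, torch.zeros_like(logits_fake)))

# Generator loss: freshly generated images should be scored 1 by the discriminator
Z = rand_img(batchsize=batchsize, output_size=100)
X_gen = model_generator(Z)
X_gen = X_gen.view(-1, 1, 28, 28)
logits_fake = model_discriminator(X_gen)
g_loss = loss_f(logits_fake, torch.ones_like(logits_fake))
```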
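`loss_f` is `torch.nn.BCEWithLogitsLoss` in the full code, i.e. binary cross-entropy applied to raw logits. A pure-Python sketch of the same per-sample formula, assuming mean reduction over the batch (PyTorch's default), shows how the targets steer the two players. The logit values are made up for illustration:

```python
import math

def bce_with_logits(logit, target):
    """-[y*log(s) + (1-y)*log(1-s)] with s = sigmoid(logit),
    written in the numerically stable form max(x, 0) - x*y + log(1 + exp(-|x|))."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))

def mean_bce(logits, target):
    # mean reduction over the batch
    return sum(bce_with_logits(x, target) for x in logits) / len(logits)

logits_real = [2.0, 1.5, 3.0]    # discriminator scores for real images
logits_fake = [-1.0, 0.5, -2.0]  # discriminator scores for generated images

d_loss = mean_bce(logits_real, 1.0) + mean_bce(logits_fake, 0.0)
g_loss = mean_bce(logits_fake, 1.0)  # the generator wants its fakes scored as real
print(f"d_loss={d_loss:.4f} g_loss={g_loss:.4f}")
```

The same fake logits give a small loss with target 0 (good for the discriminator) and a large one with target 1 (bad for the generator), which is exactly the tension between the two losses.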
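Training reduces d_loss to make the discriminator stronger and, at the same time, reduces g_loss to make the generator stronger. These two goals pull in opposite directions, yet pitting them against each other is what lets the whole model produce very good results.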
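For reference, this tug-of-war is the two-player minimax game from Goodfellow et al.'s original GAN paper. Below, D(x) = sigmoid(logit) is the discriminator's estimated probability that x is real, and z is the noise vector. Averaged over a batch, `d_loss` above is L_D, the negative of the discriminator's objective. `g_loss` is L_G, the widely used non-saturating generator loss, which maximizes log D(G(z)) rather than minimizing log(1 - D(G(z))):

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]$$

$$\mathcal{L}_D = -\log D(x) - \log\big(1 - D(G(z))\big), \qquad \mathcal{L}_G = -\log D(G(z))$$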
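GANs implemented with convolutions are often called DCGANs (Deep Convolutional GANs). The biggest difference from method 1 is that the models contain convolutional layers, and the results are generally better than with the fully connected version.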
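The discriminator uses a very common convolutional neural network structure: two convolution + LeakyReLU + max-pooling stages followed by two fully connected layers.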
```python
class Discriminator_conv(torch.nn.Module):
    def __init__(self):
        super(Discriminator_conv, self).__init__()
        # (N, 1, 28, 28) -> conv -> (N, 32, 24, 24) -> pool -> (N, 32, 12, 12)
        # -> conv -> (N, 64, 8, 8) -> pool -> (N, 64, 4, 4)
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, kernel_size=5, stride=1),
            torch.nn.LeakyReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(32, 64, kernel_size=5, stride=1),
            torch.nn.LeakyReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2)
        )
        # flattened features (64*4*4 = 1024) -> 1024 -> 1 logit
        self.dense = torch.nn.Sequential(
            torch.nn.Linear(64*4*4, 64*4*4),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(64*4*4, 1)
        )

    def forward(self, input):
        output = self.conv(input)
        output = output.view(-1, 64*4*4)  # flatten for the dense layers
        output = self.dense(output)
        return output
```
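The `64*4*4` input size of the first dense layer follows from the shapes. Each 5×5 convolution with stride 1 and no padding shrinks the side by 4, and each 2×2 max pool with stride 2 halves it. A pure-Python check of that arithmetic, using the standard output-size formula floor((n + 2p - k) / s) + 1 with dilation 1 (the helper `out_size` is illustrative, not a PyTorch function):

```python
def out_size(n, kernel_size, stride=1, padding=0):
    """Output side length of a Conv2d or MaxPool2d layer with dilation 1."""
    return (n + 2 * padding - kernel_size) // stride + 1

n = 28
n = out_size(n, kernel_size=5)            # Conv2d(1, 32, 5)   -> 24
n = out_size(n, kernel_size=2, stride=2)  # MaxPool2d(2, 2)    -> 12
n = out_size(n, kernel_size=5)            # Conv2d(32, 64, 5)  -> 8
n = out_size(n, kernel_size=2, stride=2)  # MaxPool2d(2, 2)    -> 4
flat = 64 * n * n
print(n, flat)  # 4 1024
```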
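The generator upsamples with transposed convolution (ConvTranspose2d, often called deconvolution), turning a 100-dimensional noise vector into a 1×28×28 image.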
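For reference, the PyTorch `ConvTranspose2d` documentation gives the output height of a transposed convolution as below. The formula is shown here for dilation 1 and `output_padding` 0, which is what the generator uses, and the width is analogous:

$$H_{\text{out}} = (H_{\text{in}} - 1) \times \text{stride} - 2 \times \text{padding} + \text{kernel\_size}$$

With kernel_size=4, stride=2 and padding=1 this simplifies to $H_{\text{out}} = 2H_{\text{in}}$, so the two transposed convolutions turn a 7×7 feature map into 14×14 and then 28×28.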
```python
class Generator_conv(torch.nn.Module):
    def __init__(self):
        super(Generator_conv, self).__init__()
        # 100-dim noise -> 1024 -> 7*7*128 features
        self.conv_dense = torch.nn.Sequential(
            torch.nn.Linear(100, 1024),
            torch.nn.LeakyReLU(),
            torch.nn.BatchNorm1d(num_features=1024),
            torch.nn.Linear(1024, 7*7*128),
            torch.nn.BatchNorm1d(num_features=7*7*128)
        )
        # (N, 128, 7, 7) -> (N, 64, 14, 14) -> (N, 1, 28, 28)
        self.transpose_conv = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(num_features=64),
            torch.nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            torch.nn.Tanh()  # pixels in (-1, 1)
        )

    def forward(self, input):
        output = self.conv_dense(input)
        output = output.view(-1, 128, 7, 7)  # reshape to a 128-channel 7x7 feature map
        output = self.transpose_conv(output)
        return output
```
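A pure-Python trace of the generator's spatial size, using the `ConvTranspose2d` output-size relation with dilation 1 (the helper `conv_transpose_out` is illustrative, not a PyTorch function):

```python
def conv_transpose_out(n, kernel_size, stride=1, padding=0, output_padding=0):
    """Output side length of a ConvTranspose2d layer with dilation 1."""
    return (n - 1) * stride - 2 * padding + kernel_size + output_padding

n = 7  # Linear(1024, 7*7*128) output reshaped to (N, 128, 7, 7)
n = conv_transpose_out(n, 4, stride=2, padding=1)  # ConvTranspose2d(128, 64) -> 14
n = conv_transpose_out(n, 4, stride=2, padding=1)  # ConvTranspose2d(64, 1)   -> 28
print(n)  # 28
```

The final 28 matches the MNIST image size, which is why no cropping or padding is needed before the discriminator.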
Finally, here are the results after training for 1, 10 and 20 epochs. The generator can already produce images that resemble MNIST digits.
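[Figure: generated samples after 1 epoch]
[Figure: generated samples after 10 epochs]
[Figure: generated samples after 20 epochs]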
Finally, a few small tips.
1. You can change the original d_loss to the following form:
```python
d_loss = (loss_f(logits_real, torch.ones_like(logits_real) * (1 - smooth))
          + loss_f(logits_fake, torch.zeros_like(logits_fake)))
```
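Multiplying the real-image targets by (1 - smooth) softens them from 1 to a slightly smaller value. This technique is known as one-sided label smoothing, because only the real targets change. It keeps the discriminator from becoming overconfident, which is a form of overfitting. smooth should be small, e.g. 0.1 as in the complete code. At 0.5 or above, the real target would be 0.5 or lower, meaning the discriminator would be trained to call real images real with at most 50% confidence.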
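Here is why smoothing curbs overconfidence. With a target of 1, the binary cross-entropy of a real image keeps falling as the logit grows, so the discriminator is always rewarded for pushing its scores higher. With a target of 1 - smooth, the loss has a minimum at a finite logit, where sigmoid(logit) = 1 - smooth. A pure-Python sketch, using the same numerically stable BCE-with-logits formula as PyTorch (the grid search is only for illustration):

```python
import math

def bce_with_logits(logit, target):
    # numerically stable binary cross-entropy on a raw logit
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))

smooth = 0.1
target = 1 - smooth

# scan logits from -10 to 10 and find where the smoothed loss is smallest
grid = [i / 100 for i in range(-1000, 1001)]
best = min(grid, key=lambda x: bce_with_logits(x, target))
# best is close to log(0.9 / 0.1) = log(9), about 2.197, where sigmoid(best) is about 0.9
print(best, 1 / (1 + math.exp(-best)))
```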
2. Lower the optimizers' initial learning rate to help bring down the generator's g_loss. Both optimizers here use lr=0.0001:
```python
optimizer_dis = torch.optim.Adam(model_discriminator.parameters(), lr=0.0001)
optimizer_gen = torch.optim.Adam(model_generator.parameters(), lr=0.0001)
```
3. A deeper network can give better results, at the cost of more training time.
A very comprehensive collection of GAN variants

The complete code:
```python
import torch
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
# Jupyter-only magics: show plots inline at high resolution
%matplotlib inline
%config InlineBackend.figure_format="retina"

# use the GPU when available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

epoch_n = 20
batchsize = 128
smooth = 0.1  # one-sided label smoothing: real targets become 1 - smooth

# ToTensor gives pixels in [0, 1]; Normalize maps them to [-1, 1] to match the generator's Tanh
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
train_data = datasets.MNIST(root="data", download=True, train=True, transform=train_transform)
train_load = torch.utils.data.DataLoader(dataset=train_data, shuffle=True, batch_size=batchsize)

def plot_img(img):
    # arrange a batch of (N, 1, 28, 28) images into one grid and show it
    img = torchvision.utils.make_grid(img)
    img = img.numpy().transpose(1, 2, 0)
    plt.figure(figsize=(12, 9))
    plt.imshow(img)

class Discriminator_conv(torch.nn.Module):
    def __init__(self):
        super(Discriminator_conv, self).__init__()
        # (N, 1, 28, 28) -> (N, 64, 4, 4) after two conv + pool stages
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, kernel_size=5, stride=1),
            torch.nn.LeakyReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(32, 64, kernel_size=5, stride=1),
            torch.nn.LeakyReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.dense = torch.nn.Sequential(
            torch.nn.Linear(64*4*4, 64*4*4),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(64*4*4, 1)
        )

    def forward(self, input):
        output = self.conv(input)
        output = output.view(-1, 64*4*4)
        output = self.dense(output)
        return output  # raw logit

class Generator_conv(torch.nn.Module):
    def __init__(self):
        super(Generator_conv, self).__init__()
        self.conv_dense = torch.nn.Sequential(
            torch.nn.Linear(100, 1024),
            torch.nn.LeakyReLU(),
            torch.nn.BatchNorm1d(num_features=1024),
            torch.nn.Linear(1024, 7*7*128),
            torch.nn.BatchNorm1d(num_features=7*7*128)
        )
        # (N, 128, 7, 7) -> (N, 64, 14, 14) -> (N, 1, 28, 28)
        self.transpose_conv = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(num_features=64),
            torch.nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            torch.nn.Tanh()
        )

    def forward(self, input):
        output = self.conv_dense(input)
        output = output.view(-1, 128, 7, 7)
        output = self.transpose_conv(output)
        return output

def initialize_weights(m):
    # Xavier-uniform init for fully connected and convolutional layers
    if isinstance(m, (torch.nn.Linear, torch.nn.Conv2d)):
        torch.nn.init.xavier_uniform_(m.weight)

model_discriminator = Discriminator_conv().to(device)
model_discriminator.apply(initialize_weights)
model_generator = Generator_conv().to(device)
model_generator.apply(initialize_weights)

loss_f = torch.nn.BCEWithLogitsLoss()
optimizer_dis = torch.optim.Adam(model_discriminator.parameters(), lr=0.0001)
optimizer_gen = torch.optim.Adam(model_generator.parameters(), lr=0.0001)

samples = []
losses = []

def rand_img(batchsize, output_size):
    # uniform noise in [-1, 1] on the training device
    Z = np.random.uniform(-1., 1., size=(batchsize, output_size))
    Z = torch.from_numpy(np.float32(Z))
    return Z.to(device)

for epoch in range(epoch_n):
    for X_train, _ in train_load:  # the digit labels are not needed
        X_train = X_train.to(device).view(-1, 1, 28, 28)

        # discriminator step: real -> 1 - smooth, fake -> 0
        optimizer_dis.zero_grad()
        Z = rand_img(batchsize=batchsize, output_size=100)
        X_gen = model_generator(Z).view(-1, 1, 28, 28)
        logits_real = model_discriminator(X_train)
        # detach so this step does not backpropagate into the generator
        logits_fake = model_discriminator(X_gen.detach())
        d_loss = (loss_f(logits_real, torch.ones_like(logits_real) * (1 - smooth))
                  + loss_f(logits_fake, torch.zeros_like(logits_fake)))
        d_loss.backward()
        optimizer_dis.step()

        # generator step: make the discriminator score fresh fakes as 1
        optimizer_gen.zero_grad()
        Z = rand_img(batchsize=batchsize, output_size=100)
        X_gen = model_generator(Z).view(-1, 1, 28, 28)
        logits_fake = model_discriminator(X_gen)
        g_loss = loss_f(logits_fake, torch.ones_like(logits_fake))
        g_loss.backward()
        optimizer_gen.step()

    print("Epoch {}/{}...".format(epoch + 1, epoch_n),
          "Discriminator Loss: {:.4f}...".format(d_loss.item()),
          "Generator Loss: {:.4f}...".format(g_loss.item()))
    # store plain floats: GPU tensors cannot be converted to a NumPy array
    losses.append((d_loss.item(), g_loss.item()))
    # keep one batch of generated images per epoch for plotting
    with torch.no_grad():
        samples.append(model_generator(Z).cpu())

fig, ax = plt.subplots()
losses = np.array(losses)
plt.plot(losses.T[0], label="Discriminator")
plt.plot(losses.T[1], label="Generator")
plt.title("Training Losses")
plt.legend()

def to_img(img):
    # map generator output from [-1, 1] back to [0, 1] pixels
    img = img.detach().cpu()
    img = (img + 1) / 2
    img = img.clamp(0, 1)
    img = img.view(-1, 1, 28, 28)
    return img

for i in range(len(samples)):
    img = to_img(samples[i])
    plot_img(img)
```
https://zhuanlan.zhihu.com/p/40393929