这种训练方式定义了一种全新的网络结构,就是生成对抗网络,也就是 GANs。
根据这个名字就可以知道这个网络是由两部分组成的,第一部分是生成,第二部分是对抗。简单来说,就是有一个生成网络和一个判别网络,通过训练让两个网络相互竞争,生成网络来生成假的数据,对抗网络通过判别器去判别真伪,最后希望生成器生成的数据能够以假乱真。
判别网络的结构非常简单,就是一个二分类器,结构如下:
其中 leakyrelu 是指 f(x) = max( x, x)
def discriminator():
net = nn.Sequential(
nn.Linear(784, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1)
)
return net
接下来我们看看生成网络,生成网络的结构也很简单,就是根据一个随机噪声生成一个和数据维度一样的张量,结构如下:
def generator(noise_dim=NOISE_DIM):
net = nn.Sequential(
nn.Linear(noise_dim, 1024),
nn.ReLU(True),
nn.Linear(1024, 1024),
nn.ReLU(True),
nn.Linear(1024, 784),
nn.Tanh()
)
return net
接下来是两个网络的loss,对于判别网络来说,我们需要让它的判断越来越好,所以我们需要用真实数据和1做loss,假的数据和0做loss。
bce_loss = nn.BCEWithLogitsLoss()
def discriminator_loss(logits_real, logits_fake): # 判别器的 loss
size = logits_real.shape[0]
true_labels = Variable(torch.ones(size, 1)).float().cuda() #全1的tensor
false_labels = Variable(torch.zeros(size, 1)).float().cuda() #全0的tensor
loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
return loss
而生成网络我们需要让他生成的假数据接近真实的数据,所以将生成的假数据进入我们训练好的判别器得到分数并和全1的tensor做loss。
def generator_loss(logits_fake): # 生成器的 loss
size = logits_fake.shape[0]
true_labels = Variable(torch.ones(size, 1)).float().cuda()
loss = bce_loss(logits_fake, true_labels)
return loss
优化函数。
# 使用 adam 来进行训练,学习率是 3e-4, beta1 是 0.5, beta2 是 0.999
def get_optimizer(net):
optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
return optimizer
def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250,
noise_size=96, num_epochs=10):
iter_count = 0
for epoch in range(num_epochs):
for x, _ in train_data:
bs = x.shape[0]
# 判别网络
real_data = Variable(x).view(bs, -1).cuda() # 真实数据
logits_real = D_net(real_data) # 判别网络得分
sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5 # -1 ~ 1 的均匀分布
g_fake_seed = Variable(sample_noise).cuda()
fake_images = G_net(g_fake_seed) # 生成的假的数据
logits_fake = D_net(fake_images) # 判别网络得分
d_total_error = discriminator_loss(logits_real, logits_fake) # 判别器的 loss
D_optimizer.zero_grad()
d_total_error.backward()
D_optimizer.step() # 优化判别网络
# 生成网络
g_fake_seed = Variable(sample_noise).cuda()
fake_images = G_net(g_fake_seed) # 生成的假的数据
gen_logits_fake = D_net(fake_images) # 将假的数据在判别器得到分数
g_error = generator_loss(gen_logits_fake) # 生成网络的 loss
G_optimizer.zero_grad()
g_error.backward()
G_optimizer.step() # 优化生成网络
if (iter_count % show_every == 0):
print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.data[0], g_error.data[0]))
imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())
show_images(imgs_numpy[0:16])
plt.show()
print()
iter_count += 1
D = discriminator().cuda()
G = generator().cuda()
D_optim = get_optimizer(D)
G_optim = get_optimizer(G)
train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)
Least Squares GAN 比最原始的 GANs 的 loss 更加稳定,通过名字我们也能够看出这种 GAN 是通过最小平方误差来进行估计,而不是通过二分类的损失函数,下面我们看看 loss 的计算公式
可以看到 Least Squares GAN 通过最小二乘代替了二分类的 loss,下面我们定义一下 loss 函数
def ls_discriminator_loss(scores_real, scores_fake): loss = 0.5 * ((scores_real - 1) ** 2).mean() + 0.5 * (scores_fake ** 2).mean() return loss def ls_generator_loss(scores_fake): loss = 0.5 * ((scores_fake - 1) ** 2).mean() return loss
深度卷积生成对抗网络特别简单,就是将生成网络和对抗网络都改成了卷积网络的形式,下面我们来实现一下
卷积判别网络就是一个一般的卷积网络,结构如下
class build_dc_classifier(nn.Module): def __init__(self): super(build_dc_classifier, self).__init__() self.conv = nn.Sequential( nn.Conv2d(1, 32, 5, 1), nn.LeakyReLU(0.01), nn.MaxPool2d(2, 2), nn.Conv2d(32, 64, 5, 1), nn.LeakyReLU(0.01), nn.MaxPool2d(2, 2) ) self.fc = nn.Sequential( nn.Linear(1024, 1024), nn.LeakyReLU(0.01), nn.Linear(1024, 1) ) def forward(self, x): x = self.conv(x) x = x.view(x.shape[0], -1) x = self.fc(x) return x
卷积生成网络需要将一个低维的噪声向量变成一个图片数据,结构如下
class build_dc_generator(nn.Module): def __init__(self, noise_dim=NOISE_DIM): super(build_dc_generator, self).__init__() self.fc = nn.Sequential( nn.Linear(noise_dim, 1024), nn.ReLU(True), nn.BatchNorm1d(1024), nn.Linear(1024, 7 * 7 * 128), nn.ReLU(True), nn.BatchNorm1d(7 * 7 * 128) ) self.conv = nn.Sequential( nn.ConvTranspose2d(128, 64, 4, 2, padding=1), nn.ReLU(True), nn.BatchNorm2d(64), nn.ConvTranspose2d(64, 1, 4, 2, padding=1), nn.Tanh() ) def forward(self, x): x = self.fc(x) x = x.view(x.shape[0], 128, 7, 7) # reshape 通道是 128,大小是 7x7 x = self.conv(x) return x
原始链接:https://github.com/L1aoXingyu/code-of-learn-deep-learning-with-pytorch/blob/master/chapter6_GAN/gan.ipynb