本文为365天深度学习训练营 中的学习记录博客
原作者:K同学啊|接辅导、项目定制
我的环境:
1.语言:python3.7
2.编译器:pycharm
3.深度学习框架Pytorch 1.8.0+cu111
GAN网络介绍:
生成对抗网络(GAN)是一种深度学习模型,最初由Ian Goodfellow等人于2014年提出。GAN的核心思想是通过训练两个神经网络,一个生成器(Generator)和一个判别器(Discriminator),彼此博弈来学习生成与真实数据相似的样本。
生成器(Generator):
判别器(Discriminator):
对抗过程:
目标函数:
GAN的优点包括能够生成高质量、逼真的数据,无需显式规定生成规则。它在图像生成、风格迁移、图像超分辨率等任务上取得了显著的成功。然而,GAN的训练也面临一些挑战,如模式崩溃、训练不稳定等,需要仔细的调参和设计。
结合整体模型图示,再以生成图片作为例子具体说明下面。我们有两个网络,G(Generator)和D(Discriminator)。Generator是一个生成图片的网络,它接收一个随机的噪声z,通过这个噪声生成图片,记做G(z)。Discriminator是一个判别网络,判别一张图片是不是“真实的”。它的输入是x,x代表一张图片,输出D(x)代表x为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。
import argparse
import os
import numpy as np
import torch.cuda
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch
os.makedirs("./images/",exist_ok = True) #记录训练过程的图片效果
os.makedirs("./save/",exist_ok = True) #训练完成时模型保存的位置
os.makedirs("./datasets/mnist",exist_ok = True) #下载数据集存放的位置
#超参数配置
n_epochs = 50
batch_size = 64
Ir = 0.0002
b1 = 0.5
b2 = 0.999
n_cpu = 2
latent_dim = 100
img_size = 28
channels = 1
sample_interval = 500
img_shape = (channels,img_size,img_size) #图像尺寸(1,28,28)
img_area = np.prod(img_shape) #图像像素面积784
cuda = True if torch.cuda.is_available() else False
print(cuda)
mnist = datasets.MNIST(
root = './datasets/',train = True,download =True,transform = transforms.Compose(
[transforms.Resize(img_size),transforms.ToTensor(),transforms.Normalize([0.5],[0.5])]),)
#鉴别器
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator,self).__init__()
self.model = nn.Sequential(
nn.Linear(img_area,512),
nn.LeakyReLU(0.2,inplace = True),
nn.Linear(512,256),
nn.LeakyReLU(0.2,inplace=True),
nn.Linear(256,1),
nn.Sigmoid(),
)
def forward(self,img):
img_flat = img.view(img.size(0),-1)
validity = self.model(img_flat)
return validity
判别器是一个神经网络模型,它接受一个图像作为输入并输出一个标量值,表示输入图像是真实图像的概率。通过使用nn.Sequential
定义了一个包含线性层、LeakyReLU激活函数和Sigmoid激活函数的模型。
输入图像经过一系列线性变换和非线性激活,最终通过 Sigmoid 激活函数输出一个范围在 [0, 1] 内的值,表示输入图像是真实图像的概率。
forward
方法定义了数据在模型中前向传播的流程。img.view(img.size(0),-1)
:将输入图像展平为一维张量,保留 batch 的维度。self.model(img_flat)
:将展平后的图像输入到定义的神经网络模型中。return validity
:返回模型的输出,即表示判别为真实的概率。##生成器
class Generator(nn.Module):
def __init__(self):
super(Generator,self).__init__()
def block(in_feat,out_feat,normalize = True):
layers = [nn.Linear(in_feat,out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat,0.8))
layers.append(nn.LeakyReLU(0.2,inplace = True))
return layers
self.model = nn.Sequential(
*block(latent_dim,128,normalize = False),
*block(128,256),
*block(256,512),
*block(512,1024),
nn.Linear(1024,img_area),
nn.Tanh()
)
def forward(self,z):
imgs = self.model(z)
imgs = imgs.view(imgs.size(0),*img_shape)
return imgs
Generator
类同样继承自 nn.Module
。block
,用于创建生成器中的一个块(block)。这个块由一个线性层、一个批归一化(如果需要)、和 Leaky ReLU 激活函数组成。__init__
方法中,通过堆叠多个块构建了整个生成器模型。*block(latent_dim,128,normalize=False)
:输入层,从潜在向量(latent vector)映射到大小为 128 的隐藏层,不进行批归一化。img_area
的一维向量。nn.Tanh()
:最后一层使用 Tanh 激活函数,将输出限制在 (-1, 1) 范围内。forward
方法定义了数据在模型中前向传播的流程。self.model(z)
:将输入潜在向量 z
通过生成器模型。imgs.view(imgs.size(0), *img_shape)
:将生成器的输出展平为图像形状。return imgs
:返回生成的图像。generator = Generator()
discriminator = Discriminator()
criterion = torch.nn.BCELoss()
optimizer_G = torch.optim.Adam(generator.parameters(), lr=Ir, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(),lr = Ir,betas = (b1,b2))
if torch.cuda.is_available():
generator = generator.cuda()
discriminator = discriminator.cuda()
criterion = criterion.cuda()
这两个模型是 GAN 中的关键组件,通过对抗训练的方式使得生成器生成逼真的图像,而鉴别器则尽可能准确地判别真实和生成的图像。在训练循环中,生成器和鉴别器通过对抗损失函数进行优化,以实现更好的生成图像和更准确的判别。
for epoch in range(n_epochs):
for i,(imgs,_) in enumerate(dataloader):
imgs = imgs.view(imgs.size(0),-1)
real_img = Variable(imgs).cuda()
real_label = Variable(torch.ones(imgs.size(0),1)).cuda()
fake_label = Variable(torch.zeros(imgs.size(0), 1)).cuda()
real_out = discriminator(real_img)
loss_real_D = criterion(real_out,real_label)
real_scores = real_out
z = Variable(torch.randn(imgs.size(0),latent_dim)).cuda()
fake_img = generator(z).detach()
fake_out = discriminator(fake_img)
loss_fake_D = criterion(fake_out,fake_label)
fake_scores = fake_out
loss_D = loss_real_D + loss_fake_D
optimizer_D.zero_grad()
loss_D.backward()
optimizer_D.step()
z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()
fake_img = generator(z)
fake_out = discriminator(fake_img)
loss_G = criterion(fake_out, real_label)
optimizer_G.zero_grad()
loss_G.backward()
optimizer_G.step()
if (i + 1) % 300 == 0:
print(
"[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"
% (epoch, n_epochs, i ,len(dataloader), loss_D.item(), loss_G.item(), real_scores.data.mean(), fake_scores.data.mean())
)
batches_done = epoch * len(dataloader) + i
if batches_done % sample_interval == 0:
save_image(fake_img.data[:25],"./images/%d.png" % batches_done, nrow = 5, normalize = True)
#模型保存
torch.save(generator.state_dict(),'./save/generator.pth')
torch.save(discriminator.state_dict(),'./save/discriminator.pth')
[Epoch 0/50] [Batch 299/938] [D loss: 1.062147] [G loss: 0.793500] [D real: 0.557045] [D fake: 0.361011]
[Epoch 0/50] [Batch 599/938] [D loss: 1.059478] [G loss: 1.188343] [D real: 0.587030] [D fake: 0.359636]
[Epoch 0/50] [Batch 899/938] [D loss: 1.042897] [G loss: 1.466038] [D real: 0.653879] [D fake: 0.433240]
[Epoch 1/50] [Batch 299/938] [D loss: 0.843557] [G loss: 1.561677] [D real: 0.771608] [D fake: 0.421925]
[Epoch 1/50] [Batch 599/938] [D loss: 1.026156] [G loss: 2.394096] [D real: 0.862107] [D fake: 0.558821]
[Epoch 1/50] [Batch 899/938] [D loss: 0.836007] [G loss: 1.557716] [D real: 0.652867] [D fake: 0.245031]
[Epoch 2/50] [Batch 299/938] [D loss: 1.035034] [G loss: 0.812236] [D real: 0.550502] [D fake: 0.262485]
[Epoch 2/50] [Batch 599/938] [D loss: 1.063006] [G loss: 1.113460] [D real: 0.483566] [D fake: 0.111593]
[Epoch 2/50] [Batch 899/938] [D loss: 0.803735] [G loss: 1.973702] [D real: 0.765450] [D fake: 0.374167]
[Epoch 3/50] [Batch 299/938] [D loss: 0.825772] [G loss: 1.667717] [D real: 0.780164] [D fake: 0.376694]
[Epoch 3/50] [Batch 599/938] [D loss: 1.435523] [G loss: 0.494344] [D real: 0.369774] [D fake: 0.139426]
[Epoch 3/50] [Batch 899/938] [D loss: 0.917040] [G loss: 0.905851] [D real: 0.552047] [D fake: 0.157987]
[Epoch 4/50] [Batch 299/938] [D loss: 1.141817] [G loss: 2.367955] [D real: 0.860573] [D fake: 0.589422]
[Epoch 4/50] [Batch 599/938] [D loss: 0.996989] [G loss: 0.822017] [D real: 0.590626] [D fake: 0.270388]
[Epoch 4/50] [Batch 899/938] [D loss: 0.753533] [G loss: 1.717019] [D real: 0.742206] [D fake: 0.312312]
[Epoch 5/50] [Batch 299/938] [D loss: 0.989197] [G loss: 0.888657] [D real: 0.634080] [D fake: 0.284923]
[Epoch 5/50] [Batch 599/938] [D loss: 0.844832] [G loss: 1.811573] [D real: 0.747003] [D fake: 0.359605]
[Epoch 5/50] [Batch 899/938] [D loss: 0.694102] [G loss: 1.491613] [D real: 0.720356] [D fake: 0.237716]