- 本文为365天深度学习训练营 中的学习记录博客
- 参考文章:365天深度学习训练营-第G1周:生成对抗网络入门
- 原作者:K同学啊|接辅导、项目定制
GAN并不表示某一种具体的深度学习网络,而是一种基于博弈论的神经网络,其分为Generation
和Discriminiation
两个部分,目的是为了将真实的样本和人工样本进行区分,在训练过程当中Generation
和Discriminlation
相互交替出现,互相博弈,当判别器Discrimination
无法成功地将人工样本和真实样本区分开的时候就会停止运行
import argparse
import os
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch
## 创建文件夹
os.makedirs("./images/", exist_ok=True) ## 记录训练过程的图片效果
os.makedirs("./save/", exist_ok=True) ## 训练完成时模型保存的位置
os.makedirs("./datasets/mnist", exist_ok=True) ## 下载数据集存放的位置
## 超参数配置
n_epochs=50
batch_size=512
lr=0.0002
b1=0.5
b2=0.999
n_cpu=2
latent_dim=100
img_size=28
channels=1
sample_interval=500
## 图像的尺寸:(1, 28, 28), 和图像的像素面积:(784)
img_shape = (channels, img_size, img_size)
img_area = np.prod(img_shape)
## 设置cuda:(cuda:0)
cuda = True if torch.cuda.is_available() else False
print(cuda)
True
## mnist数据集下载
mnist = datasets.MNIST(
root='./datasets/', train=True, download=True, transform=transforms.Compose(
[transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),
)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./datasets/MNIST/raw/train-images-idx3-ubyte.gz
100%|██████████| 9912422/9912422 [00:00<00:00, 438576233.89it/s]
Extracting ./datasets/MNIST/raw/train-images-idx3-ubyte.gz to ./datasets/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./datasets/MNIST/raw/train-labels-idx1-ubyte.gz
100%|██████████| 28881/28881 [00:00<00:00, 26106830.57it/s]
Extracting ./datasets/MNIST/raw/train-labels-idx1-ubyte.gz to ./datasets/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./datasets/MNIST/raw/t10k-images-idx3-ubyte.gz
100%|██████████| 1648877/1648877 [00:00<00:00, 95251028.09it/s]
Extracting ./datasets/MNIST/raw/t10k-images-idx3-ubyte.gz to ./datasets/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz
100%|██████████| 4542/4542 [00:00<00:00, 5843720.48it/s]
Extracting ./datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./datasets/MNIST/raw
## 配置数据到加载器
dataloader = DataLoader(
mnist,
batch_size=batch_size,
shuffle=True,
)
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(img_area, 512), # 输入特征数为784,输出为512
nn.LeakyReLU(0.2, inplace=True), # 进行非线性映射
nn.Linear(512, 256), # 输入特征数为512,输出为256
nn.LeakyReLU(0.2, inplace=True), # 进行非线性映射
nn.Linear(256, 1), # 输入特征数为256,输出为1
nn.Sigmoid(), # sigmoid是一个激活函数,二分类问题中可将实数映射到[0, 1],作为概率值, 多分类用softmax函数
)
def forward(self, img):
img_flat = img.view(img.size(0), -1) # 鉴别器输入是一个被view展开的(784)的一维图像:(64, 784)
validity = self.model(img_flat) # 通过鉴别器网络
return validity # 鉴别器返回的是一个[0, 1]间的概率
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
## 模型中间块儿
def block(in_feat, out_feat, normalize=True): # block(in, out )
layers = [nn.Linear(in_feat, out_feat)] # 线性变换将输入映射到out维
if normalize:
layers.append(nn.BatchNorm1d(out_feat, 0.8)) # 正则化
layers.append(nn.LeakyReLU(0.2, inplace=True)) # 非线性激活函数
return layers
## prod():返回给定轴上的数组元素的乘积:1*28*28=784
self.model = nn.Sequential(
*block(latent_dim, 128, normalize=False), # 线性变化将输入映射 100 to 128, 正则化, LeakyReLU
*block(128, 256), # 线性变化将输入映射 128 to 256, 正则化, LeakyReLU
*block(256, 512), # 线性变化将输入映射 256 to 512, 正则化, LeakyReLU
*block(512, 1024), # 线性变化将输入映射 512 to 1024, 正则化, LeakyReLU
nn.Linear(1024, img_area), # 线性变化将输入映射 1024 to 784
nn.Tanh() # 将(784)的数据每一个都映射到[-1, 1]之间
)
## view():相当于numpy中的reshape,重新定义矩阵的形状:这里是reshape(64, 1, 28, 28)
def forward(self, z): # 输入的是(64, 100)的噪声数据
imgs = self.model(z) # 噪声数据通过生成器模型
imgs = imgs.view(imgs.size(0), *img_shape) # reshape成(64, 1, 28, 28)
return imgs # 输出为64张大小为(1, 28, 28)的图像
## 创建生成器,判别器对象
generator = Generator()
discriminator = Discriminator()
## 首先需要定义loss的度量方式 (二分类的交叉熵)
criterion = torch.nn.BCELoss()
## 其次定义 优化函数,优化函数的学习率为0.0003
## betas:用于计算梯度以及梯度平方的运行平均值的系数
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))
## 如果有显卡,都在cuda模式中运行
if torch.cuda.is_available():
generator = generator.cuda()
discriminator = discriminator.cuda()
criterion = criterion.cuda()
for epoch in range(n_epochs): # epoch:50
for i, (imgs, _) in enumerate(dataloader): # imgs:(64, 1, 28, 28) _:label(64)
imgs = imgs.view(imgs.size(0), -1) # 将图片展开为28*28=784 imgs:(64, 784)
real_img = Variable(imgs).cuda() # 将tensor变成Variable放入计算图中,tensor变成variable之后才能进行反向传播求梯度
real_label = Variable(torch.ones(imgs.size(0), 1)).cuda() ## 定义真实的图片label为1
fake_label = Variable(torch.zeros(imgs.size(0), 1)).cuda() ## 定义假的图片的label为0
real_out = discriminator(real_img) # 将真实图片放入判别器中
loss_real_D = criterion(real_out, real_label) # 得到真实图片的loss
real_scores = real_out # 得到真实图片的判别值,输出的值越接近1越好
## 计算假的图片的损失
## detach(): 从当前计算图中分离下来避免梯度传到G,因为G不用更新
z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda() ## 随机生成一些噪声, 大小为(128, 100)
fake_img = generator(z).detach() ## 随机噪声放入生成网络中,生成一张假的图片。
fake_out = discriminator(fake_img) ## 判别器判断假的图片
loss_fake_D = criterion(fake_out, fake_label) ## 得到假的图片的loss
fake_scores = fake_out
## 损失函数和优化
loss_D = loss_real_D + loss_fake_D # 损失包括判真损失和判假损失
optimizer_D.zero_grad() # 在反向传播之前,先将梯度归0
loss_D.backward() # 将误差反向传播
optimizer_D.step() # 更新参数
z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda() ## 得到随机噪声
fake_img = generator(z) ## 随机噪声输入到生成器中,得到一副假的图片
output = discriminator(fake_img) ## 经过判别器得到的结果
## 损失函数和优化
loss_G = criterion(output, real_label) ## 得到的假的图片与真实的图片的label的loss
optimizer_G.zero_grad() ## 梯度归0
loss_G.backward() ## 进行反向传播
optimizer_G.step() ## step()一般用在反向传播后面,用于更新生成网络的参数
## 打印训练过程中的日志
## item():取出单元素张量的元素值并返回该值,保持原元素类型不变
if ( i + 1 ) % 100 == 0:
print(
"[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"
% (epoch, n_epochs, i, len(dataloader), loss_D.item(), loss_G.item(), real_scores.data.mean(), fake_scores.data.mean())
)
## 保存训练过程中的图像
batches_done = epoch * len(dataloader) + i
if batches_done % sample_interval == 0:
save_image(fake_img.data[:25], "./images/%d.png" % batches_done, nrow=5, normalize=True)
[Epoch 0/50] [Batch 99/118] [D loss: 1.392201] [G loss: 0.720495] [D real: 0.660287] [D fake: 0.605812]
[Epoch 1/50] [Batch 99/118] [D loss: 1.237947] [G loss: 0.645015] [D real: 0.627908] [D fake: 0.521414]
[Epoch 2/50] [Batch 99/118] [D loss: 1.206743] [G loss: 1.371007] [D real: 0.831680] [D fake: 0.612845]
[Epoch 3/50] [Batch 99/118] [D loss: 1.042226] [G loss: 1.403389] [D real: 0.738673] [D fake: 0.510074]
[Epoch 4/50] [Batch 99/118] [D loss: 0.950092] [G loss: 1.879504] [D real: 0.826950] [D fake: 0.522136]
[Epoch 5/50] [Batch 99/118] [D loss: 0.964685] [G loss: 1.957089] [D real: 0.719685] [D fake: 0.451049]
[Epoch 6/50] [Batch 99/118] [D loss: 1.250396] [G loss: 3.282628] [D real: 0.890914] [D fake: 0.671558]
[Epoch 7/50] [Batch 99/118] [D loss: 1.436018] [G loss: 2.394736] [D real: 0.889253] [D fake: 0.720616]
[Epoch 8/50] [Batch 99/118] [D loss: 0.997427] [G loss: 1.298990] [D real: 0.646482] [D fake: 0.406250]
[Epoch 9/50] [Batch 99/118] [D loss: 1.015665] [G loss: 1.133040] [D real: 0.583219] [D fake: 0.313580]
[Epoch 10/50] [Batch 99/118] [D loss: 1.104344] [G loss: 0.890277] [D real: 0.487496] [D fake: 0.231743]
[Epoch 11/50] [Batch 99/118] [D loss: 0.805449] [G loss: 1.476539] [D real: 0.771647] [D fake: 0.393425]
[Epoch 12/50] [Batch 99/118] [D loss: 0.886394] [G loss: 1.365783] [D real: 0.625857] [D fake: 0.272819]
[Epoch 13/50] [Batch 99/118] [D loss: 1.160892] [G loss: 2.270098] [D real: 0.819089] [D fake: 0.600196]
[Epoch 14/50] [Batch 99/118] [D loss: 0.990126] [G loss: 2.221423] [D real: 0.845885] [D fake: 0.544953]
[Epoch 15/50] [Batch 99/118] [D loss: 0.814652] [G loss: 1.242788] [D real: 0.629658] [D fake: 0.231237]
[Epoch 16/50] [Batch 99/118] [D loss: 1.292980] [G loss: 2.503134] [D real: 0.784579] [D fake: 0.633823]
[Epoch 17/50] [Batch 99/118] [D loss: 1.031758] [G loss: 2.699657] [D real: 0.815461] [D fake: 0.542024]
[Epoch 18/50] [Batch 99/118] [D loss: 0.988402] [G loss: 1.569268] [D real: 0.678417] [D fake: 0.390312]
[Epoch 19/50] [Batch 99/118] [D loss: 1.008053] [G loss: 2.010935] [D real: 0.820956] [D fake: 0.532101]
[Epoch 20/50] [Batch 99/118] [D loss: 0.928145] [G loss: 1.021322] [D real: 0.581234] [D fake: 0.226135]
[Epoch 21/50] [Batch 99/118] [D loss: 0.901849] [G loss: 1.050935] [D real: 0.586989] [D fake: 0.204523]
[Epoch 22/50] [Batch 99/118] [D loss: 0.741626] [G loss: 1.595031] [D real: 0.732048] [D fake: 0.289632]
[Epoch 23/50] [Batch 99/118] [D loss: 1.299593] [G loss: 0.612525] [D real: 0.399023] [D fake: 0.091832]
[Epoch 24/50] [Batch 99/118] [D loss: 0.892590] [G loss: 1.221449] [D real: 0.614085] [D fake: 0.213004]
[Epoch 25/50] [Batch 99/118] [D loss: 0.911475] [G loss: 1.191773] [D real: 0.555970] [D fake: 0.123458]
[Epoch 26/50] [Batch 99/118] [D loss: 0.995571] [G loss: 2.681632] [D real: 0.854862] [D fake: 0.527149]
[Epoch 27/50] [Batch 99/118] [D loss: 1.160800] [G loss: 0.873747] [D real: 0.452850] [D fake: 0.120306]
[Epoch 28/50] [Batch 99/118] [D loss: 0.812079] [G loss: 3.594522] [D real: 0.888997] [D fake: 0.474353]
[Epoch 29/50] [Batch 99/118] [D loss: 0.629567] [G loss: 1.881860] [D real: 0.751697] [D fake: 0.232717]
[Epoch 30/50] [Batch 99/118] [D loss: 0.990794] [G loss: 1.515161] [D real: 0.638545] [D fake: 0.315502]
[Epoch 31/50] [Batch 99/118] [D loss: 0.719803] [G loss: 2.252865] [D real: 0.770139] [D fake: 0.299516]
[Epoch 32/50] [Batch 99/118] [D loss: 0.692556] [G loss: 2.410813] [D real: 0.821995] [D fake: 0.356726]
[Epoch 33/50] [Batch 99/118] [D loss: 0.804586] [G loss: 1.265048] [D real: 0.616957] [D fake: 0.143595]
[Epoch 34/50] [Batch 99/118] [D loss: 1.002946] [G loss: 0.998340] [D real: 0.542722] [D fake: 0.129555]
[Epoch 35/50] [Batch 99/118] [D loss: 0.638164] [G loss: 2.125413] [D real: 0.781333] [D fake: 0.267743]
[Epoch 36/50] [Batch 99/118] [D loss: 0.796073] [G loss: 2.105274] [D real: 0.733458] [D fake: 0.275736]
[Epoch 37/50] [Batch 99/118] [D loss: 0.707837] [G loss: 1.712856] [D real: 0.723462] [D fake: 0.213725]
[Epoch 38/50] [Batch 99/118] [D loss: 0.574488] [G loss: 1.854247] [D real: 0.762389] [D fake: 0.191676]
[Epoch 39/50] [Batch 99/118] [D loss: 0.665478] [G loss: 2.232678] [D real: 0.820965] [D fake: 0.334450]
[Epoch 40/50] [Batch 99/118] [D loss: 0.721808] [G loss: 2.323272] [D real: 0.814131] [D fake: 0.358871]
[Epoch 41/50] [Batch 99/118] [D loss: 0.794077] [G loss: 1.027562] [D real: 0.625926] [D fake: 0.107704]
[Epoch 42/50] [Batch 99/118] [D loss: 0.809034] [G loss: 1.199560] [D real: 0.628133] [D fake: 0.120829]
[Epoch 43/50] [Batch 99/118] [D loss: 0.679062] [G loss: 2.121304] [D real: 0.762760] [D fake: 0.273702]
[Epoch 44/50] [Batch 99/118] [D loss: 0.565462] [G loss: 1.739519] [D real: 0.764866] [D fake: 0.179747]
[Epoch 45/50] [Batch 99/118] [D loss: 0.788362] [G loss: 1.181516] [D real: 0.638391] [D fake: 0.095294]
[Epoch 46/50] [Batch 99/118] [D loss: 0.761360] [G loss: 3.341885] [D real: 0.876044] [D fake: 0.434379]
[Epoch 47/50] [Batch 99/118] [D loss: 0.755073] [G loss: 2.987507] [D real: 0.849711] [D fake: 0.408325]
[Epoch 48/50] [Batch 99/118] [D loss: 0.737976] [G loss: 1.869045] [D real: 0.739559] [D fake: 0.229506]
[Epoch 49/50] [Batch 99/118] [D loss: 0.726861] [G loss: 1.471112] [D real: 0.676479] [D fake: 0.113519]
torch.save(generator.state_dict(), './generator.pth')
torch.save(discriminator.state_dict(), './discriminator.pth')