参见另一篇博文:生成式对抗网络(Generative Adversarial Nets,GAN)
import 模块名和from 模块名 import 函数名的区别:
导入一个模块时,会创建新的命名空间,就可以使用命名空间来调用其中的代码;同时,还会在新创建的命名空间中执行模块中包含的代码,如果有输出也可以在控制台看到。
import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
from torch.autograd import Variable
import torchvision
import torchvision.transforms as T
import torchvision.datasets as dset
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
%matplotlib inline
模块简介
pylot使用rc配置文件来自定义图形的各种默认属性,称之为rc配置或rc参数。通过rc参数可以修改默认的属性,包括窗体大小、每英寸的点数、线条宽度、颜色、样式、坐标轴、坐标和网络属性、文本、字体等。
# 设置输出图像的默认属性
plt.rcParams['figure.figsize'] = (10.0, 8.0)
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
参考:
plt.rcParams[]
设置图像大小为(sqrtn, sqrtn),将整张图划分为(sqrtn, sqrtn)个网格,网格间的空隙为(0.05, 0.05),每个网格显示一张图片,图片大小为(sqrtimg, sqrtimg)。
参考:使用GridSpec和其他功能自定义图布局
def show_images(images):
# 将图片地形状更改为(batch_size, D)
# 个人认为这句代码没有意义,因为输入的图片均是二维的,执行该命令之后图片的形状并未发生任何变化
# images = np.reshape(images, [images.shape[0], -1])
sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))
fig = plt.figure(figsize=(sqrtn, sqrtn))
gs = gridspec.GridSpec(sqrtn, sqrtn)
gs.update(wspace=0.05, hspace=0.05)
for i, img in enumerate(images):
ax = plt.subplot(gs[i])
plt.axis('off')
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_aspect('equal')
img = img.reshape([sqrtimg, sqrtimg])
plt.imshow(img)
plt.show()
从网上下载MNIST数据集,从训练集中取前50000张图像用作训练集,接下来的5000张图像用作验证集。
数据最终装载在以下两个变量中:
loader_train:训练集;
loader_val:验证集。
类ChunkSampler继承自父类torch.utils.sampler.Sampler
数据下载及载入
dset.MNIST(’./data/’, train=True, download=True,
transform=T.ToTensor())
imgs = loader_train.iter().next()[0].view(batch_size, 784).numpy().squeeze()
从训练集中取一个样本,将128张图片的像素存放起来,并从Tensor转换成一个128*784维的numpy矩阵,每一行存放着一张图片的所有像素值,将该矩阵传递给show_images()函数,展示图片。
for inputs, labels in dataloaders
进行可迭代对象的访问;# 采样函数为自己定义的序列采样(即按顺序采样)
class ChunkSampler(sampler.Sampler):
"""
顺序采样:
参数:
1)num_samples: 采样个数;
2)start: 采样开始索引值。
"""
def __init__(self, num_samples, start=0):
self.num_samples = num_samples
self.start = start
def __iter__(self):
result = iter(range(self.start, self.start+self.num_samples))
return result
def __len__(self):
return self.num_samples
NUM_TRAIN = 50000 # 训练集数量
NUM_VAL = 5000 # 测试集数量
batch_size = 128 # batch的大小
mnist_train = dset.MNIST('./data/', train=True, download=True,
transform=T.ToTensor())
# 从0位置开始采样NUM_TRAIN个数
loader_train = DataLoader(mnist_train, batch_size=batch_size,
sampler=ChunkSampler(NUM_TRAIN, 0))
mnist_val = dset.MNIST('./data/', train=True, download=True,
transform=T.ToTensor())
# 从NUM_TRAIN位置开始采样NUM_VAL个数
loader_val = DataLoader(mnist_val, batch_size=batch_size,
sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))
# squeeze(): 从数组的形状中删除单维度条目,即把shape中为1的维度去掉。个人认为squeeze()没有意义,因为图像的维度为(batch_size, 784)
# imgs = loader_train.__iter__().next()[0].view(batch_size, 784).numpy().squeeze()
imgs = loader_train.__iter__().next()[0].view(batch_size, 784).numpy()
show_images(imgs)
这里产生一个从-1~1的均匀噪声函数,形状为[batch_size, noise_dim]
def sample_noise(batch_size, noise_dim):
"""
生成均匀随机噪声,充当生成器的输入值
输入:
1)batch_size:batch尺寸;
2)noise_dim:生成噪声的维度。
输出:
维度为(batch_size, noise_dim)的张量,每个元素值介于-1和1之间
"""
temp = torch.rand(batch_size, noise_dim) + torch.rand(batch_size, noise_dim) * (-1)
return temp
class Flatten(nn.Module):
def forward(self, x):
# x.shape = [batch_size, channel, height, weight]
# 将shape中的数据依次赋给N, C, H, W
N, C, H, W = x.size()
# 将张量由[batch_size, channel, height, weight]平铺为
# [batch_size, channel*hight*weight],每一行分别代表一张图片的像素值
return x.view(N, -1)
class Unflatten(nn.Module):
# 将输入张量由[N, C*H*W]转变为[N, C, H, W]
def __init__(self, N=-1, C=128, H=7, W=7):
super(Unflatten, self).__init__()
self.N = N
self.C = C
self.H = H
self.W = W
def forward(self, x):
return x.view(self.N, self.C, self.H, self.W)
Discriminator输入为图片,输出为scalar。
def discriminator():
# 搭建判别器模型
model = nn.Sequential(
Flatten(),
nn.Linear(784, 256),
nn.LeakyReLU(0.01, inplace=True),
nn.Linear(256, 256),
nn.LeakyReLU(0.01, inplace=True),
nn.Linear(256, 1)
)
return model
Generator输入为noise,输出为图片。
def generator(noise_dim=NOISE_DIM):
# 搭建生成器模型
model = nn.Sequential(
nn.Linear(noise_dim, 1024),
nn.ReLU(inplace=True),
nn.Linear(1024, 1024),
nn.ReLU(inplace=True),
nn.Linear(1024, 784),
nn.Tanh()
)
return model
当计算交叉熵时,使用原生的函数会造成不稳定的求导,推荐使用 BCEWithLogitsLoss()。
# 定义损失函数
Bce_loss = nn.BCEWithLogitsLoss()
# 定义优化函数
def get_optimizer(model):
"""
为模型构建优化函数,学习率为0.001,beta1=0.5,beta2=0.999
输入:
模型
输出:
模型的优化器
"""
optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.5, 0.999))
return optimizer
分别计算G和D的损失函数,使用Adam优化:
def discriminator_loss(logits_real, logits_fake):
"""
计算判别器的损失函数
输入:
1) logits_real:判别器对输入其中的真实图像的评分,维度为batch_size*1
2) logits_fake:判别器对输入其中的生成图像的评分,维度为batch_size*1
输出:
判别器的损失值
"""
loss = None
# Batch size
N = logits_real.size()
# 目标Label,全部设置为1意味着判别器需要做到的是将正确的全识别为正确,错误的全识别为错误
true_labels = Variable(torch.ones(N)).type(dtype)
# 识别正确的为正确-计算当输入真实图片时,D输出的scalar与正确标签1之间的差距,即D对于真实图片的损失函数
real_image_loss = Bce_loss(logits_real, true_labels)
# 识别错误的为错误-计算当输入生成图片时,D输出的scalar与正确标签0之间的差距,即D对于生成图片的损失函数
fake_image_loss = Bce_loss(logits_fake, 1-true_labels)
# 总的损失值由以上两部分组成
loss = real_image_loss + fake_image_loss
return loss
def generator_loss(logits_fake):
"""
计算生成器的损失值
输入:
logits_fake:判别器对于生成的图片的评分,维度为batch_size*1
输出:
生成器的损失值
"""
# Batch size
N = logits_fake.size()
# 由于生成器的目标是生成尽量真实的图片,试图骗过D,所以G的目标标签应当是1,即生成器的作用是将所有”假“的向真的(1)靠拢
true_labels = Variable(torch.ones(N)).type(dtype)
# 计算生成器损失
loss = Bce_loss(logits_fake, true_labels)
return loss
def run_a_gan(D, G, D_solver, G_solver, show_every=250, batch_size=128,
noise_size=96, num_epochs=10):
"""
训练GAN
参数:
1)D,G:分别表示判别器和生成器模型;
2)D_solver,G_solver:D和G的优化函数;
3)batch_size:训练过程中使用的batch的大小;
4)noise_size:输入生成器的随机向量的维度;
5)num_epochs:训练的次数。
"""
iter_count = 1
for epoch in range(num_epochs):
for x, _ in loader_train:
if len(x) != batch_size:
continue
# step1: Training D
# 将D模型的参数梯度归零
D_solver.zero_grad()
# 转换参数类型,并在GPU可用的情况下完成数据迁移
real_data = Variable(x).type(dtype)
# 2*(real_data-0.5)???
logits_real = D(2*(real_data-0.5)).type(dtype)
# logits_real = D(real_data-0.5).type(dtype)
# 生成随机噪声
g_fake_seed = Variable(sample_noise(batch_size, noise_size)).type(dtype)
# 生成图片
fake_images = G(g_fake_seed).detach()
# 将生成的图片输入D,输出图片的scalar
logits_fake = D(fake_images.view(batch_size, 1, 28, 28))
# 计算D的损失值
d_total_error = discriminator_loss(logits_real, logits_fake)
"""
该函数的功能在于让模型根据计算图自动计算每个节点的梯度值并根据需求进行保留,有了这一步,
模型中的权重参数就可以直接使用在自动梯度过程中求得的梯度值,并结合学习速率来对现有的参数进行更新、优化
"""
d_total_error.backward()
# 使用计算得到的梯度值对各个节点的参数进行梯度更新
D_solver.step()
# Step2: Training G
# 将G模型的参数梯度归零
G_solver.zero_grad()
# 生成随机噪声
g_fake_seed = Variable(sample_noise(batch_size, noise_size)).type(dtype)
# 生成图片
fake_images = G(g_fake_seed)
# 将生成的图片输入D,输出图片的scalar
gen_logits_fake = D(fake_images.view(batch_size, 1, 28, 28))
# 计算G的损失值
g_error = generator_loss(gen_logits_fake)
# 参数更新、优化
g_error.backward()
G_solver.step()
# 每训练show_every次,输出一次样本图片
if(iter_count % show_every==0):
print('Iter:{}, D:{:.4f}, G:{:.4f}'.format(iter_count, d_total_error.data, g_error.data))
# 若GPU可用,则目前数据均为GPU张量,因为之后要完成从数组中取值的操作,所以需要先将数据转换为CPU张量,再转为numpy
if Use_gpu:
imgs_numpy = fake_images.data.cpu().numpy()
else:
imgs_numpy = fake_images.data.numpy()
images = imgs_numpy[0:16]
# 输出前16张图片
show_images(images)
print()
iter_count += 1
# 判断GPU是否可用
Use_gpu = torch.cuda.is_available()
# 若可用,将数据迁移到GPU上,并将数据类型转换成FloatTensor
if Use_gpu:
dtype = torch.cuda.FloatTensor
# 若不可用,数据仍存放在CPU中,仅将数据类型转换成FloatTensor
else:
dtype = torch.FloatTensor
# 定义判别器,并完成数据类型转换
D = discriminator().type(dtype)
# 定义生成器,并完成数据类型转换
G = generator().type(dtype)
# 定义D和G的优化函数
D_solver = get_optimizer(D)
G_solver = get_optimizer(G)
# 开始模型训练
run_a_gan(D, G, D_solver, G_solver)
KL散度
生成式对抗网络(Generative Adversarial Nets,GAN)
机器之心GitHub项目:GAN完整理论推导与实现,Perfect!
PyTorch中文文档
PyTorch英文文档
pytorch实现自由的数据读取-torch.utils.data的学习
NumPy
NumPy 教程
Matplotlib中文文档
Matplotlib
Python数据可视化利器Matplotlib,绘图入门篇,Pyplot介绍
matplotlib: matplotlib.gridspec
Python关于%matplotlib inline
什么是GAN呢?
利用pytorch实现GAN(生成对抗网络)-MNIST图像-cs231n-assignment3
numpy.reshape
reshape(-1,1)什么意思 numpy.reshape
pytorch之dataloader深入剖析
什么是Tensor
PyTorch中网络里面的inplace=True字段的意思
Python中inplace=True的理解
神经网络中常用的激活函数
Pytorch - Cross Entropy Loss
Pytorch详解BCELoss和BCEWithLogitsLoss