1.数据集
数据集使用 torchvision.datasets 提供的 MNIST 手写数字(0-9)数据集,每张图像为 28×28 的灰度图。数据集的详细介绍请自行查阅 torchvision 官方文档。
2.模型
生成器(Generator)和判别器(Discriminator)
model.py文件
from torch import nn
# 图像生成器
class Generator(nn.Module):
    """Image generator: maps a 100-dim noise vector to a flattened 28x28 image."""

    def __init__(self):
        super().__init__()
        # MLP widening 100 -> 256 -> 512 -> 784 (28 * 28 pixels).
        # The final Tanh squashes outputs into [-1, 1], matching the range of
        # real images after the Normalize((0.5,), (0.5,)) transform.
        self.gen = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 784),
            nn.Tanh(),
        )

    def forward(self, x):
        """Return a batch of flattened fake images for the noise batch *x*."""
        return self.gen(x)
# 图像判别器
class Discriminator(nn.Module):
    """Image discriminator: maps a flattened 28x28 image to a real/fake probability."""

    def __init__(self):
        super().__init__()
        # Two LeakyReLU hidden layers narrowing 784 -> 512 -> 256,
        # then a Sigmoid head producing a probability in (0, 1).
        self.f1 = nn.Sequential(
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2),
        )
        self.f2 = nn.Sequential(
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
        )
        self.out = nn.Sequential(
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return P(real) for each flattened image in the batch *x*."""
        return self.out(self.f2(self.f1(x)))
3.训练模型
train.py
import os

import torch
from torch import nn
from torch.autograd import Variable
from torchvision import transforms, datasets
from torchvision.utils import save_image

from model import Discriminator, Generator
def to_img(x):
    """Map generator output from [-1, 1] back to images in [0, 1].

    Returns a tensor reshaped to (N, 1, 28, 28), ready for save_image.
    """
    # Invert the Normalize((0.5,), (0.5,)) transform: [-1, 1] -> [0, 1],
    # clamping any values that fall slightly outside the range.
    images = (x + 1) * 0.5
    return images.clamp(0, 1).view(-1, 1, 28, 28)
def GAN_train_model(dataset, generator, discriminator, batch_size, epoch, lr, z_dim, device):
    """Train a vanilla GAN on *dataset*.

    Args:
        dataset: torch Dataset yielding (image, label) pairs; images are
            expected to be normalized to [-1, 1] (matching the Tanh generator).
        generator: network mapping z_dim-dim noise to 784-dim fake images.
        discriminator: network mapping 784-dim images to a real/fake probability.
        batch_size: mini-batch size for the DataLoader.
        epoch: number of training epochs.
        lr: learning rate used by both Adam optimizers.
        z_dim: dimensionality of the noise vector fed to the generator.
        device: the string "CUDA" trains on GPU; any other value trains on CPU.

    Side effects: writes sample image grids to ./img/ each epoch and saves
    both full model objects under model/ when training finishes.
    """
    # Create output directories up front; the original code raised
    # FileNotFoundError when ./img or model/ did not exist.
    os.makedirs("./img", exist_ok=True)
    os.makedirs("model", exist_ok=True)

    # DataLoader reshuffles each epoch so batches are not in dataset order.
    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=batch_size,
                                              shuffle=True)
    if device == "CUDA":
        D = discriminator.cuda()
        G = generator.cuda()
    else:
        D = discriminator.cpu()
        G = generator.cpu()

    criterion = nn.BCELoss()  # binary cross-entropy on the discriminator outputs
    d_optimizer = torch.optim.Adam(D.parameters(), lr=lr)
    g_optimizer = torch.optim.Adam(G.parameters(), lr=lr)
    steps_per_epoch = len(data_loader)

    for cur_epoch in range(epoch):
        total_d_loss = 0
        total_g_loss = 0
        for img, _ in data_loader:
            num_img = img.size(0)
            # Flatten each 28x28 image to a 784-dim vector for the MLP networks.
            real_img = img.view(num_img, -1)
            # Target labels: 1 for real images, 0 for generated (fake) ones.
            real_label = torch.ones(num_img, 1)
            fake_label = torch.zeros(num_img, 1)
            if device == "CUDA":
                real_img = real_img.cuda()
                real_label = real_label.cuda()
                fake_label = fake_label.cuda()

            # ---- Discriminator step: real images should score 1, fakes 0 ----
            real_out = D(real_img)
            d_loss_real = criterion(real_out, real_label)

            z = torch.randn(num_img, z_dim)  # noise for a batch of fakes
            if device == "CUDA":
                z = z.cuda()
            # detach() blocks gradients from flowing into G; only D updates here.
            fake_img = G(z).detach()
            fake_out = D(fake_img)
            d_loss_fake = criterion(fake_out, fake_label)

            # Total D loss = misclassifying reals + misclassifying fakes.
            d_loss = d_loss_real + d_loss_fake
            total_d_loss += d_loss.item()  # .item() replaces deprecated .data.item()
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            # ---- Generator step: fool D into scoring fresh fakes as real ----
            # D is not updated here; G improves when D assigns fakes the
            # real label, which is the adversarial objective.
            z = torch.randn(num_img, z_dim)
            if device == "CUDA":
                z = z.cuda()
            fake_img = G(z)
            output = D(fake_img)
            g_loss = criterion(output, real_label)
            total_g_loss += g_loss.item()
            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()

        # Report mean discriminator/generator losses for this epoch.
        print('Epoch[{}/{}],d_loss:{:.6f},g_loss:{:.6f} '.format(
            cur_epoch, epoch, total_d_loss / steps_per_epoch, total_g_loss / steps_per_epoch)
        )
        # Save a grid of real images once, and the latest fakes every epoch.
        if cur_epoch == 0:
            real_images = to_img(real_img.cpu().data)
            save_image(real_images, './img/real_images.png')
        fake_images = to_img(fake_img.cpu().data)
        save_image(fake_images, './img/fake_images-{}.png'.format(cur_epoch + 1))

    # Persist the full model objects (architecture + weights).
    torch.save(generator, "model/generator.pkl")
    torch.save(discriminator, "model/discriminator.pkl")
if __name__ == "__main__":
    # Convert PIL images to tensors in [0, 1], then normalize to [-1, 1]
    # so real images share the output range of the generator's Tanh.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    # Download (if needed) and load the MNIST training split.
    dataset = datasets.MNIST(root='./data/',
                             train=True,
                             transform=transform,
                             download=True)

    # Fresh, untrained networks.
    discriminator = Discriminator()
    generator = Generator()

    # Hyperparameters.
    batch_size = 128  # images per mini-batch
    epoch = 100       # number of training epochs
    lr = 3e-4         # Adam learning rate for both networks
    z_dim = 100       # noise vector dimensionality

    GAN_train_model(
        dataset=dataset,
        discriminator=discriminator,
        generator=generator,
        batch_size=batch_size,
        epoch=epoch,
        lr=lr,
        z_dim=z_dim,
        device="CUDA"
    )