【注1】代码的原文来自以下网址,修改部分及增添注释(基本上都注释了)。修改版整体见最后,原版下方链接,均可以跑通,有问题欢迎交流。生成对抗网络GAN---生成mnist手写数字图像示例(
附代码)_陶将的博客-CSDN博客_gan生成手写数字
【注2】环境要求:≥python3.8,Windows10,pycharm2019,tensorflow2.70
如果是tensorflow版本问题可以考虑升级或使用原代码
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
import numpy as np
import os
from tensorflow.examples.tutorials.mnist import input_data
from matplotlib import pyplot as plt
这里由于2.x版本的语法进行了修改,导致原文中部分代码无法运行,compat.v1使得可以让2.x直接运行1.x版本的代码。
【注3】正文的很多部分例如layers,dense等都需要在tf和包中间加入compat.v1
【注4】examples包我是自己下载的,csdn上已经有大佬上传了,记得找一个日期前一点的。
# 初始化准备
BATCH_SIZE = 64 # 每一轮训练的数量
UNITS_SIZE = 128 # 生成器隐藏层的参数
LEARNING_RATE = 0.001 # 学习速率
EPOCH = 300 # 训练迭代轮数
SMOOTH = 0.1 # 标签平滑
# 读入mnist数据,理论上1.9版本之后数据集不会自动下载,但是这个代码运行的时候是会下载出数据集的。
mnist = input_data.read_data_sets('/mnist_data/', one_hot=True)
【注5】参数可以随自己修改,本地cpu的要是显卡太垃圾(和我一样的话)建议找云服务器
# 生成器
def generatorModel(noise_img, units_size, out_size, alpha=0.01):
# 参数解析
# noise_img:生成器生成噪声图片
# units_size: 隐藏层单元数
# out_size:生成器输出图片大小
# alpha:激活函数的系数
with tf.compat.v1.variable_scope('generator'):
# 创建一个空间generator,使得在这个空间当中,变量可以重复使用
# 全连接,连接输入和隐藏层
FC = tf.compat.v1.layers.dense(noise_img, units_size)
# 隐藏层的激活函数,之后的dropout方法是为了防止发生过拟合的现象
reLu = tf.nn.leaky_relu(FC, alpha)
drop = tf.compat.v1.layers.dropout(reLu, rate=0.2)
# 全连接,连接隐藏层和输出层,输出层的激活函数选择tanh
logits = tf.compat.v1.layers.dense(drop, out_size)
outputs = tf.tanh(logits)
return logits, outputs
# 判别模型
def discriminatorModel(images, units_size, alpha=0.01, reuse=False):
# 参数详解
# images:真实图片
# reuse:是否重复占用空间
with tf.compat.v1.variable_scope('discriminator', reuse=reuse):
# 全连接
FC = tf.compat.v1.layers.dense(images, units_size)
# 隐藏层激活函数
reLu = tf.nn.leaky_relu(FC, alpha)
# 全连接,这里输出层的激活函数改为sigmoid
logits = tf.compat.v1.layers.dense(reLu, 1)
outputs = tf.sigmoid(logits)
return logits, outputs
【注6】这里可以看出,判别器和生成器的主要差别在于输出层的激活函数
def loss_function(real_logits, fake_logits, smooth):
# 生成器希望判别器判别出来的标签为1; tf.ones_like()创建一个将所有元素都设置为1的张量
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits,
labels=tf.ones_like(fake_logits) * (1 - smooth)))
# 判别器识别生成器产出的图片,希望识别出来的标签为0
fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits,
labels=tf.zeros_like(fake_logits)))
# 判别器判别真实图片,希望判别出来的标签为1
real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=real_logits,
labels=tf.ones_like(real_logits) * (1 - smooth)))
# 判别器总loss
D_loss = tf.add(fake_loss, real_loss)
return G_loss, fake_loss, real_loss, D_loss
【注7】
tf.nn.sigmoid_cross_entropy_with_logits
这个方法对传入的参数先使用sigmoid进行计算,然后在计算他们的交叉熵损失,使得结果不会溢出。
# 优化器
def optimizer(G_loss, D_loss, learning_rate):
# 首先获取网络结构中的参数,也就是判别器和生成器的变量,在后面的最小化损失时修改
train_var = tf.compat.v1.trainable_variables()
G_var = [var for var in train_var if var.name.startswith('generator')]
D_var = [var for var in train_var if var.name.startswith('discriminator')]
# 因为GAN中一共训练了两个网络,所以分别对G和D进行优化
# 这里使用AdamOptimizer方法来减少损失(娘希匹的2.x这玩意怎么直接用),动态调整每个参数的学习速率。
G_optimizer = tf.compat.v1.train.AdadeltaOptimizer(learning_rate).minimize(G_loss, var_list=G_var)
D_optimizer = tf.compat.v1.train.AdadeltaOptimizer(learning_rate).minimize(D_loss, var_list=D_var)
return G_optimizer, D_optimizer
【注8】以下两个部分均定义在一个def下!
def train(mnist):
# 前期准备,该流程和上面的逻辑顺序相同
# 真实图片的大小
image_size = mnist.train.images[0].shape[0]
# 定义接收输入的方法,占位符placeholder来获得输入的数据
real_images = tf.compat.v1.placeholder(tf.float32, [None, image_size])
fake_images = tf.compat.v1.placeholder(tf.float32, [None, image_size])
# 生成器参数解释
# 将噪声,生成器隐藏层节点数,真实图片大小传入生成器(这样搞可以生成大小一样的图片)
G_logits, G_output = generatorModel(fake_images, UNITS_SIZE, image_size)
# 判别器:先传入参数,给真实图片打分,再给生成图片打分。
# D对真实图像的判别
real_logits, real_output = discriminatorModel(real_images, UNITS_SIZE)
# D对G生成图像的判别,为其打分
fake_logits, fake_output = discriminatorModel(G_output, UNITS_SIZE, reuse=True)
# 计算损失函数
G_loss, real_loss, fake_loss, D_loss = loss_function(real_logits, fake_logits, SMOOTH)
# 优化
G_optimizer, D_optimizer = optimizer(G_loss, D_loss, LEARNING_RATE)
# 保存生成器变量
saver = tf.compat.v1.train.Saver()
step = 0
with tf.compat.v1.Session() as session:
# 初始化模型的参数
session.run(tf.compat.v1.global_variables_initializer())
for epoch in range(EPOCH):
for batch_i in range(mnist.train.num_examples // BATCH_SIZE):
batch_image, _ = mnist.train.next_batch(BATCH_SIZE)
# 对图像像素进行scale,tanh的输出结果为(-1,1),real和fake图片共享参数
batch_image = batch_image * 2 - 1
# 生成模型的输入噪声(图片)
noise_image = np.random.uniform(-1, 1, size=(BATCH_SIZE, image_size))
# 先训练生成器,在训练判别器
session.run(G_optimizer, feed_dict={fake_images: noise_image})
session.run(D_optimizer, feed_dict={real_images: batch_image, fake_images: noise_image})
step = step + 1
# 判别器D的损失(每一轮训练之后)
loss_D = session.run(D_loss, feed_dict={real_images: batch_image, fake_images: noise_image})
# D对真实图片(训练时)
loss_real = session.run(real_loss, feed_dict={real_images: batch_image, fake_images: noise_image})
# D对生成图片(训练时)
loss_fake = session.run(fake_loss, feed_dict={real_images: batch_image, fake_images: noise_image})
# 生成器的损失
loss_G = session.run(G_loss, feed_dict={fake_images: noise_image})
print('epoch:', epoch, 'loss_D:', loss_D, ' loss_real', loss_real, ' loss_fake', loss_fake, ' loss_G',
loss_G)
model_path = os.getcwd() + os.sep + "mnist.model"
# 存储
saver.save(session, model_path, global_step=step)
下面是代码成功运行的图片,300轮的化大概24分钟左右(我是垃圾显卡2g)
从这里可以看出,还是比较模糊的,在不大的改变代码的情况下,只能增加迭代次数。
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
import numpy as np
import os
from tensorflow.examples.tutorials.mnist import input_data
from matplotlib import pyplot as plt
# 初始化准备
BATCH_SIZE = 64 # 每一轮训练的数量
UNITS_SIZE = 128 # 生成器隐藏层的参数
LEARNING_RATE = 0.001 # 学习速率
EPOCH = 300 # 训练迭代轮数
SMOOTH = 0.1 # 标签平滑
# 读入mnist数据,理论上1.9版本之后数据集不会自动下载,但是这个代码运行的时候是会下载出数据集的。
mnist = input_data.read_data_sets('/mnist_data/', one_hot=True)
# 生成器
def generatorModel(noise_img, units_size, out_size, alpha=0.01):
# 参数解析
# noise_img:生成器生成噪声图片
# units_size: 隐藏层单元数
# out_size:生成器输出图片大小
# alpha:激活函数的系数
with tf.compat.v1.variable_scope('generator'):
# 创建一个空间generator,使得在这个空间当中,变量可以重复使用
# 全连接,连接输入和隐藏层
FC = tf.compat.v1.layers.dense(noise_img, units_size)
# 隐藏层的激活函数,之后的dropout方法是为了防止发生过拟合的现象
reLu = tf.nn.leaky_relu(FC, alpha)
drop = tf.compat.v1.layers.dropout(reLu, rate=0.2)
# 全连接,连接隐藏层和输出层,输出层的激活函数选择tanh
logits = tf.compat.v1.layers.dense(drop, out_size)
outputs = tf.tanh(logits)
return logits, outputs
# 判别模型
def discriminatorModel(images, units_size, alpha=0.01, reuse=False):
# 参数详解
# images:真实图片
# reuse:是否重复占用空间
with tf.compat.v1.variable_scope('discriminator', reuse=reuse):
# 全连接
FC = tf.compat.v1.layers.dense(images, units_size)
# 隐藏层激活函数
reLu = tf.nn.leaky_relu(FC, alpha)
# 全连接,这里输出层的激活函数改为sigmoid
logits = tf.compat.v1.layers.dense(reLu, 1)
outputs = tf.sigmoid(logits)
return logits, outputs
# 损失函数
"""
判别器的目的是:
1. 对于真实图片,D要为其打上标签1
2. 对于生成图片,D要为其打上标签0
生成器的目的是:对于生成的图片,G希望D打上标签1
"""
def loss_function(real_logits, fake_logits, smooth):
# 生成器希望判别器判别出来的标签为1; tf.ones_like()创建一个将所有元素都设置为1的张量
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits,
labels=tf.ones_like(fake_logits) * (1 - smooth)))
# 判别器识别生成器产出的图片,希望识别出来的标签为0
fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits,
labels=tf.zeros_like(fake_logits)))
# 判别器判别真实图片,希望判别出来的标签为1
real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=real_logits,
labels=tf.ones_like(real_logits) * (1 - smooth)))
# 判别器总loss
D_loss = tf.add(fake_loss, real_loss)
return G_loss, fake_loss, real_loss, D_loss
# 优化器
def optimizer(G_loss, D_loss, learning_rate):
# 首先获取网络结构中的参数,也就是判别器和生成器的变量,在后面的最小化损失时修改
train_var = tf.compat.v1.trainable_variables()
G_var = [var for var in train_var if var.name.startswith('generator')]
D_var = [var for var in train_var if var.name.startswith('discriminator')]
# 因为GAN中一共训练了两个网络,所以分别对G和D进行优化
# 这里使用AdamOptimizer方法来减少损失(娘希匹的2.x这玩意怎么直接用),动态调整每个参数的学习速率。
G_optimizer = tf.compat.v1.train.AdadeltaOptimizer(learning_rate).minimize(G_loss, var_list=G_var)
D_optimizer = tf.compat.v1.train.AdadeltaOptimizer(learning_rate).minimize(D_loss, var_list=D_var)
return G_optimizer, D_optimizer
# 训练代码
def train(mnist):
# 前期准备,该流程和上面的逻辑顺序相同
# 真实图片的大小
image_size = mnist.train.images[0].shape[0]
# 定义接收输入的方法,占位符placeholder来获得输入的数据
real_images = tf.compat.v1.placeholder(tf.float32, [None, image_size])
fake_images = tf.compat.v1.placeholder(tf.float32, [None, image_size])
# 生成器参数解释
# 将噪声,生成器隐藏层节点数,真实图片大小传入生成器(这样搞可以生成大小一样的图片)
G_logits, G_output = generatorModel(fake_images, UNITS_SIZE, image_size)
# 判别器:先传入参数,给真实图片打分,再给生成图片打分。
# D对真实图像的判别
real_logits, real_output = discriminatorModel(real_images, UNITS_SIZE)
# D对G生成图像的判别,为其打分
fake_logits, fake_output = discriminatorModel(G_output, UNITS_SIZE, reuse=True)
# 计算损失函数
G_loss, real_loss, fake_loss, D_loss = loss_function(real_logits, fake_logits, SMOOTH)
# 优化
G_optimizer, D_optimizer = optimizer(G_loss, D_loss, LEARNING_RATE)
# 保存生成器变量
saver = tf.compat.v1.train.Saver()
step = 0
with tf.compat.v1.Session() as session:
# 初始化模型的参数
session.run(tf.compat.v1.global_variables_initializer())
for epoch in range(EPOCH):
for batch_i in range(mnist.train.num_examples // BATCH_SIZE):
batch_image, _ = mnist.train.next_batch(BATCH_SIZE)
# 对图像像素进行scale,tanh的输出结果为(-1,1),real和fake图片共享参数
batch_image = batch_image * 2 - 1
# 生成模型的输入噪声(图片)
noise_image = np.random.uniform(-1, 1, size=(BATCH_SIZE, image_size))
# 先训练生成器,在训练判别器
session.run(G_optimizer, feed_dict={fake_images: noise_image})
session.run(D_optimizer, feed_dict={real_images: batch_image, fake_images: noise_image})
step = step + 1
# 判别器D的损失(每一轮训练之后)
loss_D = session.run(D_loss, feed_dict={real_images: batch_image, fake_images: noise_image})
# D对真实图片(训练时)
loss_real = session.run(real_loss, feed_dict={real_images: batch_image, fake_images: noise_image})
# D对生成图片(训练时)
loss_fake = session.run(fake_loss, feed_dict={real_images: batch_image, fake_images: noise_image})
# 生成器的损失
loss_G = session.run(G_loss, feed_dict={fake_images: noise_image})
print('epoch:', epoch, 'loss_D:', loss_D, ' loss_real', loss_real, ' loss_fake', loss_fake, ' loss_G',
loss_G)
model_path = os.getcwd() + os.sep + "mnist.model"
# 存储
saver.save(session, model_path, global_step=step)
def main(argv=None):
train(mnist)
if __name__ == '__main__':
tf.compat.v1.app.run()