生成式对抗网络是一种无监督深度学习模型
主要组成:
生成模型G (generative model)
判别模型D (discriminative model)
网络的训练过程是这样的。首先,我们有一些随机噪声Z,喂入生成模型里面它会生成相应的数据,这个数据是假的。而相对于这个假的数据,我们有一些进行标注的或者是识别的真实数据。有了这个真数据和假数据之后,我们分别喂入判别网络模型,让他去学习里面的特征,数据类型,给出真假的一个判断。
举一个更通俗的例子来帮助大家理解。我们可以把生成模型想象成为一个伪造字画的造假的人。判别模型,可以想象成一个鉴定字画的鉴定师。开始的时候他们都还是新手,需要不断地学习,G的任务是生成以假乱真的字画,D的任务是从真假字画当中判别它的真假。
具体过程是这样的,首先我们给G一些材料也就是噪声,然后他会生成一幅字画,D会学习假字画特征和真字画的特征,然后去判定G生成这个到底是真的还是假的。开始的时候,G生成的这些字画是非常的拙劣,所以D一开始会判定为假。然后这个假的结果会返回到生成模型,他会在思考怎么才能生成更好,以假乱真的数据。如此这样的循环往复就是一个对抗的过程,最后,我们的目的就是,给生成模型寄一些随机的噪声,它就能生成以假乱真的一些数据骗过我们的判别模型。
对抗过程,用数学来表达,是一个二元极小极大值博弈,关于判别网络和生成网络的一个价值函数。
第一个是关于D的训练网络的一个函数,我们希望最大化log D(x),训练网络能够最大概率的分配这个训练模型的标签,也就是说我们希望鉴定师,他能够对真假数据的判定越来越准确。
第二个我们希望最小化log(1-D(G(Z)))。意思是,训练网络G生成能够欺骗网络D的数据,最大化D的损失。假设真数据为1,假数据为零。把G生成的数据拿去判别,结果应该为0,但是站在G的角度,希望能判别为1。
导入相关的包
tensorflow (谷歌开源的机器学习平台,有很多关于深度学习训练的函数)
numpy (用于数值计算)
matplotlib (画图用的)
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
读入mnist数据集的数据
将代码和数据集文件夹(mnistdata)放在同一目录下
mnist = input_data.read_data_sets('./mnistdata', one_hot=True)
随机噪声函数
采用均匀分布 np.randow.uniform(low, high,size)
返回值是一个从-1到1的噪声
def sample_z(m,n):
return np.random.uniform(-1.,1.,size=[m,n])
生成模型输入和参数初始化
Z = tf.placeholder(tf.float32, shape=[None, 100])
G_W1 = tf.get_variable("G_W1", shape=[100,128],initializer=tf.contrib.layers.xavier_initializer())
G_b1= tf.Variable(tf.zeros(shape=[128]))
G_W2 = tf.get_variable("G_W2", shape=[128,784],initializer=tf.contrib.layers.xavier_initializer())
G_b2= tf.Variable(tf.zeros(shape=[784]))
theta_G = [G_W1, G_W2, G_b1, G_b2]
生成模型
def Gene(Z):
G_h1 = tf.nn.relu(tf.matmul(Z, G_W1)+G_b1)#第一层矩阵相乘后激活
G_log_prob = tf.matmul(G_h1, G_W2)+G_b2#第二层
return tf.nn.sigmoid(G_log_prob)
判别模型输入和参数初始化
与生成模型基本相同,但有些数据需要对应,从784到1
X = tf.placeholder(tf.float32, shape=[None, 784])
D_W1 = tf.get_variable("D_W1", shape=[784,128],initializer=tf.contrib.layers.xavier_initializer())
D_b1= tf.Variable(tf.zeros(shape=[128]))
D_W2 = tf.get_variable("D_W2", shape=[128,1],initializer=tf.contrib.layers.xavier_initializer())
D_b2= tf.Variable(tf.zeros(shape=[1]))
theta_D = [D_W1, D_W2, D_b1, D_b2]
判别模型
比生成模型多返回第二层
def Disc(x):
D_h1 = tf.nn.relu(tf.matmul(x, D_W1)+D_b1)
D_logit = tf.matmul(D_h1, D_W2)+D_b2
D_prob = tf.nn.sigmoid(D_logit)
return D_prob, D_logit
画图
def plot(samples):
fig = plt.figure(figsize=(4, 4))
gs = gridspec.GridSpec(4, 4)
gs.update(wspace=0.05, hspace=0.05)
for i, sample in enumerate(samples):
ax = plt.subplot(gs[i])
plt.axis('off')
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_aspect('equal')
plt.imshow(sample.reshape(28, 28), cmap='Greys_r')
return fig
喂入数据
G_sample = Gene(Z)
D_real, D_logit_real = Disc(X)
D_fake, D_logit_fake = Disc(G_sample)
计算G和D的损失(loss)均值
交叉熵(度量两个概率分布间的差异性信息),差异越大,交叉熵越大
D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = D_logit_real, labels=tf.ones_like(D_logit_real)))
D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = D_logit_fake, labels=tf.zeros_like(D_logit_fake)))
D_loss = D_loss_real + D_loss_fake
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))
Adam算法优化器,学习率为0.0001
D_solver = tf.train.AdamOptimizer(0.0001).minimize(D_loss,var_list=theta_D)
G_solver = tf.train.AdamOptimizer(0.0001).minimize(G_loss, var_list=theta_G)
图像输出的位置
if not os.path.exists('out/'):
os.makedirs('out/')
开始训练
一共迭代了1000000次,每1000次会生成一张图片
i=0
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
mb_size = 128
Z_dim = 100
for it in range(1000000):
if it % 1000 == 0:
samples = sess.run(G_sample, feed_dict={Z: sample_z(16, Z_dim)})
fig = plot(samples)
plt.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
i += 1
plt.close(fig)
X_mb, _ = mnist.train.next_batch(mb_size)
_, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: sample_z(mb_size, Z_dim)})
_, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_z(mb_size, Z_dim)})
if it % 1000 == 0:
print('Iter: {}'.format(it),'D loss: {}'.format(D_loss_curr),'G_loss: {}'.format(G_loss_curr))
从左到右,从上到下依次是迭代了1000次,333000次,666000次,999000次后得到的图像
完整代码
https://github.com/SinsoledadFairy/mnist-gan