Building the DCGAN Network
"""
PROJECT:MNIST_DCGAN
Author:Ephemeroptera
Date:2018-4-25
QQ:605686962
Reference:' improved_wgan_training-master':
'Zardinality/WGAN-tensorflow':
'NELSONZHAO/zhihu':
"""
"""
Note: in this section , we add batch-normalization-laysers in G\D to acclerate training.Additionally,we use
moving average model to G to get well products from G
"""
import tensorflow as tf
import numpy as np
import pickle
import visualization
import os
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from threading import Thread
import time
from time import sleep
import cv2
mnist_dir = r'../mnist_dataset'
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets(mnist_dir)
def deconv(img, new_size, fmaps, name='deconv'):
    # nearest-neighbour upsampling followed by a 3x3 convolution
    with tf.variable_scope(name):
        img = tf.image.resize_nearest_neighbor(img, new_size, name='upscale')
        return tf.layers.conv2d(img, fmaps, 3, padding='SAME', name='conv2d')
def Generator_DC_28x28(latents, is_train):
    # the first (training) call creates the variables; the test call reuses them
    with tf.variable_scope("generator", reuse=(not is_train)):
        # project the latent vector to a 4x4x512 feature map
        dense0 = tf.layers.dense(latents, 4 * 4 * 512, name='dense0')
        dense0 = tf.reshape(dense0, [-1, 4, 4, 512])
        dense0 = tf.layers.batch_normalization(dense0, training=is_train)
        dense0 = tf.nn.leaky_relu(dense0)
        dense0 = tf.layers.dropout(dense0, rate=0.2, training=is_train)
        # upsample 4x4 -> 7x7 -> 14x14 -> 28x28
        deconv1 = deconv(dense0, (7, 7), 256, name='deconv1')
        deconv1 = tf.layers.batch_normalization(deconv1, training=is_train)
        deconv1 = tf.nn.leaky_relu(deconv1)
        deconv1 = tf.layers.dropout(deconv1, rate=0.2, training=is_train)
        deconv2 = deconv(deconv1, (14, 14), 128, name='deconv2')
        deconv2 = tf.layers.batch_normalization(deconv2, training=is_train)
        deconv2 = tf.nn.leaky_relu(deconv2)
        deconv2 = tf.layers.dropout(deconv2, rate=0.2, training=is_train)
        deconv3 = deconv(deconv2, (28, 28), 64, name='deconv3')
        deconv3 = tf.layers.batch_normalization(deconv3, training=is_train)
        deconv3 = tf.nn.leaky_relu(deconv3)
        deconv3 = tf.layers.dropout(deconv3, rate=0.2, training=is_train)
        # map to a single-channel image in [-1, 1]
        toimg = tf.layers.conv2d(deconv3, 1, 3, padding='SAME',
                                 bias_initializer=tf.zeros_initializer(),
                                 activation=tf.nn.tanh, name='toimg')
        return toimg
def Discriminator_DC_28x28(img, reuse=False):
    with tf.variable_scope("discriminator", reuse=reuse):
        # 28x28 -> 14x14
        conv0 = tf.layers.conv2d(img, 128, 3, padding='SAME', activation=tf.nn.leaky_relu,
                                 kernel_initializer=tf.random_normal_initializer(0, 1), name='conv0')
        conv0 = tf.layers.average_pooling2d(conv0, 2, 2, padding='SAME', name='pool0')
        # 14x14 -> 7x7
        conv1 = tf.layers.conv2d(conv0, 256, 3, padding='SAME',
                                 kernel_initializer=tf.random_normal_initializer(0, 1), name='conv1')
        conv1 = tf.layers.batch_normalization(conv1, training=True)
        conv1 = tf.nn.leaky_relu(conv1)
        conv1 = tf.layers.average_pooling2d(conv1, 2, 2, padding='SAME', name='pool1')
        # 7x7 -> 5x5 (VALID convolution)
        conv2 = tf.layers.conv2d(conv1, 512, 3, padding='VALID',
                                 kernel_initializer=tf.random_normal_initializer(0, 1), name='conv2')
        conv2 = tf.layers.batch_normalization(conv2, training=True)
        conv2 = tf.nn.leaky_relu(conv2)
        # flatten and emit a single real/fake logit
        dense3 = tf.reshape(conv2, [-1, 5 * 5 * 512])
        dense3 = tf.layers.dense(dense3, 1, name='dense3')
        outputs = tf.nn.sigmoid(dense3)
        return dense3, outputs
def COUNT_VARS(vars):
    total_para = 0
    for variable in vars:
        shape = variable.get_shape()
        variable_para = 1
        for dim in shape:
            variable_para *= dim.value
        total_para += variable_para
    return total_para
def ShowParasList(paras):
    p = open('./trainLog/Paras.txt', 'w')
    p.writelines(['vars_total: %d' % COUNT_VARS(paras), '\n'])
    for variable in paras:
        p.writelines([variable.name, str(variable.get_shape()), '\n'])
        print(variable.name, variable.get_shape())
    p.close()
def GEN_DIR():
    if not os.path.isdir('ckpt'):
        print('DIR:ckpt NOT FOUND,BUILDING ON CURRENT PATH..')
        os.mkdir('ckpt')
    if not os.path.isdir('trainLog'):
        print('DIR:trainLog NOT FOUND,BUILDING ON CURRENT PATH..')
        os.mkdir('trainLog')
latents_dim = 128
smooth = 0.1  # one-sided label smoothing: real labels become 0.9 instead of 1.0
latents = tf.placeholder(shape=[None, latents_dim], dtype=tf.float32, name='latents')
input_real = tf.placeholder(shape=[None, 28, 28, 1], dtype=tf.float32, name='input_real')
g_outputs = Generator_DC_28x28(latents, is_train=True)
g_test = Generator_DC_28x28(latents, is_train=False)
d_logits_real, d_outputs_real = Discriminator_DC_28x28(input_real, reuse=False)
d_logits_fake, d_outputs_fake = Discriminator_DC_28x28(g_outputs, reuse=True)
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_real, labels=tf.ones_like(d_logits_real) * (1 - smooth)))
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_fake, labels=tf.zeros_like(d_logits_fake)))
d_loss = tf.add(d_loss_real, d_loss_fake)
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_fake, labels=tf.ones_like(d_logits_fake) * (1 - smooth)))
train_vars = tf.trainable_variables()
d_train_vars = [var for var in train_vars if var.name.startswith("discriminator")]
g_train_vars = [var for var in train_vars if var.name.startswith("generator")]
for var in g_train_vars:
    tf.add_to_collection('G_RAW', var)
# collect the generator's batch-norm moving statistics so they can be saved too
all_vars = tf.global_variables()
g_all_vars = [var for var in all_vars if var.name.startswith("generator")]
g_bn_m_v = [var for var in g_all_vars if 'moving_mean' in var.name]
g_bn_m_v += [var for var in g_all_vars if 'moving_variance' in var.name]
for var in g_bn_m_v:
    tf.add_to_collection('G_BN_MV', var)
learn_rate = 2e-4
G_step = tf.Variable(0, trainable=False)
D_step = tf.Variable(0, trainable=False)
# run the batch-norm moving-average updates together with the training ops
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    d_train_opt = tf.train.AdamOptimizer(learn_rate, beta1=0.5).minimize(
        d_loss, var_list=d_train_vars, global_step=D_step)
    g_train_opt = tf.train.AdamOptimizer(learn_rate, beta1=0.5).minimize(
        g_loss, var_list=g_train_vars, global_step=G_step)
# keep exponential moving averages (shadow copies) of the generator weights
G_averages = tf.train.ExponentialMovingAverage(0.999, G_step)
gvars_averages_op = G_averages.apply(g_train_vars)
g_vars_ema = [G_averages.average(var) for var in g_train_vars]
for ema in g_vars_ema:
    tf.add_to_collection('G_EMA', ema)
# one op that performs a generator step and refreshes the shadow weights
with tf.control_dependencies([g_train_opt, gvars_averages_op]):
    g_train_opt_ema = tf.no_op(name='g_train_opt_ema')
max_iters = 5000
batch_size = 50
critic_n = 1  # discriminator updates per generator update
GEN_DIR()
GenLog = []
Losses = []
# save only what is needed to rebuild the generator: raw weights, EMA shadows, BN statistics
saver = tf.train.Saver(var_list=g_train_vars + g_vars_ema + g_bn_m_v)
def SavingRecords():
    # persist the loss history and the generated-sample log for later inspection
    global Losses
    global GenLog
    with open('./trainLog/loss_variation.loss', 'wb') as l:
        losses = np.array(Losses)
        pickle.dump(losses, l)
        print('saving Losses successfully!')
    with open('./trainLog/GenLog.log', 'wb') as g:
        GenLog = np.array(GenLog)
        pickle.dump(GenLog, g)
        print('saving GenLog successfully!')
def training():
    with tf.Session() as sess:
        init = (tf.global_variables_initializer(), tf.local_variables_initializer())
        sess.run(init)
        time_start = time.time()
        for steps in range(max_iters + 1):
            # fetch a batch and rescale it from [0, 1] to the generator's tanh range [-1, 1]
            data_batch = mnist.train.next_batch(batch_size)[0]
            data_batch = np.reshape(data_batch, [-1, 28, 28, 1])
            data_batch = data_batch * 2 - 1
            data_batch = data_batch.astype(np.float32)
            z = np.random.normal(0, 1, size=[batch_size, latents_dim]).astype(np.float32)
            # update D critic_n times, then update G (together with its EMA shadows)
            for n in range(critic_n):
                sess.run(d_train_opt, feed_dict={input_real: data_batch, latents: z})
            sess.run(g_train_opt_ema, feed_dict={input_real: data_batch, latents: z})
            train_loss_d = sess.run(d_loss, feed_dict={input_real: data_batch, latents: z})
            train_loss_g = sess.run(g_loss, feed_dict={latents: z})
            info = [steps, train_loss_d, train_loss_g]
            gen_samples = sess.run(g_outputs, feed_dict={latents: z})
            visualization.CV2_BATCH_SHOW((gen_samples[0:9] + 1) / 2, 0.5, 3, 3, delay=1)
            print('iters::%d/%d..Discriminator_loss:%.3f..Generator_loss:%.3f..' % (
                steps, max_iters, train_loss_d, train_loss_g))
            if steps % 5 == 0:
                Losses.append(info)
                GenLog.append(gen_samples)
            if steps % 1000 == 0 and steps > 0:
                saver.save(sess, './ckpt/generator.ckpt', global_step=steps)
            if steps == max_iters:
                # training finished: dump the records on a background thread
                time_over = time.time()
                print('iterating is over! consuming time :%.2f' % (time_over - time_start))
                sleep(3)
                thread1 = Thread(target=SavingRecords, args=())
                thread1.start()
            yield info
"""
note: in this code , we will see the runtime-variation of G,D losses
"""
iters = []
dloss = []
gloss = []
fig = plt.figure('runtime-losses')
ax1 = fig.add_subplot(2, 1, 1, xlim=(0, max_iters), ylim=(-10, 10))
ax2 = fig.add_subplot(2, 1, 2, xlim=(0, max_iters), ylim=(-20, 20))
ax1.set_title('discriminator_loss')
ax2.set_title('generator_loss')
line1, = ax1.plot([], [], color='red', lw=1, label='discriminator')
line2, = ax2.plot([], [], color='blue', lw=1, label='generator')
fig.tight_layout()
def init():
    line1.set_data([], [])
    line2.set_data([], [])
    return line1, line2
def update(info):
    # each yielded record [step, d_loss, g_loss] extends the two curves
    iters.append(info[0])
    dloss.append(info[1])
    gloss.append(info[2])
    line1.set_data(iters, dloss)
    line2.set_data(iters, gloss)
    return line1, line2
ani = FuncAnimation(fig, update, frames=training, init_func=init, blit=True, interval=1, repeat=False)
plt.show()
Validating the DCGAN
After 5000 training iterations, the DCGAN model produced the following results.
1. Loss curves
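The loss curves can also be re-plotted offline from the records that SavingRecords writes. A minimal sketch, assuming ./trainLog/loss_variation.loss exists (each record is [iteration, d_loss, g_loss]):

import pickle
import matplotlib.pyplot as plt

with open('./trainLog/loss_variation.loss', 'rb') as f:
    losses = pickle.load(f)          # ndarray of shape [N, 3]
plt.plot(losses[:, 0], losses[:, 1], color='red', lw=1, label='discriminator')
plt.plot(losses[:, 0], losses[:, 2], color='blue', lw=1, label='generator')
plt.xlabel('iterations')
plt.ylabel('loss')
plt.legend()
plt.show()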
2. Generation log (samples over training)
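The saved generation log can be replayed to watch sample quality evolve. A minimal sketch, assuming ./trainLog/GenLog.log was written above and reusing the visualization helper (its delay argument is assumed to follow cv2.waitKey millisecond semantics):

import pickle
import visualization

with open('./trainLog/GenLog.log', 'rb') as f:
    GenLog = pickle.load(f)          # shape [num_records, batch, 28, 28, 1], values in [-1, 1]
# show the first 9 images of every 100th recorded batch, rescaled to [0, 1]
for record in GenLog[::100]:
    visualization.CV2_BATCH_SHOW((record[0:9] + 1) / 2, 0.5, 3, 3, delay=500)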
Generator validation
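The checkpoint stores the raw generator weights, their EMA shadows, and the batch-norm statistics. A minimal sketch of validation in the graph built above; the assign step, which copies each EMA shadow into its raw weight before sampling, is my reading of the intended use, not the author's exact code:

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, tf.train.latest_checkpoint('./ckpt'))
    # replace every raw generator weight with its smoothed EMA counterpart
    sess.run([tf.assign(raw, ema) for raw, ema in zip(g_train_vars, g_vars_ema)])
    z = np.random.normal(0, 1, size=[9, latents_dim]).astype(np.float32)
    samples = sess.run(g_test, feed_dict={latents: z})   # test-mode generator (BN in inference mode)
    visualization.CV2_BATCH_SHOW((samples + 1) / 2, 0.5, 3, 3, delay=0)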