新手上路,小心晕车
一、前言
1.学习代码地址:https://github.com/carpedm20/DCGAN-tensorflow
2.参考博客:https://blog.csdn.net/liuxiao214/article/details/74502975
二、代码整体概述
1.代码分为五个部分:download.py、utils.py、ops.py、main.py、model.py。前三个文件的用途分别为下载数据集、定义图像操作函数、定义G网和D网的操作函数;后两个文件分别为定义程序入口、建立对抗网络模型及训练过程。
2.程序使用的数据集有三个:celebA、lsun、mnist
3.程序根据数据集的不同,分为两个部分:一个是处理不带标签的数据集,例如celebA数据集;另一个则是带标签的数据集例如mnist数据集,分两个部分也就是说构建的G网和D网网络结构不一样
4.代码参数有很多,做如下总结:
y_dim: 标签位数,例如mnist数据集y_dim=10
z_dim: g网输入的噪声数据z的维度,默认是100
c_dim: 图片的通道数,等于1表示灰度图,等于3表示RGB彩图
gf_dim: g网最后一层的特征图数量,默认是64
gfc_dim: g网网络参数,默认1024
df_dim: d网第一个卷积层特征图数量,默认64
dfc_dim: d网最后一个全连接层特征点数
g_bn0\g_bn1\g_bn2\g_bn3: g网的第h0、h1、h2、h3层的批标准化对象
d_bn1\d_bn2\d_bn3: d网的第h1、h2、h3层的批标准化对象
data_X: mnist数据集输入数据[70000,28,28,1]
data_y: mnist数据集标签数据[70000,10]
inputs: D网输入图片的placeholder
y: d网标签数据的placeholder,根据y_dim来取值
z:g网输入的噪声数据[batch_size,z_dim],服从[-1,1)的均匀分布
output_height\output_width: g网生成图片的尺寸
input_height\input_width: 真实图片的尺寸
5.使用celebA数据集与mnist数据集测试g网与d网的结构,结果如下:
a.对于celebA数据集(代表无标签数据集):
g网: z(64,100)---h0(64,4,4,512)---h1(64,8,8,256)---h2(64,16,16,128)--h3(64,32,32,64)--h4(64,64,64,3)
d网: image(64,64,64,3)--h0(64,32,32,64)--h1(64,16,16,128)--h2(64,8,8,256)--h3(64,4,4,512)--h4(64,1)
b.对于mnist数据集(代表有标签的数据集,不需要裁剪):
g网: z(64,110)--h0(64,1034)--h1(64,7,7,138)--h2(64,14,14,138)--h3(64,28,28,1)
d网: image(64,28,28,1)--x(64,28,28,11)--h0(64,24,24,21)--h1(64,3636)--h2(64,1034)--h3(64,1)
三、代码注解
1. utils.py
# __future__库可以让程序员使用未来python版本的功能,该导入语句必须放在模块的最前面(只能位于文档字符串和注释之后),division支持精确除法
from __future__ import division
# python数学函数库,ceil函数向上取整
import math
# json(javascript object notation)用于在python里接受js对象标注数据
import json
import random
#格式化输出python数据结构
import pprint
# 图像处理库
import scipy.misc
import numpy as np
# 获取当前时间(服务器时间),gmtime获取能被计算机处理的时间格式,strftime支持格式化输出时间
from time import gmtime, strftime
# six.moves消除python2和python3之间的差异
from six.moves import xrange
import tensorflow as tf
# tensorflow.contrib库并不是官方提供的库,它是其它开发者提供的开源库,slim用于简化代码,也就是变化代码书写方式,使之简洁明了
import tensorflow.contrib.slim as slim
# Shared pretty-printer used by the other modules (e.g. to dump the flags).
pp = pprint.PrettyPrinter()


def get_stddev(x, k_h, k_w):
    """Return 1 / sqrt(k_w * k_h * last_dim) for the tensor-like ``x``.

    Args:
        x: Object exposing ``get_shape()``; only the last entry of the
            returned shape is used.
        k_h: Kernel height.
        k_w: Kernel width.

    Returns:
        The reciprocal of the square root of ``k_w * k_h * x.get_shape()[-1]``.
    """
    fan_in = k_w * k_h * x.get_shape()[-1]
    return 1 / math.sqrt(fan_in)
def show_all_variables():
    """Print structure information for every trainable variable.

    Collects the list returned by ``tf.trainable_variables()`` and passes it
    to ``slim.model_analyzer.analyze_vars`` with ``print_info=True``.
    """
    trainable = tf.trainable_variables()
    slim.model_analyzer.analyze_vars(trainable, print_info=True)
def get_image(image_path, input_height, input_width, resize_height=64, resize_width=64, crop=True, grayscale=False):
    """Load one image from disk and pass it through ``transform``.

    Args:
        image_path: Path of the image file, forwarded to ``imread``.
        input_height: Crop height forwarded to ``transform``.
        input_width: Crop width forwarded to ``transform``.
        resize_height: Output height forwarded to ``transform``.
        resize_width: Output width forwarded to ``transform``.
        crop: Forwarded to ``transform``; selects crop-then-resize vs. resize only.
        grayscale: Forwarded to ``imread``; when true the image is read flattened.

    Returns:
        Whatever ``transform`` returns for the loaded image (pixel values
        rescaled by ``/127.5 - 1``).
    """
    raw = imread(image_path, grayscale)
    return transform(
        raw,
        input_height,
        input_width,
        resize_height=resize_height,
        resize_width=resize_width,
        crop=crop,
    )
def save_images(images, size, image_path):
    """Tile a batch of images into one grid image and save it.

    Args:
        images: Batch of images, passed through ``inverse_transform`` first.
        size: Grid layout forwarded to ``imsave``.
        image_path: Destination path forwarded to ``imsave``.

    Returns:
        The return value of ``imsave``.
    """
    rescaled = inverse_transform(images)
    return imsave(rescaled, size, image_path)
def imread(path, grayscale=False):
    """Read an image file into a floating-point array.

    Args:
        path: Path of the image file to read.
        grayscale: When true the image is read with ``flatten=True``
            (collapsed to a single gray channel).

    Returns:
        The array returned by ``scipy.misc.imread`` converted to Python
        ``float`` dtype.
    """
    if grayscale:
        pixels = scipy.misc.imread(path, flatten=True)
    else:
        pixels = scipy.misc.imread(path)
    # ``np.float`` was only an alias of the builtin ``float`` and has been
    # removed from recent NumPy releases; the builtin gives the same dtype.
    return pixels.astype(float)
def merge_images(images, size):
    """Return ``inverse_transform(images)``.

    Despite its name this function does not tile anything: ``size`` is
    accepted but never used, and the batch is only passed through
    ``inverse_transform``.
    """
    restored = inverse_transform(images)
    return restored
def merge(images, size):
    """Tile a batch of images into one large image, filling row by row.

    Args:
        images: 4-D array indexed as [image, height, width, channel].
        size: Grid layout; ``size[0]`` is the number of tile rows and
            ``size[1]`` the number of tile columns.

    Returns:
        A zero-initialised array of shape (h*size[0], w*size[1], c) when the
        channel count is 3 or 4, or (h*size[0], w*size[1]) when it is 1,
        with image ``n`` placed at row ``n // size[1]``, column ``n % size[1]``.

    Raises:
        ValueError: If the channel count is not 1, 3 or 4.
    """
    tile_h, tile_w = images.shape[1], images.shape[2]
    channels = images.shape[3]
    n_rows, n_cols = size[0], size[1]

    # Colour images (RGB / RGBA): keep the channel axis in the canvas.
    if channels in (3, 4):
        canvas = np.zeros((tile_h * n_rows, tile_w * n_cols, channels))
        for position, tile in enumerate(images):
            row, col = divmod(position, n_cols)
            top, left = row * tile_h, col * tile_w
            canvas[top:top + tile_h, left:left + tile_w, :] = tile
        return canvas

    # Single-channel images: the canvas is 2-D and the channel axis is dropped.
    if channels == 1:
        canvas = np.zeros((tile_h * n_rows, tile_w * n_cols))
        for position, tile in enumerate(images):
            row, col = divmod(position, n_cols)
            top, left = row * tile_h, col * tile_w
            canvas[top:top + tile_h, left:left + tile_w] = tile[:, :, 0]
        return canvas

    raise ValueError('in merge(images,size) images parameter '
                     'must have dimensions: HxW or HxWx3 or HxWx4')
def imsave(images, size, path):
    """Merge a batch into one grid image and write it to ``path``.

    Args:
        images: Batch of images forwarded to ``merge``.
        size: Grid layout forwarded to ``merge``.
        path: Destination file path.

    Returns:
        The return value of ``scipy.misc.imsave``.
    """
    grid = merge(images, size)
    # np.squeeze drops every axis whose length is 1.
    return scipy.misc.imsave(path, np.squeeze(grid))
def center_crop(x, crop_h, crop_w, resize_h=64, resize_w=64):
    """Cut a centred window out of ``x`` and resize it.

    Args:
        x: Image array whose first two axes are height and width.
        crop_h: Height of the centred window.
        crop_w: Width of the centred window; ``None`` means "same as crop_h".
        resize_h: Height passed to ``scipy.misc.imresize``.
        resize_w: Width passed to ``scipy.misc.imresize``.

    Returns:
        The result of ``scipy.misc.imresize`` applied to the cropped window.
    """
    if crop_w is None:
        crop_w = crop_h
    height, width = x.shape[:2]
    # Split the leftover margin evenly; round() picks the nearest offset.
    top = int(round((height - crop_h) / 2.))
    left = int(round((width - crop_w) / 2.))
    window = x[top:top + crop_h, left:left + crop_w]
    return scipy.misc.imresize(window, [resize_h, resize_w])
def transform(image, input_height, input_width, resize_height=64, resize_width=64, crop=True):
    """Crop and/or resize an image, then rescale its values by ``/127.5 - 1``.

    Args:
        image: Image array to process.
        input_height: Crop height, used only when ``crop`` is true.
        input_width: Crop width, used only when ``crop`` is true.
        resize_height: Target height.
        resize_width: Target width.
        crop: When true the image goes through ``center_crop``; otherwise it
            is resized directly with ``scipy.misc.imresize``.

    Returns:
        ``np.array(resized) / 127.5 - 1.``. For 0-255 input this lies in
        [-1, 1] (presumably chosen to match a tanh generator output —
        confirm against the model).
    """
    if crop:
        resized = center_crop(
            image, input_height, input_width, resize_height, resize_width)
    else:
        target_size = [resize_height, resize_width]
        resized = scipy.misc.imresize(image, target_size)
    return np.array(resized)/127.5 - 1.
def inverse_transform(images):
    """Map values from [-1, 1] to [0, 1] by computing ``(images + 1) / 2``.

    ``transform`` in this file rescales pixels with ``/127.5 - 1``; this
    undoes the shift so the values are non-negative before ``save_images``
    hands them to ``imsave``.
    """
    shifted = images + 1.
    return shifted / 2.
def to_json(output_path, *layers):
    """Dump layer parameters as JavaScript ``var layer_N = {...};`` text.

    Args:
        output_path: File that is opened for writing (truncated) and receives
            the generated text.
        *layers: Triples ``(w, b, bn)``. ``w`` must expose ``.name`` and
            ``.eval()``, ``b`` must expose ``.eval()``, and ``bn`` is either
            ``None`` or an object whose ``gamma`` and ``beta`` attributes
            expose ``.eval()``.

    Notes:
        * A layer whose ``w.name`` contains ``"lin/"`` is written as an
          ``"fc"`` layer; every other layer is written as a ``"deconv"``
          layer with hard-coded 5x5 kernels, stride 2 and pad 1.
        * The layer index is the text after the first ``h`` in the first
          path component of ``w.name``.
        * All values are formatted with two decimals. Single quotes are
          removed and all whitespace runs are collapsed to one space before
          writing, so the whitespace inside the templates below is irrelevant.
    """
    with open(output_path, "w") as layer_f:
        lines = ""
        for w, b, bn in layers:
            layer_idx = w.name.split('/')[0].split('h')[1]
            is_linear = "lin/" in w.name

            B = b.eval()

            if is_linear:
                W = w.eval()
                depth = W.shape[1]
            else:
                # Move axis 2 to the front so iterating W yields one filter
                # per entry of that axis.
                W = np.rollaxis(w.eval(), 2, 0)
                depth = W.shape[0]

            biases = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(B)]}

            # Identity check: ``!= None`` would call a possibly overloaded
            # comparison operator on the object.
            if bn is not None:
                gamma_values = bn.gamma.eval()
                beta_values = bn.beta.eval()
                gamma = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(gamma_values)]}
                beta = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(beta_values)]}
            else:
                gamma = {"sy": 1, "sx": 1, "depth": 0, "w": []}
                beta = {"sy": 1, "sx": 1, "depth": 0, "w": []}

            if is_linear:
                # One filter entry per output unit (column of W).
                fs = [
                    {"sy": 1, "sx": 1, "depth": W.shape[0], "w": ['%.2f' % elem for elem in list(column)]}
                    for column in W.T
                ]

                lines += """
                    var layer_%s = {
                        "layer_type": "fc",
                        "sy": 1, "sx": 1,
                        "out_sx": 1, "out_sy": 1,
                        "stride": 1, "pad": 0,
                        "out_depth": %s, "in_depth": %s,
                        "biases": %s,
                        "gamma": %s,
                        "beta": %s,
                        "filters": %s
                    };""" % (layer_idx.split('_')[0], W.shape[1], W.shape[0], biases, gamma, beta, fs)
            else:
                fs = [
                    {"sy": 5, "sx": 5, "depth": W.shape[3], "w": ['%.2f' % elem for elem in list(kernel.flatten())]}
                    for kernel in W
                ]

                lines += """
                    var layer_%s = {
                        "layer_type": "deconv",
                        "sy": 5, "sx": 5,
                        "out_sx": %s, "out_sy": %s,
                        "stride": 2, "pad": 1,
                        "out_depth": %s, "in_depth": %s,
                        "biases": %s,
                        "gamma": %s,
                        "beta": %s,
                        "filters": %s
                    };""" % (layer_idx, 2**(int(layer_idx)+2), 2**(int(layer_idx)+2),
                             W.shape[0], W.shape[3], biases, gamma, beta, fs)
        layer_f.write(" ".join(lines.replace("'","").split()))
def make_gif(images, fname, duration=2, true_image=False):
    """Write a sequence of images to an animated GIF using moviepy.

    Args:
        images: Indexable sequence of frames (arrays).
        fname: Output GIF path.
        duration: Clip length passed to ``VideoClip``; frames are spread
            evenly over it, so the GIF is written at ``len(images) / duration`` fps.
        true_image: When true, frames are only cast to ``uint8``; otherwise
            they are first mapped with ``(x + 1) / 2 * 255``.
    """
    # Imported lazily so the module can be used without moviepy installed.
    import moviepy.editor as mpy

    def make_frame(t):
        # Pick the frame for time t; past the end, repeat the last frame.
        try:
            x = images[int(len(images)/duration*t)]
        except IndexError:
            x = images[-1]

        if true_image:
            # uint8: unsigned 8-bit integer, 0..255
            return x.astype(np.uint8)
        else:
            return ((x+1)/2*255).astype(np.uint8)

    clip = mpy.VideoClip(make_frame, duration=duration)
    clip.write_gif(fname, fps = len(images) / duration)
def _run_sampler(sess, dcgan, config, z_sample):
    """Run ``dcgan.sampler`` on ``z_sample``; for mnist also feed random one-hot labels."""
    feed = {dcgan.z: z_sample}
    if config.dataset == "mnist":
        y = np.random.choice(10, config.batch_size)
        y_one_hot = np.zeros((config.batch_size, 10))
        y_one_hot[np.arange(config.batch_size), y] = 1
        feed[dcgan.y] = y_one_hot
    return sess.run(dcgan.sampler, feed_dict=feed)


def visualize(sess, dcgan, config, option):
    """Generate and save test images from a trained model.

    Args:
        sess: Session used to run ``dcgan.sampler``.
        dcgan: Model object exposing ``sampler``, ``z``, ``z_dim`` and (for
            mnist) ``y``.
        config: Object exposing ``batch_size`` and ``dataset``.
        option: Selects the visualisation:
            0 - one grid image from noise drawn uniformly from [-0.5, 0.5).
            1 - for every z dimension, sweep it over ``values`` across the
                batch (other dimensions uniform in [-1, 1)) and save a grid.
            2 - like 1 but for randomly chosen dimensions, with one shared
                base vector in [-0.2, 0.2); saved as GIF, falling back to a
                grid PNG if GIF creation fails.
            3 - sweep every dimension with all other dimensions zero; GIFs.
            4 - like 3, plus one merged GIF of all sweeps.
            Any other value does nothing.

    Files are written under ``./samples/``.
    """
    # Side length of the smallest square grid that holds one batch.
    image_frame_dim = int(math.ceil(config.batch_size**.5))
    grid = [image_frame_dim, image_frame_dim]

    if option == 0:
        z_sample = np.random.uniform(-0.5, 0.5, size=(config.batch_size, dcgan.z_dim))
        samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
        save_images(samples, grid, './samples/test_%s.png' % strftime("%Y-%m-%d-%H-%M-%S", gmtime()))
    elif option == 1:
        # Evenly spaced sweep values starting at 0 with step 1/batch_size.
        values = np.arange(0, 1, 1./config.batch_size)
        for idx in xrange(dcgan.z_dim):
            print(" [*] %d" % idx)
            z_sample = np.random.uniform(-1, 1, size=(config.batch_size , dcgan.z_dim))
            for kdx, row in enumerate(z_sample):
                row[idx] = values[kdx]

            samples = _run_sampler(sess, dcgan, config, z_sample)
            save_images(samples, grid, './samples/test_arange_%s.png' % (idx))
    elif option == 2:
        values = np.arange(0, 1, 1./config.batch_size)
        for idx in [random.randint(0, dcgan.z_dim - 1) for _ in xrange(dcgan.z_dim)]:
            print(" [*] %d" % idx)
            base = np.random.uniform(-0.2, 0.2, size=(dcgan.z_dim))
            z_sample = np.tile(base, (config.batch_size, 1))
            for kdx, row in enumerate(z_sample):
                row[idx] = values[kdx]

            samples = _run_sampler(sess, dcgan, config, z_sample)

            # Best effort: fall back to a still image if the GIF cannot be
            # written (e.g. moviepy is missing), but let Ctrl-C through.
            try:
                make_gif(samples, './samples/test_gif_%s.gif' % (idx))
            except Exception:
                save_images(samples, grid, './samples/test_%s.png' % strftime("%Y-%m-%d-%H-%M-%S", gmtime()))
    elif option == 3:
        values = np.arange(0, 1, 1./config.batch_size)
        for idx in xrange(dcgan.z_dim):
            print(" [*] %d" % idx)
            z_sample = np.zeros([config.batch_size, dcgan.z_dim])
            for kdx, row in enumerate(z_sample):
                row[idx] = values[kdx]

            samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
            make_gif(samples, './samples/test_gif_%s.gif' % (idx))
    elif option == 4:
        image_set = []
        values = np.arange(0, 1, 1./config.batch_size)

        for idx in xrange(dcgan.z_dim):
            print(" [*] %d" % idx)
            z_sample = np.zeros([config.batch_size, dcgan.z_dim])
            for kdx, row in enumerate(z_sample):
                row[idx] = values[kdx]

            image_set.append(sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}))
            make_gif(image_set[-1], './samples/test_gif_%s.gif' % (idx))

        # Play the sweep forwards then backwards. range objects cannot be
        # concatenated with ``+`` on Python 3, so build lists first.
        # NOTE(review): 64 frames and the 10x10 grid are hard-coded; they
        # look tied to batch_size == 64 and z_dim == 100 — confirm.
        frame_order = list(range(64)) + list(range(63, -1, -1))
        new_image_set = [merge(np.array([images[idx] for images in image_set]), [10, 10])
                         for idx in frame_order]
        make_gif(new_image_set, './samples/test_gif_merged.gif', duration=8)
def image_manifold_size(num_images):
    """Return a (rows, cols) grid that holds exactly ``num_images`` tiles.

    ``rows`` is floor(sqrt(n)) and ``cols`` is ceil(sqrt(n)). The assertion
    therefore passes for perfect squares (rows == cols) and also for counts
    of the form k*(k+1), e.g. 6 -> (2, 3).

    Raises:
        AssertionError: If ``rows * cols != num_images``.
    """
    root = np.sqrt(num_images)
    rows = int(np.floor(root))
    cols = int(np.ceil(root))
    assert rows * cols == num_images,'图片数量不可被开方'
    return rows, cols
2. ops.py
import math
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops
from utils import *
# Summary API aliases: use the old top-level TensorFlow names when they
# exist, otherwise fall back to the ``tf.summary`` equivalents. Only a
# missing attribute triggers the fallback, so unrelated errors are not hidden.
try:
    image_summary = tf.image_summary
    scalar_summary = tf.scalar_summary
    histogram_summary = tf.histogram_summary
    merge_summary = tf.merge_summary
    SummaryWriter = tf.train.SummaryWriter
except AttributeError:
    image_summary = tf.summary.image
    scalar_summary = tf.summary.scalar
    histogram_summary = tf.summary.histogram
    merge_summary = tf.summary.merge
    SummaryWriter = tf.summary.FileWriter
# dir(tf) lists the attributes of the tensorflow module; the check records
# whether this TensorFlow build still ships ``concat_v2``.
_HAS_CONCAT_V2 = "concat_v2" in dir(tf)


def concat(tensors, axis, *args, **kwargs):
    """Concatenate ``tensors`` along ``axis``.

    Delegates to ``tf.concat_v2`` when the module provides it and to
    ``tf.concat`` otherwise; extra positional and keyword arguments are
    passed through unchanged.
    """
    joiner = tf.concat_v2 if _HAS_CONCAT_V2 else tf.concat
    return joiner(tensors, axis, *args, **kwargs)
class batch_norm(object):
    """Callable wrapper around ``tf.contrib.layers.batch_norm``.

    The constructor only stores the hyper-parameters; the layer itself is
    created when the instance is called on a tensor.
    """

    def __init__(self, epsilon=1e-5, momentum = 0.9, name="batch_norm"):
        """Store ``epsilon``, ``momentum`` and the scope ``name``."""
        with tf.variable_scope(name):
            self.epsilon = epsilon
            self.momentum = momentum
            self.name = name

    def __call__(self, x, train=True):
        """Apply batch normalisation to ``x``.

        Args:
            x: Input tensor.
            train: Passed as ``is_training``.

        Returns:
            The output of ``tf.contrib.layers.batch_norm``.
        """
        layer_options = dict(
            decay=self.momentum,
            updates_collections=None,
            epsilon=self.epsilon,
            scale=True,
            is_training=train,
            scope=self.name,
        )
        return tf.contrib.layers.batch_norm(x, **layer_options)
def conv_cond_concat(x, y):
    """Concatenate conditioning vector on feature map axis.

    ``y`` is multiplied by a ones tensor shaped
    ``[x.shape[0], x.shape[1], x.shape[2], y.shape[3]]`` and the product is
    concatenated to ``x`` along axis 3.
    """
    x_shapes = x.get_shape()
    y_shapes = y.get_shape()
    tiled_shape = [x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]]
    tiled_y = y*tf.ones(tiled_shape)
    return concat([x, tiled_y], 3)
def conv2d(input_, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,name="conv2d"):
    """2-D convolution with 'SAME' padding plus a bias term.

    Args:
        input_: Input tensor; its last dimension is used as the input channel count.
        output_dim: Number of output feature maps.
        k_h: Kernel height.
        k_w: Kernel width.
        d_h: Stride along height.
        d_w: Stride along width.
        stddev: Standard deviation of the truncated-normal weight initializer.
        name: Variable scope holding the variables ``w`` and ``biases``.

    Returns:
        The biased convolution output, reshaped to the convolution's static shape.
    """
    with tf.variable_scope(name):
        in_channels = input_.get_shape()[-1]
        w = tf.get_variable(
            'w', [k_h, k_w, in_channels, output_dim],
            initializer=tf.truncated_normal_initializer(stddev=stddev))
        conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')

        biases = tf.get_variable(
            'biases', [output_dim], initializer=tf.constant_initializer(0.0))
        biased = tf.nn.bias_add(conv, biases)
        return tf.reshape(biased, conv.get_shape())
def deconv2d(input_, output_shape,k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,name="deconv2d", with_w=False):
    """Transposed 2-D convolution plus a bias term.

    Args:
        input_: Input tensor; its last dimension is the filter's in_channels.
        output_shape: Requested output shape; its last entry is the number of
            output channels.
        k_h: Kernel height.
        k_w: Kernel width.
        d_h: Stride along height.
        d_w: Stride along width.
        stddev: Standard deviation of the random-normal weight initializer.
        name: Variable scope holding the variables ``w`` and ``biases``.
        with_w: When true, also return the weight and bias variables.

    Returns:
        The output tensor, or ``(output, w, biases)`` when ``with_w`` is true.
    """
    strides = [1, d_h, d_w, 1]
    with tf.variable_scope(name):
        # filter : [height, width, output_channels, in_channels]
        w = tf.get_variable(
            'w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]],
            initializer=tf.random_normal_initializer(stddev=stddev))

        try:
            deconv = tf.nn.conv2d_transpose(
                input_, w, output_shape=output_shape, strides=strides)
        # Support for versions of TensorFlow before 0.7.0
        except AttributeError:
            deconv = tf.nn.deconv2d(
                input_, w, output_shape=output_shape, strides=strides)

        biases = tf.get_variable(
            'biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0))
        deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape())

        if not with_w:
            return deconv
        return deconv, w, biases
def lrelu(x, leak=0.2, name="lrelu"):
    """Leaky ReLU: element-wise maximum of ``x`` and ``leak * x``.

    ``name`` is accepted but not used.
    """
    scaled = leak*x
    return tf.maximum(x, scaled)
def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):
    """Fully connected layer computing ``input_ @ Matrix + bias``.

    Args:
        input_: 2-D input tensor; its second dimension is the input size.
        output_size: Number of output units.
        scope: Variable scope name; defaults to ``"Linear"``.
        stddev: Standard deviation of the random-normal weight initializer.
        bias_start: Constant used to initialise the bias.
        with_w: When true, also return the weight matrix and bias variables.

    Returns:
        The output tensor, or ``(output, matrix, bias)`` when ``with_w`` is true.

    Raises:
        ValueError: Re-raised from ``tf.get_variable`` with a hint about
            image dimensions appended to its arguments.
    """
    in_size = input_.get_shape().as_list()[1]

    with tf.variable_scope(scope or "Linear"):
        try:
            matrix = tf.get_variable(
                "Matrix", [in_size, output_size], tf.float32,
                tf.random_normal_initializer(stddev=stddev))
        except ValueError as err:
            msg = "NOTE: Usually, this is due to an issue with the image dimensions.  Did you correctly set '--crop' or '--input_height' or '--output_height'?"
            err.args = err.args + (msg,)
            raise
        bias = tf.get_variable(
            "bias", [output_size],
            initializer=tf.constant_initializer(bias_start))

        output = tf.matmul(input_, matrix) + bias
        if with_w:
            return output, matrix, bias
        return output
3.main.py
import os
import scipy.misc
import numpy as np
from model import DCGAN
from utils import pp, visualize, to_json, show_all_variables
import tensorflow as tf
# Command-line flags. The bracketed value at the end of each help string is
# the default, so it has to agree with the default actually passed.
flags = tf.app.flags
flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")
flags.DEFINE_float("learning_rate", 0.0002, "Learning rate for adam [0.0002]")
flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")
flags.DEFINE_float("train_size", np.inf, "The size of train images [np.inf]")
flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")
flags.DEFINE_integer("input_height", 28, "The size of image to use (will be center cropped). [28]")
flags.DEFINE_integer("input_width", None, "The size of image to use (will be center cropped). If None, same value as input_height [None]")
flags.DEFINE_integer("output_height", 28, "The size of the output images to produce [28]")
flags.DEFINE_integer("output_width", None, "The size of the output images to produce. If None, same value as output_height [None]")
flags.DEFINE_string("dataset", "mnist", "The name of dataset [celebA, mnist, lsun]")
flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*.jpg]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
flags.DEFINE_string("data_dir", "./data", "Root directory of dataset [./data]")
flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")
flags.DEFINE_boolean("train", True, "True for training, False for testing [True]")
flags.DEFINE_boolean("crop", False, "True to center-crop input images before resizing, False to only resize [False]")
flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]")
flags.DEFINE_integer("generate_test_images", 100, "Number of images to generate during test. [100]")
FLAGS = flags.FLAGS
def main(_):
    """Program entry point: build the DCGAN, then train or restore it, then visualise.

    Steps:
        1. Print the parsed flags.
        2. Default ``input_width`` / ``output_width`` to the matching height.
        3. Create the checkpoint and sample directories when missing.
        4. Build the model in a session with GPU memory growth enabled;
           ``y_dim=10`` is passed only for the mnist dataset.
        5. Train when ``FLAGS.train`` is set, otherwise restore a checkpoint.
        6. Run ``visualize`` with option 1.

    Raises:
        Exception: In test mode when no checkpoint could be loaded.
    """
    pp.pprint(flags.FLAGS.__flags)

    if FLAGS.input_width is None:
        FLAGS.input_width = FLAGS.input_height
    if FLAGS.output_width is None:
        FLAGS.output_width = FLAGS.output_height

    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
    run_config = tf.ConfigProto()
    run_config.gpu_options.allow_growth=True

    with tf.Session(config=run_config) as sess:
        # Arguments shared by both dataset kinds.
        # NOTE(review): z_dim is taken from the generate_test_images flag.
        model_kwargs = dict(
            input_width=FLAGS.input_width,
            input_height=FLAGS.input_height,
            output_width=FLAGS.output_width,
            output_height=FLAGS.output_height,
            batch_size=FLAGS.batch_size,
            sample_num=FLAGS.batch_size,
            z_dim=FLAGS.generate_test_images,
            dataset_name=FLAGS.dataset,
            input_fname_pattern=FLAGS.input_fname_pattern,
            crop=FLAGS.crop,
            checkpoint_dir=FLAGS.checkpoint_dir,
            sample_dir=FLAGS.sample_dir,
            data_dir=FLAGS.data_dir,
        )
        # Only the labelled dataset gets a label dimension.
        if FLAGS.dataset == 'mnist':
            model_kwargs['y_dim'] = 10
        dcgan = DCGAN(sess, **model_kwargs)

        show_all_variables()

        if FLAGS.train:
            dcgan.train(FLAGS)
        elif not dcgan.load(FLAGS.checkpoint_dir)[0]:
            raise Exception("[!] Train a model first, then run test mode")

        # to_json("./web/js/layers.js", [dcgan.h0_w, dcgan.h0_b, dcgan.g_bn0],
        #                 [dcgan.h1_w, dcgan.h1_b, dcgan.g_bn1],
        #                 [dcgan.h2_w, dcgan.h2_b, dcgan.g_bn2],
        #                 [dcgan.h3_w, dcgan.h3_b, dcgan.g_bn3],
        #                 [dcgan.h4_w, dcgan.h4_b, None])

        # Below is codes for visualization
        OPTION = 1
        visualize(sess, dcgan, FLAGS, OPTION)
# __name__ equals "__main__" only when this file is executed directly; when
# the file is imported it holds the module name instead, so the guard keeps
# the program from starting on import.
if __name__ == "__main__":
    # tf.app.run() parses the command-line flags and then calls main().
    tf.app.run()
4.model.py
from __future__ import division
import os
import time
import math
from glob import glob
import tensorflow as tf
import numpy as np
from six.moves import xrange
from tensorflow.examples.tutorials.mnist import input_data
from ops import *
from utils import *
def conv_out_size_same(size, stride):
    """Return ceil(size / stride) as an int.

    Used to derive the spatial size of each generator layer by repeatedly
    dividing the output size by the stride.
    """
    ratio = float(size) / float(stride)
    return int(math.ceil(ratio))
class DCGAN(object):
    """DCGAN model: generator, discriminator, training loop and checkpointing.

    Two network layouts are built depending on ``y_dim``:

    * ``y_dim`` falsy: unconditional networks made of four strided
      (de)convolution layers plus one linear layer.
    * ``y_dim`` set: label-conditioned networks in which the label tensor is
      concatenated to the inputs of every layer.
    """

    def __init__(self, sess, input_height=108, input_width=108, crop=True,
                 batch_size=64, sample_num = 64, output_height=64, output_width=64,
                 y_dim=None, z_dim=100, gf_dim=64, df_dim=64,
                 gfc_dim=1024, dfc_dim=1024, c_dim=3, dataset_name='default',
                 input_fname_pattern='*.jpg', checkpoint_dir=None, sample_dir=None, data_dir='./data'):
        """Store the configuration, load or locate the data and build the graph.

        Args:
            sess: TensorFlow session used for training, saving and restoring.
            input_height: Height of the input images.
            input_width: Width of the input images.
            crop: Whether inputs are centre-cropped; also decides whether the
                image placeholder uses the output or the input size.
            batch_size: Number of images per batch.
            sample_num: Number of fixed samples used for periodic sample images.
            output_height: Height of the generated images.
            output_width: Width of the generated images.
            y_dim: Label dimension, or None for unlabelled data.
            z_dim: Dimension of the generator's noise input.
            gf_dim: Base number of generator filters.
            df_dim: Base number of discriminator filters.
            gfc_dim: Units of the generator's fully connected layer
                (conditional layout only).
            dfc_dim: Units of the discriminator's fully connected layer
                (conditional layout only).
            c_dim: Accepted but overwritten by the channel count detected
                from the data.
            dataset_name: ``'mnist'`` or the name of a directory under ``data_dir``.
            input_fname_pattern: Glob pattern of the image files.
            checkpoint_dir: Root directory for checkpoints.
            sample_dir: Accepted but not stored by this constructor.
            data_dir: Root directory of the datasets.

        Raises:
            Exception: If no image matches the glob pattern, or the dataset
                holds fewer images than ``batch_size``.
        """
        self.sess = sess
        self.crop = crop

        self.batch_size = batch_size
        self.sample_num = sample_num

        self.input_height = input_height
        self.input_width = input_width
        self.output_height = output_height
        self.output_width = output_width

        self.y_dim = y_dim
        self.z_dim = z_dim

        self.gf_dim = gf_dim
        self.df_dim = df_dim

        self.gfc_dim = gfc_dim
        self.dfc_dim = dfc_dim

        # batch normalization : deals with poor initialization helps gradient flow
        self.d_bn1 = batch_norm(name='d_bn1')
        self.d_bn2 = batch_norm(name='d_bn2')

        if not self.y_dim:
            self.d_bn3 = batch_norm(name='d_bn3')

        self.g_bn0 = batch_norm(name='g_bn0')
        self.g_bn1 = batch_norm(name='g_bn1')
        self.g_bn2 = batch_norm(name='g_bn2')

        if not self.y_dim:
            self.g_bn3 = batch_norm(name='g_bn3')

        self.dataset_name = dataset_name
        self.input_fname_pattern = input_fname_pattern
        self.checkpoint_dir = checkpoint_dir
        self.data_dir = data_dir

        if self.dataset_name == 'mnist':
            self.data_X, self.data_y = self.load_mnist()
            self.c_dim = self.data_X[0].shape[-1]
        else:
            data_path = os.path.join(self.data_dir, self.dataset_name, self.input_fname_pattern)
            self.data = glob(data_path)
            if len(self.data) == 0:
                raise Exception("[!] No data found in '" + data_path + "'")
            np.random.shuffle(self.data)
            imreadImg = imread(self.data[0])
            # A third axis means colour channels; reuse the image that was
            # already read instead of reading the file a second time.
            if len(imreadImg.shape) >= 3:
                self.c_dim = imreadImg.shape[-1]
            else:
                self.c_dim = 1

            if len(self.data) < self.batch_size:
                raise Exception("[!] Entire dataset size is less than the configured batch_size")

        self.grayscale = (self.c_dim == 1)

        self.build_model()

    def build_model(self):
        """Create placeholders, both networks, the losses, summaries and the saver."""
        if self.y_dim:
            self.y = tf.placeholder(tf.float32, [self.batch_size, self.y_dim], name='y')
        else:
            self.y = None

        if self.crop:
            image_dims = [self.output_height, self.output_width, self.c_dim]
        else:
            image_dims = [self.input_height, self.input_width, self.c_dim]

        self.inputs = tf.placeholder(
            tf.float32, [self.batch_size] + image_dims, name='real_images')

        inputs = self.inputs

        self.z = tf.placeholder(
            tf.float32, [None, self.z_dim], name='z')
        self.z_sum = histogram_summary("z", self.z)

        self.G = self.generator(self.z, self.y)
        self.D, self.D_logits = self.discriminator(inputs, self.y, reuse=False)
        # From here on the instance attribute ``sampler`` is the sampler
        # tensor and shadows the method of the same name; external code
        # (e.g. visualize) relies on that.
        self.sampler = self.sampler(self.z, self.y)
        self.D_, self.D_logits_ = self.discriminator(self.G, self.y, reuse=True)

        self.d_sum = histogram_summary("d", self.D)
        self.d__sum = histogram_summary("d_", self.D_)
        self.G_sum = image_summary("G", self.G)

        def sigmoid_cross_entropy_with_logits(x, y):
            # Older TensorFlow releases name the second argument ``targets``;
            # an unknown keyword raises TypeError.
            try:
                return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, labels=y)
            except TypeError:
                return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, targets=y)

        # D should output 1 for real images and 0 for generated ones;
        # G is trained to make D output 1 for generated images.
        self.d_loss_real = tf.reduce_mean(
            sigmoid_cross_entropy_with_logits(self.D_logits, tf.ones_like(self.D)))
        self.d_loss_fake = tf.reduce_mean(
            sigmoid_cross_entropy_with_logits(self.D_logits_, tf.zeros_like(self.D_)))
        self.g_loss = tf.reduce_mean(
            sigmoid_cross_entropy_with_logits(self.D_logits_, tf.ones_like(self.D_)))

        self.d_loss_real_sum = scalar_summary("d_loss_real", self.d_loss_real)
        self.d_loss_fake_sum = scalar_summary("d_loss_fake", self.d_loss_fake)

        self.d_loss = self.d_loss_real + self.d_loss_fake

        self.g_loss_sum = scalar_summary("g_loss", self.g_loss)
        self.d_loss_sum = scalar_summary("d_loss", self.d_loss)

        # Variables are split between the two optimizers by name substring.
        t_vars = tf.trainable_variables()

        self.d_vars = [var for var in t_vars if 'd_' in var.name]
        self.g_vars = [var for var in t_vars if 'g_' in var.name]

        self.saver = tf.train.Saver()

    def _load_image_batch(self, files):
        """Read image files with get_image and stack them into a float32 batch."""
        images = [
            get_image(image_file,
                      input_height=self.input_height,
                      input_width=self.input_width,
                      resize_height=self.output_height,
                      resize_width=self.output_width,
                      crop=self.crop,
                      grayscale=self.grayscale) for image_file in files]
        batch = np.array(images).astype(np.float32)
        # Grayscale images come back without a channel axis; add one.
        if self.grayscale:
            return batch[:, :, :, None]
        return batch

    def _save_sample_grid(self, sample_feed, config, epoch, idx):
        """Run the sampler on the fixed sample feed, save the grid image and print the losses."""
        samples, d_loss, g_loss = self.sess.run(
            [self.sampler, self.d_loss, self.g_loss],
            feed_dict=sample_feed
        )
        save_images(samples, image_manifold_size(samples.shape[0]),
                    './{}/train_{:02d}_{:04d}.png'.format(config.sample_dir, epoch, idx))
        print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss))

    def train(self, config):
        """Train both networks.

        Args:
            config: Object exposing ``learning_rate``, ``beta1``, ``dataset``,
                ``epoch``, ``train_size``, ``batch_size``, ``data_dir``,
                ``sample_dir`` and ``checkpoint_dir``.

        Per batch, the discriminator optimizer runs once and the generator
        optimizer twice. A sample grid is saved whenever ``counter % 100 == 1``
        and a checkpoint whenever ``counter % 500 == 2``.
        """
        # Optimizers and the variable sets each of them updates.
        d_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
            .minimize(self.d_loss, var_list=self.d_vars)
        g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
            .minimize(self.g_loss, var_list=self.g_vars)
        # Initialise graph variables; fall back to the old API name when the
        # new one does not exist.
        try:
            tf.global_variables_initializer().run()
        except AttributeError:
            tf.initialize_all_variables().run()

        # Merged summaries for the generator and discriminator steps.
        self.g_sum = merge_summary([self.z_sum, self.d__sum,
                                    self.G_sum, self.d_loss_fake_sum, self.g_loss_sum])
        self.d_sum = merge_summary(
            [self.z_sum, self.d_sum, self.d_loss_real_sum, self.d_loss_sum])
        self.writer = SummaryWriter("./logs", self.sess.graph)

        is_mnist = config.dataset == 'mnist'

        # Fixed inputs for the periodic sample images: uniform noise in
        # [-1, 1) and the first ``sample_num`` items of the dataset.
        sample_z = np.random.uniform(-1, 1, size=(self.sample_num , self.z_dim))
        sample_feed = {self.z: sample_z}
        if is_mnist:
            sample_feed[self.inputs] = self.data_X[0:self.sample_num]
            sample_feed[self.y] = self.data_y[0:self.sample_num]
        else:
            sample_feed[self.inputs] = self._load_image_batch(self.data[0:self.sample_num])

        # Resume the step counter from the latest checkpoint when there is one.
        counter = 1
        start_time = time.time()
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        for epoch in xrange(config.epoch):
            # Number of batches in one pass over the data.
            if is_mnist:
                batch_idxs = min(len(self.data_X), config.train_size) // config.batch_size
            else:
                self.data = glob(os.path.join(
                    config.data_dir, config.dataset, self.input_fname_pattern))
                np.random.shuffle(self.data)
                batch_idxs = min(len(self.data), config.train_size) // config.batch_size

            for idx in xrange(0, int(batch_idxs)):
                batch_slice = slice(idx*config.batch_size, (idx+1)*config.batch_size)
                if is_mnist:
                    batch_images = self.data_X[batch_slice]
                    batch_labels = self.data_y[batch_slice]
                else:
                    batch_images = self._load_image_batch(self.data[batch_slice])

                batch_z = np.random.uniform(-1, 1, [config.batch_size, self.z_dim]) \
                    .astype(np.float32)

                # Feed dicts: the discriminator step needs real images and
                # noise, the generator step only noise; labels are added to
                # all of them for the labelled dataset.
                d_feed = {self.inputs: batch_images, self.z: batch_z}
                g_feed = {self.z: batch_z}
                real_feed = {self.inputs: batch_images}
                if is_mnist:
                    for feed in (d_feed, g_feed, real_feed):
                        feed[self.y] = batch_labels

                # Update D network
                _, summary_str = self.sess.run([d_optim, self.d_sum], feed_dict=d_feed)
                self.writer.add_summary(summary_str, counter)

                # Update G network
                _, summary_str = self.sess.run([g_optim, self.g_sum], feed_dict=g_feed)
                self.writer.add_summary(summary_str, counter)

                # Run g_optim twice to make sure that d_loss does not go to zero (different from paper)
                _, summary_str = self.sess.run([g_optim, self.g_sum], feed_dict=g_feed)
                self.writer.add_summary(summary_str, counter)

                errD_fake = self.d_loss_fake.eval(g_feed)
                errD_real = self.d_loss_real.eval(real_feed)
                errG = self.g_loss.eval(g_feed)

                counter += 1
                print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                      % (epoch, config.epoch, idx, batch_idxs,
                         time.time() - start_time, errD_fake+errD_real, errG))

                # Save a sample grid every 100 steps.
                if np.mod(counter, 100) == 1:
                    if is_mnist:
                        self._save_sample_grid(sample_feed, config, epoch, idx)
                    else:
                        # Best effort for image datasets: report and keep
                        # training, but do not swallow KeyboardInterrupt.
                        try:
                            self._save_sample_grid(sample_feed, config, epoch, idx)
                        except Exception:
                            print("one pic error!...")

                # Save a checkpoint every 500 steps.
                if np.mod(counter, 500) == 2:
                    self.save(config.checkpoint_dir, counter)

    def discriminator(self, image, y=None, reuse=False):
        """Build the discriminator.

        Args:
            image: Batch of images.
            y: Label tensor; used only when ``self.y_dim`` is set.
            reuse: When true, variables of the "discriminator" scope are reused.

        Returns:
            ``(sigmoid(logits), logits)``.
        """
        with tf.variable_scope("discriminator") as scope:
            if reuse:
                scope.reuse_variables()
            print('输出D网训练信息:','y_dim:',self.y_dim)
            if not self.y_dim:
                h0 = lrelu(conv2d(image, self.df_dim, name='d_h0_conv'))
                h1 = lrelu(self.d_bn1(conv2d(h0, self.df_dim*2, name='d_h1_conv')))
                h2 = lrelu(self.d_bn2(conv2d(h1, self.df_dim*4, name='d_h2_conv')))
                h3 = lrelu(self.d_bn3(conv2d(h2, self.df_dim*8, name='d_h3_conv')))
                h4 = linear(tf.reshape(h3, [self.batch_size, -1]), 1, 'd_h4_lin')
                print('image:', image.shape, ' h0:', h0.shape, ' h1:', h1.shape, ' h2:', h2.shape, ' h3:', h3.shape, ' h4:',
                      h4.shape)
                return tf.nn.sigmoid(h4), h4
            else:
                # Labels are broadcast over the feature maps and concatenated
                # before every layer.
                yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])
                x = conv_cond_concat(image, yb)

                h0 = lrelu(conv2d(x, self.c_dim + self.y_dim, name='d_h0_conv'))
                h0 = conv_cond_concat(h0, yb)

                h1 = lrelu(self.d_bn1(conv2d(h0, self.df_dim + self.y_dim, name='d_h1_conv')))
                h1 = tf.reshape(h1, [self.batch_size, -1])
                h1 = concat([h1, y], 1)

                h2 = lrelu(self.d_bn2(linear(h1, self.dfc_dim, 'd_h2_lin')))
                h2 = concat([h2, y], 1)

                h3 = linear(h2, 1, 'd_h3_lin')
                print('image:', image.shape, " yb:",yb.shape,' x:',x.shape,' h0:', h0.shape,
                      ' h1:', h1.shape, ' h2:', h2.shape, ' h3:', h3.shape)
                return tf.nn.sigmoid(h3), h3

    def generator(self, z, y=None):
        """Build the generator.

        Args:
            z: Noise tensor.
            y: Label tensor; used only when ``self.y_dim`` is set.

        Returns:
            The generated image tensor: tanh output for the unconditional
            layout, sigmoid output for the conditional one.
        """
        with tf.variable_scope("generator") as scope:
            print('显示G网训练信息\n','y_dim',self.y_dim)
            if not self.y_dim:
                # Spatial sizes of the layers, obtained by halving the output
                # size four times.
                s_h, s_w = self.output_height, self.output_width
                s_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2)
                s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2)
                s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2)
                s_h16, s_w16 = conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2)

                # project `z` and reshape
                self.z_, self.h0_w, self.h0_b = linear(
                    z, self.gf_dim*8*s_h16*s_w16, 'g_h0_lin', with_w=True)

                self.h0 = tf.reshape(
                    self.z_, [-1, s_h16, s_w16, self.gf_dim * 8])
                h0 = tf.nn.relu(self.g_bn0(self.h0))

                self.h1, self.h1_w, self.h1_b = deconv2d(
                    h0, [self.batch_size, s_h8, s_w8, self.gf_dim*4], name='g_h1', with_w=True)
                h1 = tf.nn.relu(self.g_bn1(self.h1))

                h2, self.h2_w, self.h2_b = deconv2d(
                    h1, [self.batch_size, s_h4, s_w4, self.gf_dim*2], name='g_h2', with_w=True)
                h2 = tf.nn.relu(self.g_bn2(h2))

                h3, self.h3_w, self.h3_b = deconv2d(
                    h2, [self.batch_size, s_h2, s_w2, self.gf_dim*1], name='g_h3', with_w=True)
                h3 = tf.nn.relu(self.g_bn3(h3))

                h4, self.h4_w, self.h4_b = deconv2d(
                    h3, [self.batch_size, s_h, s_w, self.c_dim], name='g_h4', with_w=True)
                print('z:',z.shape,' h0:',h0.shape,' h1:',h1.shape,' h2:',h2.shape,' h3:',h3.shape,' h4:',h4.shape)
                return tf.nn.tanh(h4)
            else:
                s_h, s_w = self.output_height, self.output_width
                s_h2, s_h4 = int(s_h/2), int(s_h/4)
                s_w2, s_w4 = int(s_w/2), int(s_w/4)

                # yb = tf.expand_dims(tf.expand_dims(y, 1),2)
                yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])
                z = concat([z, y], 1)

                h0 = tf.nn.relu(
                    self.g_bn0(linear(z, self.gfc_dim, 'g_h0_lin')))
                h0 = concat([h0, y], 1)

                h1 = tf.nn.relu(self.g_bn1(
                    linear(h0, self.gf_dim*2*s_h4*s_w4, 'g_h1_lin')))
                h1 = tf.reshape(h1, [self.batch_size, s_h4, s_w4, self.gf_dim * 2])

                h1 = conv_cond_concat(h1, yb)

                h2 = tf.nn.relu(self.g_bn2(deconv2d(h1,
                    [self.batch_size, s_h2, s_w2, self.gf_dim * 2], name='g_h2')))
                h2 = conv_cond_concat(h2, yb)
                h3=tf.nn.sigmoid(
                    deconv2d(h2, [self.batch_size, s_h, s_w, self.c_dim], name='g_h3'))
                print('z:', z.shape, ' yb:',yb.shape,' h0:', h0.shape, ' h1:', h1.shape,
                      ' h2:', h2.shape, ' h3:', h3.shape)
                return h3

    def sampler(self, z, y=None):
        """Build an inference copy of the generator.

        Reuses the variables of the "generator" scope and runs every batch
        norm layer with ``train=False``.

        Args:
            z: Noise tensor.
            y: Label tensor; used only when ``self.y_dim`` is set.

        Returns:
            The generated image tensor.
        """
        with tf.variable_scope("generator") as scope:
            scope.reuse_variables()

            if not self.y_dim:
                s_h, s_w = self.output_height, self.output_width
                s_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2)
                s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2)
                s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2)
                s_h16, s_w16 = conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2)

                # project `z` and reshape
                h0 = tf.reshape(
                    linear(z, self.gf_dim*8*s_h16*s_w16, 'g_h0_lin'),
                    [-1, s_h16, s_w16, self.gf_dim * 8])
                h0 = tf.nn.relu(self.g_bn0(h0, train=False))

                h1 = deconv2d(h0, [self.batch_size, s_h8, s_w8, self.gf_dim*4], name='g_h1')
                h1 = tf.nn.relu(self.g_bn1(h1, train=False))

                h2 = deconv2d(h1, [self.batch_size, s_h4, s_w4, self.gf_dim*2], name='g_h2')
                h2 = tf.nn.relu(self.g_bn2(h2, train=False))

                h3 = deconv2d(h2, [self.batch_size, s_h2, s_w2, self.gf_dim*1], name='g_h3')
                h3 = tf.nn.relu(self.g_bn3(h3, train=False))

                h4 = deconv2d(h3, [self.batch_size, s_h, s_w, self.c_dim], name='g_h4')

                return tf.nn.tanh(h4)
            else:
                s_h, s_w = self.output_height, self.output_width
                s_h2, s_h4 = int(s_h/2), int(s_h/4)
                s_w2, s_w4 = int(s_w/2), int(s_w/4)

                # yb = tf.reshape(y, [-1, 1, 1, self.y_dim])
                yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])
                z = concat([z, y], 1)

                h0 = tf.nn.relu(self.g_bn0(linear(z, self.gfc_dim, 'g_h0_lin'), train=False))
                h0 = concat([h0, y], 1)

                h1 = tf.nn.relu(self.g_bn1(
                    linear(h0, self.gf_dim*2*s_h4*s_w4, 'g_h1_lin'), train=False))
                h1 = tf.reshape(h1, [self.batch_size, s_h4, s_w4, self.gf_dim * 2])
                h1 = conv_cond_concat(h1, yb)

                h2 = tf.nn.relu(self.g_bn2(
                    deconv2d(h1, [self.batch_size, s_h2, s_w2, self.gf_dim * 2], name='g_h2'), train=False))
                h2 = conv_cond_concat(h2, yb)

                return tf.nn.sigmoid(deconv2d(h2, [self.batch_size, s_h, s_w, self.c_dim], name='g_h3'))

    def load_mnist(self):
        """Load MNIST and return shuffled images and one-hot labels.

        The train, validation and test splits (55000 + 5000 + 10000 images)
        are concatenated. Images and labels are shuffled with the same seed
        so that they stay aligned.

        Returns:
            ``(X / 255., y)`` where ``X`` has shape (70000, 28, 28, 1) and
            ``y`` holds the matching one-hot labels as floats.
        """
        # NOTE(review): the path is hard-coded to "data/" and does not use
        # self.data_dir — confirm this is intended.
        mnist = input_data.read_data_sets("data/", one_hot=True)
        # Builtin ``float`` replaces the removed ``np.float`` alias (same dtype).
        ti = mnist.train.images.reshape(55000, 28, 28, 1).astype(float)
        mv = mnist.validation.images.reshape(5000, 28, 28, 1).astype(float)
        x = np.concatenate([ti, mv], axis=0)
        y_ = np.concatenate([mnist.train.labels, mnist.validation.labels], axis=0)
        mt = mnist.test.images.reshape(10000, 28, 28, 1).astype(float)

        X = np.concatenate([x, mt], axis=0)
        y = np.concatenate([y_, mnist.test.labels], axis=0).astype(float)

        seed = 547
        # Re-seeding before each shuffle applies the same permutation to
        # images and labels (shuffle permutes along the first axis).
        np.random.seed(seed)
        np.random.shuffle(X)
        np.random.seed(seed)
        np.random.shuffle(y)

        return X/255.,y

    @property
    def model_dir(self):
        """Checkpoint sub-directory name: ``<dataset>_<batch>_<out_h>_<out_w>``."""
        return "{}_{}_{}_{}".format(
            self.dataset_name, self.batch_size,
            self.output_height, self.output_width)

    def save(self, checkpoint_dir, step):
        """Save the session under ``checkpoint_dir/model_dir``.

        Args:
            checkpoint_dir: Root checkpoint directory; the model-specific
                sub-directory is created when missing.
            step: Global step passed to the saver.
        """
        model_name = "DCGAN.model"
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess,
                        os.path.join(checkpoint_dir, model_name),
                        global_step=step)

    def load(self, checkpoint_dir):
        """Restore the latest checkpoint into ``self.sess``.

        Args:
            checkpoint_dir: Root checkpoint directory.

        Returns:
            ``(True, counter)`` when a checkpoint was restored, where
            ``counter`` is the last run of digits in the checkpoint file
            name; ``(False, 0)`` otherwise.
        """
        import re
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        # ``model_checkpoint_path`` holds the file name of the most recent checkpoint.
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            # Raw string so ``\d`` reaches the regex engine unchanged; the
            # pattern matches the last run of digits in the name.
            counter = int(next(re.finditer(r"(\d+)(?!.*\d)",ckpt_name)).group(0))
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0