这是一种GAN网络增强技术----具有匹配感知的判别器。前面讲过,在InfoGAN中,使用了ACGAN的方式进行指导模拟数据与生成数据的对应关系(分类)。在GAN-cls中该效果会以更简单的方式来实现,即增强判别器的功能,令其不仅能判断图片真伪,还能判断匹配真伪。
(个人理解)没啥实质性改变,时间并未缩短,技术也没有怎么简化甚至变得复杂了。就是思想上的一个转变,原本ACGan是模拟样本+正确分类信息输入进去/真实样本+正确分类信息输入进D去。现在的GAN-cls变为输入真实样本和真实标签、虚拟样本和真实标签、虚拟标签和真实样本的三种组合形式(无对应图片的随机标签)
GAN-cls的具体做法是,在原有的GAN网络上,将判别器的输入变为图片与对应标签的连接数据。这样判别器的输入特征中就会有生成图像的特征与对应标签的特征。然后用这样的判别器分别对真实标签与真实图片、假标签与真实图片、真实标签与假图片进行判断,预期的结果依次为真、假、假,在训练的过程中沿着这个方向收敛即可。而对于生成器,则不需要做任何改动。这样简单的一步就完成了生成根据标签匹配的模拟数据功能。
直接修改上一篇 GAN生成对抗网络合集(五):LSGan-最小二乘GAN(附代码) 代码,将其改成GAN-cls。
# def discriminator(x, num_classes=10, num_cont=2):
def discriminator(x, y): # 判别器函数 : x两次卷积,再接两次全连接; y代表输入的样本标签
reuse = len([t for t in tf.global_variables() if t.name.startswith('discriminator')]) > 0
# print (reuse)
# print (x.get_shape())
with tf.variable_scope('discriminator', reuse=reuse):
y = slim.fully_connected(y, num_outputs=n_input, activation_fn=leaky_relu) # 将y变为与图片一样维度的映射
y = tf.reshape(y, shape=[-1, 28, 28, 1]) # 将y统一成图片格式
x = tf.reshape(x, shape=[-1, 28, 28, 1])
# 将二者连接到一起,统一处理
x = tf.concat(axis=3, values=[x, y]) # x.shape = [-1, 28, 28, 2]
x = slim.conv2d(x, num_outputs=64, kernel_size=[4, 4], stride=2, activation_fn=leaky_relu)
x = slim.conv2d(x, num_outputs=128, kernel_size=[4, 4], stride=2, activation_fn=leaky_relu)
# print ("conv2d",x.get_shape())
x = slim.flatten(x) # 输入扁平化
shared_tensor = slim.fully_connected(x, num_outputs=1024, activation_fn=leaky_relu)
# recog_shared = slim.fully_connected(shared_tensor, num_outputs=128, activation_fn=leaky_relu)
# 生成的数据可以分别连接不同的输出层产生不同的结果
# 1维的输出层产生判别结果1或是0
disc = slim.fully_connected(shared_tensor, num_outputs=1, activation_fn=tf.nn.sigmoid)
disc = tf.squeeze(disc, -1)
# print ("disc",disc.get_shape()) # 0 or 1
# 10维的输出层产生分类结果 (样本标签)
# recog_cat = slim.fully_connected(recog_shared, num_outputs=num_classes, activation_fn=None)
# 2维输出层产生重构造的隐含维度信息
# recog_cont = slim.fully_connected(recog_shared, num_outputs=num_cont, activation_fn=tf.nn.sigmoid)
return disc # recog_cat, recog_cont
注:这里是将3种输入的x与y分别按照batch_size维度连接变为判别器的一个输入的。生成结果后再使用split函数将其裁成3个结果disc_real、disc_fake和disc_mis,分别代表真实样本与真实标签、生成的图像gen与真实标签、真实样本与错误标签所对应的判别值。这么写会使代码看上去简洁一些,当然也可以一个一个地输入x、y,然后调用三次判别器,效果是一样的。
##################################################################
# 3.定义网络模型 : 定义 参数/输入/输出/中间过程(经过G/D)的输入输出
##################################################################
batch_size = 10 # 获取样本的批次大小32
classes_dim = 10 # 10 classes
con_dim = 2 # 隐含信息变量的维度, 应节点为z_con
rand_dim = 38 # 一般噪声的维度, 应节点为z_rand, 二者都是符合标准高斯分布的随机数。
n_input = 784 # 28 * 28
x = tf.placeholder(tf.float32, [None, n_input]) # x为输入真实图片images
y = tf.placeholder(tf.int32, [None]) # y为真实标签labels
misy = tf.placeholder(tf.int32, [None]) # 错误标签
# z_con = tf.random_normal((batch_size, con_dim)) # 2列
z_rand = tf.random_normal((batch_size, rand_dim)) # 38列
z = tf.concat(axis=1, values=[tf.one_hot(y, depth=classes_dim), z_rand]) # 50列 shape = (10, 50)
gen = generator(z) # shape = (10, 28, 28, 1)
genout = tf.squeeze(gen, -1) # shape = (10, 28, 28)
# labels for discriminator
# y_real = tf.ones(batch_size) # 真
# y_fake = tf.zeros(batch_size) # 假
# 判别器D
xin = tf.concat([x, tf.reshape(gen, shape=[-1, 784]), x], 0)
yin = tf.concat([tf.one_hot(y, depth=classes_dim), tf.one_hot(y, depth=classes_dim), tf.one_hot(misy, depth=classes_dim)], 0)
# disc_real, class_real, _ = discriminator(x)
# disc_fake, class_fake, con_fake = discriminator(gen)
# pred_class = tf.argmax(class_fake, dimension=1)
disc_all = discriminator(xin, yin)
# 真实样本与真实标签、生成的图像gen与真实标签、真实样本与错误标签所对应的判别值
disc_real, disc_fake, disc_mis = tf.split(disc_all, 3)
# 判别器 loss
# loss_d_r = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_real, labels=y_real)) # 1
# loss_d_f = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake, labels=y_fake)) # 0
loss_d = tf.reduce_sum(tf.square(disc_real-1) + (tf.square(disc_fake-0)+tf.square(disc_mis-0))/2) / 2
# generator loss
loss_g = tf.reduce_sum(tf.square(disc_fake-1)) / 2
# categorical factor loss 分类因素损失
# loss_cf = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=class_fake, labels=y))
# loss_cr = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=class_real, labels=y))
# loss_c = (loss_cf + loss_cr) / 2
# continuous factor loss 隐含信息变量的损失
# loss_con = tf.reduce_mean(tf.square(con_fake - z_con))
##################################################################
# 5.训练与测试
# 建立session,循环中使用run来运行两个优化器
##################################################################
training_epochs = 3
display_step = 1
with tf.train.MonitoredTrainingSession(checkpoint_dir='log/checkpointsnew', save_checkpoint_secs=60) as sess:
total_batch = int(mnist.train.num_examples / batch_size)
print("global_step.eval(session=sess)", global_step.eval(session=sess),
int(global_step.eval(session=sess) / total_batch))
for epoch in range(int(global_step.eval(session=sess) / total_batch), training_epochs):
avg_cost = 0.
# 遍历全部数据集
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size) # 取数据
_, mis_batch_ys = mnist.train.next_batch(batch_size) # 取数据
feeds = {x: batch_xs, y: batch_ys, misy: mis_batch_ys}
# Fit training using batch data
l_disc, _, l_d_step = sess.run([loss_d, train_disc, global_step], feeds)
l_gen, _, l_g_step = sess.run([loss_g, train_gen, gen_global_step], feeds)
# 显示训练中的详细信息
if epoch % display_step == 0:
print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f} ".format(l_disc), l_gen)
print("完成!")
-----------------------------------------------------------------------------------------------------------------------------------------
附上全部代码:
# !/usr/bin/env python3
# -*- coding: utf-8 -*-
__author__ = '黎明'
##################################################################
# 1.引入头文件并加载mnist数据
##################################################################
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from scipy.stats import norm
import tensorflow.contrib.slim as slim
import time
from timer import Timer
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/media/S318080208/py_pictures/minist/") # ,one_hot=True)
tf.reset_default_graph() # 用于清除默认图形堆栈并重置全局默认图形
##################################################################
# 2.定义生成器与判别器
##################################################################
def generator(x): # 生成器函数 : 两个全连接+两个反卷积模拟样本的生成,每一层都有BN(批量归一化)处理
reuse = len([t for t in tf.global_variables() if t.name.startswith('generator')]) > 0 # 确认该变量作用域没有变量
# print (x.get_shape())
with tf.variable_scope('generator', reuse=reuse):
x = slim.fully_connected(x, 1024)
# print(x)
x = slim.batch_norm(x, activation_fn=tf.nn.relu)
x = slim.fully_connected(x, 7 * 7 * 128)
x = slim.batch_norm(x, activation_fn=tf.nn.relu)
x = tf.reshape(x, [-1, 7, 7, 128])
# print ('22', tf.tensor.get_shape())
x = slim.conv2d_transpose(x, 64, kernel_size=[4, 4], stride=2, activation_fn=None)
# print ('gen',x.get_shape())
x = slim.batch_norm(x, activation_fn=tf.nn.relu)
z = slim.conv2d_transpose(x, 1, kernel_size=[4, 4], stride=2, activation_fn=tf.nn.sigmoid)
# print ('genz',z.get_shape())
return z
def leaky_relu(x):
return tf.where(tf.greater(x, 0), x, 0.01 * x)
# def discriminator(x, num_classes=10, num_cont=2):
def discriminator(x, y): # 判别器函数 : x两次卷积,再接两次全连接; y代表输入的样本标签
reuse = len([t for t in tf.global_variables() if t.name.startswith('discriminator')]) > 0
# print (reuse)
# print (x.get_shape())
with tf.variable_scope('discriminator', reuse=reuse):
y = slim.fully_connected(y, num_outputs=n_input, activation_fn=leaky_relu) # 将y变为与图片一样维度的映射
y = tf.reshape(y, shape=[-1, 28, 28, 1]) # 将y统一成图片格式
x = tf.reshape(x, shape=[-1, 28, 28, 1])
# 将二者连接到一起,统一处理
x = tf.concat(axis=3, values=[x, y]) # x.shape = [-1, 28, 28, 2]
x = slim.conv2d(x, num_outputs=64, kernel_size=[4, 4], stride=2, activation_fn=leaky_relu)
x = slim.conv2d(x, num_outputs=128, kernel_size=[4, 4], stride=2, activation_fn=leaky_relu)
# print ("conv2d",x.get_shape())
x = slim.flatten(x) # 输入扁平化
shared_tensor = slim.fully_connected(x, num_outputs=1024, activation_fn=leaky_relu)
# recog_shared = slim.fully_connected(shared_tensor, num_outputs=128, activation_fn=leaky_relu)
# 生成的数据可以分别连接不同的输出层产生不同的结果
# 1维的输出层产生判别结果1或是0
disc = slim.fully_connected(shared_tensor, num_outputs=1, activation_fn=tf.nn.sigmoid)
disc = tf.squeeze(disc, -1)
# print ("disc",disc.get_shape()) # 0 or 1
# 10维的输出层产生分类结果 (样本标签)
# recog_cat = slim.fully_connected(recog_shared, num_outputs=num_classes, activation_fn=None)
# 2维输出层产生重构造的隐含维度信息
# recog_cont = slim.fully_connected(recog_shared, num_outputs=num_cont, activation_fn=tf.nn.sigmoid)
return disc # recog_cat, recog_cont
##################################################################
# 3.定义网络模型 : 定义 参数/输入/输出/中间过程(经过G/D)的输入输出
##################################################################
batch_size = 10 # 获取样本的批次大小32
classes_dim = 10 # 10 classes
con_dim = 2 # 隐含信息变量的维度, 应节点为z_con
rand_dim = 38 # 一般噪声的维度, 应节点为z_rand, 二者都是符合标准高斯分布的随机数。
n_input = 784 # 28 * 28
x = tf.placeholder(tf.float32, [None, n_input]) # x为输入真实图片images
y = tf.placeholder(tf.int32, [None]) # y为真实标签labels
misy = tf.placeholder(tf.int32, [None]) # 错误标签
# z_con = tf.random_normal((batch_size, con_dim)) # 2列
z_rand = tf.random_normal((batch_size, rand_dim)) # 38列
z = tf.concat(axis=1, values=[tf.one_hot(y, depth=classes_dim), z_rand]) # 50列 shape = (10, 50)
gen = generator(z) # shape = (10, 28, 28, 1)
genout = tf.squeeze(gen, -1) # shape = (10, 28, 28)
# labels for discriminator
# y_real = tf.ones(batch_size) # 真
# y_fake = tf.zeros(batch_size) # 假
# 判别器D
xin = tf.concat([x, tf.reshape(gen, shape=[-1, 784]), x], 0)
yin = tf.concat([tf.one_hot(y, depth=classes_dim), tf.one_hot(y, depth=classes_dim), tf.one_hot(misy, depth=classes_dim)], 0)
# disc_real, class_real, _ = discriminator(x)
# disc_fake, class_fake, con_fake = discriminator(gen)
# pred_class = tf.argmax(class_fake, dimension=1)
disc_all = discriminator(xin, yin)
# 真实样本与真实标签、生成的图像gen与真实标签、真实样本与错误标签所对应的判别值
disc_real, disc_fake, disc_mis = tf.split(disc_all, 3)
##################################################################
# 4.定义损失函数和优化器
##################################################################
# 判别器 loss
# loss_d_r = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_real, labels=y_real)) # 1
# loss_d_f = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake, labels=y_fake)) # 0
loss_d = tf.reduce_sum(tf.square(disc_real-1) + (tf.square(disc_fake-0)+tf.square(disc_mis-0))/2) / 2
# generator loss
loss_g = tf.reduce_sum(tf.square(disc_fake-1)) / 2
# categorical factor loss 分类因素损失
# loss_cf = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=class_fake, labels=y))
# loss_cr = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=class_real, labels=y))
# loss_c = (loss_cf + loss_cr) / 2
# continuous factor loss 隐含信息变量的损失
# loss_con = tf.reduce_mean(tf.square(con_fake - z_con))
# 获得各个网络中各自的训练参数列表
t_vars = tf.trainable_variables()
d_vars = [var for var in t_vars if 'discriminator' in var.name]
g_vars = [var for var in t_vars if 'generator' in var.name]
# 优化器
# disc_global_step = tf.Variable(0, trainable=False)
gen_global_step = tf.Variable(0, trainable=False)
global_step = tf.train.get_or_create_global_step() # 使用MonitoredTrainingSession,必须有
train_disc = tf.train.AdamOptimizer(0.0001).minimize(loss_d, var_list=d_vars,
global_step=global_step)
train_gen = tf.train.AdamOptimizer(0.001).minimize(loss_g, var_list=g_vars,
global_step=gen_global_step)
##################################################################
# 5.训练与测试
# 建立session,循环中使用run来运行两个优化器
##################################################################
training_epochs = 3
display_step = 1
with tf.train.MonitoredTrainingSession(checkpoint_dir='log/checkpointsnew', save_checkpoint_secs=60) as sess:
total_batch = int(mnist.train.num_examples / batch_size)
print("global_step.eval(session=sess)", global_step.eval(session=sess),
int(global_step.eval(session=sess) / total_batch))
for epoch in range(int(global_step.eval(session=sess) / total_batch), training_epochs):
avg_cost = 0.
# 遍历全部数据集
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size) # 取数据
_, mis_batch_ys = mnist.train.next_batch(batch_size) # 取数据
feeds = {x: batch_xs, y: batch_ys, misy: mis_batch_ys}
# Fit training using batch data
l_disc, _, l_d_step = sess.run([loss_d, train_disc, global_step], feeds)
l_gen, _, l_g_step = sess.run([loss_g, train_gen, gen_global_step], feeds)
# 显示训练中的详细信息
if epoch % display_step == 0:
print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f} ".format(l_disc), l_gen)
print("完成!")
# 测试
_, mis_batch_ys = mnist.train.next_batch(batch_size)
print("result:",
loss_d.eval({x: mnist.test.images[:batch_size], y: mnist.test.labels[:batch_size], misy: mis_batch_ys},
session=sess)
, loss_g.eval({x: mnist.test.images[:batch_size], y: mnist.test.labels[:batch_size], misy: mis_batch_ys},
session=sess))
# 根据图片模拟生成图片
show_num = 10
gensimple, inputx, inputy = sess.run(
[genout, x, y], feed_dict={x: mnist.test.images[:batch_size], y: mnist.test.labels[:batch_size]})
f, a = plt.subplots(2, 10, figsize=(10, 2))
for i in range(show_num):
a[0][i].imshow(np.reshape(inputx[i], (28, 28)))
a[1][i].imshow(np.reshape(gensimple[i], (28, 28)))
plt.draw()
plt.show()
-----------------------------------------------------------------------------------------------------------------------------------------
运行结果:
使用GAN-cls技术同样也实现了生成与标签对应的样本,而且整体代码的运算要比ACGAN简洁很多(丝毫没觉得,专门算过时间,没啥变化 =.=)。