论文依据:CVAE-GAN: Fine-Grained Image Generation through Asymmetric Training。
代码来源:GitHub
关于论文的讲解分析,CSDN 上已有不少文章,在此不做详细解释。
该模型可应用于图像修复、超分辨率和数据增强(例如借助数据增强训练更好的人脸识别模型)等领域。
该程序按实验目的分为三部分:训练网络(train.py)、测试网络(test.py)和分类网络(classifier_net.py);另有三个基础支持模块:基本模型(model_utils.py)、VAE 网络(vae_net.py)和判别网络(discriminator_net.py)。各部分分工明确,环环相扣。
首先是基本模型
model_utils.py
import tensorflow as tf
import tensorlayer as tl
import numpy as np
def _channel_shuffle(x, n_group):
    """Interleave the channels of an NHWC tensor across `n_group` groups.

    The channel axis is split into (n_group, c // n_group), the two sub-axes are
    swapped, and the result is flattened back, so channel k of group g moves
    next to channel k of every other group. `c` must be divisible by `n_group`
    and height/width must be statically known.
    """
    _, height, width, channels = x.shape.as_list()
    grouped = tf.reshape(x, [-1, height, width, n_group, channels // n_group])
    swapped = tf.transpose(grouped, [0, 1, 2, 4, 3])
    return tf.reshape(swapped, [-1, height, width, channels])
def _group_norm_and_channel_shuffle(x, is_train, G=32, epsilon=1e-12, use_shuffle=False, name='_group_norm'):
    """Group normalization for an NHWC tensor, optionally followed by a channel shuffle.

    Channels are split into `G` groups (G is capped at the channel count) and each
    group is normalized with its own mean/variance over (H, W, channels-in-group).
    A learned per-channel scale `gamma` and shift `beta` are then applied.

    Args:
        x: NHWC tensor with static H, W and C; C must be divisible by min(G, C).
        is_train: passed as `trainable` to the gamma/beta variables.
        G: number of channel groups.
        epsilon: added to the variance for numerical stability.
        use_shuffle: interleave channels across groups after normalizing.
        name: variable scope holding gamma/beta.

    Raises:
        ValueError: if the channel count is not divisible by the group count.
    """
    with tf.variable_scope(name):
        N, H, W, C = x.get_shape().as_list()
        if N is None:
            N = -1
        G = min(G, C)
        if C % G != 0:
            raise ValueError('channel count %d is not divisible by group count %d' % (C, G))
        # Split only the channel axis: reshaping NHWC directly to [N, G, H, W, C//G]
        # would carve the tensor into bands of rows instead of channel groups.
        x = tf.reshape(x, [N, H, W, G, C // G])
        mean, var = tf.nn.moments(x, [1, 2, 4], keep_dims=True)
        x = (x - mean) / tf.sqrt(var + epsilon)
        if use_shuffle:
            # Same interleaving as _channel_shuffle: swap the (group, in-group) axes.
            x = tf.transpose(x, [0, 1, 2, 4, 3])
        # Per-channel affine parameters.
        gamma = tf.get_variable('gamma', [C], initializer=tf.constant_initializer(1.0), trainable=is_train)
        beta = tf.get_variable('beta', [C], initializer=tf.constant_initializer(0.0), trainable=is_train)
        gamma = tf.reshape(gamma, [1, 1, 1, C])
        beta = tf.reshape(beta, [1, 1, 1, C])
        output = tf.reshape(x, [N, H, W, C]) * gamma + beta
    return output
def _switch_norm(x, name='_switch_norm'):
    """Switchable normalization of an NHWC tensor.

    Batch (axes 0,1,2), instance (axes 1,2) and layer (axes 1,2,3) statistics are
    blended with softmax-normalized learned weights (one set for means, one for
    variances). The result is then scaled by a per-channel `gamma` and shifted
    by `beta`. Batch statistics are always computed from the current input.
    """
    with tf.variable_scope(name):
        channels = x.shape[-1]
        eps = 1e-5
        (batch_mean, batch_var), (inst_mean, inst_var), (layer_mean, layer_var) = [
            tf.nn.moments(x, axes, keep_dims=True)
            for axes in ([0, 1, 2], [1, 2], [1, 2, 3])
        ]
        # Variables are created in the same order as always (gamma, beta,
        # mean_weight, var_weight) so saved parameter lists stay aligned.
        gamma = tf.get_variable("gamma", [channels], initializer=tf.constant_initializer(1.0))
        beta = tf.get_variable("beta", [channels], initializer=tf.constant_initializer(0.0))
        mean_w = tf.nn.softmax(tf.get_variable("mean_weight", [3], initializer=tf.constant_initializer(1.0)))
        var_w = tf.nn.softmax(tf.get_variable("var_weight", [3], initializer=tf.constant_initializer(1.0)))
        mixed_mean = mean_w[0] * batch_mean + mean_w[1] * inst_mean + mean_w[2] * layer_mean
        mixed_var = var_w[0] * batch_var + var_w[1] * inst_var + var_w[2] * layer_var
        normalized = (x - mixed_mean) / (tf.sqrt(mixed_var + eps))
        return normalized * gamma + beta
def _add_coord(x):
    """Append two coordinate channels (y first, then x) to an NHWC tensor, CoordConv-style.

    Each coordinate is the row/column index divided by (size - 1), so it spans
    [0, 1] inclusive. Height and width must be statically known.

    Returns:
        A tensor with the same batch, height and width and two extra channels.
    """
    batch_size = tf.shape(x)[0]
    height, width = x.shape.as_list()[1:3]
    # A size-1 axis would divide 0 by 0 (NaN); dividing by 1 instead gives that
    # single position the coordinate 0.
    y_coord = tf.range(0, height, dtype=tf.float32) / max(height - 1, 1)
    y_coord = tf.tile(tf.reshape(y_coord, [1, -1, 1, 1]), [batch_size, 1, width, 1])  # b,h,w,1
    x_coord = tf.range(0, width, dtype=tf.float32) / max(width - 1, 1)
    x_coord = tf.tile(tf.reshape(x_coord, [1, 1, -1, 1]), [batch_size, height, 1, 1])  # b,h,w,1
    return tf.concat([x, y_coord, x_coord], 3)
def coord_layer(net):
    """Wrap `_add_coord` as a TensorLayer layer named 'coord_layer' (adds y/x coordinate channels)."""
    with_coords = tl.layers.LambdaLayer(net, _add_coord, name='coord_layer')
    return with_coords
def switchnorm_layer(net, act, name):
    """Apply switchable normalization (`_switch_norm`), then the optional activation.

    Args:
        net: input TensorLayer layer.
        act: activation callable, or None to skip it.
        name: name of the normalization layer; the activation layer is named
            `name + '_act'`.

    Returns:
        The resulting TensorLayer layer.
    """
    net = tl.layers.LambdaLayer(net, _switch_norm, name=name)
    if act is not None:
        # Give the activation its own name. Reusing `name` gives two layers the same
        # name and, since LambdaLayer presumably opens a variable scope per name,
        # makes the activation re-enter the norm layer's scope.
        net = tl.layers.LambdaLayer(net, act, name=name + '_act')
    return net
def groupnorm_layer(net, is_train, G, use_shuffle, act, name):
    """Apply group normalization (optionally channel-shuffled), then the optional activation.

    Args:
        net: input TensorLayer layer.
        is_train: forwarded to `_group_norm_and_channel_shuffle` (trainable gamma/beta).
        G: number of channel groups.
        use_shuffle: interleave channels across groups after normalizing.
        act: activation callable, or None to skip it.
        name: name of the normalization layer (also its inner variable scope);
            the activation layer is named `name + '_act'`.
    """
    net = tl.layers.LambdaLayer(net, _group_norm_and_channel_shuffle, {
        'is_train': is_train, 'G': G, 'use_shuffle': use_shuffle, 'name': name}, name=name)
    if act is not None:
        # Distinct name so the activation layer does not collide with the norm layer.
        net = tl.layers.LambdaLayer(net, act, name=name + '_act')
    return net
def upsampling_layer(net, shortpoint):
    """Resize `net` to the spatial size of `shortpoint` and concatenate both along channels."""
    target_hw = shortpoint.outputs.shape.as_list()[1:3]
    resized = tl.layers.UpSampling2dLayer(net, target_hw, is_scale=False)
    return tl.layers.ConcatLayer([resized, shortpoint], -1)
def upsampling_layer2(net, shortpoint, name):
    """Fuse a coarse feature map with a skip connection.

    Both inputs first have their channel counts halved by normalized 1x1
    convolutions ('up1' and 'up2'). `net` is then resized to `shortpoint`'s
    spatial size, and the two are concatenated along the channel axis.
    """
    with tf.variable_scope(name):
        hw = shortpoint.outputs.shape.as_list()[1:3]
        dim1 = net.outputs.shape.as_list()[3]
        dim2 = shortpoint.outputs.shape.as_list()[3]
        # conv2d's signature is (net, n_filter, filter_size, strides, act, padding,
        # use_norm, name); passing extra positional flags raised a TypeError.
        net = conv2d(net, dim1 // 2, 1, 1, None, 'SAME', True, 'up1')
        shortpoint = conv2d(shortpoint, dim2 // 2, 1, 1, None, 'SAME', True, 'up2')
        net = tl.layers.UpSampling2dLayer(net, hw, is_scale=False)
        net = tl.layers.ConcatLayer([net, shortpoint], -1)
    return net
def upsampling_layer3(net, shortpoint):
    """Concatenate `net`, resized to `shortpoint`'s size, with the first half of `shortpoint`'s channels."""
    target_hw = shortpoint.outputs.shape.as_list()[1:3]

    def _first_half(t):
        # Keep only the first of two equal channel splits of the skip tensor.
        return tf.split(t, 2, -1)[0]

    half_skip = tl.layers.LambdaLayer(shortpoint, _first_half)
    resized = tl.layers.UpSampling2dLayer(net, target_hw, is_scale=False)
    return tl.layers.ConcatLayer([resized, half_skip], -1)
def conv2d(net, n_filter, filter_size, strides, act, padding, use_norm, name):
    """2-D convolution under variable scope `name`, optionally switch-normalized.

    `filter_size` and `strides` may be scalars or pairs. With `use_norm`, the
    convolution has no bias and `act` is applied after switchable normalization;
    otherwise `act` is applied by the convolution itself.
    """
    kernel = np.broadcast_to(filter_size, [2])
    stride = np.broadcast_to(strides, [2])
    with tf.variable_scope(name):
        if not use_norm:
            return tl.layers.Conv2d(net, n_filter, kernel, stride, act, padding, name='c2d')
        conv = tl.layers.Conv2d(net, n_filter, kernel, stride, None, padding, b_init=None, name='c2d')
        return switchnorm_layer(conv, act, 'sn')
def groupconv2d(net, n_filter, filter_size, strides, n_group, act, padding, use_norm, use_shuffle, name):
    """Grouped 2-D convolution under variable scope `name`.

    With `use_norm`, the convolution has no bias, is followed by switchable
    normalization plus `act`, and, if `use_shuffle` is set, by a channel shuffle
    over `n_group` groups. Without `use_norm`, `act` is applied by the
    convolution and no shuffle is done.
    """
    kernel = np.broadcast_to(filter_size, [2])
    stride = np.broadcast_to(strides, [2])
    with tf.variable_scope(name):
        if not use_norm:
            return tl.layers.GroupConv2d(net, n_filter, kernel, stride, n_group, act, padding, name='gc2d')
        out = tl.layers.GroupConv2d(net, n_filter, kernel, stride, n_group, None, padding, b_init=None, name='gc2d')
        out = switchnorm_layer(out, act, 'sn')
        if use_shuffle:
            out = tl.layers.LambdaLayer(out, lambda t: _channel_shuffle(t, n_group))
        return out
def deconv2d(net, n_filter, filter_size, strides, act, padding, use_norm, name):
    """Transposed 2-D convolution under variable scope `name`, optionally switch-normalized.

    `filter_size` and `strides` may be scalars or pairs. With `use_norm`, the
    layer has no bias and `act` follows switchable normalization.
    """
    kernel = np.broadcast_to(filter_size, [2])
    stride = np.broadcast_to(strides, [2])
    with tf.variable_scope(name):
        if not use_norm:
            return tl.layers.DeConv2d(net, n_filter, kernel, strides=stride, act=act, padding=padding, name='dc2d')
        deconv = tl.layers.DeConv2d(net, n_filter, kernel, strides=stride, padding=padding, b_init=None, name='dc2d')
        return switchnorm_layer(deconv, act, 'sn')
def depthwiseconv2d(net, depth_multiplier, filter_size, strides, act, padding, use_norm, name, dilation_rate=1):
    """Depthwise 2-D convolution under variable scope `name`, optionally switch-normalized.

    `filter_size`, `strides` and `dilation_rate` may be scalars or pairs. With
    `use_norm`, the layer has no bias and `act` follows switchable normalization.
    """
    kernel = np.broadcast_to(filter_size, [2])
    stride = np.broadcast_to(strides, [2])
    dilation = np.broadcast_to(dilation_rate, [2])
    with tf.variable_scope(name):
        if not use_norm:
            return tl.layers.DepthwiseConv2d(net, kernel, stride, act, padding, dilation_rate=dilation,
                                             depth_multiplier=depth_multiplier, name='dwc2d')
        dw = tl.layers.DepthwiseConv2d(net, kernel, stride, None, padding, dilation_rate=dilation,
                                       depth_multiplier=depth_multiplier, b_init=None, name='dwc2d')
        return switchnorm_layer(dw, act, 'sn')
def resblock_1(net, n_filter, strides, act, name):
    """Residual unit built from grouped and depthwise convolutions.

    Main branch: coordinate channels -> 1x1 conv -> grouped 3x3 (20 groups,
    shuffled) -> depthwise 3x3 carrying `strides` -> grouped 1x1 (20 groups,
    shuffled). The shortcut is a normalized 3x3 projection when the stride is
    above 1 or the channel count changes, otherwise the input itself.
    `n_filter` must be divisible by 20.
    """
    strides = np.broadcast_to(strides, [2])
    with tf.variable_scope(name):
        needs_projection = np.max(strides) > 1 or net.outputs.shape.as_list()[-1] != n_filter
        shortcut = (conv2d(net, n_filter, (3, 3), strides, None, 'SAME', True, 'shortcut')
                    if needs_projection else net)
        branch = coord_layer(net)
        branch = conv2d(branch, n_filter, 1, 1, None, 'SAME', False, 'c1')
        branch = groupconv2d(branch, n_filter, 3, 1, 20, act, 'SAME', True, True, 'gc1')
        branch = depthwiseconv2d(branch, 1, 3, strides, None, 'SAME', True, 'dwc2')
        branch = groupconv2d(branch, n_filter, 1, 1, 20, act, 'SAME', True, True, 'gc2')
        net = tl.layers.ElementwiseLayer([shortcut, branch], tf.add)
    return net
def resblock_2(net, n_filter, strides, act, name):
    """Plain residual unit: two normalized 3x3 convolutions plus a shortcut.

    The first convolution carries `strides`. The shortcut is a normalized 3x3
    projection when the stride is above 1 or the channel count changes,
    otherwise the input itself.
    """
    strides = np.broadcast_to(strides, [2])
    with tf.variable_scope(name):
        if np.max(strides) > 1 or net.outputs.shape.as_list()[-1] != n_filter:
            shortcut = conv2d(net, n_filter, (3, 3), strides, None, 'SAME', True, 'shortcut')
        else:
            shortcut = net
        net = conv2d(net, n_filter, 3, strides, act, 'SAME', True, 'c1')
        # Needs its own scope: naming it 'c1' again would try to create the
        # c1/c2d variables a second time and fail.
        net = conv2d(net, n_filter, 3, 1, act, 'SAME', True, 'c2')
        net = tl.layers.ElementwiseLayer([shortcut, net], tf.add)
    return net
def ablock(net, n_filter, strides, act, n_block, name):
    """Stack resblock_1 units under scope 'ab_<name>'.

    The first unit ('rb_0') uses `strides` and is always built. Units 'rb_1' ..
    'rb_<n_block-1>' use stride 1.
    """
    with tf.variable_scope('ab_' + name):
        for idx in range(max(n_block, 1)):
            unit_stride = strides if idx == 0 else 1
            net = resblock_1(net, n_filter, unit_stride, act, 'rb_%d' % idx)
    return net
def group_block(net, n_filter, strides, act, block_type, n_block, name):
    """Stack `block_type` units under scope 'gb_<name>'.

    The first unit ('b_0') receives `strides` as keyword arguments. Units 'b_1' ..
    'b_<n_block-1>' are called positionally with stride 1.
    """
    with tf.variable_scope('gb_' + name):
        net = block_type(net, n_filter=n_filter, strides=strides, act=act, name='b_0')
        for block_name in ('b_%d' % i for i in range(1, n_block)):
            net = block_type(net, n_filter, 1, act, block_name)
    return net
vae_net.py
from model_utils import *
def act(x):
    """Default activation: tl.act.leaky_twice_relu6 with both leak coefficients set to 0.1."""
    return tl.act.leaky_twice_relu6(x, 0.1, 0.1)


# Width of the encoder's latent heads and of the decoder's initial reshape.
n_hidden = 64
# Network builders
def get_encoder(img, c, reuse):
    """Class-conditional encoder of the CVAE.

    The class ids `c` are one-hot encoded (depth 10), projected by a dense layer
    to one H*W map, and concatenated to `img` as an extra channel. Four ablock
    stages and global mean pooling follow, then dense heads for the Gaussian
    posterior.

    Args:
        img: NHWC image tensor with static height and width.
        c: integer class-id tensor of shape (batch,).
        reuse: reuse the variables of the 'encoder' scope.

    Returns:
        (net, [z, mean, log_sigma]). z = mean + eps * exp(0.5 * log_sigma) is a
        reparameterized sample with eps ~ N(0, 1), so log_sigma acts as a
        log-variance. The exponent is clipped at 20.
    """
    hw = img.shape.as_list()[1:3]
    with tf.variable_scope('encoder', reuse=reuse):
        net = tl.layers.InputLayer(img)
        # Class condition as an extra image channel.
        c_net = tl.layers.OneHotInputLayer(c, 10, axis=-1, dtype=tf.float32)
        c_net = tl.layers.DenseLayer(c_net, hw[0] * hw[1], act, name='d1')
        c_net = tl.layers.ReshapeLayer(c_net, (-1, hw[0], hw[1], 1))
        net = tl.layers.ConcatLayer([net, c_net], -1)
        # (n_filter, stride, n_block) per stage; stage scopes are named '1'..'4'.
        stages = ((20, 1, 2), (40, 2, 3), (60, 2, 3), (80, 2, 4))
        for idx, (n_filter, stride, n_block) in enumerate(stages, 1):
            net = ablock(net, n_filter, stride, act, n_block, str(idx))
        net = tl.layers.GlobalMeanPool2d(net)
        net = tl.layers.DenseLayer(net, n_hidden, act, name='out')
        # Mean and log-variance are unconstrained reals, so these heads are
        # linear. Passing them through `act` would distort every value outside
        # its linear range and bias the posterior.
        mean = tl.layers.DenseLayer(net, n_hidden, None, name='mean')
        log_sigma = tl.layers.DenseLayer(net, n_hidden, None, name='log_sigma')
        net = tl.layers.merge_networks([net, mean, log_sigma])
    mean = mean.outputs
    log_sigma = log_sigma.outputs
    half_log_var = log_sigma * 0.5
    noise = tf.random_normal(tf.shape(mean))
    # Clip before exp to avoid overflow.
    z = mean + noise * tf.exp(tf.minimum(half_log_var, 20))
    return net, [z, mean, log_sigma]
def get_decoder(z, c, reuse):
    """Decode latent codes into single-channel images.

    `z` (optionally joined with a dense projection of the one-hot, depth-10
    class ids `c`) is projected to z_len units, reshaped to 1x1, and tiled to
    4x4. Three 2x upsamplings follow, interleaved with ablock stages, so the
    output is 32x32x1. The output activation is identity on [0, 1] with slope
    0.1 outside it.

    Args:
        z: latent tensor of shape (batch, z_len).
        c: integer class ids of shape (batch,), or None (the discriminator
            reuses this decoder without a class input).
        reuse: reuse the variables of the 'decoder' scope.

    Returns:
        (net, images) where images = net.outputs.
    """
    z_len = z.shape.as_list()[-1]
    with tf.variable_scope('decoder', reuse=reuse):
        net = tl.layers.InputLayer(z)
        if c is not None:
            c_net = tl.layers.OneHotInputLayer(c, 10, axis=-1, dtype=tf.float32)
            c_net = tl.layers.DenseLayer(c_net, z_len, act, name='d1')
            net = tl.layers.ConcatLayer([net, c_net], -1)
        net = tl.layers.DenseLayer(net, z_len, act, name='d2')
        b_id = 0

        def get_unique_name():
            nonlocal b_id
            b_id += 1
            return str(b_id)

        # d2 emits z_len units. Reshaping with the module-level n_hidden would
        # silently fold units into the batch axis whenever z_len != n_hidden.
        net = tl.layers.ReshapeLayer(net, (-1, 1, 1, z_len))
        net = tl.layers.TileLayer(net, [1, 4, 4, 1])
        net = ablock(net, 80, 1, act, 3, get_unique_name())
        net = tl.layers.UpSampling2dLayer(net, (2, 2), True, 1)
        net = ablock(net, 60, 1, act, 3, get_unique_name())
        net = tl.layers.UpSampling2dLayer(net, (2, 2), True, 1)
        net = ablock(net, 40, 1, act, 3, get_unique_name())
        net = tl.layers.UpSampling2dLayer(net, (2, 2), True, 1)
        net = ablock(net, 20, 1, act, 3, get_unique_name())
        out_act = lambda x: tf.where(x < 0, 0.1 * x, tf.where(x > 1, 0.1 * x + 1, x))
        net = conv2d(net, 1, 3, 1, out_act, 'SAME', False, 'out')
    return net, net.outputs
if __name__ == '__main__':
    # Smoke test: build encoder and decoder on 32x32 single-channel inputs.
    x = tf.placeholder(tf.float32, [None, 32, 32, 1])
    classes_label = tf.placeholder(tf.int32, [None, ])
    samples_placeholder = tf.placeholder(tf.float32, [None, 64])
    encoder, encoder_output = get_encoder(x, classes_label, False)
    print(encoder_output)
    sampled_z = encoder_output[0]
    decoder, decoder_output = get_decoder(sampled_z, classes_label, False)
    print(decoder_output)
discriminator_net.py
from model_utils import *
import vae_net
def act(x):
    """Default activation: tl.act.leaky_twice_relu6 with both leak coefficients set to 0.1."""
    return tl.act.leaky_twice_relu6(x, 0.1, 0.1)
def get_discriminator(img, reuse):
    """Autoencoder discriminator: encode `img` to a 64-d code and decode it back.

    Encoding uses four ablock stages, global mean pooling and a linear dense
    layer. Decoding reuses vae_net.get_decoder without a class input, inside the
    'discriminator' scope.

    Returns:
        (net, net.outputs) of the network merged by tl.layers.merge_networks
        from [decoder, code]. The outputs are presumably the decoder's
        reconstruction, since merge_networks keeps the first network; verify.
    """
    with tf.variable_scope('discriminator', reuse=reuse):
        net = tl.layers.InputLayer(img)
        # (n_filter, stride, n_block) per stage; stage scopes are named '1'..'4'.
        stages = ((20, 1, 2), (40, 2, 3), (60, 2, 3), (80, 2, 4))
        for idx, (n_filter, stride, n_block) in enumerate(stages, 1):
            net = ablock(net, n_filter, stride, act, n_block, str(idx))
        net = tl.layers.GlobalMeanPool2d(net)
        code = tl.layers.DenseLayer(net, 64, None, name='out')
        # Autoencoder half: decode the code back to an image.
        recon, _ = vae_net.get_decoder(code.outputs, None, reuse)
        merged = tl.layers.merge_networks([recon, code])
    return merged, merged.outputs
if __name__ == '__main__':
    # Smoke test: build the discriminator on a 32x32 single-channel placeholder.
    x = tf.placeholder(tf.float32, [None, 32, 32, 1])
    _, reconstruction = get_discriminator(x, False)
    print(reconstruction)
下面三个为功能性主程序
train.py
源码中使用手写数字数据集(MNIST)训练,下述程序改用尺寸为 (64, 64, 1) 的数据集,可依据需求修改参数。
训练结束后会对模型训练好的参数以及生成的图片进行保存。
from time import time
# Start the clock before the heavy imports so their load time can be reported.
t1 = time()
import tensorflow as tf
import tensorlayer as tl
import numpy as np
import vae_net
import discriminator_net
import classifier_net
# from skimage.transform import resize
# import my_py_lib.utils
import os
from progressbar import progressbar
print('加载 tf 耗时', time()-t1)
tl.logging.set_verbosity('INFO')
# Batch size of the TFRecord input queue built by read_tfrecords.
steps = 64
# Restart the clock to time the dataset-loading stage.
t1 = time()
def read_tfrecords(tfrecord_name):
    """Build a shuffled, batched queue input pipeline over one TFRecord file.

    Each record holds 'label' (int64) and 'img_raw' (raw uint8 bytes reshaped
    to 64x64x1). The file-name queue cycles indefinitely.

    Returns:
        (x_batch, y_batch): float32 images scaled to [0, 1], in batches of
        `steps`, and one-hot labels of depth `classes_num` (a module-level name
        looked up when this function is called).
    """
    # Queue of file names; re-reads the file when exhausted.
    filename_queue = tf.train.string_input_producer([tfrecord_name])
    _, serialized = tf.TFRecordReader().read(filename_queue)
    feature_spec = {
        'label': tf.FixedLenFeature([], tf.int64),
        'img_raw': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(serialized, features=feature_spec)
    # Raw bytes -> single-channel 64x64 image scaled to [0, 1].
    pixels = tf.reshape(tf.decode_raw(parsed['img_raw'], tf.uint8), [64, 64, 1])
    image = tf.cast(pixels, tf.float32) * (1. / 255)
    label = tf.one_hot(tf.cast(parsed['label'], tf.int32), depth=classes_num, on_value=1)
    x_batch, y_batch = tf.train.shuffle_batch(
        [image, label],
        batch_size=steps,
        num_threads=1,
        capacity=30 * steps,
        min_after_dequeue=15 * steps)
    return x_batch, y_batch
# The dataset has 4 classes; read_tfrecords one-hot encodes labels with this depth.
classes_num = 4
# The conditional encoder/decoder one-hot their class input with depth 10 and the
# classifier emits 10 logits, so classification targets use depth 10
# (class ids 0..classes_num-1 are a subset).
n_logits = 10
# vae_net.get_decoder turns a 4x4 seed into a 32x32 image (three 2x upsamplings),
# so the 64x64 records are resized to 32x32. Otherwise `img - decoder_output`
# cannot be built.
model_hw = 32
# Queue-backed training batches: images in [0, 1] and one-hot labels.
train_images, train_labels = read_tfrecords('路径')
train_images = tf.image.resize_images(train_images, [model_hw, model_hw])
print('加载数据集耗时', time() - t1)

img = tf.placeholder(tf.float32, [None, model_hw, model_hw, 1])
classes_label = tf.placeholder(tf.int32, [None, ])
kt = tf.Variable(0, False, dtype=tf.float32, name='kt')
lr_placeholder = tf.placeholder(tf.float32, name='lr_placeholder')
# BEGAN equilibrium hyper-parameters: target ratio gamma and kt step size lamda.
gamma = 0.5
lamda = 0.5
# Latent codes fed straight to the decoder when drawing sample grids.
samples_placeholder = tf.placeholder(tf.float32, [None, 64])
encoder, encoder_output = vae_net.get_encoder(img, classes_label, False)
decoder, decoder_output = vae_net.get_decoder(encoder_output[0], classes_label, False)
_, samples_decoder_output = vae_net.get_decoder(samples_placeholder, classes_label, True)
classifier_real, classifier_real_output = classifier_net.get_classifier(img, False)
classifier_fake, classifier_fake_output = classifier_net.get_classifier(decoder_output, True)
discriminator_real, discriminator_real_output = discriminator_net.get_discriminator(img, False)
discriminator_fake, discriminator_fake_output = discriminator_net.get_discriminator(decoder_output, True)
print('encoder params count', encoder.count_params())
print('decoder params count', decoder.count_params())
print('classifier params count', classifier_real.count_params())
print('discriminator params count', discriminator_real.count_params())


def get_kl_loss(mean, log_sigma):
    """KL(N(mean, exp(log_sigma)) || N(0, I)), summed over latent dims and averaged over the batch."""
    # Clip before exp so the loss cannot overflow.
    log_sigma = tf.minimum(log_sigma, 20)
    return tf.reduce_mean(-0.5 * tf.reduce_sum(1 + log_sigma - tf.square(mean) - tf.exp(log_sigma), 1))


kl_loss_op = get_kl_loss(*encoder_output[1:])
d_loss_real = tf.reduce_mean(tf.abs(discriminator_real_output - img))
# The fake term is the discriminator's reconstruction error on the *generated*
# images. Using the real-image reconstruction would leave the fake branch out
# of every loss.
d_loss_fake = tf.reduce_mean(tf.abs(discriminator_fake_output - decoder_output))
d_loss_op = d_loss_real - kt * d_loss_fake
kt_update_op = tf.assign(kt, tf.clip_by_value(kt + lamda * (gamma * d_loss_real - d_loss_fake), 0., 1.))
m_global = d_loss_real + tf.abs(gamma * d_loss_real - d_loss_fake)
g_loss_img = tf.reduce_mean(tf.abs(img - decoder_output))
g_loss_disc = d_loss_fake
# One-hot targets must be as wide as the classifier's logits.
class_targets = tf.one_hot(classes_label, n_logits, dtype=tf.int32)
g_loss_class = tf.losses.sigmoid_cross_entropy(class_targets, classifier_fake_output)
g_loss_op = g_loss_img + g_loss_disc + g_loss_class
c_loss_op = tf.losses.sigmoid_cross_entropy(class_targets, classifier_real_output)
c_top_1 = tf.reduce_mean(tf.to_float(tf.nn.in_top_k(classifier_real_output, classes_label, 1)))

tf.summary.scalar('g/kl_loss_op', kl_loss_op)
tf.summary.histogram('g/z', encoder_output[0])
tf.summary.scalar('g/img', g_loss_img)
tf.summary.scalar('g/disc', g_loss_disc)
tf.summary.scalar('g/class', g_loss_class)
tf.summary.image('g/ori_img', tf.cast(tf.clip_by_value(img*255, 0, 255), tf.uint8), 3)
tf.summary.image('g/gen_img', tf.cast(tf.clip_by_value(decoder_output*255, 0, 255), tf.uint8), 3)
tf.summary.image('d/real_img', tf.cast(tf.clip_by_value(discriminator_real_output*255, 0, 255), tf.uint8), 3)
tf.summary.image('d/fake_img', tf.cast(tf.clip_by_value(discriminator_fake_output*255, 0, 255), tf.uint8), 3)
tf.summary.scalar('d/real', d_loss_real)
tf.summary.scalar('d/fake', d_loss_fake)
tf.summary.scalar('c/loss', c_loss_op)
tf.summary.scalar('c/in_top_1', c_top_1)
tf.summary.scalar('misc/kt', kt)
tf.summary.scalar('misc/m_global', m_global)
tf.summary.scalar('misc/lr', lr_placeholder)
os.makedirs('imgs', exist_ok=True)
os.makedirs('logs', exist_ok=True)
merge_summary_op = tf.summary.merge_all()
sum_file = tf.summary.FileWriter('logs', tf.get_default_graph())
g_optim = tf.train.AdamOptimizer(lr_placeholder).minimize(kl_loss_op + g_loss_op, var_list=encoder.all_params + decoder.all_params)
d_optim = tf.train.AdamOptimizer(lr_placeholder).minimize(d_loss_op + c_loss_op, var_list=discriminator_real.all_params + classifier_real.all_params)

# A single session: it initializes the fully built graph, restores checkpoints
# and runs the queue that feeds it.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.allow_soft_placement = True
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
tl.files.load_and_assign_npz(sess, 'encoder.npz', encoder)
tl.files.load_and_assign_npz(sess, 'decoder.npz', decoder)
tl.files.load_and_assign_npz(sess, 'classifier.npz', classifier_real)
tl.files.load_and_assign_npz(sess, 'discriminator.npz', discriminator_real)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

n_epoch = 200
# The queue pipeline has no natural epoch boundary. Set this to
# ceil(number_of_training_records / steps) for your dataset.
n_batch = 100
lr = 0.001
try:
    for e in range(n_epoch):
        for b in progressbar(range(n_batch)):
            batch_imgs, batch_onehot = sess.run([train_images, train_labels])
            feed_dict = {
                img: batch_imgs,
                # read_tfrecords yields one-hot rows; the placeholder expects class ids.
                classes_label: np.argmax(batch_onehot, -1).astype(np.int32),
                lr_placeholder: lr}
            step = e * n_batch + b
            if b == 5 and e == 0:
                # Record one full execution trace for TensorBoard.
                run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()
                _, summary = sess.run([[g_optim, d_optim, kt_update_op], merge_summary_op],
                                      feed_dict, run_options, run_metadata)
                sum_file.add_run_metadata(run_metadata, 'step%d' % step)
            else:
                _, summary = sess.run([[g_optim, d_optim, kt_update_op], merge_summary_op], feed_dict)
            sum_file.add_summary(summary, step)
            if b % 20 == 0:
                lr *= 0.95
            if b % 100 == 0:
                tl.files.save_npz(encoder.all_params, 'encoder.npz', sess)
                tl.files.save_npz(decoder.all_params, 'decoder.npz', sess)
                tl.files.save_npz(classifier_real.all_params, 'classifier.npz', sess)
                tl.files.save_npz(discriminator_real.all_params, 'discriminator.npz', sess)
                samples_z = np.random.normal(size=[36, 64])
                # Only condition on classes that actually occur in the data.
                samples_c = np.random.randint(0, classes_num, [36, ], np.int32)
                samples_imgs = sess.run(samples_decoder_output, {
                    samples_placeholder: samples_z, classes_label: samples_c})
                samples_imgs = np.asarray((samples_imgs - np.min(samples_imgs)) / (np.max(samples_imgs) - np.min(samples_imgs)) * 255, np.uint8)
                tl.vis.save_images(samples_imgs, (6, 6), 'imgs/%d_%d.jpg' % (e, b))
finally:
    coord.request_stop()
    coord.join(threads)
    sess.close()
test.py
加载保存的参数并进行测试运行,保存测试生成的图片,以检验模型性能。
import tensorflow as tf
import tensorlayer as tl
import numpy as np
import vae_net
import discriminator_net
import os
from progressbar import progressbar
tl.logging.set_verbosity('INFO')
# Inputs of the conditional decoder: class ids and 64-d latent codes.
classes_label = tf.placeholder(tf.int32, [None, ])
z_placeholder = tf.placeholder(tf.float32, [None, 64])
# Generate images from z, then run them through the autoencoder discriminator.
decoder, decoder_output = vae_net.get_decoder(z_placeholder, classes_label, False)
discriminator_fake, discriminator_fake_output = discriminator_net.get_discriminator(decoder_output, False)
print('decoder params count', decoder.count_params())
print('discriminator params count', discriminator_fake.count_params())
config = tf.ConfigProto(allow_soft_placement=True,
                        gpu_options=tf.GPUOptions(allow_growth=True))
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())
# Restore the weights saved by train.py.
for npz_path, network in (('decoder.npz', decoder), ('discriminator.npz', discriminator_fake)):
    tl.files.load_and_assign_npz(sess, npz_path, network)
# Number of 6x6 sample grids written to test_output/.
n_samples = 64
os.makedirs('test_output', exist_ok=True)
def tr_imgs(imgs_float):
    """Min-max rescale a float image batch to uint8 in [0, 255].

    One min/max pair is shared by the whole batch, and values are truncated
    when cast to uint8. A constant batch (max == min) would divide by zero and
    produce NaNs, so it maps to all zeros instead.
    """
    lo = np.min(imgs_float)
    span = np.max(imgs_float) - lo
    if span == 0:
        return np.zeros(np.shape(imgs_float), np.uint8)
    return np.asarray((imgs_float - lo) / span * 255, np.uint8)
# Each grid holds 18 generated images followed by their 18 discriminator
# reconstructions (6x6 = 36 tiles).
per_half = 3 * 6
for b in progressbar(range(n_samples)):
    feed_dict = {
        z_placeholder: np.random.normal(size=[per_half, 64]),
        classes_label: np.random.randint(0, 10, [per_half, ], np.int32)}
    generated, reconstructed = sess.run([decoder_output, discriminator_fake_output], feed_dict)
    grid = np.concatenate([tr_imgs(generated), tr_imgs(reconstructed)], 0)
    tl.vis.save_images(grid, (6, 6), 'test_output/%d.jpg' % b)
classifier_net.py
对输入的数据进行分类处理,检验模型的分类能力。
from model_utils import *
def act(x):
    """Default activation: tl.act.leaky_twice_relu6 with both leak coefficients set to 0.1."""
    return tl.act.leaky_twice_relu6(x, 0.1, 0.1)
def get_classifier(img, reuse, n_classes=10):
    """Image classifier: four ablock stages, global mean pooling, and a linear logit layer.

    Args:
        img: NHWC image tensor (e.g. (None, 32, 32, 1) in the __main__ demo).
        reuse: reuse the variables of the 'classifier' scope.
        n_classes: number of output logits (default 10).

    Returns:
        (net, logits) where logits has `n_classes` columns.
    """
    with tf.variable_scope('classifier', reuse=reuse):
        net = tl.layers.InputLayer(img)
        # (n_filter, stride, n_block) per stage; stage scopes are named '1'..'4'.
        stages = ((20, 1, 2), (40, 2, 3), (60, 2, 3), (80, 2, 4))
        for idx, (n_filter, stride, n_block) in enumerate(stages, 1):
            net = ablock(net, n_filter, stride, act, n_block, str(idx))
        net = tl.layers.GlobalMeanPool2d(net)
        # Raw logits (consumed by sigmoid cross-entropy / in_top_k), so there is
        # no activation here.
        net = tl.layers.DenseLayer(net, n_classes, None, name='out')
    return net, net.outputs
if __name__ == '__main__':
    # Smoke test: build the classifier on a 32x32 single-channel placeholder.
    x = tf.placeholder(tf.float32, [None, 32, 32, 1])
    _, logits = get_classifier(x, False)
    print(logits)
欢迎有兴趣的朋友一起交流,共同进步!
希望对大家有所帮助!