cvae-gan tensorflow实现

论文依据:CVAE-GAN: Fine-Grained Image Generation through Asymmetric Training。
代码来源:github
关于论文讲解分析,csdn已经有不少例子,在此不做详细解释。
该模型可应用于图像修复、超分辨率以及数据增强(例如扩充数据以训练更好的人脸识别模型)等领域。
该程序按实验目的分为三部分:训练网络、测试网络和分类网络;另有三个基础支持模块:基本模型、VAE 网络和判别网络。各部分分工明确,环环相扣。
首先是基本模型
model_utils.py

import tensorflow as tf
import tensorlayer as tl
import numpy as np


def _channel_shuffle(x, n_group):
    """Interleave the channels of a rank-4 tensor across ``n_group`` groups.

    The last axis (size ``c``) is viewed as ``(n_group, c // n_group)``, the
    two sub-axes are swapped, and the result is flattened back to ``c``
    channels.  The static height, width and channel sizes of ``x`` are used
    for the reshapes; the leading axis is left dynamic (``-1``).
    """
    _, height, width, channels = x.shape.as_list()
    per_group = channels // n_group
    grouped = tf.reshape(x, [-1, height, width, n_group, per_group])
    swapped = tf.transpose(grouped, [0, 1, 2, 4, 3])
    return tf.reshape(swapped, [-1, height, width, channels])


def _group_norm_and_channel_shuffle(x, is_train, G=32, epsilon=1e-12, use_shuffle=False, name='_group_norm'):
    """Group-normalise a channels-last (N, H, W, C) tensor, optionally shuffling channels.

    Args:
        x: Rank-4 tensor laid out as (batch, height, width, channels), the
            layout used by the other helpers in this module.
        is_train: Passed as ``trainable`` for the ``gamma`` / ``beta`` variables.
        G: Requested number of groups; clamped to at most ``C``.
        epsilon: Added to the variance before the square root.
        use_shuffle: If true, swap the (group, channel-in-group) axes after
            normalisation so the flattened channels are interleaved.
        name: Variable scope holding ``gamma`` and ``beta``.

    Returns:
        Tensor with the same (N, H, W, C) shape as ``x``.

    Raises:
        ValueError: If the channel count is not divisible by the group count.
    """
    with tf.variable_scope(name):
        N, H, W, C = x.get_shape().as_list()
        if N is None:
            N = -1
        G = min(G, C)
        if C % G != 0:
            raise ValueError('channel count %d is not divisible by group count %d' % (C, G))
        # Channels are the LAST axis, so only that axis may be split into
        # (G, C // G).  Reshaping to [N, G, H, W, C // G] would reinterpret the
        # row-major buffer and mix spatial positions into the group axis.
        x = tf.reshape(x, [N, H, W, G, C // G])
        # Statistics per sample and per group: reduce over H, W and the
        # channels inside each group.
        mean, var = tf.nn.moments(x, [1, 2, 4], keep_dims=True)
        x = (x - mean) / tf.sqrt(var + epsilon)
        # shuffle channel: swap the group axis with the in-group channel axis
        if use_shuffle:
            x = tf.transpose(x, [0, 1, 2, 4, 3])
        # per channel gamma and beta
        gamma = tf.get_variable('gamma', [C], initializer=tf.constant_initializer(1.0), trainable=is_train)
        beta = tf.get_variable('beta', [C], initializer=tf.constant_initializer(0.0), trainable=is_train)
        gamma = tf.reshape(gamma, [1, 1, 1, C])
        beta = tf.reshape(beta, [1, 1, 1, C])

        output = tf.reshape(x, [N, H, W, C]) * gamma + beta
    return output


def _switch_norm(x, name='_switch_norm'):
    """Normalise ``x`` with a learned softmax mixture of three sets of moments.

    Moments are taken over axes ``[0, 1, 2]``, ``[1, 2]`` and ``[1, 2, 3]``
    (the "batch", "instance" and "layer" statistics).  Two 3-element
    variables, passed through a softmax, weight the means and the variances
    respectively.  The normalised tensor is then scaled by ``gamma`` and
    shifted by ``beta``, both sized by the last dimension of ``x``.
    """
    eps = 1e-5
    with tf.variable_scope(name):
        n_channels = x.shape[-1]

        # Candidate statistics, in the order the mixing weights index them.
        b_mean, b_var = tf.nn.moments(x, [0, 1, 2], keep_dims=True)
        i_mean, i_var = tf.nn.moments(x, [1, 2], keep_dims=True)
        l_mean, l_var = tf.nn.moments(x, [1, 2, 3], keep_dims=True)

        scale = tf.get_variable("gamma", [n_channels], initializer=tf.constant_initializer(1.0))
        shift = tf.get_variable("beta", [n_channels], initializer=tf.constant_initializer(0.0))

        mean_logits = tf.get_variable("mean_weight", [3], initializer=tf.constant_initializer(1.0))
        mean_mix = tf.nn.softmax(mean_logits)
        var_logits = tf.get_variable("var_weight", [3], initializer=tf.constant_initializer(1.0))
        var_mix = tf.nn.softmax(var_logits)

        mixed_mean = mean_mix[0] * b_mean + mean_mix[1] * i_mean + mean_mix[2] * l_mean
        mixed_var = var_mix[0] * b_var + var_mix[1] * i_var + var_mix[2] * l_var

        normalised = (x - mixed_mean) / (tf.sqrt(mixed_var + eps))
        return normalised * scale + shift


def _add_coord(x):
    """Append normalised y and x coordinate channels to a (b, h, w, c) tensor.

    Two extra channels are concatenated on axis 3: the row index and the
    column index of every pixel, each divided by ``size - 1`` so that the
    values span ``[0, 1]`` inclusive.  Height and width must be statically
    known; the batch size is read dynamically.

    Returns:
        Tensor with the same batch/height/width as ``x`` and ``c + 2`` channels.
    """
    batch_size = tf.shape(x)[0]
    height, width = x.shape.as_list()[1:3]

    # Dividing by (size - 1) maps the last index onto exactly 1.  For a
    # 1-pixel axis that denominator is 0 (0/0 -> NaN), so clamp it to 1 and
    # let the single coordinate be 0.
    y_denominator = max(height - 1, 1)
    x_denominator = max(width - 1, 1)

    y_coord = tf.range(0, height, dtype=tf.float32)
    y_coord = tf.reshape(y_coord, [1, -1, 1, 1])  # b,h,w,c
    y_coord = tf.tile(y_coord, [batch_size, 1, width, 1]) / y_denominator

    x_coord = tf.range(0, width, dtype=tf.float32)
    x_coord = tf.reshape(x_coord, [1, 1, -1, 1])  # b,h,w,c
    x_coord = tf.tile(x_coord, [batch_size, height, 1, 1]) / x_denominator

    o = tf.concat([x, y_coord, x_coord], 3)
    return o


def coord_layer(net):
    """Wrap ``_add_coord`` in a TensorLayer ``LambdaLayer`` named ``'coord_layer'``."""
    layer = tl.layers.LambdaLayer(net, _add_coord, name='coord_layer')
    return layer


def switchnorm_layer(net, act, name):
    """Apply ``_switch_norm`` to ``net`` and, if ``act`` is given, the activation.

    Both steps are ``LambdaLayer`` instances created with the same ``name``.
    """
    normed = tl.layers.LambdaLayer(net, _switch_norm, name=name)
    if act is None:
        return normed
    return tl.layers.LambdaLayer(normed, act, name=name)


def groupnorm_layer(net, is_train, G, use_shuffle, act, name):
    """Apply ``_group_norm_and_channel_shuffle`` to ``net``, then optionally ``act``.

    The normalisation arguments are forwarded through the ``LambdaLayer``'s
    function-argument dictionary; ``name`` is used both for the layers and
    for the normalisation's own variable scope.
    """
    norm_kwargs = {
        'is_train': is_train,
        'G': G,
        'use_shuffle': use_shuffle,
        'name': name,
    }
    normed = tl.layers.LambdaLayer(net, _group_norm_and_channel_shuffle, norm_kwargs, name=name)
    if act is None:
        return normed
    return tl.layers.LambdaLayer(normed, act, name=name)


def upsampling_layer(net, shortpoint):
    """Resize ``net`` to the height/width of ``shortpoint`` and concatenate them.

    The target size is read from the static shape of ``shortpoint.outputs``;
    the two layers are joined along the last axis.
    """
    target_hw = shortpoint.outputs.shape.as_list()[1:3]
    resized = tl.layers.UpSampling2dLayer(net, target_hw, is_scale=False)
    return tl.layers.ConcatLayer([resized, shortpoint], -1)


def upsampling_layer2(net, shortpoint, name):
    """Halve the channels of both inputs, resize ``net`` and concatenate.

    Each input first goes through a normalised 1x1 convolution that halves
    its channel count (scopes ``'up1'`` and ``'up2'``).  ``net`` is then
    resized to the static height/width of ``shortpoint`` and the two are
    concatenated along the last axis.

    Args:
        net: Layer to be upsampled.
        shortpoint: Skip-connection layer that provides the target size.
        name: Variable scope for the two 1x1 convolutions.
    """
    with tf.variable_scope(name):
        hw = shortpoint.outputs.shape.as_list()[1:3]
        dim1 = net.outputs.shape.as_list()[3]
        dim2 = shortpoint.outputs.shape.as_list()[3]

        # conv2d takes (net, n_filter, filter_size, strides, act, padding,
        # use_norm, name); the previous calls passed two extra positional
        # flags and therefore raised TypeError.
        net = conv2d(net, dim1 // 2, 1, 1, None, 'SAME', True, 'up1')
        shortpoint = conv2d(shortpoint, dim2 // 2, 1, 1, None, 'SAME', True, 'up2')

        net = tl.layers.UpSampling2dLayer(net, hw, is_scale=False)
        net = tl.layers.ConcatLayer([net, shortpoint], -1)
    return net


def upsampling_layer3(net, shortpoint):
    """Concatenate resized ``net`` with the first half of ``shortpoint``'s channels.

    ``shortpoint`` is split in two along its last axis and only the first
    part is kept; ``net`` is resized to ``shortpoint``'s static height/width
    before the concatenation on the last axis.
    """
    target_hw = shortpoint.outputs.shape.as_list()[1:3]
    first_half = tl.layers.LambdaLayer(shortpoint, lambda t: tf.split(t, 2, -1)[0])
    resized = tl.layers.UpSampling2dLayer(net, target_hw, is_scale=False)
    return tl.layers.ConcatLayer([resized, first_half], -1)


def conv2d(net, n_filter, filter_size, strides, act, padding, use_norm, name):
    """2-D convolution, optionally followed by switch normalisation.

    ``filter_size`` and ``strides`` may be scalars or pairs; both are
    broadcast to length 2.  With ``use_norm`` the convolution is created
    without bias and without activation, and ``act`` is applied after the
    normalisation; otherwise ``act`` is handed to the convolution itself.
    """
    kernel = np.broadcast_to(filter_size, [2])
    stride = np.broadcast_to(strides, [2])
    with tf.variable_scope(name):
        if not use_norm:
            return tl.layers.Conv2d(net, n_filter, kernel, stride, act, padding, name='c2d')
        conv = tl.layers.Conv2d(net, n_filter, kernel, stride, None, padding, b_init=None, name='c2d')
        return switchnorm_layer(conv, act, 'sn')


def groupconv2d(net, n_filter, filter_size, strides, n_group, act, padding, use_norm, use_shuffle, name):
    """Grouped 2-D convolution with optional switch norm and channel shuffle.

    With ``use_norm`` the convolution has no bias or activation, is followed
    by ``switchnorm_layer`` (which applies ``act``) and, when ``use_shuffle``
    is set, by ``_channel_shuffle`` over ``n_group`` groups.  Without
    ``use_norm`` only the convolution (with ``act``) is created and
    ``use_shuffle`` is ignored.
    """
    kernel = np.broadcast_to(filter_size, [2])
    stride = np.broadcast_to(strides, [2])
    with tf.variable_scope(name):
        if not use_norm:
            return tl.layers.GroupConv2d(net, n_filter, kernel, stride, n_group, act, padding, name='gc2d')

        out = tl.layers.GroupConv2d(net, n_filter, kernel, stride, n_group, None, padding, b_init=None, name='gc2d')
        out = switchnorm_layer(out, act, 'sn')
        if use_shuffle:
            out = tl.layers.LambdaLayer(out, lambda t: _channel_shuffle(t, n_group))
    return out


def deconv2d(net, n_filter, filter_size, strides, act, padding, use_norm, name):
    """Transposed 2-D convolution, optionally followed by switch normalisation.

    Scalar ``filter_size`` / ``strides`` are broadcast to pairs.  With
    ``use_norm`` the layer is built without bias and ``act`` is applied by
    ``switchnorm_layer``; otherwise ``act`` is given to the layer directly.
    """
    kernel = np.broadcast_to(filter_size, [2])
    stride = np.broadcast_to(strides, [2])
    with tf.variable_scope(name):
        if not use_norm:
            return tl.layers.DeConv2d(net, n_filter, kernel, strides=stride, act=act, padding=padding, name='dc2d')
        deconv = tl.layers.DeConv2d(net, n_filter, kernel, strides=stride, padding=padding, b_init=None, name='dc2d')
        return switchnorm_layer(deconv, act, 'sn')


def depthwiseconv2d(net, depth_multiplier, filter_size, strides, act, padding, use_norm, name, dilation_rate=1):
    """Depthwise 2-D convolution, optionally followed by switch normalisation.

    ``filter_size``, ``strides`` and ``dilation_rate`` are broadcast to
    pairs.  With ``use_norm`` the convolution has no bias or activation and
    ``act`` is applied after normalisation; otherwise ``act`` is passed to
    the convolution.
    """
    kernel = np.broadcast_to(filter_size, [2])
    stride = np.broadcast_to(strides, [2])
    dilation = np.broadcast_to(dilation_rate, [2])
    with tf.variable_scope(name):
        if not use_norm:
            return tl.layers.DepthwiseConv2d(net, kernel, stride, act, padding, dilation_rate=dilation, depth_multiplier=depth_multiplier, name='dwc2d')
        dw = tl.layers.DepthwiseConv2d(net, kernel, stride, None, padding, dilation_rate=dilation, depth_multiplier=depth_multiplier, b_init=None, name='dwc2d')
        return switchnorm_layer(dw, act, 'sn')


def resblock_1(net, n_filter, strides, act, name):
    """Residual block built from coordinate, grouped and depthwise convolutions.

    Main path: coordinate channels -> 1x1 conv -> 3x3 grouped conv (20
    groups, shuffled) -> 3x3 depthwise conv carrying ``strides`` -> 1x1
    grouped conv (20 groups, shuffled).  The shortcut is the input itself,
    or a normalised 3x3 convolution when the block strides or changes the
    channel count.  Both paths are summed element-wise.
    """
    strides = np.broadcast_to(strides, [2])
    with tf.variable_scope(name):
        in_channels = net.outputs.shape.as_list()[-1]
        needs_projection = np.max(strides) > 1 or in_channels != n_filter
        if needs_projection:
            shortcut = conv2d(net, n_filter, (3, 3), strides, None, 'SAME', True, 'shortcut')
        else:
            shortcut = net

        body = coord_layer(net)
        body = conv2d(body, n_filter, 1, 1, None, 'SAME', False, 'c1')
        body = groupconv2d(body, n_filter, 3, 1, 20, act, 'SAME', True, True, 'gc1')
        body = depthwiseconv2d(body, 1, 3, strides, None, 'SAME', True, 'dwc2')
        body = groupconv2d(body, n_filter, 1, 1, 20, act, 'SAME', True, True, 'gc2')
        return tl.layers.ElementwiseLayer([shortcut, body], tf.add)


def resblock_2(net, n_filter, strides, act, name):
    """Residual block made of two normalised 3x3 convolutions.

    The first convolution carries ``strides``; the second uses stride 1.
    The shortcut is the input itself, or a normalised 3x3 convolution when
    the block strides or changes the channel count.  Both paths are summed
    element-wise.

    Args:
        net: Input layer.
        n_filter: Number of output channels.
        strides: Scalar or pair; broadcast to length 2.
        act: Activation applied after each normalisation on the main path.
        name: Variable scope of the block.
    """
    strides = np.broadcast_to(strides, [2])
    with tf.variable_scope(name):
        if np.max(strides) > 1 or net.outputs.shape.as_list()[-1] != n_filter:
            shortcut = conv2d(net, n_filter, (3, 3), strides, None, 'SAME', True, 'shortcut')
        else:
            shortcut = net
        net = conv2d(net, n_filter, 3, strides, act, 'SAME', True, 'c1')
        # The second convolution needs its own scope: reusing 'c1' would ask
        # for the first convolution's variables again.
        net = conv2d(net, n_filter, 3, 1, act, 'SAME', True, 'c2')
        net = tl.layers.ElementwiseLayer([shortcut, net], tf.add)
    return net


def ablock(net, n_filter, strides, act, n_block, name):
    """Stack ``resblock_1`` blocks inside the scope ``'ab_' + name``.

    The first block (``rb_0``) is always created and receives ``strides``;
    blocks ``rb_1`` .. ``rb_{n_block-1}`` use stride 1.
    """
    plan = [(strides, 'rb_0')]
    plan.extend((1, 'rb_%d' % idx) for idx in range(1, n_block))
    with tf.variable_scope('ab_' + name):
        for block_strides, block_name in plan:
            net = resblock_1(net, n_filter, block_strides, act, block_name)
    return net


def group_block(net, n_filter, strides, act, block_type, n_block, name):
    """Stack ``n_block`` blocks built by ``block_type`` inside ``'gb_' + name``.

    The first block (``b_0``) is always created, is called with keyword
    arguments and receives ``strides``; the remaining blocks are called
    positionally with stride 1.
    """
    with tf.variable_scope('gb_' + name):
        out = block_type(net, n_filter=n_filter, strides=strides, act=act, name='b_0')
        follow_up_names = ['b_%d' % idx for idx in range(1, n_block)]
        for block_name in follow_up_names:
            out = block_type(out, n_filter, 1, act, block_name)
    return out

vae_net.py

from model_utils import *

# Activation shared by the encoder and decoder layers: TensorLayer's
# leaky_twice_relu6 called with 0.1 for both of its trailing arguments.
act = lambda x: tl.act.leaky_twice_relu6(x, 0.1, 0.1)
# Width of the encoder's dense 'out', 'mean' and 'log_sigma' layers, and the
# channel count get_decoder reshapes its input to.
n_hidden = 64
# block

def get_encoder(img, c, reuse):
    """Build the conditional encoder under the variable scope ``'encoder'``.

    The class ids ``c`` are one-hot encoded (depth 10), projected by a dense
    layer to one value per pixel, reshaped into a single extra image channel
    and concatenated with ``img``.  Four ``ablock`` stages and a global mean
    pool follow, then dense layers ``'out'``, ``'mean'`` and ``'log_sigma'``
    (each ``n_hidden`` wide).

    Returns:
        ``(net, [z, mean, log_sigma])`` where ``net`` is the merged
        TensorLayer network and
        ``z = mean + noise * exp(min(0.5 * log_sigma, 20))`` with standard
        normal ``noise``.
    """
    spatial = img.shape.as_list()[1:3]
    n_pixels = spatial[0] * spatial[1]
    with tf.variable_scope('encoder', reuse=reuse):
        image_net = tl.layers.InputLayer(img)
        # Feed the class information in as one additional image plane.
        label_net = tl.layers.OneHotInputLayer(c, 10, axis=-1, dtype=tf.float32)
        label_net = tl.layers.DenseLayer(label_net, n_pixels, act, name='d1')
        label_net = tl.layers.ReshapeLayer(label_net, (-1, spatial[0], spatial[1], 1))
        net = tl.layers.ConcatLayer([image_net, label_net], -1)

        # (n_filter, strides, n_block) per stage; stages are named '1'..'4'.
        stages = [(20, 1, 2), (40, 2, 3), (60, 2, 3), (80, 2, 4)]
        for stage_id, (n_filter, stride, n_block) in enumerate(stages, 1):
            net = ablock(net, n_filter, stride, act, n_block, str(stage_id))

        net = tl.layers.GlobalMeanPool2d(net)
        net = tl.layers.DenseLayer(net, n_hidden, act, name='out')

        mean_layer = tl.layers.DenseLayer(net, n_hidden, act, name='mean')
        log_sigma_layer = tl.layers.DenseLayer(net, n_hidden, act, name='log_sigma')

        net = tl.layers.merge_networks([net, mean_layer, log_sigma_layer])

        mean = mean_layer.outputs
        log_sigma = log_sigma_layer.outputs
        half_log_sigma = log_sigma * 0.5
        noise = tf.random_normal(tf.shape(mean))

        # The exponent is capped at 20 to keep exp() finite.
        z = mean + noise * tf.exp(tf.minimum(half_log_sigma, 20))

    return net, [z, mean, log_sigma]


def get_decoder(z, c, reuse):
    """Build the decoder under the variable scope ``'decoder'``.

    Args:
        z: Latent tensor whose last dimension is the latent width.
        c: Integer class ids, or ``None``.  When given, the ids are one-hot
            encoded (depth 10), projected to the latent width, concatenated
            with ``z`` and projected back to the latent width.  ``None`` lets
            the discriminator reuse this decoder without a class input.
        reuse: Passed to ``tf.variable_scope``.

    Returns:
        ``(net, net.outputs)``.  The latent vector is reshaped to a 1x1 map,
        tiled to 4x4, and passed through four ``ablock`` stages separated by
        three (2, 2) upsamplings before a final 1-channel 3x3 convolution.
    """
    z_len = z.shape.as_list()[-1]
    # The tensor entering the reshape below always has the latent width, so
    # reshape to that width instead of the module constant; fall back to
    # n_hidden only when the static size is unknown.
    latent_dim = n_hidden if z_len is None else z_len
    with tf.variable_scope('decoder', reuse=reuse):
        net = tl.layers.InputLayer(z)
        # c is optional so the discriminator, which has no class input, can
        # reuse this decoder.
        if c is not None:
            c_net = tl.layers.OneHotInputLayer(c, 10, axis=-1, dtype=tf.float32)
            c_net = tl.layers.DenseLayer(c_net, z_len, act, name='d1')
            net = tl.layers.ConcatLayer([net, c_net], -1)
            net = tl.layers.DenseLayer(net, z_len, act, name='d2')

        net = tl.layers.ReshapeLayer(net, (-1, 1, 1, latent_dim))
        net = tl.layers.TileLayer(net, [1, 4, 4, 1])

        # Stages are named '1'..'4'; every stage but the last is followed by
        # a (2, 2) upsampling.
        stage_filters = [80, 60, 40, 20]
        last_stage = len(stage_filters)
        for stage_id, n_filter in enumerate(stage_filters, 1):
            net = ablock(net, n_filter, 1, act, 3, str(stage_id))
            if stage_id != last_stage:
                net = tl.layers.UpSampling2dLayer(net, (2, 2), True, 1)

        # Leaky clip towards [0, 1].
        # NOTE(review): the upper branch evaluates to 1.1 at x = 1, so the
        # activation jumps there; 0.1 * (x - 1) + 1 would be continuous.
        # Left unchanged because it would alter outputs of saved models.
        out_act = lambda x: tf.where(x<0, 0.1*x, tf.where(x>1, 0.1*x+1, x))
        net = conv2d(net, 1, 3, 1, out_act, 'SAME', False, 'out')

    return net, net.outputs


if __name__ == '__main__':
    # Smoke test: build encoder and decoder on 32x32 single-channel
    # placeholders and print their output tensors.
    image_in = tf.placeholder(tf.float32, [None, 32, 32, 1])
    label_in = tf.placeholder(tf.int32, [None, ])
    samples_placeholder = tf.placeholder(tf.float32, [None, 64])

    _, enc_outputs = get_encoder(image_in, label_in, False)
    print(enc_outputs)
    _, dec_output = get_decoder(enc_outputs[0], label_in, False)
    print(dec_output)

discriminator_net.py

from model_utils import *
import vae_net

# Activation used by the discriminator blocks: TensorLayer's
# leaky_twice_relu6 called with 0.1 for both of its trailing arguments.
act = lambda x: tl.act.leaky_twice_relu6(x, 0.1, 0.1)

def get_discriminator(img, reuse):
    """Build the auto-encoding discriminator under the scope ``'discriminator'``.

    ``img`` goes through four ``ablock`` stages, a global mean pool and a
    64-unit linear dense layer; that code is then decoded by
    ``vae_net.get_decoder`` without a class input.  The decoder network and
    the coding network are merged into one TensorLayer network.

    Returns:
        ``(net, net.outputs)`` for the merged network; the outputs are
        presumably those of the decoder, which is listed first in the merge
        (confirm against TensorLayer's ``merge_networks``).
    """
    with tf.variable_scope('discriminator', reuse=reuse):
        code_net = tl.layers.InputLayer(img)

        # (n_filter, strides, n_block) per stage; stages are named '1'..'4'.
        stages = [(20, 1, 2), (40, 2, 3), (60, 2, 3), (80, 2, 4)]
        for stage_id, (n_filter, stride, n_block) in enumerate(stages, 1):
            code_net = ablock(code_net, n_filter, stride, act, n_block, str(stage_id))

        code_net = tl.layers.GlobalMeanPool2d(code_net)
        code_net = tl.layers.DenseLayer(code_net, 64, None, name='out')

        # Auto-encoder half: reconstruct an image from the 64-d code.
        recon_net, _ = vae_net.get_decoder(code_net.outputs, None, reuse)

        merged = tl.layers.merge_networks([recon_net, code_net])

    return merged, merged.outputs


if __name__ == '__main__':
    # Smoke test: build the discriminator on a 32x32 single-channel
    # placeholder and print its output tensor.
    image_in = tf.placeholder(tf.float32, [None, 32, 32, 1])
    _, disc_output = get_discriminator(image_in, False)
    print(disc_output)

下面三个为功能性主程序
train.py
源码中使用手写字符(MNIST)数据集训练,下述程序改用尺寸为 (64, 64, 1) 的数据集,可依据需求修改参数。
训练结束后会对模型训练好的参数以及生成的图片进行保存。

from time import time
t1 = time()
import tensorflow as tf
import tensorlayer as tl
import numpy as np
import vae_net
import discriminator_net
import classifier_net
# from skimage.transform import resize
# import my_py_lib.utils
import os
from progressbar import progressbar
print('加载 tf 耗时', time()-t1)

tl.logging.set_verbosity('INFO')

steps = 64
t1 = time()


def read_tfrecords(tfrecord_name):
    """Create a shuffled, batched queue pipeline over one TFRecord file.

    Each record must hold an int64 ``'label'`` and a raw uint8 ``'img_raw'``
    buffer of 64*64*1 values.  Images are scaled by 1/255 to float32; labels
    are one-hot encoded with depth ``classes_num``.  The batch size is the
    module-level ``steps``.

    Returns:
        ``(x_batch, y_batch)`` tensors produced by ``tf.train.shuffle_batch``.
    """
    # Queue of file names that the reader cycles through.
    file_queue = tf.train.string_input_producer([tfrecord_name])
    # read() yields (key, serialized record); only the record is needed.
    _, serialized = tf.TFRecordReader().read(file_queue)

    feature_spec = {
        'label': tf.FixedLenFeature([], tf.int64),
        'img_raw': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(serialized, features=feature_spec)

    # Raw bytes -> uint8 pixels -> 64x64 single-channel image in float32.
    pixels = tf.decode_raw(parsed['img_raw'], tf.uint8)
    pixels = tf.reshape(pixels, [64, 64, 1])
    image = tf.cast(pixels, tf.float32) * (1. / 255)

    # Integer class id -> one-hot vector.
    class_id = tf.cast(parsed['label'], tf.int32)
    label = tf.one_hot(class_id, depth=classes_num, on_value=1)

    # Shuffled batches of `steps` examples.
    x_batch, y_batch = tf.train.shuffle_batch(
        [image, label],
        batch_size=steps,
        num_threads=1,
        capacity=30 * steps,
        min_after_dequeue=15 * steps)
    return x_batch, y_batch

# The images fall into 4 classes; used as the one-hot depth in read_tfrecords.
classes_num = 4


def _load_tfrecords_as_arrays(tfrecord_name):
    """Read every record of a TFRecord file into in-memory NumPy arrays.

    Decodes the same features as ``read_tfrecords`` (``'img_raw'`` as a
    64x64x1 uint8 buffer, ``'label'`` as an int64 class id) but eagerly, so
    the result can be sliced and measured with ``len()``.

    Returns:
        ``(images, labels)``: float32 images scaled to [0, 1] with shape
        (n, 64, 64, 1) and int32 class ids with shape (n,).
    """
    images, labels = [], []
    for record in tf.python_io.tf_record_iterator(tfrecord_name):
        example = tf.train.Example()
        example.ParseFromString(record)
        feature = example.features.feature
        raw_pixels = feature['img_raw'].bytes_list.value[0]
        images.append(np.frombuffer(raw_pixels, np.uint8).reshape(64, 64, 1))
        labels.append(feature['label'].int64_list.value[0])
    images = np.asarray(images, np.float32) / 255.
    labels = np.asarray(labels, np.int32)
    return images, labels


# Load and prepare the training data.  The training loop slices x_dataset and
# y_dataset by batch index and feeds them to placeholders, so both must be
# NumPy arrays (images and integer class ids) rather than the queue tensors
# returned by read_tfrecords.  The test set is not used during training.
x_dataset, y_dataset = _load_tfrecords_as_arrays('路径')
print('加载数据集耗时', time() - t1)

# Training script.
#
# Assumes x_dataset is an array of images matching the `img` placeholder and
# y_dataset an array of integer class ids; the feed_dict below slices both.
#
# NOTE(review): vae_net.get_decoder tiles a 4x4 map and applies three (2, 2)
# upsamplings, so decoder_output (and the discriminator reconstructions)
# appear to be 32x32 while `img` is declared 64x64.  The element-wise losses
# below need matching shapes -- confirm and align the sizes.

# ---- Graph construction (no session is needed until all variables exist) ----
img = tf.placeholder(tf.float32, [None, 64, 64, 1])
classes_label = tf.placeholder(tf.int32, [None, ])
kt = tf.Variable(0, False, dtype=tf.float32, name='kt')
lr_placeholder = tf.placeholder(tf.float32, name='lr_placeholder')
gamma = 0.5
lamda = 0.5
# for test
samples_placeholder = tf.placeholder(tf.float32, [None, 64])

encoder, encoder_output = vae_net.get_encoder(img, classes_label, False)
decoder, decoder_output = vae_net.get_decoder(encoder_output[0], classes_label, False)
# for test
_, samples_decoder_output = vae_net.get_decoder(samples_placeholder, classes_label, True)

classifier_real, classifier_real_output = classifier_net.get_classifier(img, False)
classifier_fake, classifier_fake_output = classifier_net.get_classifier(decoder_output, True)
discriminator_real, discriminator_real_output = discriminator_net.get_discriminator(img, False)
discriminator_fake, discriminator_fake_output = discriminator_net.get_discriminator(decoder_output, True)

print('encoder params count', encoder.count_params())
print('decoder params count', decoder.count_params())
print('classifier params count', classifier_real.count_params())
print('discriminator params count', discriminator_real.count_params())


def get_kl_loss(mean, log_sigma):
    """Batch mean of -0.5 * sum(1 + log_sigma - mean^2 - exp(log_sigma))."""
    # Cap log_sigma so exp() cannot blow up.
    log_sigma = tf.minimum(log_sigma, 20)
    return tf.reduce_mean(-0.5 * tf.reduce_sum(1 + log_sigma - tf.square(mean) - tf.exp(log_sigma), 1))


kl_loss_op = get_kl_loss(*encoder_output[1:])

# Auto-encoding discriminator: each loss is the reconstruction error of the
# discriminator on its own input.  The fake loss must therefore use the
# discriminator applied to the generated image, not to the real one.
d_loss_real = tf.reduce_mean(tf.abs(discriminator_real_output - img))
d_loss_fake = tf.reduce_mean(tf.abs(discriminator_fake_output - decoder_output))

d_loss_op = d_loss_real - kt * d_loss_fake

kt_update_op = tf.assign(kt, tf.clip_by_value(kt + lamda * (gamma * d_loss_real - d_loss_fake), 0., 1.))
m_global = d_loss_real + tf.abs(gamma * d_loss_real - d_loss_fake)

# The one-hot targets must have as many columns as the classifier has
# outputs, so take the depth from the logits instead of hard-coding it.
n_classifier_outputs = classifier_real_output.shape.as_list()[-1]
class_targets = tf.one_hot(classes_label, n_classifier_outputs, dtype=tf.int32)

g_loss_img = tf.reduce_mean(tf.abs(img - decoder_output))
g_loss_disc = d_loss_fake
g_loss_class = tf.losses.sigmoid_cross_entropy(class_targets, classifier_fake_output)

g_loss_op = g_loss_img + g_loss_disc + g_loss_class

c_loss_op = tf.losses.sigmoid_cross_entropy(class_targets, classifier_real_output)

c_top_1 = tf.reduce_mean(tf.to_float(tf.nn.in_top_k(classifier_real_output, classes_label, 1)))

tf.summary.scalar('g/kl_loss_op', kl_loss_op)
tf.summary.histogram('g/z', encoder_output[0])
tf.summary.scalar('g/img', g_loss_img)
tf.summary.scalar('g/disc', g_loss_disc)
tf.summary.scalar('g/class', g_loss_class)
tf.summary.image('g/ori_img', tf.cast(tf.clip_by_value(img*255, 0, 255), tf.uint8), 3)
tf.summary.image('g/gen_img', tf.cast(tf.clip_by_value(decoder_output*255, 0, 255), tf.uint8), 3)
tf.summary.image('d/real_img', tf.cast(tf.clip_by_value(discriminator_real_output*255, 0, 255), tf.uint8), 3)
tf.summary.image('d/fake_img', tf.cast(tf.clip_by_value(discriminator_fake_output*255, 0, 255), tf.uint8), 3)
tf.summary.scalar('d/real', d_loss_real)
tf.summary.scalar('d/fake', d_loss_fake)
tf.summary.scalar('c/loss', c_loss_op)
tf.summary.scalar('c/in_top_1', c_top_1)
tf.summary.scalar('misc/kt', kt)
tf.summary.scalar('misc/m_global', m_global)
tf.summary.scalar('misc/lr', lr_placeholder)

os.makedirs('imgs', exist_ok=True)
os.makedirs('logs', exist_ok=True)
merge_summary_op = tf.summary.merge_all()
sum_file = tf.summary.FileWriter('logs', tf.get_default_graph())

g_optim = tf.train.AdamOptimizer(lr_placeholder).minimize(kl_loss_op + g_loss_op, var_list=encoder.all_params + decoder.all_params)
d_optim = tf.train.AdamOptimizer(lr_placeholder).minimize(d_loss_op + c_loss_op, var_list=discriminator_real.all_params + classifier_real.all_params)

# ---- Session: one configured session for initialisation, queues and training ----
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.allow_soft_placement = True

with tf.Session(config=config) as sess:
    # Initialise only after the whole graph, including the optimizers'
    # variables, has been built.
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        # Restore previously saved parameters (presumably skipped by
        # TensorLayer when the file is absent -- confirm).
        tl.files.load_and_assign_npz(sess, 'encoder.npz', encoder)
        tl.files.load_and_assign_npz(sess, 'decoder.npz', decoder)
        tl.files.load_and_assign_npz(sess, 'classifier.npz', classifier_real)
        tl.files.load_and_assign_npz(sess, 'discriminator.npz', discriminator_real)

        n_epoch = 200
        batch_size = 30
        n_batch = int(np.ceil(len(x_dataset)/batch_size))
        lr = 0.001
        train_ops = [g_optim, d_optim, kt_update_op]

        for e in range(n_epoch):
            for b in progressbar(range(n_batch)):
                step = e*n_batch + b
                batch = slice(b*batch_size, (b+1)*batch_size)
                feed_dict = {img: x_dataset[batch], classes_label: y_dataset[batch], lr_placeholder: lr}

                if b == 5 and e == 0:
                    # Record a full execution trace once, early in training.
                    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                    run_metadata = tf.RunMetadata()
                    _, summary = sess.run([train_ops, merge_summary_op], feed_dict, run_options, run_metadata)
                    sum_file.add_run_metadata(run_metadata, 'step%d' % step)
                else:
                    _, summary = sess.run([train_ops, merge_summary_op], feed_dict)
                sum_file.add_summary(summary, step)

                if b % 20 == 0:
                    lr *= 0.95

                if b % 100 == 0:
                    tl.files.save_npz(encoder.all_params, 'encoder.npz', sess)
                    tl.files.save_npz(decoder.all_params, 'decoder.npz', sess)
                    tl.files.save_npz(classifier_real.all_params, 'classifier.npz', sess)
                    tl.files.save_npz(discriminator_real.all_params, 'discriminator.npz', sess)

                    # Sample a 6x6 grid of images; draw labels only from the
                    # classes that exist in the training data.
                    samples_z = np.random.normal(size=[36, 64])
                    samples_c = np.random.randint(0, classes_num, [36,], np.int32)
                    samples_imgs = sess.run(samples_decoder_output, {samples_placeholder: samples_z, classes_label: samples_c})
                    samples_imgs = np.asarray((samples_imgs - np.min(samples_imgs)) / (np.max(samples_imgs) - np.min(samples_imgs)) * 255, np.uint8)
                    tl.vis.save_images(samples_imgs, (6, 6), 'imgs/%d_%d.jpg' % (e, b))
    finally:
        # Always stop the input threads and flush the summaries, even when
        # training is interrupted.
        coord.request_stop()
        coord.join(threads)
        sum_file.close()

test.py
加载保存的参数,并进行测试运行,对测试生成的图片进行保存。检验模型性能。

import tensorflow as tf
import tensorlayer as tl
import numpy as np
import vae_net
import discriminator_net
import os
from progressbar import progressbar

# Test-time setup: build the decoder and the discriminator, restore their
# saved parameters and prepare the output directory.
tl.logging.set_verbosity('INFO')


# Graph inputs: integer class ids and 64-d latent vectors.
classes_label = tf.placeholder(tf.int32, [None, ])
z_placeholder = tf.placeholder(tf.float32, [None, 64])

# The discriminator is applied to the decoder's output.
decoder, decoder_output = vae_net.get_decoder(z_placeholder, classes_label, False)
discriminator_fake, discriminator_fake_output = discriminator_net.get_discriminator(decoder_output, False)

print('decoder params count', decoder.count_params())
print('discriminator params count', discriminator_fake.count_params())

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.allow_soft_placement = True
sess = tf.Session(config=config)
# Initialise first, then overwrite with the parameters saved by train.py.
sess.run(tf.global_variables_initializer())

tl.files.load_and_assign_npz(sess, 'decoder.npz', decoder)
tl.files.load_and_assign_npz(sess, 'discriminator.npz', discriminator_fake)

# Number of image grids to generate (one output file per iteration).
n_samples = 64######32

os.makedirs('test_output', exist_ok=True)

def tr_imgs(imgs_float):
    """Min-max rescale a float image array to uint8 values in [0, 255].

    The global minimum maps to 0 and the global maximum to 255; fractional
    results are truncated by the uint8 conversion.  A constant input has no
    range to stretch, so it maps to all zeros instead of dividing by zero.
    """
    lowest = np.min(imgs_float)
    value_range = np.max(imgs_float) - lowest
    if value_range == 0:
        return np.zeros(np.shape(imgs_float), np.uint8)
    return np.asarray((imgs_float - lowest) / value_range * 255, np.uint8)

# For every output file: sample 18 latent vectors and labels, run the decoder
# and the discriminator, and save the 18 generated images followed by the 18
# discriminator outputs as one 6x6 grid.
for batch_idx in progressbar(range(n_samples)):
    latent = np.random.normal(size=[3*6, 64])
    labels = np.random.randint(0, 10, [3*6,], np.int32)
    generated, reconstructed = sess.run(
        [decoder_output, discriminator_fake_output],
        {z_placeholder: latent, classes_label: labels})
    grid = np.concatenate([tr_imgs(generated), tr_imgs(reconstructed)], 0)
    tl.vis.save_images(grid, (6, 6), 'test_output/%d.jpg' % batch_idx)

classifier_net.py
对输入的数据进行分类处理,检验模型的分类能力。

from model_utils import *

# Activation used by the classifier layers: TensorLayer's leaky_twice_relu6
# called with 0.1 for both of its trailing arguments.
act = lambda x: tl.act.leaky_twice_relu6(x, 0.1, 0.1)

def get_classifier(img, reuse):
    """Build the classifier under the variable scope ``'classifier'``.

    ``img`` goes through four ``ablock`` stages and a global mean pool,
    followed by a 10-unit dense layer ``'out'`` that uses the module
    activation ``act``.

    Returns:
        ``(net, net.outputs)``.
    """
    with tf.variable_scope('classifier', reuse=reuse):
        net = tl.layers.InputLayer(img)

        # (n_filter, strides, n_block) per stage; stages are named '1'..'4'.
        stages = [(20, 1, 2), (40, 2, 3), (60, 2, 3), (80, 2, 4)]
        for stage_id, (n_filter, stride, n_block) in enumerate(stages, 1):
            net = ablock(net, n_filter, stride, act, n_block, str(stage_id))

        pooled = tl.layers.GlobalMeanPool2d(net)
        net = tl.layers.DenseLayer(pooled, 10, act, name='out')

    return net, net.outputs


if __name__ == '__main__':
    # Smoke test: build the classifier on a 32x32 single-channel placeholder
    # and print its output tensor.
    image_in = tf.placeholder(tf.float32, [None, 32, 32, 1])
    _, class_output = get_classifier(image_in, False)
    print(class_output)

欢迎有兴趣的一起交流,共同进步!
希望对大家有所帮助!

你可能感兴趣的:(vae,GAN,生成模型,深度学习,tensorflow,图像识别,神经网络)