2018-12-12-VAE

class VAE(object):
    """Variational autoencoder over flattened feature-map samples (TensorFlow 1.x).

    Construction starts a background producer (``Build_maps``) that fills a
    shared ``Warehouse`` buffer, then builds the graph via ``vae.autoencoder``.
    ``train`` consumes batches from that buffer forever and checkpoints
    periodically; ``run_encoder`` restores a checkpoint and prints the latent
    code of each newly buffered sample forever.
    """

    # Seconds to sleep while waiting for the producer to fill the buffer.
    # Without it, the wait loops spin a CPU core at 100%.
    _POLL_INTERVAL_SECS = 0.5

    def __init__(self, n_hidden=500, dim_z=20, n_epochs=20, batch_size=128, learn_rate=1e-3,
                 model_path=base_dir + "/similarity_k_line/feature_extract/vae_feature_map_model"):
        """Store hyper-parameters, start the data producer and build the graph.

        Args:
            n_hidden: hidden-layer width passed to ``vae.autoencoder``.
            dim_z: latent dimensionality passed to ``vae.autoencoder``.
            n_epochs: optimisation steps run per pass of the training loop
                (each step consumes one batch).
            batch_size: stored only; ``train`` currently uses a fixed batch of 500.
            learn_rate: Adam learning rate.
            model_path: directory used for checkpoint save/restore.
        """
        self.model_path = model_path

        # network architecture
        self.n_hidden = n_hidden
        self.dim_img = 7 * 89  # number of values in one flattened feature image
        self.dim_z = dim_z

        # training
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.learn_rate = learn_rate

        # start the background map producer, then build the graph
        self.data_generator()
        self.build_graph()

    def data_generator(self):
        """Create the shared sample buffer and start the producer that fills it."""
        self.warehouse = Warehouse(pool=[])
        producer = Build_maps(pool=self.warehouse)
        producer.start()

    def build_graph(self):
        """Create input placeholders, the VAE network and the loss summaries."""
        # In a denoising autoencoder x_hat == x + noise; otherwise x_hat == x.
        self.x_hat = tf.placeholder(tf.float32, shape=[None, self.dim_img], name='input_img')
        self.x = tf.placeholder(tf.float32, shape=[None, self.dim_img], name='target_img')
        self.global_steps = tf.Variable(0, trainable=False, name="global_steps")

        # dropout keep probability, fed at run time
        self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')

        # network; the reconstruction output is not used by this class
        _y, self.z, self.loss, self.neg_marginal_likelihood, self.KL_divergence = \
            vae.autoencoder(self.x_hat, self.x, self.dim_img, self.dim_z, self.n_hidden, self.keep_prob)

        with tf.name_scope('loss'):
            tf.summary.scalar('total_loss', self.loss)
            tf.summary.scalar('KL divergence', self.KL_divergence)
            tf.summary.scalar('likelihood loss', self.neg_marginal_likelihood)

    def _restore_checkpoint(self, sess, saver):
        """Restore the latest checkpoint found in ``self.model_path`` into ``sess``.

        Returns:
            True if a checkpoint was found and restored, False otherwise.
        """
        ckpt = tf.train.get_checkpoint_state(self.model_path)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir=self.model_path))
            print('finish loading model!')
            return True
        return False

    def train(self):
        """Train the VAE forever on batches drawn from the warehouse.

        Restores the latest checkpoint if one exists; otherwise initialises all
        variables. The method then loops forever: it waits until at least 10000
        samples are buffered, then runs ``n_epochs`` optimisation steps of 500
        samples each. A checkpoint is saved whenever the global step is a
        multiple of 10. Never returns normally.
        """
        import time

        # train_op is built before the Saver so that the optimizer's variables
        # exist when the Saver collects the variables to save/restore.
        self.train_op = tf.train.AdamOptimizer(self.learn_rate).minimize(
            self.loss, global_step=self.global_steps)
        saver = tf.train.Saver(max_to_keep=4)

        with tf.Session() as sess:
            writer = tf.summary.FileWriter("logs/", sess.graph)
            try:
                merge_summary = tf.summary.merge_all()

                if not self._restore_checkpoint(sess, saver):
                    print("no checkpoint found...")
                    sess.run(tf.global_variables_initializer())

                while True:
                    if self.warehouse.get_length() < 10000:
                        # wait for the producer instead of busy-spinning
                        time.sleep(self._POLL_INTERVAL_SECS)
                        continue

                    for epoch in range(self.n_epochs):
                        # NOTE(review): the batch size is fixed at 500 and
                        # self.batch_size is ignored; kept to preserve behaviour.
                        batch_xs_input = np.array(
                            self.warehouse.get(num_retrived=500)).reshape([500, -1])
                        batch_xs_target = batch_xs_input  # plain (non-denoising) VAE

                        _, tot_loss, loss_likelihood, loss_divergence, train_summary, global_steps = sess.run(
                            (self.train_op, self.loss, self.neg_marginal_likelihood,
                             self.KL_divergence, merge_summary, self.global_steps),
                            feed_dict={self.x_hat: batch_xs_input,
                                       self.x: batch_xs_target,
                                       self.keep_prob: 0.9})
                        writer.add_summary(train_summary, global_steps)

                        print("epoch %d: L_tot %03.2f L_likelihood %03.2f L_divergence %03.2f" %
                              (epoch, tot_loss, loss_likelihood, loss_divergence))

                        if global_steps % 10 == 0:
                            saver.save(sess=sess, save_path=self.model_path + "/vae_model",
                                       global_step=global_steps)
            finally:
                # flush pending summaries even if training is interrupted
                writer.close()

    def run_encoder(self, keep_prob=1.0):
        """Restore the trained model and print latent codes forever.

        Whenever more than 100 samples are buffered, one sample is taken from
        the warehouse and encoded, and its latent code ``z`` is printed.

        Args:
            keep_prob: dropout keep probability fed to the network. Defaults to
                1.0 (dropout disabled) for inference. The previous code fed
                the training value 0.9.

        Raises:
            RuntimeError: if no checkpoint exists in ``self.model_path``.
        """
        import time

        saver = tf.train.Saver(max_to_keep=4)
        with tf.Session() as sess:
            if not self._restore_checkpoint(sess, saver):
                # Without a checkpoint the variables are never initialised and
                # every sess.run below would fail; report it clearly instead.
                raise RuntimeError("no checkpoint found in %s" % self.model_path)

            while True:
                if self.warehouse.get_length() <= 100:
                    # Only encode freshly retrieved samples; never feed None or
                    # re-encode a stale sample while waiting.
                    time.sleep(self._POLL_INTERVAL_SECS)
                    continue

                feature_map_input = np.array(self.warehouse.get(num_retrived=1)).reshape([1, -1])
                z = sess.run(fetches=self.z,
                             feed_dict={self.x_hat: feature_map_input,
                                        self.x: feature_map_input,
                                        self.keep_prob: keep_prob})
                print(z)

你可能感兴趣的:(2018-12-12-VAE)