【tensorflow】加性高斯噪声AutoEncoder

来源于《TensorFlow实战》（黄文坚、唐源 著）

#coding=utf8
import numpy as np
import sklearn.preprocessing as prep
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

def xavier_init(fan_in, fan_out, constant=1):
    """Return a (fan_in, fan_out) float32 tensor of Xavier/Glorot uniform samples.

    Samples come from U(-limit, limit) with
    limit = constant * sqrt(6 / (fan_in + fan_out)).  For constant = 1 the
    distribution has mean 0 and variance 2 / (fan_in + fan_out), because the
    variance of U(-a, a) is a**2 / 3.
    """
    # The bounds are symmetric around zero, so compute the magnitude once.
    limit = constant * np.sqrt(6.0 / (fan_in + fan_out))
    return tf.random_uniform((fan_in, fan_out),
                             minval=-limit,
                             maxval=limit,
                             dtype=tf.float32)

class AdditiveGaussianNoiseAutoencoder(object):  # additive Gaussian noise autoencoder
    """One-hidden-layer autoencoder whose input is corrupted by additive Gaussian noise.

    Graph built in ``__init__``.  ``noise`` is standard normal, drawn anew on
    every ``sess.run`` and independently for every example in the batch:

        hidden         = transfer((x + scale * noise) @ w1 + b1)
        reconstruction = hidden @ w2 + b2
        cost           = 0.5 * sum((reconstruction - x) ** 2)

    The cost is measured against the clean input ``x``.  It is summed (not
    averaged) over the batch.  Each instance owns a ``tf.Session``; call
    ``close()`` to release it.
    """

    def __init__(self, n_input, n_hidden, transfer_function = tf.nn.softplus, optimizer = None,
                 scale = 0.1):
        '''Build the graph, open a session and initialise all variables.

        n_input: number of input features
        n_hidden: number of hidden units
        transfer_function: activation applied to the hidden layer
        optimizer: a tf.train optimizer.  When None (the default), a fresh
            tf.train.AdamOptimizer() is created for this instance.
        scale: multiplier of the standard-normal input noise.  It is fed to the
            ``self.scale`` placeholder by every run method below.
        '''
        if optimizer is None:
            # Created per instance.  An optimizer used as a default argument would be
            # built once, when the class is defined, and shared by every autoencoder.
            optimizer = tf.train.AdamOptimizer()
        self.n_input = n_input
        self.n_hidden = n_hidden
        self.transfer = transfer_function
        # Noise level used by the graph; its value is supplied through feed_dict.
        self.scale = tf.placeholder(tf.float32)
        self.training_scale = scale
        self.weights = self._initialize_weights()

        # model
        self.x = tf.placeholder(tf.float32, [None, self.n_input])
        # Two deliberate choices in the noise term:
        # - The noise takes the dynamic shape of x, so each example in the batch gets
        #   its own draw.  A (n_input,)-shaped sample would be broadcast, i.e. shared,
        #   across the whole batch.
        # - It is scaled by the fed placeholder rather than the Python float, so the
        #   value passed in feed_dict is the one actually used.
        noisy_x = self.x + self.scale * tf.random_normal(tf.shape(self.x))
        self.hidden = self.transfer(tf.add(tf.matmul(noisy_x, self.weights['w1']),
                                           self.weights['b1']))
        # Linear decoder back to the input dimensionality.
        self.reconstruction = tf.add(tf.matmul(self.hidden, self.weights['w2']), self.weights['b2'])

        # cost: half the squared reconstruction error, summed over batch and features
        self.cost = 0.5 * tf.reduce_sum(tf.pow(tf.subtract(self.reconstruction, self.x), 2.0))
        self.optimizer = optimizer.minimize(self.cost)

        init = tf.global_variables_initializer()
        self.sess = tf.Session()
        self.sess.run(init)

    def _initialize_weights(self):
        """Create the encoder/decoder variables.

        The encoder weights use Xavier init; the decoder weights and both
        biases start at zero.
        """
        all_weights = dict()
        all_weights['w1'] = tf.Variable(xavier_init(self.n_input, self.n_hidden))
        all_weights['b1'] = tf.Variable(tf.zeros([self.n_hidden], dtype = tf.float32))
        all_weights['w2'] = tf.Variable(tf.zeros([self.n_hidden, self.n_input], dtype = tf.float32))
        all_weights['b2'] = tf.Variable(tf.zeros([self.n_input], dtype = tf.float32)) # reconstruct to the input size
        return all_weights

    def partial_fit(self, X):
        """Run one optimiser step on batch X; return the cost fetched in that same run."""
        cost, opt = self.sess.run((self.cost, self.optimizer), feed_dict = {self.x: X,
                                                                            self.scale: self.training_scale
                                                                            })
        return cost

    def calc_total_cost(self, X):
        """Return the (summed) cost on X without updating any weights.

        The input is still corrupted with noise at the training scale.
        """
        return self.sess.run(self.cost, feed_dict = {self.x: X,
                                                     self.scale: self.training_scale
                                                     })

    def transform(self, X):
        """Return the hidden-layer codes for X, computed from noise-corrupted input."""
        return self.sess.run(self.hidden, feed_dict = {self.x: X,
                                                       self.scale: self.training_scale
                                                       })

    def generate(self, hidden = None):
        """Decode hidden codes into reconstructions.

        If hidden is None, one code of shape (1, n_hidden) is sampled from a
        standard normal distribution.
        """
        if hidden is None:
            # The previous code passed the b1 tf.Variable as `size`, which NumPy rejects.
            hidden = np.random.normal(size = (1, self.n_hidden))
        # self.hidden is fed directly, so x and the noise placeholder are not needed.
        return self.sess.run(self.reconstruction, feed_dict = {self.hidden: hidden})

    def reconstruct(self, X):
        """Return reconstructions of X after noise corruption at the training scale."""
        return self.sess.run(self.reconstruction, feed_dict = {self.x: X,
                                                               self.scale: self.training_scale
                                                               })

    def getWeights(self):
        """Return the current encoder weight matrix w1 as a NumPy array."""
        return self.sess.run(self.weights['w1'])

    def getBiases(self):
        """Return the current encoder bias vector b1 as a NumPy array."""
        return self.sess.run(self.weights['b1'])

    def close(self):
        """Release the TensorFlow session owned by this instance."""
        self.sess.close()
        
        
        
        
# Load MNIST from the MNIST_data/ directory; read_data_sets fetches the archives
# there if they are missing.  one_hot only affects the labels, and the
# autoencoder below trains on the images alone.
mnist = input_data.read_data_sets('MNIST_data',
                                  one_hot=True)

def standard_scale(X_train, X_test):
    """Standardise both splits with statistics fitted on the training split only.

    A StandardScaler is fitted to X_train.  That same fitted transform is then
    applied to both X_train and X_test, so the test data is scaled with
    training statistics and no test information leaks into preprocessing.

    Returns:
        (scaled X_train, scaled X_test)
    """
    scaler = prep.StandardScaler()
    scaler.fit(X_train)
    return scaler.transform(X_train), scaler.transform(X_test)

def get_random_block_from_data(data, batch_size):
    """Return ``batch_size`` consecutive rows of ``data`` starting at a random offset.

    The start index is drawn uniformly from every valid offset
    ``0 .. len(data) - batch_size`` (inclusive), so the block always lies
    entirely inside ``data``.  Sampling uses NumPy's global RNG (``np.random``).

    Raises:
        ValueError: if ``batch_size`` exceeds ``len(data)``.
    """
    n_rows = len(data)
    if batch_size > n_rows:
        raise ValueError("batch_size (%d) exceeds the number of rows in data (%d)"
                         % (batch_size, n_rows))
    # randint's upper bound is exclusive.  The +1 makes the last valid start
    # (n_rows - batch_size) reachable and allows batch_size == len(data).
    start_index = np.random.randint(0, n_rows - batch_size + 1)
    return data[start_index:start_index + batch_size]

# Standardise images using training-set statistics only.
X_train, X_test = standard_scale(mnist.train.images, mnist.test.images)

n_samples = int(mnist.train.num_examples)
training_epochs = 100
batch_size = 128
display_step = 1

autoencoder = AdditiveGaussianNoiseAutoencoder(
    n_input=784,
    n_hidden=200,
    transfer_function=tf.nn.softplus,
    optimizer=tf.train.AdamOptimizer(learning_rate=0.001),
    scale=0.01)

# Mini-batches per epoch; this does not change between epochs, so compute it once.
total_batch = int(n_samples / batch_size)
# Size of the test set, used to turn the summed test cost into a per-sample cost.
n_test = float(X_test.shape[0])

for epoch in range(training_epochs):
    # Sum of the per-batch costs returned by partial_fit during this epoch.
    all_cost = 0.
    for _ in range(total_batch):
        all_cost += autoencoder.partial_fit(get_random_block_from_data(X_train, batch_size))
    # Average cost per training sample seen in this epoch.
    avg_cost = all_cost / (total_batch * batch_size)
    if epoch % display_step == 0:
        test_cost = float(autoencoder.calc_total_cost(X_test)) / n_test
        print("Epoch{0}, training, {1}, test, {2}".format(epoch, str(avg_cost), str(test_cost)))

(tensorflow_1.1.0) $ python 4_2_AutoEncoder.py

输出:

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
2019-03-16 15:40:47.590879: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.1 instructions, but these are available on your machine and could speed up CPU computations.
2019-03-16 15:40:47.590903: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.2 instructions, but these are available on your machine and could speed up CPU computations.
2019-03-16 15:40:47.590930: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
2019-03-16 15:40:47.590936: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.
2019-03-16 15:40:47.590941: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use FMA instructions, but these are available on your machine and could speed up CPU computations.
2019-03-16 15:40:47.652067: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2019-03-16 15:40:47.652350: I tensorflow/core/common_runtime/gpu/gpu_device.cc:887] Found device 0 with properties: 
name: GeForce GTX 1050
major: 6 minor: 1 memoryClockRate (GHz) 1.493
pciBusID 0000:01:00.0
Total memory: 3.95GiB
Free memory: 3.47GiB
2019-03-16 15:40:47.652383: I tensorflow/core/common_runtime/gpu/gpu_device.cc:908] DMA: 0 
2019-03-16 15:40:47.652390: I tensorflow/core/common_runtime/gpu/gpu_device.cc:918] 0:   Y 
2019-03-16 15:40:47.652415: I tensorflow/core/common_runtime/gpu/gpu_device.cc:977] Creating TensorFlow device (/gpu:0) -> (device: 0, name: GeForce GTX 1050, pci bus id: 0000:01:00.0)
Epoch0, training, 149.90008065639398, test, 110.149675
Epoch1, training, 100.7819356140279, test, 93.15680625
Epoch2, training, 82.02972562829932, test, 85.61115625
Epoch3, training, 77.79704608783855, test, 81.1944375
Epoch4, training, 73.00198553658865, test, 81.16870625
Epoch5, training, 76.58162193031578, test, 73.14924375
Epoch6, training, 65.52065071248231, test, 73.27254375
Epoch7, training, 68.13188927212518, test, 75.10165
Epoch8, training, 64.33849497886249, test, 70.59715625
Epoch9, training, 63.2652562859453, test, 69.42678125
Epoch10, training, 60.24016694200067, test, 68.8033125
Epoch11, training, 68.00162230004797, test, 70.22744375
Epoch12, training, 65.3913254993501, test, 65.38971875
Epoch13, training, 61.898716897675484, test, 66.6174625
Epoch14, training, 62.68255256884026, test, 66.12139375
Epoch15, training, 65.9540237969174, test, 63.759025
Epoch16, training, 63.017467569915844, test, 63.11196875
Epoch17, training, 61.400563651309426, test, 65.19886875
Epoch18, training, 61.48787060539761, test, 68.0450875
Epoch19, training, 58.67962571791002, test, 62.070525
Epoch20, training, 66.4236932678934, test, 65.19708125
Epoch21, training, 62.840021417969034, test, 65.21670625
Epoch22, training, 59.42216017307379, test, 67.3710125
Epoch23, training, 63.043263428694715, test, 64.9743625
Epoch24, training, 65.20769200180516, test, 66.04265
Epoch25, training, 66.18748156229655, test, 65.6527875
Epoch26, training, 60.343725729099795, test, 66.19673125
Epoch27, training, 63.28298964867225, test, 66.14604375
Epoch28, training, 55.71372350501569, test, 65.42456875
Epoch29, training, 59.14011889904529, test, 64.57674375
Epoch30, training, 55.97405335842035, test, 62.45036875
Epoch31, training, 59.48967899864926, test, 62.13161875
Epoch32, training, 57.57828269193778, test, 64.37306875
Epoch33, training, 58.93725188017447, test, 63.98418125
Epoch34, training, 59.13745266574246, test, 65.03913125
Epoch35, training, 55.790950535060645, test, 62.1180375
Epoch36, training, 58.87303747163786, test, 63.575225
Epoch37, training, 56.73477933123395, test, 61.70260625
Epoch38, training, 58.09486912458371, test, 64.1324625
Epoch39, training, 63.67588301972076, test, 65.4092875
Epoch40, training, 58.61002230699801, test, 67.6374
Epoch41, training, 62.15032804262388, test, 66.48238125
Epoch42, training, 63.876240692494356, test, 65.57199375
Epoch43, training, 58.875221510311384, test, 64.860025
Epoch44, training, 59.08009947485579, test, 62.80025
Epoch45, training, 58.45370927557245, test, 62.36695
Epoch46, training, 61.14604921496554, test, 63.17583125
Epoch47, training, 61.806070812376504, test, 63.87290625
Epoch48, training, 58.6005185089467, test, 63.44245625
Epoch49, training, 57.38292686978142, test, 63.6083875
Epoch50, training, 55.77550971646965, test, 65.32834375
Epoch51, training, 61.42551477249964, test, 62.81951875
Epoch52, training, 57.57441598051911, test, 61.8763125
Epoch53, training, 54.00586842712545, test, 62.040425
Epoch54, training, 58.01396513985587, test, 63.741275
Epoch55, training, 56.86145811369925, test, 61.1164875
Epoch56, training, 57.98246621529817, test, 62.05344375
Epoch57, training, 57.46451657079594, test, 61.52176875
Epoch58, training, 55.86559759375655, test, 59.68309375
Epoch59, training, 57.6433522062146, test, 58.66555625
Epoch60, training, 55.73156979700902, test, 61.52915625
Epoch61, training, 54.937945292546196, test, 60.2031
Epoch62, training, 62.242481160552906, test, 59.5436125
Epoch63, training, 55.624111331148306, test, 58.70915625
Epoch64, training, 60.1892326550606, test, 58.02835625
Epoch65, training, 57.382898015020054, test, 61.10570625
Epoch66, training, 57.050446552552266, test, 62.05825625
Epoch67, training, 53.6121877790331, test, 62.70536875
Epoch68, training, 58.36627696110652, test, 65.23339375
Epoch69, training, 56.821734439520846, test, 62.20115625
Epoch70, training, 59.093814867637654, test, 63.5755875
Epoch71, training, 56.52735379803708, test, 61.76349375
Epoch72, training, 56.219313361428, test, 59.89129375
Epoch73, training, 56.50808762059067, test, 63.8745
Epoch74, training, 58.00832246900438, test, 62.2881625
Epoch75, training, 60.135561322832444, test, 62.969525
Epoch76, training, 58.97952186477768, test, 62.967175
Epoch77, training, 59.207802481306736, test, 64.7482625
Epoch78, training, 55.44021341461679, test, 60.74296875
Epoch79, training, 55.22020384426161, test, 60.10155
Epoch80, training, 55.75481741133825, test, 62.73705
Epoch81, training, 52.94930788853785, test, 60.31095
Epoch82, training, 54.90460995138386, test, 63.21538125
Epoch83, training, 54.265996515056194, test, 62.39475625
Epoch84, training, 54.349244849109425, test, 60.32793125
Epoch85, training, 54.625338925626174, test, 62.575625
Epoch86, training, 55.478113894696, test, 61.2159
Epoch87, training, 57.70553330552606, test, 62.3741375
Epoch88, training, 57.59218071446274, test, 62.98685625
Epoch89, training, 59.31323751258405, test, 63.296175
Epoch90, training, 53.57663123924416, test, 61.96986875
Epoch91, training, 54.07729722403146, test, 62.744075
Epoch92, training, 60.70762411975638, test, 60.740275
Epoch93, training, 50.84226532249184, test, 61.1522625
Epoch94, training, 51.81098407576412, test, 60.18791875
Epoch95, training, 53.99923409504213, test, 62.51310625
Epoch96, training, 55.52134593438991, test, 60.09274375
Epoch97, training, 53.98442616940656, test, 61.54753125
Epoch98, training, 55.46080613747621, test, 62.35149375
Epoch99, training, 54.96678209638262, test, 62.52013125

你可能感兴趣的:(深度学习)