tf5. 构建基础网络

import sys
print(sys.version)
'''
3.5.3 |Continuum Analytics, Inc.| (default, May 15 2017, 10:43:23) [MSC v.1900 64 bit (AMD64)]
'''
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

def add_layer(inputs, in_size, out_size, activation_function=None):
    #weights为一个in_size行, out_size列的随机变量矩阵
    Weights = tf.Variable(tf.random_normal([in_size, out_size]))
    #在机器学习中,biases的推荐值不为0,所以我们这里是在0向量的基础上又加了0.1。
    biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
    Wx_plus_b = tf.matmul(inputs, Weights) + biases
    #激励函数为None时,输出就是当前的预测值——Wx_plus_b,不为None时,就把Wx_plus_b传到activation_function()函数中得到输出
    if activation_function is None:
        outputs = Wx_plus_b
    else:
        outputs = activation_function(Wx_plus_b)
    return outputs

x_data = np.linspace(-1,1,300, dtype=np.float32)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape).astype(np.float32)
y_data = np.square(x_data) - 0.5 + noise
#这里的None代表无论输入有多少都可以,因为输入只有一个特征,所以这里是1
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])
#通常神经层都包括输入层、隐藏层和输出层。这里的输入层只有一个属性, 所以我们就只有一个输入;隐藏层我们可以自己假设,这里我们假设隐藏层有10个神经元;
# 输出层和输入层的结构是一样的,所以我们的输出层也是只有一层。 所以,我们构建的是——输入层1个、隐藏层10个、输出层1个的神经网络。
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
#接着,定义输出层。此时的输入就是隐藏层的输出——l1,输入有10层(隐藏层的输出层),输出有1层。
prediction = add_layer(l1, 10, 1, activation_function=None)
#对二者差的平方求和再取平均
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),reduction_indices=[1]))
#接下来,是很关键的一步,如何让机器学习提升它的准确率。tf.train.GradientDescentOptimizer()中的值通常都小于1,这里取的是0.1,代表以0.1的效率来最小化误差loss。
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.scatter(x_data, y_data)
plt.ion()#持续画图,不暂停
plt.show()

for i in range(1000):
    # training
    sess.run(train_step, feed_dict={xs: x_data, ys: y_data})
    if i % 50 == 0:
        # to see the step improvement
        print(sess.run(loss, feed_dict={xs: x_data, ys: y_data}))
        try:
            ax.lines.remove(lines[0])
        except Exception:
            pass
        prediction_value = sess.run(prediction, feed_dict={xs: x_data})
        # plot the prediction
        lines = ax.plot(x_data, prediction_value, 'r-', lw=5)
        plt.pause(1)
tf5. 构建基础网络_第1张图片

你可能感兴趣的:(tf5. 构建基础网络)