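I built a binary classification network myself, typing in every line of the code by hand. Training points are drawn uniformly from the square [-5, 5) × [-5, 5). A point is a positive sample if it lies inside the circle of radius 4 centered at the origin. The network has two layers: two input nodes (x0, x1) for the planar coordinates, a hidden layer of 16 sigmoid units, and one sigmoid output y that marks a sample as positive or negative. Here is the code: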
```python
# Written against the TensorFlow 1.x API (tf.placeholder, tf.Session).
import tensorflow as tf
import numpy as np

TRAIN_SIZE = 100000
TEST_SIZE = 1000
LAYER1_SIZE = 16        # number of hidden units
CIRCLE_R = 4.0          # radius of the positive region
training_epochs = 10000
learning_rate = 2.0

# Generate training data: points uniform in [-5, 5) x [-5, 5), label 1.0 inside the circle
train_x = (np.random.rand(TRAIN_SIZE, 2) - 0.5) * 10
train_y = np.empty([TRAIN_SIZE, 1], dtype=float)
for i in range(TRAIN_SIZE):
    x0 = train_x[i, 0]
    x1 = train_x[i, 1]
    if x0*x0 + x1*x1 < CIRCLE_R * CIRCLE_R:
        train_y[i] = 1.0
    else:
        train_y[i] = 0.0

# Generate test data the same way
test_x = (np.random.rand(TEST_SIZE, 2) - 0.5) * 10
test_y = np.empty([TEST_SIZE, 1], dtype=float)
for i in range(TEST_SIZE):
    x0 = test_x[i, 0]
    x1 = test_x[i, 1]
    if x0*x0 + x1*x1 < CIRCLE_R * CIRCLE_R:
        test_y[i] = 1.0
    else:
        test_y[i] = 0.0

# Define the model: 2 inputs -> LAYER1_SIZE sigmoid hidden units -> 1 sigmoid output
x = tf.placeholder(tf.float32, [None, 2])
y = tf.placeholder(tf.float32, [None, 1])
w1 = tf.Variable(tf.random_normal([2, LAYER1_SIZE]))
b1 = tf.Variable(tf.zeros([1, LAYER1_SIZE]))
w2 = tf.Variable(tf.random_normal([LAYER1_SIZE, 1]))
b2 = tf.Variable(tf.zeros([1]))
layer1 = tf.sigmoid(tf.matmul(x, w1) + b1)
pred = tf.sigmoid(tf.matmul(layer1, w2) + b2)
# Cost: mean absolute error between the label and the predicted probability
cost = tf.reduce_mean(tf.abs(y - pred))
# Full-batch gradient descent: every step uses all TRAIN_SIZE samples
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(training_epochs):
        _, c = sess.run([optimizer, cost], feed_dict={x: train_x, y: train_y})
        print("Epoch:", "%04d" % (i+1), "cost = ", "{:9f}".format(c))
    print("finished!")

    # Accuracy: threshold the output at 0.5. eval() needs the session to still be open.
    correct_prediction = tf.equal(tf.greater(pred, 0.5), tf.greater(y, 0.5))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print("Accuracy_train: ", accuracy.eval({x: train_x, y: train_y}))
    print("Accuracy_test : ", accuracy.eval({x: test_x, y: test_y}))
```
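If you want to check what the graph computes without installing TensorFlow, here is a minimal NumPy sketch of the same data generation and forward pass. It is a hypothetical illustration, not the code that produced the results in this post: it assumes a smaller sample size (1000), a seeded `np.random.default_rng`, and random, untrained weights, so its cost and accuracy are only those of an untrained network.

```python
import numpy as np

# Hypothetical NumPy re-implementation for illustration only; the constant names
# mirror the TensorFlow code above, but TRAIN_SIZE is reduced and the seed is assumed.
TRAIN_SIZE = 1000
LAYER1_SIZE = 16
CIRCLE_R = 4.0

rng = np.random.default_rng(0)

def make_data(n, rng):
    """Sample n points uniformly in [-5, 5) x [-5, 5); label points inside the circle 1.0."""
    xs = (rng.random((n, 2)) - 0.5) * 10
    # Vectorized replacement for the per-sample loop: compare squared radius per row.
    ys = (np.sum(xs ** 2, axis=1, keepdims=True) < CIRCLE_R ** 2).astype(float)
    return xs, ys

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def forward(xs, w1, b1, w2, b2):
    """Same computation as layer1/pred in the graph: two stacked sigmoid layers."""
    layer1 = sigmoid(xs @ w1 + b1)      # shape (n, LAYER1_SIZE)
    return sigmoid(layer1 @ w2 + b2)    # shape (n, 1)

train_x, train_y = make_data(TRAIN_SIZE, rng)

# Random initialization shaped like w1, b1, w2, b2 in the graph (no training here).
w1 = rng.standard_normal((2, LAYER1_SIZE))
b1 = np.zeros((1, LAYER1_SIZE))
w2 = rng.standard_normal((LAYER1_SIZE, 1))
b2 = np.zeros(1)

pred = forward(train_x, w1, b1, w2, b2)
cost = np.mean(np.abs(train_y - pred))                # same mean-absolute-error cost
accuracy = np.mean((pred > 0.5) == (train_y > 0.5))  # same 0.5 threshold
print("untrained cost:", cost, "untrained accuracy:", accuracy)
```

The single vectorized comparison in `make_data` produces the same labels as the Python `for` loop, but as one array operation, which matters once TRAIN_SIZE reaches 100000.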
Here is the output:
```text
... ...
Epoch: 9997 cost = 0.023868
Epoch: 9998 cost = 0.023867
Epoch: 9999 cost = 0.023866
Epoch: 10000 cost = 0.023865
finished!
Accuracy_train: 0.99499995
Accuracy_test : 0.99600005
```
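To put the roughly 99.5% accuracy in context: the circle of radius 4 lies entirely inside the 10 × 10 sampling square, so about π·4²/10² ≈ 50.3% of the samples are positive, and a classifier that always predicts the majority class would be right only about half the time. A short check follows; the Monte Carlo sample size and seed are hypothetical choices, not values from the post.

```python
import math
import numpy as np

# Area-ratio estimate of the positive fraction, assuming (as in the code above)
# points are drawn uniformly from the 10 x 10 square [-5, 5) x [-5, 5).
CIRCLE_R = 4.0
SIDE = 10.0

positive_fraction = math.pi * CIRCLE_R ** 2 / SIDE ** 2
print(f"expected positive fraction: {positive_fraction:.4f}")  # -> expected positive fraction: 0.5027

# A constant classifier that always predicts the majority class scores about max(p, 1 - p).
baseline_accuracy = max(positive_fraction, 1.0 - positive_fraction)

# Monte Carlo check with an assumed seed and sample size.
rng = np.random.default_rng(0)
pts = (rng.random((100_000, 2)) - 0.5) * SIDE
empirical_fraction = np.mean(np.sum(pts ** 2, axis=1) < CIRCLE_R ** 2)
print(f"empirical positive fraction: {empirical_fraction:.4f}")
```

Because the classes are almost perfectly balanced, accuracy is a reasonable metric here, and the roughly 50% baseline shows how much of the reported 99.5% the network actually learned.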
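This was my first hands-on project, and I ran into plenty of pitfalls. I suspect I would have been more productive writing it all in C than in TensorFlow. Still, as the most popular deep learning framework, it is something I have to master. I have spent the past few years on work that had nothing to do with writing code, and it hurts to think of that wasted time!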