TensorFlow神经网络拟合抛物线

最近学习Python和TensorFlow,闲来无聊,自己写一个小程序玩玩吧!做什么呢?本人想了想,那就做一个深度学习拟合抛物线的列子吧!

    抛物线的数学表达式是y=ax^2+bx+c,我们就来模拟:y=0.3x^2-0.2X+0.5

1)我们需要构造样本,样本就从上面表达式来,加入一些噪音。x样本的分布通过Numpy库来生成。y的噪音也通过Numpy.random.normal来生成

2)输入样本数据的维度是1(其实就是一个普通的标量),输出数据的维度也是一个标量(维度为1),我们构建神经网络的时候一定要非常清楚自己输入和输出。这里是一个回归问题,不是一个分类问题,和图像等分类是不一样的。它输出的目标值一个确定的值,而不是分类里的我们设置分类维度输出(Onehot模式)。3)

3)神经网络经过1-2分钟训练后,对抛物线拟合效果不错,如下图:

TensorFlow神经网络拟合抛物线_第1张图片

import  tensorflow as tf
import  numpy as np
import  matplotlib.pyplot as plt
import  os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

#产生样本数据和图形
numdots=300
inputdata=[]
for i in range(numdots):
    x=np.random.normal(0.8,10)
    y=0.3*x*x-0.2*x+0.5+np.random.normal(0,6)
    inputdata.append([x,y])
x_data=[v[0] for v in inputdata]
y_data=[v[1] for v in inputdata]

plt.scatter(x_data,y_data)
plt.show()

#构建神经网络模式
x_data=np.array([x_data]).reshape(-1,1)
y_data=np.array([y_data]).reshape(-1,1)

x_h=tf.placeholder(dtype=tf.float64,shape=[None,1],name="xh")
y_h=tf.placeholder(dtype=tf.float64,shape=[None,1],name="yh")

w=tf.Variable(np.random.normal(0,0.3,size=[1,20]),dtype=tf.float64)
b=tf.Variable(np.random.normal(0.0,0.5,size=[20]),dtype=tf.float64)

y0=tf.nn.relu(tf.matmul(x_h,w)+b)

w1=tf.Variable(np.random.normal(0,0.5,size=[20,1]),dtype=tf.float64)
b1=tf.Variable(np.random.normal(0,0.8,size=[1]),dtype=tf.float64)

y=(tf.matmul(y0,w1)+b1)
los=tf.reduce_mean((tf.square(y-y_h)))
tran=tf.train.GradientDescentOptimizer(0.001).minimize(los)
#开始训练神经网络
sess=tf.Session()
sess.run(tf.global_variables_initializer())
#for i in range(100000):
for i in range(50000):
    sess.run(tran,feed_dict={x_h:x_data,y_h:y_data})
    if(i%100==0 ):
        los2=sess.run(los,feed_dict={x_h:x_data,y_h:y_data})
        print("epoch=", i, ",los=", los2)
        if (los2 <= 1):
            break
#用神经网络计算X上的所有点的Y值,绘制图形,看效果
testx=np.linspace(-30,30,60,dtype=np.float32).reshape(-1,1)
testy=sess.run(y,feed_dict={x_h:testx})
#plt.scatter(testx,testy)
plt.scatter(x_data,y_data)
plt.plot(testx,testy)
plt.show()



你可能感兴趣的:(AI,TensorFlow)