本例程源码来自这里。
目前我已把我自己手工敲写加注释的代码放到自己的github账户上面,项目地址在这里:https://github.com/RootYuan/tensorflow_examples_practice/。
下面是正文分割线
tensorflow一直以来是基于静态计算图的,这其实跟程序的执行过程并不一致,没办法使用python语言取控制中间流程。PyTorch 的动态图一直是 TensorFlow用户求之不得的功能。tensorflow从v1.5版本加入动态图,目前为止已更新到v1.7版本,动态图得到了进一步完善。
今天主要运行一下动态图版本的线性回归。静态计算图一般是先搭建图结构,然后使用sess.run填入数据并并运行优化器。动态图虽然思路类似,但是不在需要sess和显示grap了,就像调用函数一样方便。
另外,使用动态库开始需要import tensorflow.contrib.eager as tfe并调用tfe.enable_eager_execution()
下面是源码,我主要为了感性认识,还有很多知识点没有弄清楚,后续遇到再继续。
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow.contrib.eager as tfe
# 开启eager API
tfe.enable_eager_execution()
# 训练数据
train_X = [3.3, 4.4, 5.5, 6.71, 6.93, 4.168, 9.779, 6.182, 7.59, 2.167,
7.042, 10.791, 5.313, 7.997, 5.654, 9.27, 3.1]
train_Y = [1.7, 2.76, 2.09, 3.19, 1.694, 1.573, 3.366, 2.596, 2.53, 1.221,
2.827, 3.465, 1.65, 2.904, 2.42, 2.94, 1.3]
n_samples = len(train_X)
# 参数
learning_rate = 0.01
display_step = 100
num_steps = 1000
# 权重和偏置
W = tfe.Variable(np.random.randn())
b = tfe.Variable(np.random.randn())
# 线性回归公式函数(Wx + b)
def linear_regression(inputs):
return inputs * W + b
# 均方误差函数,计算损失
def mean_square_fn(model_fn, inputs, labels):
return tf.reduce_sum(tf.pow(model_fn(inputs) - labels, 2)) / (2 * n_samples)
# 随机梯度下降法作为优化器
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
# 计算梯度
grad = tfe.implicit_gradients(mean_square_fn)
# 优化之前,初始化损失函数
print("Initial cost= {:.9f}".format(
mean_square_fn(linear_regression, train_X, train_Y)),
"W=", W.numpy(), "b=", b.numpy())
# 训练
for step in range(num_steps):
optimizer.apply_gradients(grad(linear_regression, train_X, train_Y))
if (step + 1) % display_step == 0 or step == 0:
print( "Epoch:", '%04d' % (step + 1), "cost=",
"{:.9f}".format( mean_square_fn( linear_regression, train_X, train_Y ) ),
"W=", W.numpy(), "b=", b.numpy() )
# 图表显示
plt.plot(train_X, train_Y, 'ro', label='Original data')
plt.plot(train_X, np.array(W * train_X + b), label='Fitted line')
plt.legend()
plt.show()
终端输出:
Initial cost= 2.973774910 W= 0.8843342 b= -1.2450998
Epoch: 0001 cost= 1.136796951 W= 0.7314757 b= -1.2640716
Epoch: 0100 cost= 0.289310157 W= 0.51290995 b= -1.053521
Epoch: 0200 cost= 0.243507951 W= 0.4830278 b= -0.84167004
Epoch: 0300 cost= 0.207583487 W= 0.4565633 b= -0.6540487
Epoch: 0400 cost= 0.179406464 W= 0.43312556 b= -0.48788548
Epoch: 0500 cost= 0.157306090 W= 0.41236836 b= -0.34072652
Epoch: 0600 cost= 0.139971867 W= 0.3939852 b= -0.21039806
Epoch: 0700 cost= 0.126376018 W= 0.37770453 b= -0.09497542
Epoch: 0800 cost= 0.115712211 W= 0.36328587 b= 0.0072462982
Epoch: 0900 cost= 0.107348159 W= 0.35051632 b= 0.097776845
Epoch: 1000 cost= 0.100787930 W= 0.3392072 b= 0.17795336