使用tensorflow2.0.0a0 api实现线性回归算法

如标题,tf2相较tf1.x在api上有比较大的变动,1.x的很多api都在2.0中移除。
本文使用tf2.0的api实现一个简单的线性回归算法。

import tensorflow as tf

print(tf.__version__)

import numpy as np
import matplotlib.pyplot as plt

np.random.seed(199346745)

# 产生测试数据
w_known = [1.4, 0.4, -0.4, .3, -1.9]
DIM = len(w_known)
N = 1000
BATCH = 300

x = np.random.random((N, DIM))
# default DIM=5
y_ = sum(w_known[i]*x[:,i] for i in range(len(w_known)))
err = 0.01*np.random.normal(size=N)
y = (y_ + err).reshape((N, 1))

# 查看测试数据分布
plt.hist(err, 30)
plt.show()

plt.hist(y.reshape(1000), 30)
plt.show()

# 得益于tf2.0的动态图特性,可以在函数中直接循环训练
def lr(x, y, BATCH=None, niter=1000):
    if BATCH is None:
        BATCH = N
    losses = []
    w = tf.Variable(tf.random.normal(shape=(DIM, 1), mean=0))
    for i in range(niter):
        randidx = np.random.choice(N, size=BATCH)
        x2, y2 = (tf.constant(x[randidx], dtype='float32'),
                  tf.constant(y[randidx], dtype='float32'))
        loss = lambda: tf.losses.MeanSquaredError()(tf.matmul(x2, w), y2)
        opt = tf.keras.optimizers.SGD(1e-1)
        opt.minimize(loss, var_list=[w])
        losses.append(loss().numpy())
    return w, losses

w, losses = lr(x, y, 300)

# 查看loss的变化
plt.plot(losses)
plt.show()

# 查看w和设定的w的距离
tf.losses.MeanSquaredError()(w, tf.reshape(tf.constant(w_known),(5,1)))


你可能感兴趣的:(使用tensorflow2.0.0a0 api实现线性回归算法)