2. 使用keras-神经网络来做线性回归问题

代码:

  • 导入库:
import keras
import numpy as np
import matplotlib.pyplot as plt

from keras.models import Sequential # 按顺序构成的模型
from keras.layers import Dense # 全连接层
  • 生成随机数

2. 使用keras-神经网络来做线性回归问题_第1张图片

# 使用Numpy生成随机点
x_data = np.random.rand(100)

noise = np.random.normal(0, 0.01, x_data.shape) # 加上噪音
y_data = x_data*0.1 + 0.2 + noise  # 构建 y 的数据

# 显示数据点(散点图)
plt.scatter(x_data, y_data)
plt.show()

2. 使用keras-神经网络来做线性回归问题_第2张图片

  • 创建模型并训练
# 构建一个顺序模型
model = Sequential()

# 在模型中添加一个全连接层
model.add(Dense(units=1, input_dim=1)) # 输入1维数据,输出1维数据

# 优化器:sgd:Stochastic gradient descent,随机梯度下降
# 损失函数:mse:Mean Squared Error,均方误差
model.compile(optimizer='sgd', loss='mse')

# 训练5001个批次
for step in range(5001): # 迭代次数越多,结果会越精确
    # 每次训练一个批次
    cost = model.train_on_batch(x_data,y_data)
    # 每500个batch打印一次loss值
    if step%500 == 0:
        print('cost', cost)

# 打印权值和偏置值
W,b = model.layers[0].get_weights()
print('W: ',W,'  b: ',b)


# 用模型进行预测
y_pred = model.predict(x_data)

# 显示随机点
plt.scatter(x_data,y_data)
# 显示预测结果
plt.plot(x_data, y_pred,'r-', lw=3)
plt.show()

2. 使用keras-神经网络来做线性回归问题_第3张图片

总代码:

import keras
import numpy as np
import matplotlib.pyplot as plt

from keras.models import Sequential # 按顺序构成的模型
from keras.layers import Dense # 全连接层

# 使用Numpy生成随机点
x_data = np.random.rand(100)

noise = np.random.normal(0, 0.01, x_data.shape) # 加上噪音
y_data = x_data*0.1 + 0.2 + noise  # 构建 y 的数据

# 显示数据点(散点图)
plt.scatter(x_data, y_data)
plt.show()

# 构建一个顺序模型
model = Sequential()

# 在模型中添加一个全连接层
model.add(Dense(units=1, input_dim=1)) # 输入1维数据,输出1维数据

# 优化器:sgd:Stochastic gradient descent,随机梯度下降
# 损失函数:mse:Mean Squared Error,均方误差
model.compile(optimizer='sgd', loss='mse')

# 训练5001个批次
for step in range(5001): # 迭代次数越多,结果会越精确
    # 每次训练一个批次
    cost = model.train_on_batch(x_data,y_data)
    # 每500个batch打印一次loss值
    if step%500 == 0:
        print('cost', cost)

# 打印权值和偏置值
W,b = model.layers[0].get_weights()
print('W: ',W,'  b: ',b)


# 用模型进行预测
y_pred = model.predict(x_data)

# 显示随机点
plt.scatter(x_data,y_data)
# 显示预测结果
plt.plot(x_data, y_pred,'r-', lw=3)
plt.show()

参考:

视频: 覃秉丰老师的“Keras入门”:http://www.ai-xlab.com/course/32
博客参考:https://www.cnblogs.com/XUEYEYU/tag/keras%E5%AD%A6%E4%B9%A0/

你可能感兴趣的:(Keras)