tensorflow2.0 入门学习——tf.keras实现线性回归

tensorflow2.0 入门学习——tf.keras实现线性回归

一、 简介:
tensorflow 安装(cpu):pip install tensorflow-cpu==2.3.0 -i https://pypi.douban.com/simple/

二、 实操:
数据集例子:
“教师授课年份(Education)”与“收入(Income)”之间的线性关系
Education Income
1 10 26.658839
2 10.401338 27.306435
3 10.842809 22.13241
4 11.244147 21.169841
5 11.645449 15.192634
6 12.086957 26.398951
7 12.048829 17.435307
……
28 21.167191 77.355021
29 21.598662 72.11879
30 22 80.260571
*完整数据集附于文章末尾

准备数据集:
将其拷贝到一个文本文件中并保存,之后再改名为 .csv
在这里插入图片描述
代码:

  1. 数据集的读取
#pandas 是基于NumPy的一种工具,该工具是为解决数据分析任务而创建的。
import pandas as pd 
#读取数据集文件
data = pd.read_csv("./Incomel.csv")     
  1. plt绘制散点图查看线性关系(可选)
#用来画图的库
import matplotlib.pyplot as plt
#指定x,y轴绘制散点图更加直观
plt.scatter(x=data.Education,y=data.Income)
plt.show()       

tensorflow2.0 入门学习——tf.keras实现线性回归_第1张图片

  1. 引入tensorflow构建模型
    在tf.keras中有两种类型的模型,tf.keras.Sequential 和 函数式模型(Model)
    最常见的模型是层的堆叠:tf.keras.Sequential,层的线性叠加。
    线性回归问题用Sequential,复杂网络用Model
import tensorflow as tf
model = tf.keras.Sequential()#构建Dense()层,构建y=wx+b关系
	#Dense(输出数据维度,输入数据维度)
layer_1 = tf.keras.layers.Dense(1,input_shape=(1,))    
model.add(layer_1 ) #为模型添加Dense()层
info = model.summary()  #可查看当前模型信息
print(info)

tensorflow2.0 入门学习——tf.keras实现线性回归_第2张图片
Model: “sequential” ---- 顺序模型
Layer (type) — 层的类型 Output Shape — 模型输出维度(None, 1)代表1维,batch=None
Param — 参数个数 因为是y=wx+b w、b两个参数

  1. 设置损失、优化函数开始模型训练
# 为指定模型训练用的优化算法和损失函数
	# 优化算法:adam,损失函数:mean square error
model.compile(optimizer="adam",loss="mse")
# 开始训练,指定数据集的input_x、output_y、训练次数epochs=3000
model.fit(x=data.Education,y=data.Income,epochs=3000)
  

自动打印:
tensorflow2.0 入门学习——tf.keras实现线性回归_第3张图片

  1. 模型预测:
# 用训练好的模型做一次测试(这里用数据集的x做演示)
model.predict(x=data.Education)
# 也可以输入自定义的数据20,不过要构造出pd.Series()对象
model.predict(x=pd.Series([20]))

tensorflow2.0 入门学习——tf.keras实现线性回归_第4张图片
在这里插入图片描述

完整数据集:
Education Income
1 10 26.658839
2 10.401338 27.306435
3 10.842809 22.13241
4 11.244147 21.169841
5 11.645449 15.192634
6 12.086957 26.398951
7 12.048829 17.435307
8 12.889632 25.507885
9 13.29097 36.884595
10 13.732441 39.666109
11 14.133779 34.396281
12 14.635117 41.497994
13 14.978589 44.981575
14 15.377926 47.039595
15 15.779264 48.252578
16 16.220736 57.034251
17 16.622074 51.490919
18 17.023411 51.336621
19 17.464883 57.681998
20 17.866221 68.553714
21 18.267559 64.310925
22 18.70903 68.959009
23 19.110368 74.614639
24 19.511706 71.867195
25 19.913043 76.098135
26 20.354515 75.775216
27 20.755853 72.486055
28 21.167191 77.355021
29 21.598662 72.11879
30 22 80.260571

你可能感兴趣的:(tensorflow,深度学习,python,机器学习,算法)