TensorFlow Linear Regression

  • Linear Regression
    • Overview
    • Code Implementation
      • Imports
      • Getting the Training Data
      • Building the Prediction Function
      • Building the Loss Function
      • Adjusting the Prediction Function
      • Training

Linear Regression

Overview

For a detailed introduction to linear regression, see my machine learning column; we will not repeat it here.

The linear regression model: y = W * x + b, where W is the weight and b is the bias.

Code Implementation

Imports

First, we need to import the tensorflow and matplotlib packages.

import tensorflow as tf
from matplotlib import pyplot as plt
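
Note that this code uses the TensorFlow 1.x API (tf.Session, tf.train.GradientDescentOptimizer). If you are running TensorFlow 2.x, one option is the v1 compatibility layer; a minimal sketch:

# Run TF1-style graph code under TensorFlow 2.x
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from matplotlib import pyplot as plt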

Getting the Training Data

# ------------------1. Get the training data------------------
x = [1.0, 2.0, 3.0, 4.0]
y = [25.0, 40.0, 60.0, 95.0]

# Model parameters
W = tf.Variable(tf.random_uniform([1]))  # a single random number in [0, 1)
b = tf.Variable(tf.zeros([1]))  # initialize b to 0
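
tf.random_uniform([1]) draws one value uniformly from [0, 1) by default. The same initialization with the bounds spelled out explicitly (a sketch):

# Equivalent initialization, with the default range made explicit
W = tf.Variable(tf.random_uniform([1], minval=0.0, maxval=1.0))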

Building the Prediction Function

# ------------2. Build the linear prediction function y = W * x + b------------
predict = W * x + b
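
Here the Python list x is converted to a tensor of shape [4]; W has shape [1] and broadcasts against it, so predict has shape [4], one prediction per training sample. A quick sanity check (sketch):

# predict should hold one value per training sample
print(predict.get_shape())  # expected: (4,)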

Building the Loss Function

# ------------------3. Evaluate how well the hypothesis fits------------------

# Loss function: mean squared error
cost = tf.reduce_mean(tf.square(y - predict))
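
This is the mean squared error, cost = (1/n) * Σ (y_i - predict_i)^2. For intuition, the same quantity computed on concrete (hypothetical) numbers with plain NumPy, outside the graph:

import numpy as np

y_true = np.array([25.0, 40.0, 60.0, 95.0])
y_hat = np.array([22.0, 44.0, 66.0, 88.0])  # hypothetical predictions
mse = np.mean((y_true - y_hat) ** 2)  # (3^2 + 4^2 + 6^2 + 7^2) / 4
print(mse)  # 27.5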

Adjusting the Prediction Function

# ------------------4. Adjust the hypothesis function------------------

# Gradient descent
optimizer = tf.train.GradientDescentOptimizer(0.05)  # 0.05 is the learning rate
train = optimizer.minimize(cost)
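
GradientDescentOptimizer applies the update W ← W - α * ∂cost/∂W (and likewise for b), with learning rate α = 0.05. An equivalent hand-rolled step using tf.gradients, as a sketch assuming the cost defined above:

# Manual gradient-descent step, equivalent to optimizer.minimize(cost)
grad_W, grad_b = tf.gradients(cost, [W, b])
train_manual = tf.group(
    W.assign_sub(0.05 * grad_W),  # W <- W - learning_rate * dcost/dW
    b.assign_sub(0.05 * grad_b),  # b <- b - learning_rate * dcost/db
)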

Training

# ------------------5. Start training------------------

# Create a session
with tf.Session() as sess:
    # Initialize all variables
    init = tf.global_variables_initializer()
    sess.run(init)
    print("x=", x, "y=", y)

    # Start training
    for j in range(500):  # run 500 iterations
        sess.run(train)
        print("cost=", sess.run(cost), "W=", sess.run(W), "b=", sess.run(b))  # debug output

    print("Training finished")

    # Build the plot
    plt.plot(x, y, 'ro', label='train data')  # plot the original data points
    plt.plot(x, sess.run(predict), label='train result')  # plot the fitted line

    plt.legend()
    plt.show()  # display the plot
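
Once the loop finishes, the fitted W and b can be read back from the session and used on new inputs. A sketch that would go inside the same with block (x_new = 5.0 is just an illustrative value):

    # Read back the trained parameters and predict for an unseen x
    w_val, b_val = sess.run([W, b])
    x_new = 5.0  # hypothetical new input
    print("prediction for x =", x_new, ":", w_val[0] * x_new + b_val[0])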

Output:
x= [1.0, 2.0, 3.0, 4.0] y= [25.0, 40.0, 60.0, 95.0]
cost= 130.54874 W= [16.869402] b= [5.2555995]
cost= 41.342194 W= [19.52845] b= [6.012689]
cost= 38.550335 W= [20.00394] b= [6.029308]
cost= 38.121033 W= [20.118658] b= [5.925392]
cost= 37.76654 W= [20.173317] b= [5.8031883]
cost= 37.424282 W= [20.217531] b= [5.67954]
cost= 37.092247 W= [20.259497] b= [5.5572033]
cost= 36.770065 W= [20.300573] b= [5.436609]
cost= 36.45749 W= [20.340992] b= [5.317805]
cost= 36.154144 W= [20.380796] b= [5.200776]
cost= 35.859837 W= [20.420006] b= [5.0855]
cost= 35.574295 W= [20.458626] b= [4.971948]
cost= 35.29721 W= [20.49667] b= [4.860097]
cost= 35.02838 W= [20.534143] b= [4.74992]
cost= 34.767563 W= [20.571054] b= [4.6413918]
cost= 34.51443 W= [20.607416] b= [4.534489]
cost= 34.26886 W= [20.64323] b= [4.429186]
cost= 34.030563 W= [20.67851] b= [4.32546]
cost= 33.799366 W= [20.713263] b= [4.223286]
cost= 33.575058 W= [20.747494] b= [4.122642]
cost= 33.3574 W= [20.781212] b= [4.0235043]
cost= 33.1462 W= [20.814426] b= [3.9258506]
cost= 32.941277 W= [20.847145] b= [3.829659]
cost= 32.74243 W= [20.879372] b= [3.734907]
cost= 32.549538 W= [20.911116] b= [3.6415734]
cost= 32.362335 W= [20.942387] b= [3.5496373]
cost= 32.180706 W= [20.973186] b= [3.459077]
cost= 32.00448 W= [21.003529] b= [3.3698726]
cost= 31.833466 W= [21.033413] b= [3.2820032]
cost= 31.667576 W= [21.062853] b= [3.1954496]
cost= 31.506586 W= [21.09185] b= [3.1101913]
cost= 31.350405 W= [21.120415] b= [3.0262096]
cost= 31.198845 W= [21.148552] b= [2.943485]
cost= 31.05177 W= [21.176268] b= [2.8619986]
cost= 30.909086 W= [21.203568] b= [2.7817318]
cost= 30.770641 W= [21.23046] b= [2.7026668]
cost= 30.63631 W= [21.256948] b= [2.6247852]
cost= 30.505955 W= [21.283041] b= [2.5480695]
cost= 30.379513 W= [21.308743] b= [2.4725022]
cost= 30.256819 W= [21.33406] b= [2.3980663]
cost= 30.137733 W= [21.359] b= [2.3247445]
cost= 30.022242 W= [21.383564] b= [2.25252]
cost= 29.910137 W= [21.40776] b= [2.181377]
cost= 29.801365 W= [21.431597] b= [2.111299]
cost= 29.695837 W= [21.455074] b= [2.04227]
cost= 29.593433 W= [21.4782] b= [1.9742745]
cost= 29.494081 W= [21.500982] b= [1.907297]
cost= 29.397697 W= [21.523422] b= [1.841322]
cost= 29.304142 W= [21.545525] b= [1.776334]
cost= 29.213396 W= [21.567297] b= [1.7123195]
cost= 29.125343 W= [21.588745] b= [1.6492634]
cost= 29.039907 W= [21.60987] b= [1.5871509]
cost= 28.957 W= [21.63068] b= [1.5259682]
cost= 28.876587 W= [21.651178] b= [1.4657012]
cost= 28.798529 W= [21.67137] b= [1.4063364]
cost= 28.722786 W= [21.69126] b= [1.3478605]
cost= 28.649323 W= [21.71085] b= [1.2902595]
cost= 28.578009 W= [21.730146] b= [1.2335209]
cost= 28.508812 W= [21.749157] b= [1.1776322]
cost= 28.441715 W= [21.767881] b= [1.1225798]
cost= 28.376596 W= [21.786325] b= [1.0683514]
cost= 28.313396 W= [21.804493] b= [1.0149348]
cost= 28.25209 W= [21.82239] b= [0.9623179]
cost= 28.192574 W= [21.840017] b= [0.9104886]
cost= 28.134865 W= [21.857382] b= [0.85943544]
cost= 28.078865 W= [21.874487] b= [0.8091464]
cost= 28.024508 W= [21.891336] b= [0.7596101]
cost= 27.971792 W= [21.90793] b= [0.71081495]
cost= 27.920624 W= [21.924278] b= [0.6627508]
cost= 27.870962 W= [21.940382] b= [0.6154062]
cost= 27.822811 W= [21.956244] b= [0.56877005]
cost= 27.776054 W= [21.971869] b= [0.52283216]
cost= 27.730717 W= [21.987259] b= [0.47758183]
cost= 27.686745 W= [22.002419] b= [0.4330089]
cost= 27.644035 W= [22.017353] b= [0.3891034]
cost= 27.60263 W= [22.032063] b= [0.34585467]
cost= 27.562408 W= [22.04655] b= [0.3032534]
cost= 27.52344 W= [22.060825] b= [0.2612905]
cost= 27.485584 W= [22.074883] b= [0.21995495]
cost= 27.44887 W= [22.088732] b= [0.17923878]
cost= 27.413267 W= [22.102373] b= [0.13913201]
cost= 27.378666 W= [22.11581] b= [0.09962548]
cost= 27.345165 W= [22.129047] b= [0.06071051]
cost= 27.312603 W= [22.142084] b= [0.02237757]
cost= 27.281033 W= [22.154926] b= [-0.01538121]
cost= 27.250397 W= [22.167576] b= [-0.05257479]
cost= 27.220654 W= [22.180037] b= [-0.08921143]
cost= 27.191793 W= [22.192312] b= [-0.12529947]
cost= 27.16383 W= [22.204403] b= [-0.16084749]
cost= 27.136692 W= [22.216312] b= [-0.19586335]
cost= 27.110332 W= [22.228043] b= [-0.23035532]
cost= 27.084743 W= [22.2396] b= [-0.2643305]
cost= 27.059952 W= [22.250982] b= [-0.29779723]
cost= 27.0359 W= [22.262196] b= [-0.33076295]
cost= 27.012539 W= [22.27324] b= [-0.3632355]
cost= 26.989891 W= [22.284119] b= [-0.39522174]
cost= 26.967892 W= [22.294836] b= [-0.42672914]
cost= 26.946571 W= [22.305391] b= [-0.45776534]
cost= 26.925846 W= [22.315788] b= [-0.4883368]
...

(Figure: the training data points and the fitted regression line.)
