【Python可视化】使用梯度下降法实现简单线性回归 动态展示

【Python可视化】实现简单线性回归

梯度下降法的简单线性回归,原理见西瓜书第3章,这里只放Python实现代码。

import numpy as np
import matplotlib.pyplot as plt


def gradient(gt_y, pred_y, x):
    N = len(x)
    diff = (1 / N) * (pred_y - gt_y)
    dw = np.dot(diff, x)  # 这里的dw是dloss/dw
    db = np.dot(diff, np.array(np.ones(shape=(N, 1), dtype=np.float)))
    return dw, db


def train(w, b, x, gt_y, lr, max_iter):
    N = len(x)

    plt.figure()
    plt.ion()

    for num in range(max_iter):
        pred_y = np.dot(w, x) + b
        delta = pred_y - gt_y
        w = w - lr * gradient(gt_y, pred_y, x)[0]
        b = b - lr * gradient(gt_y, pred_y, x)[1]
        loss = (1 / N) * np.dot(delta, delta.T)
        # print(loss)
        # print('w: {}, b: {}, loss:{}'.format(w, b, loss))

        plt.clf()
        plt.scatter(x, gt_y)
        plt.plot(x, pred_y, 'r-', lw=3)
        plt.xlim(0, 8)
        plt.ylim(0, 15)
        plt.title('Iteration:%d' % num)
        plt.text(4, 4, 'loss=%.4f' % loss)
        plt.text(4, 3, 'Y=%.4fx+%.4f' % (w, b))
        plt.pause(0.01)
        plt.show()

        if num % 20 == 0:
            print('Iteration:%d \tY=%.4fx+%.4f \tLoss=%.4f' % (num, w, b, loss))


if __name__ == '__main__':
    x = [0.8, 2.4, 2.7, 3.7, 5.2]  # 测试数据1
    gt_y = [4.8, 5.7, 6.99, 8.54, 12.4]  # gt_y = 1.5*x+3.6
    train(0, 0, x, gt_y, 0.01, 2000)

下面是使用测试数据的运行结果,学习率和迭代次数可调,也可以改成当loss低于某个阈值后中断循环。

【Python可视化】使用梯度下降法实现简单线性回归 动态展示_第1张图片

你可能感兴趣的:(【Python可视化】使用梯度下降法实现简单线性回归 动态展示)