机器学习笔记:线性回归

线性回归是利用数理统计中回归分析,来确定两种或两种以上变量间相互依赖的定量关系的一种统计分析方法。线性回归在假设特证满足线性关系,根据给定的训练数据训练一个模型,并用此模型进行预测。

有一组“工龄 - 工资”的数据表,我们假设它满足线性关系y = a + bx,其中x为工龄,y为工资。

工龄:0  1  2  3  4  5  6  7  8  9  10
工资:103100, 104900, 106800, 108700, 110400, 112300, 114200, 116100, 117800, 119700, 121600

定义损失函数J(a, b) ,求其偏导,得到梯度下降的公式。推导过程如下:


机器学习笔记:线性回归_第1张图片
线性回归

示例代码如下:

import matplotlib.pyplot as plt
import numpy as np

y = (103100, 104900, 106800, 108700, 110400, 112300, 114200, 116100, 117800, 119700, 121600)

def calc_diff_a(a, b):
    sum = 0
    for x in range(0, 11):
        sum = sum + 2 * a + 2 * b * x  - 2 * y[x]
    return sum

def calc_diff_b(a, b):
    sum = 0
    for x in range(0, 11):
        sum = sum + x * (2 * a + 2 * b * x  - 2 * y[x])
    return sum

def cost(a, b):
    sum = 0
    for x in range(0, 11):
        sum = sum + (a*a + b*b*x*x + 2*a*b*x - 2*a*y[x] - 2*b*x*y[x] + y[x]*y[x])
    return sum;

if __name__ == "__main__":
    num1 = 100000
    num2 = 1
    ratio = 0.0001
    itercnt = 0
    while itercnt < 50000:
        tmp1 = calc_diff_a(num1, num2)
        tmp2 = calc_diff_b(num1, num2)
        num1 = num1 - ratio * tmp1
        num2 = num2 - ratio * tmp2
        itercnt = itercnt + 1
        #print(tmp1, tmp2, cost(num1, num2))

    print(num1, num2)
    listx = np.linspace(0,10,11)
    listy = num1 + num2 * listx
    plt.figure()
    plt.plot(listx, y, '*')
    plt.plot(listx, listy)
    plt.show()

运行结果如下:
a = 103086.36363635205
b = 1848.181818183475


机器学习笔记:线性回归_第2张图片
运行结果

你可能感兴趣的:(机器学习笔记:线性回归)