Code:
# -*- coding:utf-8 -*-
def gradient_descent(xi, yi):
    theta0, theta1 = 0, 0  # initialize the parameters
    m = len(xi)
    alpha = 0.01      # learning rate
    max_step = 20000  # maximum number of iterations
    count = 0
    epsilon = 0.1     # error threshold for stopping
    while True:
        # accumulate the gradient sums over all m samples
        num1, num2 = 0, 0
        for i in range(m):
            num1 += theta0 + theta1 * xi[i] - yi[i]
            num2 += (theta0 + theta1 * xi[i] - yi[i]) * xi[i]
        # update theta
        theta0 = theta0 - alpha * num1 / m
        theta1 = theta1 - alpha * num2 / m
        # print('theta0', theta0)
        # print('theta1', theta1)
        # sum of squared errors after this update
        error = 0
        for i in range(m):
            error += (theta0 + theta1 * xi[i] - yi[i]) ** 2
        if error <= epsilon:
            break
        count += 1
        if count > max_step:
            break
    return theta1, theta0, error
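# Note (added for reference; this derivation is not spelled out in the
# original post): num1 / m and num2 / m are the partial derivatives of the
# cost J(theta0, theta1) = (1 / (2 * m)) * sum((theta0 + theta1 * x_i - y_i) ** 2)
# with respect to theta0 and theta1, so each pass of the loop performs the
# standard batch-gradient-descent updates
#     theta0 <- theta0 - (alpha / m) * sum(theta0 + theta1 * x_i - y_i)
#     theta1 <- theta1 - (alpha / m) * sum((theta0 + theta1 * x_i - y_i) * x_i)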
if __name__ == '__main__':
    xi = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    yi = [10, 11.5, 12, 13, 14.5, 15.5, 16.8, 17.3, 18, 18.7]
    a, b, error = gradient_descent(xi, yi)
    print("y = %10.5fx + %10.5f" % (a, b))
    print("error: ", error)