算法逻辑请参考吴恩达机器学习相关视频
梯度下降法
import numpy as np
import matplotlib.pyplot as plt
def compute_cost(theta0, theta1, X, y):
    """Return the halved mean squared error of the line ``theta0 + theta1 * x``.

    Args:
        theta0: Intercept of the candidate line.
        theta1: Slope of the candidate line.
        X: 2-D array-like of samples; only column 0 is read as the feature.
        y: 2-D array-like of targets; only column 0 is read, and only the
            first ``len(X)`` rows are used.

    Returns:
        ``sum((theta0 + theta1 * x_i - y_i) ** 2) / (2 * M)`` where ``M`` is
        the number of rows in ``X``.

    Raises:
        ZeroDivisionError: If ``X`` has no rows.
        IndexError: If ``y`` has fewer rows than ``X``.
    """
    x_col = np.asarray(X)[:, 0]
    m = x_col.shape[0]
    if m == 0:
        raise ZeroDivisionError("cannot compute the cost of an empty data set")

    y_col = np.asarray(y)[:m, 0]
    if y_col.shape[0] < m:
        # Fail loudly instead of letting a short y broadcast silently.
        raise IndexError("y has fewer rows than X")

    # One vectorized pass replaces the per-sample Python loop.
    residual = theta0 + theta1 * x_col - y_col
    return np.dot(residual, residual) / (2 * m)
def gradient_descent(X, y, init_theta0, init_theta1, alpha, num_iterations):
    """Fit ``y ~ theta0 + theta1 * x`` with batch gradient descent.

    Args:
        X: 2-D array-like of samples; only column 0 is used as the feature.
        y: 2-D array-like of targets; only column 0 is used.
        init_theta0: Starting value for the intercept.
        init_theta1: Starting value for the slope.
        alpha: Learning rate applied to each gradient step.
        num_iterations: Number of update steps to perform.

    Returns:
        A tuple ``(theta0, theta1, cost_list)``. ``cost_list`` holds the value
        of ``compute_cost`` measured *before* each update, so it has
        ``num_iterations`` entries.
    """
    cost_list = []
    theta0, theta1 = init_theta0, init_theta1
    for _ in range(num_iterations):
        cost_list.append(compute_cost(theta0, theta1, X, y))
        # Forward y explicitly so the step uses this call's targets rather
        # than whatever module-level ``y`` happens to exist.
        theta0, theta1 = step_gradient_descent(theta0, theta1, alpha, X, y)
    return theta0, theta1, cost_list


def step_gradient_descent(theta0, theta1, alpha, X, y=None):
    """Perform one batch gradient-descent update of ``(theta0, theta1)``.

    Args:
        theta0: Current intercept.
        theta1: Current slope.
        alpha: Learning rate.
        X: 2-D array-like of samples; only column 0 is used as the feature.
        y: 2-D array-like of targets; only column 0 is used, and only the
            first ``len(X)`` rows. When omitted, the module-level ``y`` is
            used, which is what the four-argument call form relied on.

    Returns:
        A tuple ``(updated_theta0, updated_theta1)``.

    Raises:
        ZeroDivisionError: If ``X`` has no rows.
        IndexError: If ``y`` has fewer rows than ``X``.
    """
    if y is None:
        # Legacy four-argument call: fall back to the module-level targets.
        y = globals()["y"]

    x_col = np.asarray(X)[:, 0]
    m = x_col.shape[0]
    if m == 0:
        raise ZeroDivisionError("cannot take a gradient step on an empty data set")

    y_col = np.asarray(y)[:m, 0]
    if y_col.shape[0] < m:
        # Fail loudly instead of letting a short y broadcast silently.
        raise IndexError("y has fewer rows than X")

    # Partial derivatives of the halved mean squared error.
    residual = theta0 + theta1 * x_col - y_col
    grad_theta0 = residual.sum() / m
    grad_theta1 = np.dot(residual, x_col) / m

    updated_theta0 = theta0 - alpha * grad_theta0
    updated_theta1 = theta1 - alpha * grad_theta1
    return updated_theta0, updated_theta1
# --- Demo: fit a line to noisy synthetic data with gradient descent ---

# 100 samples of x drawn from [0, 2); targets follow y = 4 + 4x plus
# standard-normal noise.
X = 2 * np.random.rand(100, 1)
y = 4 + 4 * X + np.random.randn(100, 1)

# Train from a zero initialisation.
theta0, theta1, cost_list = gradient_descent(
    X,
    y,
    init_theta0=0,
    init_theta1=0,
    alpha=0.005,
    num_iterations=1000,
)

# Cost of the final parameters (cost_list only holds pre-update costs).
cost = compute_cost(theta0, theta1, X, y)

print("theta0 is :", theta0)
print("theta1 is :", theta1)
print("cost_list:", cost_list)
print("cost is:", cost)

# Figure 1: cost per iteration.
plt.plot(cost_list)
plt.show()

# Figure 2: the samples (blue dots) with the fitted line (red).
plt.subplots(figsize=(12, 8))
plt.plot(X, y, "b.")
reg_model = theta0 + theta1 * X
plt.plot(X, reg_model, c="r")
plt.show()
theta0 is : 4.74244227311996
theta1 is : 3.1105226654110716
cost is: 0.08444660160107187
正规方程
import numpy as np
import matplotlib.pyplot as plt
# --- Demo: fit the same kind of line in closed form via the normal equation ---

# Synthetic data: x drawn from [0, 2); y = 4 + 4x plus standard-normal noise.
X = 2 * np.random.rand(100, 1)
X_prev = X  # keep the raw feature column for plotting
y = 4 + 4 * X + np.random.randn(100, 1)

# Design matrix: a column of ones (intercept term) followed by the feature.
# Plain ndarrays are used throughout; np.matrix is no longer recommended.
X0 = np.ones((X_prev.shape[0], 1))
X = np.c_[X0, X_prev]
Y = np.asarray(y)

# Normal equation (X^T X) theta = X^T Y. Solving the linear system is more
# accurate than forming the explicit inverse of A.
A = X.T @ X
theta = np.linalg.solve(A, X.T @ Y)
print(theta)

# Plain Python floats; np.float_ was removed in NumPy 2.0.
theta0 = float(theta[0, 0])
theta1 = float(theta[1, 0])

# Plot the samples (blue dots) with the fitted line (red).
plt.subplots(figsize=(12, 8))
plt.plot(X_prev, y, "b.")
reg_model = theta0 + theta1 * X_prev
plt.plot(X_prev, reg_model, c='r')
plt.show()