李航《统计学习方法》第二章-感知机的python实现

重点:

  1. 感知机是一种二类分类的线性分类模型,属于判别模型。感知机对应于特征空间中的分离超平面 w*x+b=0
  2. 损失函数:误分类点到分离超平面的总距离。
  3. 学习算法:随机梯度下降法。有原始和对偶两种形式。
  4. 当训练数据线性可分时,感知机学习算法存在无穷多解,其解由不同初值和迭代顺序而可能不同。

李航《统计学习方法》第二章-感知机的python实现_第1张图片


实现代码:

import numpy as np  
import matplotlib.pyplot as plt  
p_x = np.array([[3, 3], [4, 3], [1, 1]])  
y = np.array([1, 1, -1])   
plt.figure()  
for i in range(len(p_x)):  
    if y[i] == 1:  
        plt.plot(p_x[i][0], p_x[i][1], 'ro')  
    else:  
        plt.plot(p_x[i][0], p_x[i][1], 'bo')  
        
# 初始权重w0,偏置b0,学习率delta=1
w = np.array([1, 0])  
b = 0  
delta = 1  
  
for i in range(1000):  
    choice = -1  
    #选取一个错误分类的点,计算其梯度下降
    for j in range(len(p_x)):  
        if y[j] != np.sign(np.dot(w, p_x[0]) + b):  
            choice = j  
            break  
    if choice == -1:  
        break  
    # 学习权重和偏置
    w = w + delta * y[choice]*p_x[choice]  
    b = b + delta * y[choice]  
  
line_x = [0, 20]  
line_y = [0, 0]  
  
for i in range(len(line_x)):  
    line_y[i] = (-w[0] * line_x[i]-b)/w[1] 
    
   
plt.plot(line_x, line_y)  
plt.savefig("picture.png")  

运行结果:

李航《统计学习方法》第二章-感知机的python实现_第2张图片

注意:作为数据驱动的学习算法,数据点太少,可能学习不到最后的分类超平面。


你可能感兴趣的:(机器学习)