李航《统计学习方法》例2.1 感知机学习算法的原始形式——实现代码

# 2020.01.11  By yangbocsu

例2.1 如图2.2所示的训练数据集,其正实例点是x1= (3,3)T, x2= (4,3)T,负实例点是:x3=(1,1)",试用感知机学习算法的原始形式求感知机模型f(x)=sign(wx+b)。这里,w= (w(1), w(2))T", x= (x(1), x(2))T。

李航《统计学习方法》例2.1 感知机学习算法的原始形式——实现代码_第1张图片

 

 

 

参考代码:

# 2020.01.11  By yangbocsu
import numpy as np
import matplotlib.pyplot as plt
import time

#train data
x  = np.array([[3, 3], [1, 1],[4, 3]])
y = np.array([1, -1, 1])#标签

#参数初始化
w = np.array([0,0])
b =np.array([0])
learning_rate = 1



#判断是否误分类 函数  
def judge_class(x,y,w2,b):
    misclassification = False
    cnt = 0
    mis_index = 0
    for i in range(len(x)):
        if y[i]*(np.dot(w2,x[i])+b) <= 0:
            cnt += 1
            mis_index = i
            break  #发现有误分类点 记录其下标就直接退出即可;很重要
    if cnt > 0 :
        misclassification = True
    return misclassification, mis_index

# 参数w, b 的更新
def update(x, y, w1, b, i):
    w1 = w1 + learning_rate*y[i]*x[i]
    b = b + learning_rate*y[i]
    return w1, b
    
#更新迭代
t1 = time.time()
misclassification, mis_index = judge_class(x,y,w,b)
#while misclassification:  # 可以不指定跟新次数,用misclassification来判断
for j in range(7):#书上跟新到第7次   P41
    #print(x)
    w, b = update(x, y, w, b, mis_index)
    print("第{}次循环,w={},b={}".format(j,w,b))
    misclassification, mis_index = judge_class(x,y,w,b)
t2 = time.time()   
print("w={}".format(w))
print("b={}".format(b))
print("感知机学习算法的原始形式共用:{}ms".format((t2-t1)*1000))

#绘制图
#w1*x1+w2*x2+b=0
#matplotlib画图中中文显示会有问题,需要这两行设置默认字体
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
 
 
#画的坐标轴并设置轴标签x,y 
plt.xlabel('X')
plt.ylabel('Y')
plt.xlim(xmax=12,xmin=0)
plt.ylim(ymax=12,ymin=0)
 
#散点图数据
x1=np.array([3,4])
y1=np.array([3,3])

x2 = np.arange(0,5,0.1)
y2 = -((w[0])*x2 +b)/(w[1]+0)

x3 = np.array([1])
y3 = np.array([1])
 
plt.scatter(x1, y1, marker = 'o', alpha=0.4,color="red", label='正类')
plt.scatter(x3, y3, marker = 'x', alpha=0.4,color="red", label='负类')
plt.plot(x2,y2,label="拟合直线",c='orange')

plt.legend() #label='正类' 图中显示
plt.title('感知机原始形')#图的标题
plt.show()


 

李航《统计学习方法》例2.1 感知机学习算法的原始形式——实现代码_第2张图片

 

李航《统计学习方法》例2.1 感知机学习算法的原始形式——实现代码_第3张图片

你可能感兴趣的:(统计学习方法)