李航《统计学习方法》例2.12感知机学习算法的对偶形式——实现代码

例2.2 数据同例2.1, 正样本点是x1 = (3,3)T, x2 = (4,3)T,负样本点是X3=(1,1)T,试用感知机学习算法对偶形式求感知机模型。
 

 

李航《统计学习方法》例2.12感知机学习算法的对偶形式——实现代码_第1张图片

 

暂时:

# 2020.01.11  By yangbocsu
import numpy as np
import matplotlib.pyplot as plt


#2.2 (对偶形式)
# python
x = np.array([[1,1],[4,3],[3,3]])
y = np.array([-1,1,1])#标签
x_transpose = x.T

g = np.dot(x, x_transpose)#Gram矩阵

alfa = np.array([0, 0, 0])
b =np.array([0])

learning_rate = 1


#是否还存在误分类点
def judge_class(y, g, b):
    misclassification = False
    cnt = 0
    mis_index = 0
    for i in range(len(y)):
        sum1 = 0
        for j in range(len(y)):
            sum1 += (alfa[j]*y[j]*g[i][j] + b)
        if y[i]*sum1 <= 0:
            cnt += 1
            mis_index = i
            misclassification = True
            #break
    return misclassification, mis_index


# 更新系数alfa, b
def update(y, alfa, learning_rate, b, i):
    alfa[i] = alfa[i] + learning_rate
    b = b + learning_rate*y[i]
    return alfa, b



#更新迭代
def optimization(y, alfa, b, learning_rate):
    misclassification, mis_index = judge_class(y, g, b)
    #while misclassification:
    for i in range(7):
        print ("误分类的第x{}点{}:".format(mis_index+1, x[mis_index]))
        alfa, b = update(y, alfa, learning_rate, b, mis_index)
        print ("采用第x{}误分类点 {} 更新后的权重为:alfa是 {} , b是 {} ".format(mis_index+1, x[mis_index], alfa, b))
        print("\n")
        misclassification, mis_index = judge_class(y, g, b)
    return alfa, b

#a1,b1=optimization(y, alfa, b, learning_rate)
#print("a1={},b1={}".format(a1,b1))

alfa, b = optimization(y, alfa, b, learning_rate)
print("alfa={},b={}".format(alfa,b))


alfa_y = np.multiply(list(alfa),y)#对应元素相乘
print("alfa_y={}".format(alfa_y))
w = np.dot(alfa_y,x)
b = np.dot(alfa, y)
print("w是{},b是{}".format(w, b))





#绘制图————————————————————————————————
#w1*x1+w2*x2+b=0
#matplotlib画图中中文显示会有问题,需要这两行设置默认字体
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
 
 
#画的坐标轴并设置轴标签x,y 
plt.xlabel('X')
plt.ylabel('Y')
plt.xlim(xmax=5,xmin=0)
plt.ylim(ymax=5,ymin=0)
 
#散点图数据
x1=np.array([3,4])
y1=np.array([3,3])
 
x2 = np.arange(0,5,0.1)
y2 = -((w[0])*x2 +b)/(w[1]+0.000001)
 
x3 = np.array([1])
y3 = np.array([1])
 
plt.scatter(x1, y1, marker = 'o', alpha=0.4,color="red", label='正类')
plt.scatter(x3, y3, marker = 'x', alpha=0.4,color="red", label='负类')
plt.plot(x2,y2,label="拟合直线",c='orange')
 
plt.legend() #label='正类' 图中显示
plt.title('感知机原始形')#图的标题
plt.show()

 

李航《统计学习方法》例2.12感知机学习算法的对偶形式——实现代码_第2张图片

 

李航《统计学习方法》例2.12感知机学习算法的对偶形式——实现代码_第3张图片

你可能感兴趣的:(统计学习方法)