# 2020.01.11 By yangbocsu
# 2020.01.11 By yangbocsu
import numpy as np
import matplotlib.pyplot as plt
import time
#train data
x = np.array([[3, 3], [1, 1],[4, 3]])
y = np.array([1, -1, 1])#标签
#参数初始化
w = np.array([0,0])
b =np.array([0])
learning_rate = 1
#判断是否误分类 函数
def judge_class(x,y,w2,b):
misclassification = False
cnt = 0
mis_index = 0
for i in range(len(x)):
if y[i]*(np.dot(w2,x[i])+b) <= 0:
cnt += 1
mis_index = i
break #发现有误分类点 记录其下标就直接退出即可;很重要
if cnt > 0 :
misclassification = True
return misclassification, mis_index
# 参数w, b 的更新
def update(x, y, w1, b, i):
w1 = w1 + learning_rate*y[i]*x[i]
b = b + learning_rate*y[i]
return w1, b
#更新迭代
t1 = time.time()
misclassification, mis_index = judge_class(x,y,w,b)
#while misclassification: # 可以不指定跟新次数,用misclassification来判断
for j in range(7):#书上跟新到第7次 P41
#print(x)
w, b = update(x, y, w, b, mis_index)
print("第{}次循环,w={},b={}".format(j,w,b))
misclassification, mis_index = judge_class(x,y,w,b)
t2 = time.time()
print("w={}".format(w))
print("b={}".format(b))
print("感知机学习算法的原始形式共用:{}ms".format((t2-t1)*1000))
#绘制图
#w1*x1+w2*x2+b=0
#matplotlib画图中中文显示会有问题,需要这两行设置默认字体
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
#画的坐标轴并设置轴标签x,y
plt.xlabel('X')
plt.ylabel('Y')
plt.xlim(xmax=12,xmin=0)
plt.ylim(ymax=12,ymin=0)
#散点图数据
x1=np.array([3,4])
y1=np.array([3,3])
x2 = np.arange(0,5,0.1)
y2 = -((w[0])*x2 +b)/(w[1]+0)
x3 = np.array([1])
y3 = np.array([1])
plt.scatter(x1, y1, marker = 'o', alpha=0.4,color="red", label='正类')
plt.scatter(x3, y3, marker = 'x', alpha=0.4,color="red", label='负类')
plt.plot(x2,y2,label="拟合直线",c='orange')
plt.legend() #label='正类' 图中显示
plt.title('感知机原始形')#图的标题
plt.show()