【机器学习系列】感知机学习代码

例题见《统计学习方法》P29例2.1

感知机学习代码:


#coding:gbk

train_set=[[3,3,1],
           [4,3,1],
           [1,1,-1]]

w=[0,0]
b=0
learn_ratio=1

def predict(t,w,b):
    return t[2]*(w[0]*t[0]+w[1]*t[1]+b)>0

def learn(w,b):
    have_wrong_predict_point=True
    iter_num=0
    print("%d\t\t%s\t%d\t%d*x1+%d*x2+(%d)" % (iter_num,w,b,w[0],w[1],b))
    while(have_wrong_predict_point):
        iter_num+=1
        have_wrong_predict_point=False
        for i in range(len(train_set)):
            t=train_set[i]
            if not predict(t,w,b):
                have_wrong_predict_point=True
                w[0]=w[0]+learn_ratio*t[0]*t[2]
                w[1]=w[1]+learn_ratio*t[1]*t[2]
                b=b+learn_ratio*t[2]
                print("%d\tx%d\t%s\t%d\t%d*x1+%d*x2+(%d)" % (iter_num,i+1,w,b,w[0],w[1],b))
                break
    print("%d\tx%d\t%s\t%d\t%d*x1+%d*x2+(%d)" % (iter_num,i+1,w,b,w[0],w[1],b))

learn(w,b)


你可能感兴趣的:(机器学习,感知机)