统计学--感知机

参考李航的统计学习


统计学--感知机_第1张图片
Snip20170809_40.png

统计学--感知机_第2张图片
Snip20170904_2.png

感知机学习算法

统计学--感知机_第3张图片
感知机学习的原始形式.png
Python实现感知机代码
import numpy as np

w = np.array([0,0])
b = 0

#更新w和b
def update(item):
    global w,b
    w = w + item[1]*np.array(item[0])
    b = b + item[1]

def check(date_set):
    global w,b
    flag = False

    #遍历训练集
    for item in date_set:
        x=np.array(item[0])
        jieguo = item[1] * (np.dot(w,x)+b)

        #判断是否误分类
        if jieguo <= 0:
            flag = True
            update(item)

    print(w,b)
    return flag

if __name__ == "__main__":
    #训练集
    date_set = [[(3,3),1],[(4,3),1],[(1,1),-1]]

    #不断测试模型,当模型不存在误分类,停止训练
    for i in range(500):
        if not check(date_set):break

统计学--感知机_第4张图片
感知机学习算法的对偶形式.png
Python代码实现对偶形式
import numpy as np 
#训练集
date_set = np.array([[[3,3],1],[[4,3],1],[[1,1],-1]])
#设置alpha和b的初始值
a = np.zeros((len(date_set),1),np.float)
b = 0.0
x = np.empty((len(date_set),2),np.float)
for i in range(len(date_set)):
    x[i] = date_set[i,0]
y = np.array(date_set[:,1])
Gm = None

#求Gram矩阵
def gram():
    gm = np.empty((len(date_set),len(date_set)),np.int)
    for i in range(len(date_set)):
        for j in range(len(date_set)):
            gm[i][j] = np.dot(date_set[i][0],date_set[j][0])
    return gm

#更新alpha和b的值
def update(i):
    global a,b
    a[i] += 1
    b += y[i]
#测试模型
def check():
    global a,b
    flag = False
    for i in range(len(date_set)):
        jieguo = 0
        for j in range(len(date_set)):
            jieguo += a[j]*y[j]*Gm[i,j]
        jieguo = (jieguo + b)*y[i]
        print(jieguo)

        if jieguo <= 0:
            flag = True
            update(i)
    if not flag:
        w=0.0
        for i in range(len(date_set)):
            w += a[i]*y[i]*x[i]
        print(w,b)
        return False
    return True

if __name__ == "__main__":
    Gm = gram()
    for i in range(1000):
        if not check():break

你可能感兴趣的:(统计学--感知机)