感知器算法

# -*- coding: utf-8 -*-
"""
Created on Thu Oct 15 13:58:06 2015

@author: Think
"""
#感知器算法
import mkdata as mk
import numpy as np
import matplotlib.pyplot as plt

N = 100 #生成测试点的数目

def check(item, y, w, b):
    ans = w[0]*item[0] + w[1]*item[1] + b
    ans *= y
    
    if ans > 0:
        return True
    else:
        return False
        
def perceptron(X,y):
    iterNums = 1000
    m,n = X.shape
    w = np.zeros(m)
    b = 0
    a = 0.01
    
    for i in range(iterNums):
        for j in range(n):
            if not check(X[:,j], y[0][j], w, b):
                w = w + a * y[0][j]*X[:,j]
                b += a*y[0][j]
    return (w, b)
    
if __name__ == '__main__':
    (X,y,w) = mk.mk_data(N) #是线性可分
    #(X,y,w) = mk.mk_data(N,True) # 不是线性可分
    plt.scatter(X[0,y[0]==1], X[1,y[0]==1], color='red')
    plt.scatter(X[0,y[0]==-1], X[1,y[0]==-1], color='green')
    
    w, b = perceptron(X, y)
    
    x = np.arange(-2,2,0.1)
    x2 = (-b-w[0]*x)/w[1]
    
    plt.plot(x,x2)
    plt.show()


mk_data函数的链接为:python生成测试数据点


如果是线性可分数据点,结果如下:

感知器算法_第1张图片



 如果数据点不是线性可分的,效果如下:

感知器算法_第2张图片

你可能感兴趣的:(机器学习)