# -*- coding:utf-8 -*-
# 统计学习方法笔记2--感知机 (Notes on "Statistical Learning Methods", part 2: the perceptron)

# -*- coding:utf-8 -*-
import os
import optparse

# 命令行参数解析
#----------------------------------------------------------------
optparser = optparse.OptionParser()
# -i/--input: path of the training-data file consumed by prepareInput().
optparser.add_option(
    "-i", "--input", default="",
    help="Input file location"
)

opts = optparser.parse_args()[0]

# Check parameters validity.
# Use an explicit check instead of `assert`: asserts are stripped under
# `python -O`, and a bare AssertionError gives the user no hint.
# OptionParser.error() prints the usage line plus the message to stderr and
# exits with status 2.
if not os.path.isfile(opts.input):
    optparser.error("input file not found: %r" % opts.input)
#------------------------------------------------------------------

# 绘图
#------------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt

def drawData(x, y, c, W, b):
    """Scatter-plot the samples and overlay the learned separating line.

    Args:
        x: first coordinate of each sample.
        y: second coordinate of each sample.
        c: matplotlib colour per sample (prepareInput uses 'r'/'g').
        W: weight vector; only W[0] and W[1] are used.
        b: bias.

    The line W[0]*u + W[1]*v + b = 0 is drawn across the data's extent.
    Nothing is drawn if there is no data or if W[0] == W[1] == 0.
    """
    area = np.pi * 6 ** 2  # marker size (points^2) for a radius-6 disc
    plt.scatter(x, y, s=area, c=c, alpha=0.5)

    # Cast to float so the slope/intercept use true division (under Python 2,
    # int / int floors). Branch on zero components instead of dividing by them.
    w0, w1, bias = float(W[0]), float(W[1]), float(b)
    if len(x) and len(y):
        if w1 != 0.0:
            # Non-vertical line: evaluate v = -(w0*u + b) / w1 at the x-extremes.
            us = [min(x), max(x)]
            plt.plot(us, [-(w0 * u + bias) / w1 for u in us])
        elif w0 != 0.0:
            # Vertical line u = -b / w0, spanning the y-range of the data.
            u = -bias / w0
            plt.plot([u, u], [min(y), max(y)])
        # else: W is all zeros, so there is no line to draw.
    plt.show()
#--------------------------------------------------------------------

# 处理输入数据
#--------------------------------------------------------------------
def prepareInput(inputFile):
    """Read labelled 2-D samples from a whitespace-separated text file.

    Each non-blank line must start with three integers ``x1 x2 label``.
    Extra columns are ignored, and blank or whitespace-only lines are skipped.

    Args:
        inputFile: path of the data file.

    Returns:
        (x, y, c, X, Y) where
        x: list of first coordinates;
        y: list of second coordinates;
        c: matplotlib colour per sample ('r' if label > 0, else 'g');
        X: list of (x1, x2) tuples;
        Y: list of integer labels.

    Raises:
        ValueError: if a non-blank line has fewer than three columns, or a
            column is not an integer.
    """
    x, y, c, X, Y = [], [], [], [], []
    with open(inputFile, 'r') as f:
        for lineno, line in enumerate(f, 1):
            col = line.split()  # split() with no args also strips whitespace
            if not col:
                continue  # tolerate blank lines, e.g. trailing newlines
            if len(col) < 3:
                raise ValueError("%s:%d: expected 'x1 x2 label', got %r"
                                 % (inputFile, lineno, line))
            u, v, label = int(col[0]), int(col[1]), int(col[2])
            x.append(u)
            y.append(v)
            X.append((u, v))
            Y.append(label)
            c.append('r' if label > 0 else 'g')
    return x, y, c, X, Y
#----------------------------------------------------------------------

# 训练函数
#----------------------------------------------------------------------
def train(x, y, W=None, b=0, lr=1, max_iter=None):
    """Fit a perceptron in primal form.

    The samples are scanned in order. On the first sample with
    ``y_i * (W . x_i + b) <= 0`` the update ``W += lr*y_i*x_i`` and
    ``b += lr*y_i`` is applied, and the scan restarts from the first sample.
    Training stops once a full scan finds no misclassified sample.

    Args:
        x: sequence of feature vectors (e.g. a list of tuples), one per sample.
        y: sequence of labels, expected to be +1 / -1.
        W: initial weight vector. Defaults to zeros, one entry per feature.
            A supplied array is copied, never modified in place.
        b: initial bias.
        lr: learning rate (may be fractional).
        max_iter: optional cap on the number of updates. ``None`` (the
            default) loops until convergence. That never happens if the data
            are not linearly separable.

    Returns:
        (W, b): the learned weight vector (numpy array) and bias.

    Raises:
        ValueError: if ``x`` and ``y`` have different lengths.
        RuntimeError: if ``max_iter`` updates were made without converging.
    """
    X = np.asarray(x)
    Y = np.asarray(y)
    if len(X) != len(Y):
        raise ValueError("x has %d samples but y has %d labels"
                         % (len(X), len(Y)))

    # A fresh array on every call. The old default `np.array((0,0))` was
    # shared between calls and mutated by `W +=`, so later calls started
    # from the previous result. The dtype can hold fractional lr/labels,
    # because an in-place `+=` of floats into an int array would fail.
    if W is None:
        # One weight per feature. Fall back to 2 (the old hard-coded
        # dimension) when it cannot be read off an empty input.
        n_features = X.shape[1] if X.ndim == 2 else 2
        W = np.zeros(n_features, dtype=np.result_type(X, Y, lr))
    else:
        W_init = np.asarray(W)
        W = np.array(W_init, dtype=np.result_type(W_init, X, Y, lr))

    updates = 0
    while True:
        for xi, yi in zip(X, Y):
            if yi * (W.dot(xi) + b) <= 0:
                break  # xi, yi is the first misclassified sample
        else:
            return W, b  # full scan without a mistake: converged
        if max_iter is not None and updates >= max_iter:
            raise RuntimeError(
                "perceptron did not converge within %d updates" % max_iter)
        W += lr * yi * xi
        b += lr * yi
        updates += 1

#----------------------------------------------------------------------

# 主函数
#----------------------------------------------------------------------
def main():
    """Load the samples named by --input, train the perceptron, report and plot."""
    x, y, c, X, Y = prepareInput(opts.input)
    W, b = train(X, Y)
    # Report before plotting, because plt.show() may block until the figure
    # window is closed. Single-argument print(...) is valid in Python 2 and
    # Python 3. The old print statement was a SyntaxError on Python 3.
    print("W: %s" % W)
    print("b: %s" % b)
    drawData(x, y, c, W, b)

#----------------------------------------------------------------------

if __name__ == "__main__":
    main()

# 你可能感兴趣的 (you may also be interested in): 统计学习方法 (Statistical Learning Methods), python, 感知机 (perceptron)