Perceptron Learning Algorithm(python实现)

原文链接: http://www.cnblogs.com/chenximcm/p/6285159.html

一、概论
对于给定的n维(两种类型)数据(训练集),找出一个n-1维的面,能够“尽可能”地按照数据类型分开。通过这个面,我们可以通过这个面对测试数据进行预测。
例如对于二维数据,要找一条直线,把这些数据按照不同类型分开。我们要通过PLA算法,找到这条直线,然后通过判断预测数据与这条直线的位置关系,划分测试数据类型。如下图:
Perceptron Learning Algorithm(python实现)_第1张图片
二、PLA的原理
先初始化一条直线,然后通过多次迭代,修改这条直线,通过多次迭代,这条直线会收敛于接近最佳分类直线。
修改直线的标准是,任意找出一个点(训练数据中的某个点),判断这个点按照这条直线的划分类型是否跟该点实际类型是否相同。如果相同则开始下次迭代;如果判断错误,则更新直线的参数。
三、W的更新步骤
这里写图片描述
期中W为直线的参数矩阵。y为该点的实际类型,x为该点的参数矩阵。
假设有一下测试数据:
Perceptron Learning Algorithm(python实现)_第2张图片
第1、2个位向量参数,第三个为截距值。
这几个测试数据集的类型表现为:
这里写图片描述
求出以下的测试集的类型:
Perceptron Learning Algorithm(python实现)_第3张图片
假设W的初始化值为:这里写图片描述
第一次选择E点来更新W的值:
Perceptron Learning Algorithm(python实现)_第4张图片
其中sign的符号函数,sign(x)当x的值大于0是sign(x)=+1,否则为-1。(这里+1,-1分别表示两种标签类型)
如上面公式求出来的结果是+1类型,而真实值为这里写图片描述预测值跟真实值不一样,所以需要更新W的值:
Perceptron Learning Algorithm(python实现)_第5张图片

四、python实现
1、初始化W的值和迭代次数:

ITERATION = 70;
W = [1, 1, 1];

2、读取训练、测试数据,生成训练、测试(二维)列表:

def createData():
    lines_set = open('../data/PLA/Dataset_PLA.txt').readlines();
    linesTrain = lines_set[1:7];    #测试数据
    linesTest = lines_set[9:13];     #训练数据

    trainDataList = processData(linesTrain);    #生成训练集(二维列表)
    testDataList = processData(linesTest);      #生成测试集(二维列表)
    return trainDataList, testDataList;

def processData(lines):     #按行处理从txt中读到的训练集(测试集)数据
    dataList = [];
    for line in lines:           #逐行读取txt文档里的训练集
        dataLine = line.strip().split();            #按空格切割一行训练数据(字符串)
        dataLine = [int(data) for data in dataLine];            #字符串转int
        dataList.append(dataLine);           #添加到训练数据列表
    return dataList;

3、两个矩阵相乘的结果求符号函数值:

def sign(W, dataList):      #符号函数
    sum = 0;
    for i in range(len(W)):
        sum += W[i] * dataList[i];
    if sum > 0: return 1;
    else: return -1;

如果各项相乘的和比0大则返回+1,否则返回-1;
4、检测测试的类型是否跟真实标签类型一样

def renewW(W, trainData):   #更新W
    signResult = sign(W, trainData);
    if signResult == trainData[-1]: return W;
    for k in range(len(W)):
        W[k] = W[k] + trainData[-1]*trainData[k];
    return W;

如果相等,则不更新W的值,否则按公式 W[k] = W[k] + trainData[-1]*trainData[k];更新W的值,返回W的新值。
5、通过多次迭代,训练W的值

def trainW(W, trainDatas):  #训练W
    newW = [];
    for num in range(ITERATION):
        index = num % len(trainDatas);
        newW = renewW(W, trainDatas[index]);
    return newW;

经过多次迭代后,W的值会收敛于某个值。
6、使用训练后的W对测试集进行分类(预测)

def predictTestData(W, trainDatas, testDatas):  #预测测试数据集
    W = trainW(W, trainDatas);
    print W;
    for i in range(len(testDatas)):
        result = sign(W, testDatas[i]);
        print result;

五、完整代码

ITERATION = 70;
W = [1, 1, 1];

def createData():
    lines_set = open('../data/PLA/Dataset_PLA.txt').readlines();
    linesTrain = lines_set[1:7];    #测试数据
    linesTest = lines_set[9:13];     #训练数据

    trainDataList = processData(linesTrain);    #生成训练集(二维列表)
    testDataList = processData(linesTest);      #生成测试集(二维列表)
    return trainDataList, testDataList;

def processData(lines):     #按行处理从txt中读到的训练集(测试集)数据
    dataList = [];
    for line in lines:           #逐行读取txt文档里的训练集
        dataLine = line.strip().split();            #按空格切割一行训练数据(字符串)
        dataLine = [int(data) for data in dataLine];            #字符串转int
        dataList.append(dataLine);           #添加到训练数据列表
    return dataList;

def sign(W, dataList):      #符号函数
    sum = 0;
    for i in range(len(W)):
        sum += W[i] * dataList[i];
    if sum > 0: return 1;
    else: return -1;

def renewW(W, trainData):   #更新W
    signResult = sign(W, trainData);
    if signResult == trainData[-1]: return W;
    for k in range(len(W)):
        W[k] = W[k] + trainData[-1]*trainData[k];
    return W;

def trainW(W, trainDatas):  #训练W
    newW = [];
    for num in range(ITERATION):
        index = num % len(trainDatas);
        newW = renewW(W, trainDatas[index]);
    return newW;

def predictTestData(W, trainDatas, testDatas):  #预测测试数据集
    W = trainW(W, trainDatas);
    print W;
    for i in range(len(testDatas)):
        result = sign(W, testDatas[i]);
        print result;

trainDatas, testDatas = createData();

predictTestData(W, trainDatas, testDatas);

六、数据集
Perceptron Learning Algorithm(python实现)_第6张图片
第一列为向量的第一个参数,第二列为第二个参数,第三列为截距值,(训练集)第四列为真实标签类型。

转载于:https://www.cnblogs.com/chenximcm/p/6285159.html

你可能感兴趣的:(Perceptron Learning Algorithm(python实现))