机器学习-感知机模型(pocket算法)实现

我们上一篇考虑的感知机模型有一个假定:数据是线性可分的。实际上,由于噪音或者其他因素的存在,不能都是线性可分的。所以,需要考虑非线性可分的情况下,实现方法。

这里使用Pocket算法:

Pocket算法的思想非常简单,在搜索W的时候,不断记录最好的准确率和W。这样即使数据不是线性可分的,也可以得到比较好的测试结果(只要不断的提高迭代的次数)

数据:

https://www.csie.ntu.edu.tw/~htlin/course/ml15fall/hw1/hw1_18_train.dat

https://www.csie.ntu.edu.tw/~htlin/course/ml15fall/hw1/hw1_18_test.dat

计算准确率函数:

# 计算错误率
def checkErrorRate(testMatData, testLabelData, W):
    accuracyCount = 0
    for i in range(len(testMatData)):
        vect = testMatData[i, :]
        extraBiasVect = append(1, vect)
        resultY = vdot(W, extraBiasVect)
        if (resultY <= 0):
            labelY = -1
        else:
            labelY = 1
        if (labelY == testLabelData[i]):
            accuracyCount += 1
    return accuracyCount / len(testLabelData)

Pocket算法

# 数据非线性可分的情况下,pocketPerceptron实现
def pocketPerceptronLearn(trainMatData, trainLabelData, testMatData, testLabelData):
    # 设定最大迭代次数
    maxIteration = 100000
    # 初始向量
    W = [0, 0, 0, 0, 0]
    # labely
    iterationFinish = False
    # 当前迭代次数
    times = 0
    bestW = W
    bestAccuracyRate = 0
    for interationCount in range(maxIteration):
        for dataIndex in range(len(trainMatData)):
            # 计算向量内积
            vect = trainMatData[dataIndex, :]
            extraBiasVect = append(1, vect)
            resultY = vdot(W, extraBiasVect)
            if (resultY <= 0):
                labelY = -1
            else:
                labelY = 1

            if (labelY != trainLabelData[dataIndex]):
                W = W + trainLabelData[dataIndex] * extraBiasVect
                times += 1
                rate = checkErrorRate(testMatData, testLabelData, W)
                if (rate > bestAccuracyRate):
                    bestAccuracyRate = rate
                    bestW = W

            else:
                if (dataIndex == (len(trainMatData) - 1)):
                    iterationFinish = True
        if (iterationFinish == True):
            break
            # 验证测试
        if (times >= 50):
            print(bestW)
            print(bestAccuracyRate)
    return bestW, bestAccuracyRate

刚开始用Python,好多矩阵/数组等数学操作比较啰嗦,效率也不好。在使用中不断恶补吧。

你可能感兴趣的:(机器学习,机器学习,Pocket算法,感知机)