《统计学习方法》—— 感知机原始形式、感知机对偶形式的python3代码实现(三)

前言

在前两篇博客里面,我们分别介绍了感知机的原始形式和感知机的对偶形式。在这篇博客里面,我们将用python3对上述两种感知机算法进行实现。

注意:本文参考了@akirameiao的博客内容。数据放在本文最后,直接复制进文本,保存为.txt格式,各位大佬自取。

  • 导入第三方库。
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
  • 导入数据,将数据保存为 data。
# 载入数据
def load_data(file):
    # 指定数据类型
    data_types = {
     'data1': np.float32, 'data2': np.float32, 'data3': np.float32, 'label': np.int16}
    
    # 数据读取,注意,这里的sep值,一定要为三个空格'   '
    data = pd.read_csv(file, sep='   ', header=None, names=['data1', 'data2', 'label'], dtype=data_types)
    
    # w*x+b = (w, b)*(x, 1),所以我们将特征向量x增加一维,为(x, 1)
    data.insert(2, 'data3', 1)
    
    return data

# 这里的文件路径替换成各位大佬自己文件所在的绝对路径
data = load_data('../input/ganzhiji.txt')
  • 数据的可视化。
data_plot = data.groupby('label')
for name, group in data_plot:
	plt.scatter(data=group, x='data1', y='data2', label=name)
plt.legend()
plt.show()

《统计学习方法》—— 感知机原始形式、感知机对偶形式的python3代码实现(三)_第1张图片

8. 感知机原始形式

现将算法复述如下:

  • 输入:数据集 T = { ( x 1 , y 1 ) , ( x 2 , y 2 ) , . . . , ( x N , y N ) } T=\{(x_1, y_1), (x_2, y_2), ..., (x_N, y_N)\} T={ (x1,y1),(x2,y2),...,(xN,yN)};学习步长 η \eta η
  • 输出: w w w b b b;感知机模型 f ( x ) = s i g n ( w ⋅ x + b ) f(x)=sign(w\cdot x+b) f(x)=sign(wx+b)

(1) 给定初值 ( w 0 , b 0 ) = ( 0 , 0 ) (w_0, b_0)=(0, 0) (w0,b0)=(0,0)
(2) 遍历数据集 T T T,找到第一个误分类点 ( x i , y i ) (x_i, y_i) (xi,yi),满足 y i ( w ⋅ x i + b ) < 0 y_i(w\cdot x_i+b)<0 yi(wxi+b)<0
(3) 更新 w w w b b b w ← w + η y i x i w\leftarrow w+\eta y_ix_i ww+ηyixi b ← b + η y i b\leftarrow b + \eta y_i bb+ηyi
(4) 回到步骤 (2),如果找不到误分类点,则终止算法

根据上述算法,可以写出

# 训练感知机模型
def perception(data, w_b, eta=1, wrongPoints_num=[]):
    wrong_nums = 1
    while True:
        if not wrong_nums:
            break
        
        # 当前w和b下,计算yi(w*xi+b)=yi(w, b) * (xi, 1)的值
         # 首先计算(w, b) * (xi, 1)
        data['wrong_point'] = data[['data1', 'data2', 'data3']].dot(w_b)
         # 再依次乘以yi,并保存进data
        data['wrong_point'] = data['label'].mul(data['wrong_point'])
        
        # 所有的误分类点
        temp = data[data['wrong_point']<=0]
        
        # 计算yi(w*xi+b)<=0,也就是误分类点的数量
        wrong_nums = temp['wrong_point'].count()
        wrongPoints_num.append(wrong_nums)
        
        # 找出第一个误分类点,并更新w, b
        if wrong_nums:
            # 计算 eta*yi*xi
            change = eta * temp['label'].iloc[0] * temp.iloc[0, 0:3].values
            w_b = w_b + change
            #print('更新后的w和b为', w_b[:2], w_b[2])
            #print(w_b)
        
    return w_b[:2], w_b[2], wrongPoints_num

给定初始值,并运行程序

# 初始值
w_b = np.array([0, 0, 0])
eta = 0.1
wrongPoints_num = [] # 记录每次迭代的误分类点个数

# 运行程序
w, b, wrongPoints_num = perception(data, w_b, eta, wrongPoints_num)
print('最终权重w为', w)
print('最终偏置b为', b)

可以作图看下结果

# 可视化
data_plot = data.groupby('label')
for name, group in data_plot:
    plt.scatter(data=group, x='data1', y='data2', label=name)

# 直线
x = np.arange(4, 5.5, 0.1)
y = - w[0] / w[1] * x - b / w[1]
plt.plot(x, y)

plt.legend()
plt.show()

《统计学习方法》—— 感知机原始形式、感知机对偶形式的python3代码实现(三)_第2张图片
算法迭代过程中,误分类点数量的变化曲线

plt.plot(np.arange(len(wrongPoints_num)), wrongPoints_num)
plt.xlabel('num of recursions')
plt.ylabel('num of wrong points')
plt.show()

《统计学习方法》—— 感知机原始形式、感知机对偶形式的python3代码实现(三)_第3张图片

9. 感知机对偶形式

现将算法复述如下:

  • 输入:数据集 T = { ( x 1 , y 1 ) , ( x 2 , y 2 ) , . . . , ( x N , y N ) } T=\{(x_1, y_1), (x_2, y_2), ..., (x_N, y_N)\} T={ (x1,y1),(x2,y2),...,(xN,yN)};学习步长 η \eta η
  • 输出: ( n 1 , n 2 , . . . , n N ) (n_1, n_2, ..., n_N) (n1,n2,...,nN);感知机模型 f ( x ) = s i g n ( w ⋅ x + b ) f(x)=sign(w\cdot x+b) f(x)=sign(wx+b),其中, w = ∑ i = 1 N n i η y i x i w=\sum_{i=1}^Nn_i\eta y_ix_i w=i=1Nniηyixi b = ∑ i = 1 N n i η y i b=\sum_{i=1}^Nn_i\eta y_i b=i=1Nniηyi

(1) 给定初始值 ( n 1 , n 2 , . . . , n N ) = ( 0 , 0 , . . . , 0 ) (n_1, n_2, ..., n_N)=(0, 0, ..., 0) (n1,n2,...,nN)=(0,0,...,0)
(2) 遍历数据集 T T T,找出第一个误分类点 ( x i , y i ) (x_i, y_i) (xi,yi),满足
y i ( ∑ j = 1 N n j η y j x j ⋅ x i + ∑ j = 1 N n j η y j ) = y i ∑ j = 1 N n j η y j ( x j ⋅ x i + 1 ) < 0 \begin{array}{lll} &&y_i(\sum_{j=1}^Nn_j\eta y_jx_j\cdot x_i+\sum_{j=1}^Nn_j\eta y_j)\\ &=& y_i\sum_{j=1}^Nn_j\eta y_j(x_j\cdot x_i+1)\\ &<&0 \end{array} =<yi(j=1Nnjηyjxjxi+j=1Nnjηyj)yij=1Nnjηyj(xjxi+1)0
(3) 更新 n i n_i ni n i ← n i + 1 n_i\leftarrow n_i+1 nini+1
(4) 返回步骤(2),如果没有误分类点,则终止算法

由于在判断误分类点的时候,我们仅需要 x j ⋅ x i x_j\cdot x_i xjxi 的值,所以,我们可以提前计算内积,也就是提前算出Gram矩阵
G = [ x i ⋅ x j ] N × N \mathbf{G}=\left[x_i\cdot x_j \right]_{N\times N} G=[xixj]N×N

# 计算Gram矩阵
# 实际上,我们需要计算的是 [xi * xj + 1] 
# 预处理 Gram矩阵
G = data.loc[:, ['data1', 'data2', 'data3']].values.dot(data.loc[:, ['data1', 'data2', 'data3']].values.T)

# 再计算 向量[yj] 与 Gram矩阵的第i行[xi * xj + 1] 按照元素做乘法
G_hat = G * data['label'].values

下面,我们可以写出如下程序

def perception_dual(data, eta, G_hat, n, wrongPoints_num):
    wrong_num = 1
    while True:
        if not wrong_num:
            break
        
        # 遍历数据集,找到误分类点
        temp = eta * pd.DataFrame(G_hat * n).apply(sum, axis=1)
        data['wrong_points'] = data['label'].mul(temp)
        
        # 所有的误分类点
        wrong = data[data['wrong_points']<=0]
        
        # 误分类点个数
        wrong_num = wrong['wrong_points'].count()
        wrongPoints_num.append(wrong_num)
        
        # 找出第一个误分类点(xi, yi),更新 n_i
        if wrong_num:
            first_index = list(wrong.index)[0]
            n[first_index] += 1
            #print('更新第', first_index, '个数据点')
            #print('该数据点n_i=', n[first_index])
    
    return n, wrongPoints_num

给初值,运行程序

# 给初值
n = np.zeros(len(data))
eta = 1
wrongPoints_num = []

# 运行程序
n, wrongPoints_num = perception_dual(data, eta, G_hat, n, wrongPoints_num)
# 根据 n_i,计算w和b
def w_b(data, eta, n):
    w = eta * data.loc[:, ['data1', 'data2']].mul(data['label'], axis=0).mul(n, axis=0).apply(sum, axis=0)
    b = eta * sum(data['label'] * n)
    return w.values, b

w, b = w_b(data, eta, n)

作图,看看结果对不对

# 可视化
data_plot = data.groupby('label')
for name, group in data_plot:
    plt.scatter(data=group, x='data1', y='data2', label=name)

# 直线
x = np.arange(4.4, 5.5, 0.1)
y = - w[0] / w[1] * x - b / w[1]
plt.plot(x, y)

plt.legend()
plt.show()

《统计学习方法》—— 感知机原始形式、感知机对偶形式的python3代码实现(三)_第4张图片
再看一下每次迭代后的误分类点情况

plt.plot(np.arange(len(wrongPoints_num)), wrongPoints_num)
plt.xlabel('num of recursions')
plt.ylabel('num of wrong points')
plt.show()

《统计学习方法》—— 感知机原始形式、感知机对偶形式的python3代码实现(三)_第5张图片
至此,我们将感知机的原始形式、对偶形式的数学推导以及python3实现全部完成。

下一篇博客中,我们将继续介绍 k近邻方法。

数据:100个数据,直接复制保存为.txt文件

3.542485    1.977398    -1
3.018896    2.556416    -1
7.551510    -1.580030   1
2.114999    -0.004466   -1
8.127113    1.274372    1
7.108772    -0.986906   1
8.610639    2.046708    1
2.326297    0.265213    -1
3.634009    1.730537    -1
0.341367    -0.894998   -1
3.125951    0.293251    -1
2.123252    -0.783563   -1
0.887835    -2.797792   -1
7.139979    -2.329896   1
1.696414    -1.212496   -1
8.117032    0.623493    1
8.497162    -0.266649   1
4.658191    3.507396    -1
8.197181    1.545132    1
1.208047    0.213100    -1
1.928486    -0.321870   -1
2.175808    -0.014527   -1
7.886608    0.461755    1
3.223038    -0.552392   -1
3.628502    2.190585    -1
7.407860    -0.121961   1
7.286357    0.251077    1
2.301095    -0.533988   -1
-0.232542   -0.547690   -1
3.457096    -0.082216   -1
3.023938    -0.057392   -1
8.015003    0.885325    1
8.991748    0.923154    1
7.916831    -1.781735   1
7.616862    -0.217958   1
2.450939    0.744967    -1
7.270337    -2.507834   1
1.749721    -0.961902   -1
1.803111    -0.176349   -1
8.804461    3.044301    1
1.231257    -0.568573   -1
2.074915    1.410550    -1
-0.743036   -1.736103   -1
3.536555    3.964960    -1
8.410143    0.025606    1
7.382988    -0.478764   1
6.960661    -0.245353   1
8.234460    0.701868    1
8.168618    -0.903835   1
1.534187    -0.622492   -1
9.229518    2.066088    1
7.886242    0.191813    1
2.893743    -1.643468   -1
1.870457    -1.040420   -1
5.286862    -2.358286   1
6.080573    0.418886    1
2.544314    1.714165    -1
6.016004    -3.753712   1
0.926310    -0.564359   -1
0.870296    -0.109952   -1
2.369345    1.375695    -1
1.363782    -0.254082   -1
7.279460    -0.189572   1
1.896005    0.515080    -1
8.102154    -0.603875   1
2.529893    0.662657    -1
1.963874    -0.365233   -1
8.132048    0.785914    1
8.245938    0.372366    1
6.543888    0.433164    1
-0.236713   -5.766721   -1
8.112593    0.295839    1
9.803425    1.495167    1
1.497407    -0.552916   -1
1.336267    -1.632889   -1
9.205805    -0.586480   1
1.966279    -1.840439   -1
8.398012    1.584918    1
7.239953    -1.764292   1
7.556201    0.241185    1
9.015509    0.345019    1
8.266085    -0.230977   1
8.545620    2.788799    1
9.295969    1.346332    1
2.404234    0.570278    -1
2.037772    0.021919    -1
1.727631    -0.453143   -1
1.979395    -0.050773   -1
8.092288    -1.372433   1
1.667645    0.239204    -1
9.854303    1.365116    1
7.921057    -1.327587   1
8.500757    1.492372    1
1.339746    -0.291183   -1
3.107511    0.758367    -1
2.609525    0.902979    -1
3.263585    1.367898    -1
2.912122    -0.202359   -1
1.731786    0.589096    -1
2.387003    1.573131    -1

你可能感兴趣的:(机器学习,机器学习)