Machine Learning Study Notes: Implementing the Perceptron Algorithm

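Preface: The Machine Learning series records my process of learning several commonly used machine learning algorithms. The theory sections are largely excerpted from Li Hang's Statistical Learning Methods (统计学习方法). Each algorithm comes with its own source-code implementation. A summary of strengths, weaknesses, and applicable scenarios will be added later.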

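I. How the Perceptron Algorithm Works

The three elements of statistical learning: method = model + strategy + algorithm

1) Model definition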

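The perceptron is a model used for classification. Assuming x is a vector in n-dimensional real space, the perceptron model is defined as:

f(x) = sign(w·x + b)

where w is the weight vector, b is the bias, and sign(z) is +1 for z >= 0 and -1 otherwise.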

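The perceptron is a linear classification model and belongs to the family of discriminative models. Its hypothesis space is the set of all linear classifiers defined on the feature space.

Geometrically, the equation w·x + b = 0 corresponds to a hyperplane in the feature space. This hyperplane splits the feature space into two parts, and the points in the two parts are classified as positive and negative respectively. The hyperplane is called the separating hyperplane.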
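As a small illustration of the geometric picture, the sketch below classifies a few 2-D points with a hand-picked hyperplane. The weights, bias, points, and the helper name `predict` are made-up values for illustration only; they are not taken from the implementation later in this post.

```python
import numpy as np

# Hypothetical separating hyperplane x1 + x2 - 1 = 0 (illustrative values)
w = np.array([1.0, 1.0])
b = -1.0

def predict(x, w, b):
    # sign(w·x + b); points lying exactly on the hyperplane are assigned +1
    return np.where(np.dot(x, w) + b >= 0, 1, -1)

points = np.array([
    [2.0, 2.0],   # w·x + b = 3, positive side
    [0.0, 0.0],   # w·x + b = -1, negative side
    [0.5, 0.5],   # w·x + b = 0, on the hyperplane
])
labels = predict(points, w, b)
print(labels)
```

Points on opposite sides of the hyperplane receive opposite labels, which is all the model does at prediction time.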

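2) Strategy (defining the loss function)

Using the number of misclassified points as the loss function is problematic: although it does measure the model's loss, it is not differentiable with respect to w and b, so it is hard to optimize. Instead, the loss is usually taken to be the total distance from the misclassified points to the separating hyperplane. This matches the intuition that the farther a misclassified point lies from the hyperplane (the more badly wrong it is), the larger the error.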
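To make the difference between the two candidate losses concrete, here is a minimal sketch on made-up data. The function names `misclassified_count` and `margin_loss` are hypothetical, and the distance is left unnormalized (the 1/||w|| factor is dropped), which is an assumption that matches the `loss_fn` used in the implementation section.

```python
import numpy as np

def misclassified_count(w, b, x, y):
    # Count of samples strictly on the wrong side: y * (w·x + b) < 0
    return int(np.sum(y * (x @ w + b) < 0))

def margin_loss(w, b, x, y):
    # Sum of max(0, -y * (w·x + b)); zero for correctly classified samples
    return float(np.sum(np.maximum(-y * (x @ w + b), 0)))

# Two negative samples that the hyperplane currently puts on the positive side
x = np.array([[1.0, 0.0], [3.0, 0.0]])
y = np.array([-1, -1])
w = np.array([1.0, 0.0])

for b in (0.0, -0.5):
    print(b, misclassified_count(w, b, x, y), margin_loss(w, b, x, y))
```

Shifting b from 0.0 to -0.5 leaves the misclassification count unchanged, so the count gives no direction to move in, while the distance-based loss decreases and therefore provides a usable gradient signal.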

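3) Algorithm

The optimization task then becomes solving for w and b such that they minimize the following loss function, where M is the set of misclassified points:

min over w, b of L(w, b) = -Σ_{x_i ∈ M} y_i (w·x_i + b)

1. Compute the gradient by taking the partial derivatives with respect to w and b:

∂L/∂w = -Σ_{x_i ∈ M} y_i x_i
∂L/∂b = -Σ_{x_i ∈ M} y_i

2. Update the parameters with learning rate η:

w ← w - η ∂L/∂w
b ← b - η ∂L/∂b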
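The two steps can be sketched for a single sample as follows. This is a minimal illustration, not the mini-batch loop used later in the post: the helper name `sgd_step`, the starting values, and the sample points are assumptions chosen for the example.

```python
import numpy as np

def sgd_step(w, b, x_i, y_i, lr):
    # Update only when (x_i, y_i) is misclassified: y_i * (w·x_i + b) <= 0
    if y_i * (np.dot(w, x_i) + b) <= 0:
        # The gradient is (-y_i * x_i, -y_i), so stepping against it adds these terms
        w = w + lr * y_i * x_i
        b = b + lr * y_i
    return w, b

# Start from zero weights; the first sample is misclassified and triggers an update
w, b = np.zeros(2), 0.0
w, b = sgd_step(w, b, np.array([3.0, 3.0]), 1, lr=1.0)
print(w, b)

# A sample that is now classified correctly leaves the parameters unchanged
w2, b2 = sgd_step(w, b, np.array([4.0, 3.0]), 1, lr=1.0)
print(w2, b2)
```

Correctly classified samples contribute nothing to the gradient, which is why the update is skipped for them.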

II. Implementing the Perceptron Algorithm

# -*- coding: utf-8 -*-
import sklearn.utils  # explicit submodule import for sklearn.utils.gen_batches
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification
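# Synthetic binary classification data: 569 samples, 30 features
X,Y = make_classification(n_samples=569,n_features=30,n_classes=2)
# Map the labels from {0, 1} to {-1, +1}
Y = np.where(Y>0,1,-1)
print(X.shape)
print(Y.shape)
# Hold out 30% of the samples as the test set
x_train,x_test,y_train,y_test = train_test_split(X,Y,test_size=.3)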

# Model: sign(w·x+b)
# Weight and bias initializer
mu = 0
sigma = 1
np.random.seed(0)
w = np.random.normal(mu,sigma,30)
b = np.random.randn(1)

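# Define the model
def model(x,w,b):
    '''
    Perceptron model: y = sign(w·x+b)
    :param x: samples, array of shape (n_samples, n_features)
    :param w: weight vector, array of shape (n_features,)
    :param b: bias, array of shape (1,)
    :return: predicted labels in {-1, +1}, array of shape (n_samples,)
    '''
    return np.where((np.dot(w,np.transpose(x))+b)>=0,1,-1)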

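# Define the loss: max(0, -y(w·x+b)) per sample. It is positive only for
# misclassified samples and grows with their distance from the hyperplane
# (the 1/||w|| factor is not applied).
def loss_fn(w,b,x,y_true):
    temp = np.maximum(-y_true*(np.dot(w,np.transpose(x))+b),0)
    return temp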

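def compute_gradient(loss,x,y):
    # Sum the per-sample gradients over the samples with positive loss:
    # dL/dw = -sum(y_i * x_i), dL/db = -sum(y_i)
    w_grad = 0
    b_grad = 0
    for i in range(len(loss)):
        if loss[i] > 0:
            w_grad += -x[i] * y[i]
            b_grad += -y[i]
    return w_grad,b_grad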

def compute_acc(y_true,y_pred):
    total_samples = len(y_pred)
    nums_true_sample = 0
    for i in range(total_samples):
        if y_true[i] == y_pred[i]:
            nums_true_sample += 1
    acc = nums_true_sample/float(total_samples)
    return acc

epochs = 100000
learning_rate = 1e-5
step = 0
for i in range(epochs):
    batches = sklearn.utils.gen_batches(len(y_train),32)
    for batch in batches:
        step += 1
        batch_x = x_train[batch]
        batch_y = y_train[batch]

        # Forward pass
        # y = sign(w·x+b)
        # The sign function is implemented with np.where
        pred = model(batch_x,w,b)
        acc = compute_acc(batch_y,pred)
        
        # Compute the loss
        loss = loss_fn(w, b, batch_x, batch_y)

        if step % 500 == 0:
            print('loss:%.5f  acc:%.5f'%(sum(loss),acc))

        # Backward pass: compute the gradients
        w_grad,b_grad = compute_gradient(loss,batch_x,batch_y)
        
        # Update the parameters
        w -= learning_rate * w_grad
        b -= learning_rate * b_grad

    # Evaluate on the test set after every epoch (printed only when step is a multiple of 1000)
    test_pred = model(x_test,w,b)
    test_acc = compute_acc(y_test,test_pred)
    if step % 1000 == 0:
        print('test acc:%.5f'%test_acc)
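The Python loop in `compute_gradient` can also be written with NumPy array operations. The sketch below is one possible equivalent, not part of the original listing: the name `compute_gradient_vectorized` and the tiny input arrays are made up for illustration.

```python
import numpy as np

def compute_gradient_vectorized(loss, x, y):
    # Select the samples that contribute to the loss (the misclassified ones)
    mask = loss > 0
    # Gradient w.r.t. w: -sum_i y_i * x_i over the selected samples
    w_grad = -(y[mask] @ x[mask])
    # Gradient w.r.t. b: -sum_i y_i over the selected samples
    b_grad = -np.sum(y[mask])
    return w_grad, b_grad

# Tiny check on made-up data: samples 0 and 2 have positive loss
x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = np.array([1, -1, 1])
loss = np.array([0.5, 0.0, 2.0])
w_grad, b_grad = compute_gradient_vectorized(loss, x, y)
print(w_grad, b_grad)
```

For batches of 32 samples the speed difference is small, so this is mainly a readability choice.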

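Training and test results:

Below is an excerpt from the training log. The loss shows an overall downward trend, the per-batch accuracy on the training set settles at roughly 90%, and the accuracy on the test set rises to about 80% and then starts to fluctuate.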

'''
loss:105.24404  acc:0.43750
loss:150.76018  acc:0.15625
loss:101.36801  acc:0.37500
loss:135.94005  acc:0.28125
loss:115.67770  acc:0.37500
loss:113.76541  acc:0.31250
loss:114.51207  acc:0.31250
loss:84.38081  acc:0.46875
loss:81.72001  acc:0.50000
loss:83.79107  acc:0.50000
loss:97.56048  acc:0.31250
loss:50.71439  acc:0.53125
loss:11.85965  acc:0.64286
loss:62.48295  acc:0.62500
loss:83.81445  acc:0.40625
loss:65.34037  acc:0.50000
loss:72.84614  acc:0.40625
loss:74.43541  acc:0.50000
loss:61.79947  acc:0.46875
loss:70.49474  acc:0.43750
loss:41.63456  acc:0.56250
loss:50.60463  acc:0.59375
loss:51.39166  acc:0.53125
loss:45.76451  acc:0.50000
loss:24.18704  acc:0.68750
loss:1.87922  acc:0.85714
test acc:0.53801
loss:36.63442  acc:0.71875
loss:48.50881  acc:0.56250
loss:45.31712  acc:0.59375
loss:34.31767  acc:0.46875
loss:49.56287  acc:0.59375
loss:34.27624  acc:0.62500
loss:41.15760  acc:0.56250
loss:19.99830  acc:0.65625
loss:30.16264  acc:0.59375
loss:32.66717  acc:0.59375
loss:23.36999  acc:0.68750
loss:13.30596  acc:0.68750
loss:0.00000  acc:1.00000
loss:19.43344  acc:0.68750
loss:29.43970  acc:0.62500
loss:30.82463  acc:0.59375
loss:18.29285  acc:0.68750
loss:32.54916  acc:0.68750
loss:21.77884  acc:0.68750
loss:26.32651  acc:0.62500
loss:7.72756  acc:0.81250
loss:19.46919  acc:0.65625
loss:22.92023  acc:0.68750
loss:13.68733  acc:0.78125
loss:9.21291  acc:0.78125
loss:0.00000  acc:1.00000
test acc:0.68421
loss:10.81774  acc:0.84375
loss:18.69560  acc:0.71875
loss:21.14270  acc:0.68750
loss:11.71013  acc:0.71875
loss:23.68367  acc:0.71875
loss:15.59680  acc:0.75000
loss:16.38365  acc:0.65625
loss:3.48871  acc:0.84375
loss:13.86623  acc:0.78125
loss:17.23298  acc:0.71875
loss:9.11982  acc:0.87500
loss:7.13866  acc:0.78125
loss:0.00000  acc:1.00000
loss:7.07178  acc:0.87500
loss:12.17179  acc:0.81250
loss:15.75085  acc:0.75000
loss:7.26880  acc:0.71875
loss:18.08974  acc:0.81250
loss:11.61002  acc:0.75000
loss:9.54142  acc:0.68750
loss:1.72861  acc:0.87500
loss:10.87609  acc:0.81250
loss:13.35418  acc:0.75000
loss:7.08608  acc:0.84375
loss:5.89114  acc:0.81250
loss:0.00000  acc:1.00000
test acc:0.77193
loss:4.25427  acc:0.87500
loss:8.29292  acc:0.84375
loss:12.63444  acc:0.81250
loss:4.98424  acc:0.78125
loss:15.21485  acc:0.81250
loss:8.43432  acc:0.75000
loss:5.25917  acc:0.78125
loss:0.94625  acc:0.87500
loss:8.57712  acc:0.84375
loss:10.67448  acc:0.84375
loss:5.88110  acc:0.84375
loss:5.32498  acc:0.87500
loss:0.00000  acc:1.00000
loss:2.22088  acc:0.87500
loss:5.87675  acc:0.87500
loss:10.45557  acc:0.78125
loss:3.74712  acc:0.84375
loss:13.22224  acc:0.81250
loss:6.36244  acc:0.75000
loss:3.52659  acc:0.84375
loss:0.98263  acc:0.93750
loss:6.95762  acc:0.84375
loss:9.31484  acc:0.90625
loss:4.58888  acc:0.84375
loss:5.10567  acc:0.87500
loss:0.00000  acc:1.00000
test acc:0.78363
loss:0.90419  acc:0.90625
loss:4.25921  acc:0.87500
loss:8.49134  acc:0.78125
loss:3.26698  acc:0.84375
loss:11.79509  acc:0.84375
loss:4.97345  acc:0.84375
loss:2.84477  acc:0.87500
loss:1.21036  acc:0.93750
loss:5.64792  acc:0.84375
loss:8.42069  acc:0.93750
loss:3.71578  acc:0.90625
loss:4.87594  acc:0.90625
loss:0.00000  acc:1.00000
loss:0.45847  acc:0.96875
loss:3.19834  acc:0.90625
loss:6.94889  acc:0.78125
loss:3.04756  acc:0.78125
loss:10.93073  acc:0.84375
loss:4.25836  acc:0.84375
loss:2.17207  acc:0.84375
loss:1.42136  acc:0.93750
loss:4.70966  acc:0.87500
loss:8.03839  acc:0.93750
loss:3.00213  acc:0.90625
loss:4.75577  acc:0.90625
loss:0.00000  acc:1.00000
test acc:0.78363
loss:0.31959  acc:0.96875
loss:2.40828  acc:0.90625
loss:5.82317  acc:0.84375
loss:2.83232  acc:0.81250
loss:10.00275  acc:0.84375
loss:3.57126  acc:0.84375
loss:1.65234  acc:0.87500
loss:1.61977  acc:0.93750
loss:4.03384  acc:0.87500
loss:7.69399  acc:0.90625
loss:2.63821  acc:0.90625
loss:4.54189  acc:0.90625
loss:0.00000  acc:1.00000
loss:0.18972  acc:0.96875
loss:1.75994  acc:0.90625
loss:5.05572  acc:0.84375
loss:2.65805  acc:0.81250
loss:9.21468  acc:0.87500
loss:2.95136  acc:0.84375
loss:1.49775  acc:0.87500
loss:1.76816  acc:0.93750
loss:3.52355  acc:0.90625
loss:7.41432  acc:0.93750
loss:2.25899  acc:0.90625
loss:4.25336  acc:0.90625
loss:0.00000  acc:1.00000
test acc:0.80702
loss:0.02820  acc:0.96875
loss:1.23499  acc:0.93750
loss:4.29800  acc:0.84375
loss:2.56567  acc:0.81250
loss:8.50027  acc:0.87500
loss:2.34109  acc:0.84375
loss:1.34891  acc:0.93750
loss:1.92763  acc:0.93750
loss:3.11241  acc:0.90625
loss:7.10171  acc:0.87500
loss:1.94405  acc:0.90625
loss:3.95033  acc:0.90625
loss:0.00000  acc:1.00000
loss:0.00000  acc:1.00000
loss:0.87173  acc:0.93750
loss:3.59864  acc:0.84375
loss:2.36613  acc:0.81250
loss:7.94177  acc:0.87500
loss:1.97824  acc:0.93750
loss:1.26354  acc:0.93750
loss:2.04078  acc:0.90625
loss:2.77734  acc:0.90625
loss:6.74133  acc:0.87500
loss:1.77319  acc:0.96875
loss:3.68123  acc:0.90625
loss:0.00000  acc:1.00000
test acc:0.81287
loss:0.00000  acc:1.00000
loss:0.58720  acc:0.93750
loss:3.08344  acc:0.84375
loss:2.10701  acc:0.78125
loss:7.54687  acc:0.87500
loss:1.84997  acc:0.93750
loss:1.15680  acc:0.93750
loss:2.20248  acc:0.93750
loss:2.51515  acc:0.90625
loss:6.39995  acc:0.93750
loss:1.78001  acc:0.93750
loss:3.47088  acc:0.90625
loss:0.00000  acc:1.00000
loss:0.00000  acc:1.00000
loss:0.36069  acc:0.93750
loss:2.72082  acc:0.84375
loss:1.84009  acc:0.81250
loss:7.11066  acc:0.90625
loss:1.72742  acc:0.93750
loss:1.02820  acc:0.90625
loss:2.29181  acc:0.93750
loss:2.41081  acc:0.90625
loss:6.09377  acc:0.93750
loss:1.80054  acc:0.93750
loss:3.26580  acc:0.90625
loss:0.00000  acc:1.00000
test acc:0.81871
loss:0.00000  acc:1.00000
loss:0.18003  acc:0.93750
loss:2.44613  acc:0.87500
loss:1.61047  acc:0.87500
loss:6.67611  acc:0.90625
loss:1.59183  acc:0.93750
loss:0.89677  acc:0.90625
loss:2.37084  acc:0.93750
loss:2.25829  acc:0.90625
loss:5.79726  acc:0.90625
loss:1.80508  acc:0.93750
loss:3.00829  acc:0.90625
loss:0.00000  acc:1.00000
loss:0.00000  acc:1.00000
loss:0.05135  acc:0.93750
loss:2.24199  acc:0.84375
loss:1.45565  acc:0.84375
loss:6.24971  acc:0.87500
loss:1.44302  acc:0.93750
loss:0.76834  acc:0.90625
loss:2.44379  acc:0.93750
loss:2.09498  acc:0.90625
loss:5.51264  acc:0.93750
loss:1.80650  acc:0.93750
loss:2.74188  acc:0.93750
loss:0.00000  acc:1.00000
test acc:0.81871
loss:0.00000  acc:1.00000
loss:0.00006  acc:0.96875
loss:2.06176  acc:0.84375
loss:1.31890  acc:0.87500
loss:5.81376  acc:0.90625
loss:1.29607  acc:0.93750
loss:0.66643  acc:0.90625
loss:2.45500  acc:0.93750
loss:1.95926  acc:0.93750
loss:5.30081  acc:0.90625
loss:1.78542  acc:0.93750
loss:2.53433  acc:0.90625
loss:0.00000  acc:1.00000
loss:0.00000  acc:1.00000
loss:0.00000  acc:1.00000
loss:1.92465  acc:0.87500
loss:1.18924  acc:0.84375
loss:5.40256  acc:0.84375
loss:1.15270  acc:0.93750
loss:0.56949  acc:0.93750
loss:2.44485  acc:0.93750
loss:1.81624  acc:0.93750
loss:5.10627  acc:0.87500
loss:1.76959  acc:0.93750
loss:2.38349  acc:0.93750
loss:0.00000  acc:1.00000
test acc:0.82456
'''

 
