Preface: Machine Learning is a repo that records my study of several common machine-learning algorithms. The theory sections are largely excerpted from Li Hang's Statistical Learning Methods; each algorithm comes with its own source-code implementation. Summaries of pros, cons, and applicable scenarios will be added later.
The three elements of statistical learning: method = model + strategy + algorithm.
1) Model definition
The perceptron is a model for classification. Suppose x is a vector in the n-dimensional real space R^n; the perceptron model is defined as f(x) = sign(w·x + b), where w is an n-dimensional weight vector and b is a scalar bias.
The perceptron is a linear classification model and belongs to the discriminative models. Its hypothesis space is the set of all linear classifiers defined on the feature space.
The geometric interpretation of the perceptron: the equation w·x + b = 0 corresponds to a hyperplane in the feature space, which divides the feature space into two parts; points in the two parts are classified as positive and negative, respectively. This hyperplane is called the separating hyperplane.
2) Strategy (definition of the loss function)
Using the number of misclassified points as the loss function is problematic: although it does measure the model's error, it is not differentiable with respect to w and b and is therefore hard to optimize. Instead, the total distance from the misclassified points to the separating hyperplane is taken as the loss. This matches intuition: the farther a misclassified point lies from the hyperplane (the more outrageously wrong the prediction), the larger the loss.
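For reference, the derivation as given in Statistical Learning Methods, written out in LaTeX:

% distance from a point x_0 to the hyperplane w \cdot x + b = 0:
d(x_0) = \frac{1}{\|w\|} \lvert w \cdot x_0 + b \rvert

% for any misclassified point (x_i, y_i) with y_i \in \{+1, -1\}:
-y_i (w \cdot x_i + b) > 0

% summing the distances over the misclassified set M and dropping the
% constant factor 1/\|w\| yields the perceptron loss:
L(w, b) = -\sum_{x_i \in M} y_i (w \cdot x_i + b)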
3) Algorithm
The optimization task thus turns into solving for w and b as the solution of the following loss minimization problem, where M denotes the set of misclassified points:

min_{w,b} L(w,b) = -Σ_{x_i ∈ M} y_i (w·x_i + b)
1. Compute the gradients: take the partial derivatives of L with respect to w and b:

∇_w L(w,b) = -Σ_{x_i ∈ M} y_i x_i
∇_b L(w,b) = -Σ_{x_i ∈ M} y_i
2. Update the parameters: pick a misclassified point (x_i, y_i) and step along the negative gradient with learning rate η (0 < η ≤ 1):

w ← w + η y_i x_i
b ← b + η y_i
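Before the full implementation, a minimal numerical sketch of a single update step (the point, weights, and learning rate below are made-up illustrative values, not taken from the code that follows):

import numpy as np

# hypothetical misclassified point: y_i * (w·x_i + b) <= 0
w = np.array([1.0, -1.0])
b = 0.0
x_i = np.array([2.0, 3.0])
y_i = 1          # true label
eta = 0.1        # learning rate

# margin check: 1 * (1*2 + (-1)*3 + 0) = -1 <= 0, so the point is misclassified
assert y_i * (np.dot(w, x_i) + b) <= 0

# one stochastic update: w <- w + eta*y_i*x_i, b <- b + eta*y_i
w = w + eta * y_i * x_i    # -> [1.2, -0.7]
b = b + eta * y_i          # -> 0.1

# the new margin is 1 * (1.2*2 + (-0.7)*3 + 0.1) = 0.4 > 0: now correctly classified
print(w, b)

In general a single step only moves the hyperplane toward the misclassified point; it happens to fully correct this point in one update. The full implementation on a synthetic dataset follows.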
# -*- coding: utf-8 -*-
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.utils import gen_batches
# generate a binary dataset and map the labels from {0, 1} to {-1, +1}
X, Y = make_classification(n_samples=569, n_features=30, n_classes=2)
Y = np.where(Y > 0, 1, -1)
print(X.shape)
print(Y.shape)
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=.3)
# model: y = sign(w·x + b)
# weight and bias initialization
mu = 0
sigma = 1
np.random.seed(0)
w = np.random.normal(mu, sigma, 30)
b = np.random.randn(1)
# define the model
def model(x, w, b):
    '''
    model: y = sign(w·x + b)
    :param x: sample matrix of shape (batch, 30)
    :param w: weight vector of shape (30,)
    :param b: bias of shape (1,)
    :return: predicted labels in {-1, +1}
    '''
    return np.where((np.dot(w, np.transpose(x)) + b) >= 0, 1, -1)
# define the loss: per-sample value of -y*(w·x + b), clamped at 0 so that
# only misclassified points (negative functional margin) contribute
def loss_fn(w, b, x, y_true):
    return np.maximum(-y_true * (np.dot(w, np.transpose(x)) + b), 0)
def compute_gradient(loss, x, y):
    # accumulate gradients over the misclassified points (loss > 0):
    # dL/dw = -sum(y_i * x_i), dL/db = -sum(y_i)
    w_grad = 0
    b_grad = 0
    for i in range(len(loss)):
        if loss[i] > 0:
            w_grad += -x[i] * y[i]
            b_grad += -y[i]
    return w_grad, b_grad
def compute_acc(y_true, y_pred):
    total_samples = len(y_pred)
    nums_true_sample = 0
    for i in range(total_samples):
        if y_true[i] == y_pred[i]:
            nums_true_sample += 1
    return nums_true_sample / float(total_samples)
epochs = 100000
learning_rate = 1e-5
step = 0
for i in range(epochs):
    batches = gen_batches(len(y_train), 32)
    for batch in batches:
        step += 1
        batch_x = x_train[batch]
        batch_y = y_train[batch]
        # forward pass: y = sign(w·x + b), with sign implemented via np.where
        pred = model(batch_x, w, b)
        acc = compute_acc(batch_y, pred)
        # compute the loss
        loss = loss_fn(w, b, batch_x, batch_y)
        if step % 500 == 0:
            print('loss:%.5f acc:%.5f' % (sum(loss), acc))
        # backward pass: compute the gradients
        w_grad, b_grad = compute_gradient(loss, batch_x, batch_y)
        # update the parameters
        w -= learning_rate * w_grad
        b -= learning_rate * b_grad
    # evaluate accuracy on the test set after each epoch
    test_pred = model(x_test, w, b)
    test_acc = compute_acc(y_test, test_pred)
    if step % 1000 == 0:
        print('test acc:%.5f' % test_acc)
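As a sanity check, here is a minimal sketch of the same task using scikit-learn's built-in Perceptron (assuming the x_train/x_test split from above; the hyperparameters shown are illustrative, not tuned):

from sklearn.linear_model import Perceptron

clf = Perceptron(max_iter=1000, tol=1e-3, random_state=0)
clf.fit(x_train, y_train)
print('sklearn Perceptron test acc: %.5f' % clf.score(x_test, y_test))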
Training and test results:
Below is an excerpt from the training log. The loss keeps decreasing overall, the model's accuracy on the training batches stabilizes above 90%, and the test-set accuracy climbs to about 80% and then starts to oscillate.
'''
loss:105.24404 acc:0.43750
loss:150.76018 acc:0.15625
loss:101.36801 acc:0.37500
loss:135.94005 acc:0.28125
loss:115.67770 acc:0.37500
loss:113.76541 acc:0.31250
loss:114.51207 acc:0.31250
loss:84.38081 acc:0.46875
loss:81.72001 acc:0.50000
loss:83.79107 acc:0.50000
loss:97.56048 acc:0.31250
loss:50.71439 acc:0.53125
loss:11.85965 acc:0.64286
loss:62.48295 acc:0.62500
loss:83.81445 acc:0.40625
loss:65.34037 acc:0.50000
loss:72.84614 acc:0.40625
loss:74.43541 acc:0.50000
loss:61.79947 acc:0.46875
loss:70.49474 acc:0.43750
loss:41.63456 acc:0.56250
loss:50.60463 acc:0.59375
loss:51.39166 acc:0.53125
loss:45.76451 acc:0.50000
loss:24.18704 acc:0.68750
loss:1.87922 acc:0.85714
test acc:0.53801
loss:36.63442 acc:0.71875
loss:48.50881 acc:0.56250
loss:45.31712 acc:0.59375
loss:34.31767 acc:0.46875
loss:49.56287 acc:0.59375
loss:34.27624 acc:0.62500
loss:41.15760 acc:0.56250
loss:19.99830 acc:0.65625
loss:30.16264 acc:0.59375
loss:32.66717 acc:0.59375
loss:23.36999 acc:0.68750
loss:13.30596 acc:0.68750
loss:0.00000 acc:1.00000
loss:19.43344 acc:0.68750
loss:29.43970 acc:0.62500
loss:30.82463 acc:0.59375
loss:18.29285 acc:0.68750
loss:32.54916 acc:0.68750
loss:21.77884 acc:0.68750
loss:26.32651 acc:0.62500
loss:7.72756 acc:0.81250
loss:19.46919 acc:0.65625
loss:22.92023 acc:0.68750
loss:13.68733 acc:0.78125
loss:9.21291 acc:0.78125
loss:0.00000 acc:1.00000
test acc:0.68421
loss:10.81774 acc:0.84375
loss:18.69560 acc:0.71875
loss:21.14270 acc:0.68750
loss:11.71013 acc:0.71875
loss:23.68367 acc:0.71875
loss:15.59680 acc:0.75000
loss:16.38365 acc:0.65625
loss:3.48871 acc:0.84375
loss:13.86623 acc:0.78125
loss:17.23298 acc:0.71875
loss:9.11982 acc:0.87500
loss:7.13866 acc:0.78125
loss:0.00000 acc:1.00000
loss:7.07178 acc:0.87500
loss:12.17179 acc:0.81250
loss:15.75085 acc:0.75000
loss:7.26880 acc:0.71875
loss:18.08974 acc:0.81250
loss:11.61002 acc:0.75000
loss:9.54142 acc:0.68750
loss:1.72861 acc:0.87500
loss:10.87609 acc:0.81250
loss:13.35418 acc:0.75000
loss:7.08608 acc:0.84375
loss:5.89114 acc:0.81250
loss:0.00000 acc:1.00000
test acc:0.77193
loss:4.25427 acc:0.87500
loss:8.29292 acc:0.84375
loss:12.63444 acc:0.81250
loss:4.98424 acc:0.78125
loss:15.21485 acc:0.81250
loss:8.43432 acc:0.75000
loss:5.25917 acc:0.78125
loss:0.94625 acc:0.87500
loss:8.57712 acc:0.84375
loss:10.67448 acc:0.84375
loss:5.88110 acc:0.84375
loss:5.32498 acc:0.87500
loss:0.00000 acc:1.00000
loss:2.22088 acc:0.87500
loss:5.87675 acc:0.87500
loss:10.45557 acc:0.78125
loss:3.74712 acc:0.84375
loss:13.22224 acc:0.81250
loss:6.36244 acc:0.75000
loss:3.52659 acc:0.84375
loss:0.98263 acc:0.93750
loss:6.95762 acc:0.84375
loss:9.31484 acc:0.90625
loss:4.58888 acc:0.84375
loss:5.10567 acc:0.87500
loss:0.00000 acc:1.00000
test acc:0.78363
loss:0.90419 acc:0.90625
loss:4.25921 acc:0.87500
loss:8.49134 acc:0.78125
loss:3.26698 acc:0.84375
loss:11.79509 acc:0.84375
loss:4.97345 acc:0.84375
loss:2.84477 acc:0.87500
loss:1.21036 acc:0.93750
loss:5.64792 acc:0.84375
loss:8.42069 acc:0.93750
loss:3.71578 acc:0.90625
loss:4.87594 acc:0.90625
loss:0.00000 acc:1.00000
loss:0.45847 acc:0.96875
loss:3.19834 acc:0.90625
loss:6.94889 acc:0.78125
loss:3.04756 acc:0.78125
loss:10.93073 acc:0.84375
loss:4.25836 acc:0.84375
loss:2.17207 acc:0.84375
loss:1.42136 acc:0.93750
loss:4.70966 acc:0.87500
loss:8.03839 acc:0.93750
loss:3.00213 acc:0.90625
loss:4.75577 acc:0.90625
loss:0.00000 acc:1.00000
test acc:0.78363
loss:0.31959 acc:0.96875
loss:2.40828 acc:0.90625
loss:5.82317 acc:0.84375
loss:2.83232 acc:0.81250
loss:10.00275 acc:0.84375
loss:3.57126 acc:0.84375
loss:1.65234 acc:0.87500
loss:1.61977 acc:0.93750
loss:4.03384 acc:0.87500
loss:7.69399 acc:0.90625
loss:2.63821 acc:0.90625
loss:4.54189 acc:0.90625
loss:0.00000 acc:1.00000
loss:0.18972 acc:0.96875
loss:1.75994 acc:0.90625
loss:5.05572 acc:0.84375
loss:2.65805 acc:0.81250
loss:9.21468 acc:0.87500
loss:2.95136 acc:0.84375
loss:1.49775 acc:0.87500
loss:1.76816 acc:0.93750
loss:3.52355 acc:0.90625
loss:7.41432 acc:0.93750
loss:2.25899 acc:0.90625
loss:4.25336 acc:0.90625
loss:0.00000 acc:1.00000
test acc:0.80702
loss:0.02820 acc:0.96875
loss:1.23499 acc:0.93750
loss:4.29800 acc:0.84375
loss:2.56567 acc:0.81250
loss:8.50027 acc:0.87500
loss:2.34109 acc:0.84375
loss:1.34891 acc:0.93750
loss:1.92763 acc:0.93750
loss:3.11241 acc:0.90625
loss:7.10171 acc:0.87500
loss:1.94405 acc:0.90625
loss:3.95033 acc:0.90625
loss:0.00000 acc:1.00000
loss:0.00000 acc:1.00000
loss:0.87173 acc:0.93750
loss:3.59864 acc:0.84375
loss:2.36613 acc:0.81250
loss:7.94177 acc:0.87500
loss:1.97824 acc:0.93750
loss:1.26354 acc:0.93750
loss:2.04078 acc:0.90625
loss:2.77734 acc:0.90625
loss:6.74133 acc:0.87500
loss:1.77319 acc:0.96875
loss:3.68123 acc:0.90625
loss:0.00000 acc:1.00000
test acc:0.81287
loss:0.00000 acc:1.00000
loss:0.58720 acc:0.93750
loss:3.08344 acc:0.84375
loss:2.10701 acc:0.78125
loss:7.54687 acc:0.87500
loss:1.84997 acc:0.93750
loss:1.15680 acc:0.93750
loss:2.20248 acc:0.93750
loss:2.51515 acc:0.90625
loss:6.39995 acc:0.93750
loss:1.78001 acc:0.93750
loss:3.47088 acc:0.90625
loss:0.00000 acc:1.00000
loss:0.00000 acc:1.00000
loss:0.36069 acc:0.93750
loss:2.72082 acc:0.84375
loss:1.84009 acc:0.81250
loss:7.11066 acc:0.90625
loss:1.72742 acc:0.93750
loss:1.02820 acc:0.90625
loss:2.29181 acc:0.93750
loss:2.41081 acc:0.90625
loss:6.09377 acc:0.93750
loss:1.80054 acc:0.93750
loss:3.26580 acc:0.90625
loss:0.00000 acc:1.00000
test acc:0.81871
loss:0.00000 acc:1.00000
loss:0.18003 acc:0.93750
loss:2.44613 acc:0.87500
loss:1.61047 acc:0.87500
loss:6.67611 acc:0.90625
loss:1.59183 acc:0.93750
loss:0.89677 acc:0.90625
loss:2.37084 acc:0.93750
loss:2.25829 acc:0.90625
loss:5.79726 acc:0.90625
loss:1.80508 acc:0.93750
loss:3.00829 acc:0.90625
loss:0.00000 acc:1.00000
loss:0.00000 acc:1.00000
loss:0.05135 acc:0.93750
loss:2.24199 acc:0.84375
loss:1.45565 acc:0.84375
loss:6.24971 acc:0.87500
loss:1.44302 acc:0.93750
loss:0.76834 acc:0.90625
loss:2.44379 acc:0.93750
loss:2.09498 acc:0.90625
loss:5.51264 acc:0.93750
loss:1.80650 acc:0.93750
loss:2.74188 acc:0.93750
loss:0.00000 acc:1.00000
test acc:0.81871
loss:0.00000 acc:1.00000
loss:0.00006 acc:0.96875
loss:2.06176 acc:0.84375
loss:1.31890 acc:0.87500
loss:5.81376 acc:0.90625
loss:1.29607 acc:0.93750
loss:0.66643 acc:0.90625
loss:2.45500 acc:0.93750
loss:1.95926 acc:0.93750
loss:5.30081 acc:0.90625
loss:1.78542 acc:0.93750
loss:2.53433 acc:0.90625
loss:0.00000 acc:1.00000
loss:0.00000 acc:1.00000
loss:0.00000 acc:1.00000
loss:1.92465 acc:0.87500
loss:1.18924 acc:0.84375
loss:5.40256 acc:0.84375
loss:1.15270 acc:0.93750
loss:0.56949 acc:0.93750
loss:2.44485 acc:0.93750
loss:1.81624 acc:0.93750
loss:5.10627 acc:0.87500
loss:1.76959 acc:0.93750
loss:2.38349 acc:0.93750
loss:0.00000 acc:1.00000
test acc:0.82456
'''