可能是最简单的感知机算法

前言

最近学习统计机器学习,主要是参考李航的《统计学习》,看完感知机后准备用python实现一下书上的例子。自己先在网上搜索相关代码,可是看了半天,大家的实现都太复杂了,代码动辄上百行,而且将功能拆分成各个及其短小的函数,显得十分不紧凑,而且不容易理解掌握,因此,想自己实现一个更加简单的版本,方便初学者理解和掌握,所以有了以下的代码。

代码

# -*- coding:utf8 -*-
import os
import sys

import numpy as np

#input 
x1 = np.array([3,3])
x2 = np.array([4,3])
x3 = np.array([1,1])
y = np.array([1,1,-1])
x = np.array([x1,x2,x3])
gram = np.dot(x, x.T)
print "x:\n",x
print "gram:\n",gram

#params
a = np.array([0,0,0])
b = 0;

count = 0;
condition = 0;
samples_number = len(x)
while(count < (samples_number-1)):
  for i in range(samples_number):
    condition = np.dot(a*y, gram[i])
    condition = (condition + b) * y[i]
    if condition <= 0 :
      a[i] += 1  #update parmas
      b += y[i]
      count = 0
    else:
      count += 1 

w =np.dot( a*y,x)
print "a:", a
print "w:", w
print "b:", b

总结

  1. 关键的部分只是while循环内部,初学者只要看懂这个循环就可以很快掌握感知机的写法,理解之后再写出更强壮的代码应该是比较容易了。
  2. 自己通过这个程序学习了numpy中array的用法。

你可能感兴趣的:(机器学习)