机器学习入门-用Java实现简单感知机

一、通俗理解机器学习

1、机器学习是人工智能的一种,如图所示,它是人工智能的一个子方向。

机器学习入门-用Java实现简单感知机_第1张图片

2、机器学习有点像人类的学习过程。

1. 人类学习通过经验(事件),归纳出规律。
2. 机器学习通过数据,训练出模型。

机器学习入门-用Java实现简单感知机_第2张图片

3、机器学习不是基于编程形成的直接结果,不是代码直接写出一个模型 y = -0.3x + 6,而是通过归纳得出来的模型,例如,通过初始化 y = ax + b 中的 a、b,不断迭代,获得针对样本数据最优的 a、b 值,即得到对应的、归纳出来的最优的模型。

4、机器学习中,会用一些真实的数据对算法构建的模型进行评估,评估模型的性能,如果这个模型能达到要求,就用来测试其他的数据,如果达不到要求就要调整算法来重新建立模型,再次进行评估,如此循环往复,最终获得满意的模型来处理其他的数据。

二、简单理解感知机

感知机(perceptron)是二分类的线性分类模型,输入为实例的特征向量,输出为实例的类别(取+1和-1)。感知机对应于输入空间中将实例划分为两类的分离超平面。感知机旨在求出该超平面,为求得超平面导入了基于误分类的损失函数,利用梯度下降法 对损失函数进行最优化(最优化)。感知机的学习算法具有简单而易于实现的优点,分为原始形式和对偶形式。感知机预测是用学习得到的感知机模型对新的实例进行预测的,因此属于判别模型。感知机由Rosenblatt于1957年提出的,是神经网络和支持向量机的基础。

  拿二维平面举例,

机器学习入门-用Java实现简单感知机_第3张图片

看这张图,很明显,直线没有将红蓝点完全分开在两个区域,我们可以将其称为错误的直线,感知机要做的,就是根据各点坐标,将错误的直线纠正为正确的,这样得到的直线就是训练的结果。

说了这么多,要怎么实现呢?  看下面的流程图

机器学习入门-用Java实现简单感知机_第4张图片

三、用Java实现感知机

1、关于w与b的修改

我们会定义一个变量η(0≤η≤1)作为步长,在统计学是中成为学习速率。步长越大,梯度下降的速度越快,更能接近极小点。如果步长过大,有可能导致跨过极小点,导致函数发散;如果步长过小,有可能会耗很长时间才能达到极小点。默认为1

对于wi   wi+=η*y*xi

对于b      b+=η*y

2、代码实现

package machineLearning;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class test1 {

    public static int eta = 1;//步长,默认为1
    public static double w[] = {1.0,2.0,3.0,4.0};
    public static int b = 7;

    public static List arrayList;
    public static void main(String[] args) {
        Point p1 = new Point(new double[]{0, 0, 0, 1}, -1);
        Point p2 = new Point(new double[]{1, 0, 0, 0}, 1);
        Point p3 = new Point(new double[]{2, 1, 0, 0}, 1);
        Point p4 = new Point(new double[]{2, 1, 0, 1}, -1);
        arrayList = new ArrayList<>();
        arrayList.add(p1);
        arrayList.add(p2);
        arrayList.add(p3);
        arrayList.add(p4);
        boolean classify = classify();
    }

    /*
     * 判断所有点的位置关系,进行分类
     */
    public static boolean classify() {
        boolean flag = false;
        while (!flag) {
            for (int i = 0; i < arrayList.size(); i++) {
                if (Anwser(arrayList.get(i)) <= 0) {
                    Update(arrayList.get(i));
                    break;
                }
                if (i + 1 == arrayList.size()) {
                    flag = true;
                }
            }
        }
        return true;
    }

    /*
     * 点乘返回sum
     */
    private static double Dot(double[] w, double[] x) {
        double sum = 0;
        for (int i = 0; i < x.length; i++) {
            sum += w[i] * x[i];
        }
        return sum;
    }

    /*
     * 返回函数计算的值
     */
    private static double Anwser(Point point) {
        System.out.println("w:"+Arrays.toString(w));
        System.out.println("b:"+b);
        return point.y * (Dot(w, point.x) + b);
    }

    public static void Update(Point point) {
        for (int i = 0; i < w.length; i++) {
            w[i] += eta * point.y * point.x[i];
        }
        b += eta * point.y;
        return;
    }

}
class Point{
    double x[];
    int y;

    public Point(double[] x, int y) {
        this.x = x;
        this.y = y;
    }
}

3、测试结果

w:[1.0, 2.0, 3.0, 4.0]
b:7
w:[1.0, 2.0, 3.0, 3.0]
b:6
w:[1.0, 2.0, 3.0, 2.0]
b:5
w:[1.0, 2.0, 3.0, 1.0]
b:4
w:[1.0, 2.0, 3.0, 0.0]
b:3
w:[1.0, 2.0, 3.0, -1.0]
b:2
w:[1.0, 2.0, 3.0, -2.0]
b:1
w:[1.0, 2.0, 3.0, -2.0]
b:1
w:[1.0, 2.0, 3.0, -2.0]
b:1
w:[1.0, 2.0, 3.0, -2.0]
b:1
w:[-1.0, 1.0, 3.0, -3.0]
b:0
w:[-1.0, 1.0, 3.0, -3.0]
b:0
w:[0.0, 1.0, 3.0, -3.0]
b:1
w:[0.0, 1.0, 3.0, -3.0]
b:1
w:[0.0, 1.0, 3.0, -3.0]
b:1
w:[0.0, 1.0, 3.0, -3.0]
b:1

四、一些感想

用Java实现这个感知机后,发现机器学习并没有想象中那么难,目前理解还是比较浅显,后续还需继续学习机器学习相关内容。

你可能感兴趣的:(机器学习,机器学习,java,人工智能)