机器学习:KNN用java代码实现

KNN算法:使用欧式距离计算方法,从源对象集合中选取距离目标节点最近的K个节点,判断K个节点所属类别最多的节点,即为目标节点所属的类别。

此处只是简单的实现KNN算法的过程,其中有一些优化的地方不再修改,还请小伙伴自行优化。

KNN的model类:

package com.spring5.bigdata.knn;

/**
 * @author yinxf
 * @date 2020-05-16
 */
public class KnnNode {
    private float x; //X坐标
    private float y; //Y坐标
    private Float distance; //目标节点到此节点的距离
    private String type; //所属类别

    public KnnNode(float x, float y, String type) {
        this.x = x;
        this.y = y;
        this.type = type;
    }

    public float getDistance() {
        return distance;
    }

    public void setDistance(float distance) {
        this.distance = distance;
    }

    public float getX() {
        return x;
    }

    public void setX(float x) {
        this.x = x;
    }

    public float getY() {
        return y;
    }

    public void setY(float y) {
        this.y = y;
    }

    public String getType() {
        return type;
    }

    public void setType(String type) {
        this.type = type;
    }

    @Override
    public String toString() {
        return "Node{" +
                "x=" + x +
                ", y=" + y +
                ", distance=" + distance +
                ", type='" + type + '\'' +
                '}';
    }

}

KNN的测试类,其中包括初始化数据,使用欧式距离计算选取距离目标节点最近的K个节点,

package com.spring5.bigdata.knn;

import java.util.ArrayList;
import java.util.List;

/**
 * @author yinxf
 * @date 2020-05-16
 */
public class Knn {

    //类别
    private final static String RED = "RED";
    private final static String BLACK = "BLACK";


    public static void main(String[] args) {
        //初始化所有节点坐标
        List totalType = init();
        //验证此节点属于哪个类别
//        KnnNode knnNode = new KnnNode(4,5,""); //所属类别为:black
        KnnNode knnNode = new KnnNode(3,2,""); //所属类别为:red

        //计算所有节点到目标节点的距离,欧式距离
        totalType = getDistance(totalType,knnNode);

        //计算距离目标节点最近的K个节点
        int k = 3;
        List kList = getKList(totalType,k);

        //计算提供的节点属于那种类别
        String resultType = getNodeType(kList);
        System.out.println("目标节点所属类别为:" + resultType);

    }

    /**
     * 查找距离目标节点最近的K个节点
     * @param totalType
     * @param k
     * @return
     */
    private static List getKList(List totalType, int k) {
        List kList = new ArrayList<>(k);
        //选出距离目标节点最近的K个节点
        for (int i = 0 ; i < totalType.size() ; i++ ) {
            KnnNode type = totalType.get(i);
            if (i < k){
                kList.add(type);
            }else {
                boolean flag = false;
                //判断当前节点小于K个节点集合中的节点
                for (KnnNode knnNode1 : kList) {
                    if (type.getDistance() < knnNode1.getDistance()){
                        flag = true;
                        break;
                    }
                }
                //替换距离目标节点最远的节点
                if (flag) {
                    int index = 0 ;
                    for (int j = 0; j < k; j++) {
                        if (kList.get(j).getDistance() > type.getDistance()) {
                            index = j ;
                        }
                    }
                    kList.remove(index);
                    kList.add(type);

                    kList.forEach(list -> System.out.println(list.toString()));
                    System.out.println("=========================================");
                }
            }
        }
        return kList;
    }

    /**
     * 计算所有节点到目标节点的距离
     * @param totalType
     * @param knnNode
     * @return
     */
    private static List getDistance(List totalType, KnnNode knnNode){
        for (int i = 0 ; i < totalType.size() ; i++ ) {
            KnnNode type = totalType.get(i);
            float distance = distance(type, knnNode);
            type.setDistance(distance);
            System.out.println( i+"类别为:【"+ type.getType() + "】  距离为:【"+distance +"】" );
        }
        return totalType;
    }

    /**
     * 计算目标节点所属类别
     * @param kList
     * @return
     */
    private static String getNodeType(List kList) {
        //结算距离目标节点最近的K个节点中的,节点最多的类别是什么
        int redNum = 0;
        int blackNum = 0;
        for (KnnNode result : kList) {
            if (RED.equals(result.getType())){
                redNum++;
            }else if (BLACK.equals(result.getType())){
                blackNum++;
            }
        }
        return blackNum > redNum ? BLACK : RED;
    }


    /**
     * 欧式距离计算公式
     * @param source
     * @param target
     * @return
     */
    private static float distance(KnnNode source, KnnNode target) {
        float x = source.getX() - target.getX();
        float y = source.getY() - target.getY();
        float z = x * x + y * y;
        float distance = (float) Math.sqrt(z);
        return distance;
    }

    /**
     * 初始化节点
     * @return
     */
    private static List init() {
        List totalType = new ArrayList<>();
        totalType.add(new KnnNode(1,2,RED));
        totalType.add(new KnnNode(2,2,RED));
        totalType.add(new KnnNode(1,3,RED));
        totalType.add(new KnnNode(2,1,RED));
        totalType.add(new KnnNode(2,3,RED));
        totalType.add(new KnnNode(3,5,BLACK));
        totalType.add(new KnnNode(4,6,BLACK));
        totalType.add(new KnnNode(3,4,BLACK));
        totalType.add(new KnnNode(5,4,BLACK));
        totalType.add(new KnnNode(5,3,BLACK));
        return totalType;
    }
}

测试结果如下:

0类别为:【RED】  距离为:【2.0】
1类别为:【RED】  距离为:【1.0】
2类别为:【RED】  距离为:【2.236068】
3类别为:【RED】  距离为:【1.4142135】
4类别为:【RED】  距离为:【1.4142135】
5类别为:【BLACK】  距离为:【3.0】
6类别为:【BLACK】  距离为:【4.1231055】
7类别为:【BLACK】  距离为:【2.0】
8类别为:【BLACK】  距离为:【2.828427】
9类别为:【BLACK】  距离为:【2.236068】
Node{x=1.0, y=2.0, distance=2.0, type='RED'}
Node{x=2.0, y=2.0, distance=1.0, type='RED'}
Node{x=2.0, y=1.0, distance=1.4142135, type='RED'}
=========================================
Node{x=2.0, y=2.0, distance=1.0, type='RED'}
Node{x=2.0, y=1.0, distance=1.4142135, type='RED'}
Node{x=2.0, y=3.0, distance=1.4142135, type='RED'}
=========================================
目标节点所属类别为:RED

 

你可能感兴趣的:(Java开发,java,大数据,算法)