KNN算法:使用欧式距离计算方法,从源对象集合中选取距离目标节点最近的K个节点,判断K个节点所属类别最多的节点,即为目标节点所属的类别。
此处只是简单的实现KNN算法的过程,其中有一些优化的地方不再修改,还请小伙伴自行优化。
KNN的model类:
package com.spring5.bigdata.knn;
/**
* @author yinxf
* @date 2020-05-16
*/
public class KnnNode {
private float x; //X坐标
private float y; //Y坐标
private Float distance; //目标节点到此节点的距离
private String type; //所属类别
public KnnNode(float x, float y, String type) {
this.x = x;
this.y = y;
this.type = type;
}
public float getDistance() {
return distance;
}
public void setDistance(float distance) {
this.distance = distance;
}
public float getX() {
return x;
}
public void setX(float x) {
this.x = x;
}
public float getY() {
return y;
}
public void setY(float y) {
this.y = y;
}
public String getType() {
return type;
}
public void setType(String type) {
this.type = type;
}
@Override
public String toString() {
return "Node{" +
"x=" + x +
", y=" + y +
", distance=" + distance +
", type='" + type + '\'' +
'}';
}
}
KNN的测试类,其中包括初始化数据,使用欧式距离计算选取距离目标节点最近的K个节点,
package com.spring5.bigdata.knn;
import java.util.ArrayList;
import java.util.List;
/**
* @author yinxf
* @date 2020-05-16
*/
public class Knn {
//类别
private final static String RED = "RED";
private final static String BLACK = "BLACK";
public static void main(String[] args) {
//初始化所有节点坐标
List totalType = init();
//验证此节点属于哪个类别
// KnnNode knnNode = new KnnNode(4,5,""); //所属类别为:black
KnnNode knnNode = new KnnNode(3,2,""); //所属类别为:red
//计算所有节点到目标节点的距离,欧式距离
totalType = getDistance(totalType,knnNode);
//计算距离目标节点最近的K个节点
int k = 3;
List kList = getKList(totalType,k);
//计算提供的节点属于那种类别
String resultType = getNodeType(kList);
System.out.println("目标节点所属类别为:" + resultType);
}
/**
* 查找距离目标节点最近的K个节点
* @param totalType
* @param k
* @return
*/
private static List getKList(List totalType, int k) {
List kList = new ArrayList<>(k);
//选出距离目标节点最近的K个节点
for (int i = 0 ; i < totalType.size() ; i++ ) {
KnnNode type = totalType.get(i);
if (i < k){
kList.add(type);
}else {
boolean flag = false;
//判断当前节点小于K个节点集合中的节点
for (KnnNode knnNode1 : kList) {
if (type.getDistance() < knnNode1.getDistance()){
flag = true;
break;
}
}
//替换距离目标节点最远的节点
if (flag) {
int index = 0 ;
for (int j = 0; j < k; j++) {
if (kList.get(j).getDistance() > type.getDistance()) {
index = j ;
}
}
kList.remove(index);
kList.add(type);
kList.forEach(list -> System.out.println(list.toString()));
System.out.println("=========================================");
}
}
}
return kList;
}
/**
* 计算所有节点到目标节点的距离
* @param totalType
* @param knnNode
* @return
*/
private static List getDistance(List totalType, KnnNode knnNode){
for (int i = 0 ; i < totalType.size() ; i++ ) {
KnnNode type = totalType.get(i);
float distance = distance(type, knnNode);
type.setDistance(distance);
System.out.println( i+"类别为:【"+ type.getType() + "】 距离为:【"+distance +"】" );
}
return totalType;
}
/**
* 计算目标节点所属类别
* @param kList
* @return
*/
private static String getNodeType(List kList) {
//结算距离目标节点最近的K个节点中的,节点最多的类别是什么
int redNum = 0;
int blackNum = 0;
for (KnnNode result : kList) {
if (RED.equals(result.getType())){
redNum++;
}else if (BLACK.equals(result.getType())){
blackNum++;
}
}
return blackNum > redNum ? BLACK : RED;
}
/**
* 欧式距离计算公式
* @param source
* @param target
* @return
*/
private static float distance(KnnNode source, KnnNode target) {
float x = source.getX() - target.getX();
float y = source.getY() - target.getY();
float z = x * x + y * y;
float distance = (float) Math.sqrt(z);
return distance;
}
/**
* 初始化节点
* @return
*/
private static List init() {
List totalType = new ArrayList<>();
totalType.add(new KnnNode(1,2,RED));
totalType.add(new KnnNode(2,2,RED));
totalType.add(new KnnNode(1,3,RED));
totalType.add(new KnnNode(2,1,RED));
totalType.add(new KnnNode(2,3,RED));
totalType.add(new KnnNode(3,5,BLACK));
totalType.add(new KnnNode(4,6,BLACK));
totalType.add(new KnnNode(3,4,BLACK));
totalType.add(new KnnNode(5,4,BLACK));
totalType.add(new KnnNode(5,3,BLACK));
return totalType;
}
}
测试结果如下:
0类别为:【RED】 距离为:【2.0】
1类别为:【RED】 距离为:【1.0】
2类别为:【RED】 距离为:【2.236068】
3类别为:【RED】 距离为:【1.4142135】
4类别为:【RED】 距离为:【1.4142135】
5类别为:【BLACK】 距离为:【3.0】
6类别为:【BLACK】 距离为:【4.1231055】
7类别为:【BLACK】 距离为:【2.0】
8类别为:【BLACK】 距离为:【2.828427】
9类别为:【BLACK】 距离为:【2.236068】
Node{x=1.0, y=2.0, distance=2.0, type='RED'}
Node{x=2.0, y=2.0, distance=1.0, type='RED'}
Node{x=2.0, y=1.0, distance=1.4142135, type='RED'}
=========================================
Node{x=2.0, y=2.0, distance=1.0, type='RED'}
Node{x=2.0, y=1.0, distance=1.4142135, type='RED'}
Node{x=2.0, y=3.0, distance=1.4142135, type='RED'}
=========================================
目标节点所属类别为:RED