数据挖掘Java——KNN算法的实现

一、KNN算法的前置知识

k-近邻(kNN, k-NearestNeighbor)是在训练集中选取离输入的数据点最近的k个邻居,根据这个k个邻居中出现次数最多的类别(最大表决规则),作为该数据点的类别。

分类在数据挖掘中是一项非常重要的任务。分类的目的是学会一个分类函数或分类模型(也常常称作分类器),该模型能把数据库中的数据项映射到给定类型中的某一个类别。分类可用于预测。预测的目的是从历史数据记录中自动推导出对给定数据的趋势描述,从而能对未来数据进行预测。统计学中常用的预测方法是回归。数据挖掘中的分类和统计学中的回归方法是一对相互联系又有区别的概念。一般地,分类的输出是离散的类别值,而回归的输出则是连续数值。

相似性:给定一个数据库D={t1,t2,…,tn}和一组类C={C1,C2,…,Cm}。对于任意的元组ti={ti1,ti2,…,tik}∈D,如果存在一个Cj∈C,使得sim(ti,Cj)≥sim(ti,Cp),存在Cp∈C,Cp≠Cj,则ti被分配到类Cj中,其中sim(ti,Cj)称为相似性。在实际的计算中往往用距离来表征,距离越近,相似性越大,距离越远,相似性越小。

为了计算相似性,需要首先得到表示每个类的向量。计算方法有多种,例如代表每个类的向量可以通过计算每个类的中心来表示。另外,在模式识别中,一个预先定义的图像用于代表每个类,分类就是把待分类的样例与预先定义的图像进行比较。

二、KNN算法的基本思想

KNN算法的思想比较简单。假定每个类包含多个训练数据,且每个训练数据都有一个唯一的类别标记,KNN算法的主要思想就是计算每个训练数据到待分类元组的距离,取和待分类元组距离最近的k个训练数据,k个数据中哪个类别的训练数据占多数,则待分类元组就属于哪个类别。

三、KNN算法和强关联规则挖掘的例子

KNN算法例子
数据挖掘Java——KNN算法的实现_第1张图片
数据挖掘Java——KNN算法的实现_第2张图片

四、KNN算法的实现过程

实验内容
某班有14个同学,已登记身高及等级,新同学易昌,身高174cm,等级是什么。请用knn算法进行分类识别,其中k=5。
数据挖掘Java——KNN算法的实现_第3张图片

实验思路
(1)定义学生类Student,学生类中定义姓名,身高,等级等属性,利用lombok依赖的@Data注解对Student类的get和set方法进行注入。定义初始数据集,定义14个实体学生类,将14个实体Student类添加到初始数据集dataList中。
(2)调用initData()方法对数据集进行初始化,定义一个Student类stuV0并对其姓名和身高进行实例化作为输入,调用Knn()方法,得到含有等级的Student类对象student,将对象student进行输出。
(3)在Knn()方法体内部,初始时将数据集的前5项加入到categoryList集合中,categoryList集合用于存放距离stuV0最近的k个学生,只是最初存放数据集的前5项而已。遍历数据集dataList,计算stuV0距离从数据集的第6项开始的每一项的距离v0Tod,另外调用getCalculate()方法计算出stuV0距离categoryList集合中最远的学生对象stuU,若stuU距离stuV0的身高距离uToV0大于v0Tod,则在categoryList中去除掉stuU,将数据集中的该项加入到categoryList集合中来。
(4)在getCalculate()方法体内部,定义变量maxHeight用于存放stuV0与类别集合categoryList的最远距离,定义Student类对象resultStu用于存放要返回的学生,即stuV0与类别集合categoryList距离最远的学生。遍历categoryList集合,若stuU与stuV0之间的距离大于maxHeight,则将v0ToU赋值给maxHeight,将stuU赋值给resultStu,最终将Student类对象resultStu返回。
(5)调用getCategoryStudent()方法找出categoryList中同等级占比最多的学生等级rank,最终用rank实例化stuV0的rank属性,返回stuV0。
(6)在getCategoryStudent()方法体内部,遍历categoryList,找出同等级占比最多的学生等级,其实就是找等级“高”、等级“中等”,等级“矮”的学生哪个类别的学生数量最多,并将学生数量最多的那个类别返回。

实现源码

学生类
package com.data.mining.entity;

import lombok.Data;

@Data
public class Student {
    private String name;
    private int height;
    private String rank;

    public Student(){}

    public Student(String n, int h){
        name = n;
        height = h;
    }

    public Student(String n, int h, String r){
        name = n;
        height = h;
        rank = r;
    }
}

KNN算法实现代码
package com.data.mining.main;

import com.data.mining.entity.Student;

import java.util.ArrayList;
import java.util.List;

public class Knn {
    //定义初始数据集
    public static List<Student> dataList = new ArrayList<>();

    public static void main(String[] args) {
        initData();
        Student stuV0 = new Student("易昌", 174);
        Student student = Knn(stuV0);
        System.out.println(student.toString());
    }

    /**
     * 找出同等级占比最多的学生等级
     * @param categoryList
     * @return
     */
    public static String getCategoryStudent(List<Student> categoryList){
        int tallCount = 0;
        int midCount = 0;
        int smallCount = 0;
        for (Student stuU : categoryList) {
            if (stuU.getRank().equals("高")) tallCount++;
            else if (stuU.getRank().equals("中等")) midCount++;
            else smallCount++;
        }
        int max = 0;
        max = tallCount > midCount ? tallCount : midCount;
        max = smallCount > max ? smallCount : max;
        if (smallCount == max) return "矮";
        else if (tallCount == max) return "高";
        else return "中等";
    }

    /**
     * 计算出stuV0距离categoryList集合中最远的学生对象
     * @param stuV0
     * @param categoryList
     * @return
     */
    public static Student getCalculate(Student stuV0, List<Student> categoryList) {
        int maxHeight = 0; //存放stuV0与类别集合categoryList的最远距离
        Student resultStu = new Student(); //存放要返回的学生,即stuV0与类别集合categoryList距离最远的学生
        for (Student stuU : categoryList) {
            int v0ToU = Math.abs(stuV0.getHeight() - stuU.getHeight()); //stuV0与stuU的距离
            if (v0ToU > maxHeight){ //stuV0与stuU的距离大于maxHeight,则对maxHeight和resultStu进行更新
                maxHeight = v0ToU;
                resultStu = stuU;
            }
        }
        return resultStu;
    }

    /**
     * 对输入学生类进行Knn算法实例化该学生的等级后,将该学生返回
     * @param stuV0
     * @return
     */
    public static Student Knn(Student stuV0){
        List<Student> categoryList = new ArrayList<>(); //存放距离stuV0最近的k个学生,最初存放数据集的前5项
        for (int i = 0; i < dataList.size(); i++) {
            if (i < 5) categoryList.add(dataList.get(i));
            else {
                //stuV0距离剩下数据集中某项的距离
                int v0Tod = Math.abs(stuV0.getHeight() - dataList.get(i).getHeight());
                Student stuU =  getCalculate(stuV0, categoryList); //存放stuV0距离类别集合中最远的学生
                int uToV0 = Math.abs(stuU.getHeight() - stuV0.getHeight());
                if (uToV0 > v0Tod){
                    categoryList.remove(stuU); //在集合列表中去除stuU
                    categoryList.add(dataList.get(i));
                }
            }
        }
        System.out.println(categoryList.toString());
        String rank = getCategoryStudent(categoryList);
        stuV0.setRank(rank);

        return stuV0;
    }


    /**
     * 初始化数据
     */
    public static void initData(){
        Student s1 = new Student("李丽", 150, "矮");
        Student s2 = new Student("吉米", 192, "高");
        Student s3 = new Student("马大华", 170, "中等");
        Student s4 = new Student("王晓华", 173, "中等");
        Student s5 = new Student("刘敏", 160, "矮");
        Student s6 = new Student("张强", 175, "中等");
        Student s7 = new Student("李秦", 160, "矮");
        Student s8 = new Student("王壮", 190, "高");
        Student s9 = new Student("刘冰", 168, "中等");
        Student s10 = new Student("张喆", 178, "中等");
        Student s11 = new Student("杨毅", 170, "中等");
        Student s12 = new Student("徐田", 168, "中等");
        Student s13 = new Student("高杰", 165, "矮");
        Student s14 = new Student("张晓", 178, "中等");

        dataList.add(s1);
        dataList.add(s2);
        dataList.add(s3);
        dataList.add(s4);
        dataList.add(s5);
        dataList.add(s6);
        dataList.add(s7);
        dataList.add(s8);
        dataList.add(s9);
        dataList.add(s10);
        dataList.add(s11);
        dataList.add(s12);
        dataList.add(s13);
        dataList.add(s14);
    }
}

实验结果
数据挖掘Java——KNN算法的实现_第4张图片
另外,除了输出题目要求的学生身高等级外,笔者还输出了输入学生所在的簇,以对照题目确保结果正确。
数据挖掘Java——KNN算法的实现_第5张图片

五、实验总结

本实验结果笔者并不保证一定是正确的,笔者仅仅是提供一种使用Java语言实现KNN算法的思路。因为实验并没有给答案,笔者已将网络上有答案的实验数据输入程序后,程序输出的结果和答案一致,所以问题应该不大。若有写的不到位的地方,还请各位多多指点!
笔者主页还有其他数据挖掘算法的总结,欢迎各位光顾!

你可能感兴趣的:(数据挖掘,算法,数据挖掘,java,分类算法)