基于weka手工实现KNN

一、KNN模型

K最近邻(K-Nearest Neighbors,简称KNN)算法是一种常用的基于实例的监督学习算法。它可以用于分类和回归问题,并且是一种非常直观和简单的机器学习算法。

KNN算法的基本思想是:对于一个新的样本数据,在训练数据集中找到与其最接近的K个邻居,然后根据这K个邻居的标签或属性进行预测。预测的过程即为统计K个邻居中最常见的标签(对于分类问题)或计算K个邻居的平均值(对于回归问题)。

KNN算法的主要步骤如下:

  1. 准备训练数据集:包括样本数据和对应的标签(对于分类问题)或属性值(对于回归问题)。
  2. 选择一个合适的距离度量方法,常用的有欧氏距离、overlapping距离等。
  3. 对于一个新的样本数据,计算其与训练数据集中所有样本的距离。
  4. 根据距离的大小,选取与新样本距离最近的K个邻居。
  5. 对于分类问题,统计K个邻居中各类别的出现频率,选择频率最高的类别作为预测结果。
  6. 对于回归问题,计算K个邻居的平均值作为预测结果。
  7. 输出预测结果。

KNN算法的核心思想是基于样本之间的相似性进行预测。它假设相似的样本在特征空间中具有相似的输出,因此通过寻找最近的邻居来进行预测。KNN算法的优点包括简单易懂、无需模型训练和快速预测。然而,它也有一些限制,如对于大规模数据集的计算开销较大,对于高维数据和不平衡数据的处理能力较弱等。

在使用KNN算法时,需要注意选择合适的K值,较小的K值可能会导致对噪声敏感,较大的K值可能会导致模糊性增加。此外,对数据进行预处理(如特征缩放)也可能对KNN的性能产生影响。

总的来说,KNN算法是一种简单但有效的机器学习算法,适用于小规模数据集和简单分类或回归任务。它在实际应用中被广泛使用,特别是在模式识别、推荐系统和数据挖掘等领域。

关于KNN算法更加详细的介绍可以参考这篇博客:机器学习08:最近邻学习.

二、基于weka手工实现KNN算法

package weka.classifiers.myf;

import weka.classifiers.Classifier;
import weka.core.*;

/**
 * @author YFMan
 * @Description 自定义的 KNN 分类器
 * @Date 2023/5/25 14:35
 */
public class myKNN extends Classifier {
    // 训练数据集
    protected Instances m_Train;

    // 类别数
    protected int m_NumClasses;

    // 设置 kNN 参数
    protected int m_kNN = 3;

    // 属性数
    protected double m_NumAttributesUsed;

    /*
     * @Author YFMan
     * @Description 根据训练数据 建立 KNN 模型
     * @Date 2023/5/25 18:27
     * @Param [instances]
     * @return void
     **/
    public void buildClassifier(Instances instances) throws Exception {
        // 初始化类别数
        m_NumClasses = instances.numClasses();
        // 初始化训练集
        m_Train = instances;

        // 初始化属性数
        m_NumAttributesUsed = 0.0;
        for (int i = 0; i < m_Train.numAttributes(); i++) {
            if (i != m_Train.classIndex()) {
                m_NumAttributesUsed += 1.0;
            }
        }
    }

    /*
     * @Author YFMan
     * @Description 对单个实例进行分类
     * @Date 2023/5/25 18:27
     * @Param [instance]
     * @return double[]
     **/
    public double[] distributionForInstance(Instance instance) throws Exception {
        // 计算 instance 与 instances 中每个实例的欧式距离
        double[] distances = new double[m_Train.numInstances()];
        for (int i = 0; i < m_Train.numInstances(); i++) {
            distances[i] = 0;
            // 计算 instance 与 instances 中每个实例的 d^2
            for (int j = 0; j < m_Train.numAttributes(); j++) {
                if (j != m_Train.classIndex()) {
                    // 计算 overlap 距离
//                    if(instance.value(j)!=m_Train.instance(i).value(j)){
//                        distances[i] += 1;
//                    }
                    // 计算 Euclidean 距离
                    double diff = instance.value(j) - m_Train.instance(i).value(j);
                    distances[i] += diff * diff;
                }
            }
            // 对 d^2 开根号
            distances[i] = Math.sqrt(distances[i]);
        }

        // 对 distances 进行排序 (得到的是排序后的下标)
        int[] sortedDistances = Utils.sort(distances);

        // 计算 distribution
        double[] distribution = new double[m_NumClasses];
        for (int i=0;i<m_NumClasses;i++){
            distribution[i] = 1.0;
        }
        int total = m_NumClasses;
        for (int i = 0; i < m_kNN; i++) {
            distribution[(int) m_Train.instance(sortedDistances[i]).classValue()] += 1.0;
            total += 1;
        }

        // 归一化
        for (int i=0;i<m_NumClasses;i++){
            distribution[i] /= total;
        }
        // 返回各个类别的 distribution
        return distribution;
    }

    /*
     * @Author YFMan
     * @Description 主函数
     * @Date 2023/5/25 18:27
     * @Param [argv] 命令行参数
     * @return void
     **/
    public static void main(String[] argv) {
        runClassifier(new myKNN(), argv);
    }
}

你可能感兴趣的:(机器学习,数据挖掘,人工智能,机器学习,算法)