基于weka手工实现K-means

一、K-means聚类算法

K均值聚类(K-means clustering)是一种常见的无监督学习算法,用于将数据集中的样本划分为K个不同的类别或簇。它通过最小化样本点与所属簇中心点之间的距离来确定最佳的簇划分。

K均值聚类的基本思想如下:

  1. 随机选择K个初始聚类中心(质心)。
  2. 对于每个样本,计算其与各个聚类中心之间的距离,并将样本分配到距离最近的聚类中心所代表的簇。
  3. 对于每个簇,计算簇中样本的均值,并将该均值作为新的聚类中心。
  4. 重复步骤2和步骤3,直到聚类中心不再变化或达到预定的迭代次数。

K均值聚类的关键是如何选择初始的聚类中心。常见的方法是随机选择数据集中的K个样本作为初始聚类中心,或者使用一些启发式的方法来选择。

K均值聚类的优点包括简单易实现、计算效率高和可扩展性好。它在许多领域中被广泛应用,如数据分析、图像处理、模式识别等。然而,K均值聚类也存在一些限制,例如对于初始聚类中心的敏感性、对于离群值的影响较大以及需要事先指定簇的个数K等。

在实际应用中,可以根据实际问题和数据集的特点来选择合适的K值,并进行多次运行以获得更稳定的结果。此外,K均值聚类也可以与其他算法相结合,如层次聚类(hierarchical clustering)和密度聚类(density-based clustering),以获得更好的聚类效果。

总的来说,K均值聚类是一种常用的无监督学习算法,用于将数据集中的样本划分为K个簇。它简单而高效,适用于许多聚类问题。然而,在使用K均值聚类时需要注意选择初始聚类中心和合适的K值,以及对其限制和局限性的认识。

二、基于weka手工实现K-means聚类算法

package weka.clusterers.myf;

import weka.clusterers.RandomizableClusterer;
import weka.core.*;

import java.util.*;

/**
 * @author YFMan
 * @Description 自定义的 KMeans 聚类器
 * @Date 2023/6/8 15:01
 */
public class myKMeans extends RandomizableClusterer {
    // 聚类中心的数量
    private int m_NumClusters = 2;

    // 不同聚类中心的集合
    private Instances m_ClusterCentroids;

    // 聚类的最大迭代次数
    private int m_MaxIterations = 500;

    // 追踪收敛前完成的迭代次数
    private int m_Iterations = 0;

    // 构造函数
    public myKMeans() {
        super();
        // 设置随机种子
        m_SeedDefault = 10;
        setSeed(m_SeedDefault);
    }


    /*
     * @Author YFMan
     * @Description //基类定义的接口,必须要实现
     * @Date 2023/6/8 16:37
     * @Param []
     * @return weka.core.Capabilities
     **/
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NO_CLASS);

        // attributes
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);

        return result;
    }

    /*
     * @Author YFMan
     * @Description //进行聚类
     * @Date 2023/6/8 16:38
     * @Param [data 用于聚类的数据集]
     * @return void
     **/
    @Override
    public void buildClusterer(Instances instances) throws Exception {
        // 迭代次数
        m_Iterations = 0;

        // 初始化聚类中心
        m_ClusterCentroids = new Instances(instances, m_NumClusters);

        // 每个样本属于哪个聚类中心
        int[] clusterAssignments = new int[instances.numInstances()];

        // 伪随机数生成器
        Random RandomO = new Random(getSeed());
        int instIndex;
        HashSet<Instance> initC = new HashSet<>();

        // 初始化聚类中心,随机选择 m_NumClusters 个样本作为聚类中心
        for (int j = instances.numInstances() - 1; j >= 0; j--) {
            instIndex = RandomO.nextInt(j + 1);

            if (!initC.contains(instances.instance(instIndex))) {
                m_ClusterCentroids.add(instances.instance(instIndex));
                initC.add(instances.instance(instIndex));
            }
            instances.swap(j, instIndex);

            if (m_ClusterCentroids.numInstances() == m_NumClusters) {
                break;
            }
        }

        boolean converged = false;
        // 用于存储每个聚类中心的样本集合
        Instances[] tempI = new Instances[m_NumClusters];
        while (!converged) {
            m_Iterations++;
            converged = true;
            // 计算每个样本 属于哪个聚类中心
            for (int i = 0; i < instances.numInstances(); i++) {
                Instance toCluster = instances.instance(i);
                int newC = clusterInstance(toCluster);
                // 如果样本所属的聚类中心发生变化,则说明还没有收敛
                if (newC != clusterAssignments[i]) {
                    converged = false;
                }
                clusterAssignments[i] = newC;
            }

            // 重新计算聚类中心
            m_ClusterCentroids = new Instances(instances, m_NumClusters);
            for (int i = 0; i < m_NumClusters; i++) {
                tempI[i] = new Instances(instances, 0);
            }
            for (int i = 0; i < instances.numInstances(); i++) {
                tempI[clusterAssignments[i]].add(instances.instance(i));
            }
            // 重新计算聚类中心
            for (int i = 0; i < m_NumClusters; i++) {
                // 计算每个属性的平均值
                m_ClusterCentroids.add(calculateCentroid(tempI[i]));
            }
            // 如果迭代次数达到最大值,则强制结束
            if (m_Iterations == m_MaxIterations) {
                converged = true;
            }
        }
    }

    /*
     * @Author YFMan
     * @Description //计算某个聚类中心的中心点
     * @Date 2023/6/8 16:57
     * @Param [instances 聚类中心的样本集合]
     * @return weka.core.Instance 聚类中心的中心点
     **/
    private Instance calculateCentroid(Instances instances) {

        int numInst = instances.numInstances();
        int numAttr = instances.numAttributes();

        Instance centroid = new Instance(numAttr);

        double sum;

        for (int i = 0; i < numAttr; i++) {
            sum = 0;
            for (int j = 0; j < numInst; j++) {
                sum += instances.instance(j).value(i);
            }
            centroid.setValue(i, sum / numInst);
        }

        return centroid;
    }

    /*
     * @Author YFMan
     * @Description //计算两个属性全为数值类型的样本之间的距离(欧式距离)
     * @Date 2023/6/8 16:47
     * @Param [first 第一个样例, second 第二个样例]
     * @return double
     **/
    private double distance(Instance first, Instance second) {
        // 定义欧式距离
        double euclideanDistance = 0;
        // 定义overlapping距离
        double overlappingDistance = 0;

        for (int index = 0; index < first.numAttributes(); index++) {
            if (index == first.classIndex()) {
                continue;
            }
            // 如果是数值类型的属性,则计算欧式距离
            if (first.attribute(index).isNumeric()) {
                double dis = first.value(index) - second.value(index);
                euclideanDistance += dis * dis;
            } else {
                // 如果是标称类型的属性,则计算是否相等
                if (first.value(index) != second.value(index)) {
                    overlappingDistance += 1;
                }
            }
        }

        return Math.sqrt(euclideanDistance) + overlappingDistance;
    }

    /*
     * @Author YFMan
     * @Description //对一个给定的样例进行分类
     * @Date 2023/6/8 16:50
     * @Param [instance 给定的样例]
     * @return int 返回样例所属的聚类中心id
     **/
    @Override
    public int clusterInstance(Instance instance) throws Exception {
        double minDist = Double.MAX_VALUE;
        int bestCluster = 0;
        for (int i = 0; i < m_NumClusters; i++) {
            double dist = distance(instance, m_ClusterCentroids.instance(i));
            if (dist < minDist) {
                minDist = dist;
                bestCluster = i;
            }
        }
        return bestCluster;
    }

    /*
     * @Author YFMan
     * @Description //返回聚类中心的数量
     * @Date 2023/6/8 16:34
     * @Param []
     * @return int
     **/
    @Override
    public int numberOfClusters() throws Exception {
        return m_NumClusters;
    }

    /*
     * @Author YFMan
     * @Description //主函数
     * @Date 2023/6/8 16:33
     * @Param [argv 命令行参数]
     * @return void
     **/
    public static void main(String[] argv) {
        runClusterer(new myKMeans(), argv);
    }
}

你可能感兴趣的:(机器学习,数据挖掘,kmeans,机器学习)