python手写聚类算法:Kmeans&DBscan算法

python手写聚类算法:Kmeans&DBscan

文章目录

  • python手写聚类算法:Kmeans&DBscan
    • 算法思路以及步骤介绍
    • 手写代码
      • Kmeans
    • 手写DBscan
    • 关于手写程序的说明

算法思路以及步骤介绍

首先,我们分别介绍一下Kmeans算法以及DBSCAN算法。

Kmeans算法步骤:首先先随机的选择K个点(这里的K是超参数),这K个点作为中心点,对于剩下的所有的点,计算剩下的点和这三个点的距离,距离中最小的,认为属于这个类。在更新完一遍之后,计算类中的均值向量作为新的中心,再次重复上面的步骤,直到类的中心不变。主要的思路就是:每一类一定有一个中心,这个类中的对象一定离自己所属的类的中心最近,在和所有的中心的距离中。

DBSCAN的思路是通过定义了核心对象和密度直达来进行聚类(优点像感染式的聚类),其中核心对象指的是,此对象的较近的距离内有足够多的样本的点,这个距离和点的个数都是我们来决定的。这个域里面的对象都叫做密度直达的点,对于这些点,如果也是一个核心对象的话,就把它们的领域也加入到这个类里面,以此类推,直到所有的核心对象都已经考虑过了,停止算法。主要的思路就是,如果我们是一类,我们一定离的比较足够近,而如果你和其他的一个足够近,那么说明可能大家都是一类。

手写代码

Kmeans

# 输入没有标签的点
# 返回分类的情况(类别及对应的样本)
# 实现输入新的点,可以完成对于点类别的预测
class Kmeans:
    def __init__(self):
        pass

    def distance(self, Xi):
        import numpy as np
        dist = []
        for k in self.centroid:  # 对模型里面所有的样本进行遍历
            dist.append(np.linalg.norm(k - Xi, ord=self.p))
        dist0 = list(enumerate(dist))
        near = sorted(dist0, key=lambda x: x[1])[0][0]
        return near

    def predict(self, Xi):
        import numpy as np
        num = Xi.shape[0]
        dist = []
        for i in range(num):
            dist.append([])
            for k in self.centroid:  # 对模型里面所有的样本进行遍历
                dist[i].append(np.linalg.norm(k - Xi, ord=self.p))
            dist0 = list(enumerate(dist[i]))
            near = sorted(dist0, key=lambda x: x[1])[0][0]
            dist[i] = near
        return dist

    def fit(self, X, k=3, p=1):  # 开始训练
        import random
        import numpy as np
        import copy
        self.k = k
        self.p = p
        self.centroid = []
        self.data = X
        num, col = X.shape
        pick = random.sample(range(0, num), self.k)  # 随机取k个样本点
        # 初始化分类结果记录矩阵以及中心点
        for i in pick:
            self.centroid.append(X[i, :])
        flag = True
        while flag:
            classlist = []
            for i in range(self.k):
                classlist.append([])
            centroid = self.centroid.copy()  # 记录原来的中心
            for i in range(num):
                near = self.distance(self.data[i, :])
                classlist[near].append(self.data[i, :])
            # 更新中心点
            for i in range(self.k):
                self.centroid[i] = np.average(classlist[i], axis=0)
            if (np.array(self.centroid) - np.array(centroid)).all() == 0:  # 如果中心未改变停止迭代
                flag = False
        return self.centroid, classlist

手写DBscan

# 这一次我们学会了上面的编程经验,即对象有时候操作会比较的复杂,因为可能会出现两层列表反复操作,我们这里还是使用矩阵和索引完成对于程序的编写
import numpy as np

# 找到核心对象
def fit(X,eta,num):
    coreindex = {}  # 放所有的核心对象以及它们密度直达的点
    # 对每一个样本进行求解
    for i in range(X.shape[0]):
        dist = [] # 放所有可以密度直达的点
        for k in range(X.shape[0]):
            if i != k:
                distance = np.linalg.norm(X[i,:]-X[k,:], ord=2)
                if distance <= eta:
                    dist.append(k)
        if len(dist)>=num:   # 判断是不是核心对象
            coreindex[i]=dist
    return coreindex

# 输入核心对象进行训练
def train(coreindex):
    import random
    classindex = 0  # 记录一个类的名字
    classdict = {}
    index = list(coreindex.keys())  # 存放未找的核心对象
    while len(index)!=0:
         # 先随机找一个对象
        sample = random.randint(0,len(index)-1)
        sample = index[sample]
        classdict[classindex] = coreindex[sample]   # 把核心对象的密度直达的点放进去
        classdict[classindex].append(sample)  # 把这个点也放进去
        index.remove(sample)
        for i in classdict[classindex]:  # 在核心对象的域进行搜索
            if i in index:    # 如果域里面有的点也是核心对象的话,就加入这个点和它的成员。并且在核心对象的成员里面去掉所有的有关的点
                index.remove(i)
                for j in coreindex[i]:
                    if j in classdict[classindex]: # 不能重复加点
                        pass
                    else:
                        classdict[classindex].append(j)
        classindex += 1  # 寻找下一个类别
    return classdict

关于手写程序的说明

自己在设计程序的时候,读懂几个地方,就很好读懂全部了。首先在Kmeans里面,centroid表示的是中心的点,每一次都会留一下,并且更新一遍,停止的条件就是这个更新前后不变了,classlist = []这个是一个二维的列表,第一个维度表示的是第几类,第二个维度表示的是样本(我直接把横纵坐标都放进去了),其实更好的方式是只放样本的索引进去。这也是我们最后会返回的结果。

在DBSCAN算法中,我们设计了coreindex = {},一个字典,键用来放所有的核心对象,值用来放核心对象的邻域,classdict,这个也是一个字典,键表示类别的编号,值是样本的索引。其实设计好这些之后,其余的按照程序的顺序,就很好很自然的可以写出来和读懂。感觉了解程序最需要知道的就是输入输出,知道这些设计之后,基本就是编写或者使用就很舒服了。

你可能感兴趣的:(算法,Algorithm,聚类,python,算法)