spark自带的kmeans例子分析

import sys

import numpy as np
from pyspark.sql import SparkSession
#该函数主要是将文件的string类型转换成float类型
def parseVector(line):
    return np.array([float(x) for x in line.split(' ')])
#该函数将点分配到点集中,返回的是点集的index
#其中传入的参数p是需分配的点的值(可以看成矢量,假设是m维),centers是目前的中心点的值(可以看成n*m维的矩阵,其中n指的是中心点的个数)
def closestPoint(p, centers):
    bestIndex = 0
    closest = float("+inf")
    for i in range(len(centers)):
        tempDist = np.sum((p - centers[i]) ** 2)#计算需分配的点与第i个中心点的距离
        if tempDist < closest:
            closest = tempDist
            bestIndex = i
    return bestIndex


if __name__ == "__main__":

    if len(sys.argv) != 4:
        print("Usage: kmeans   ", file=sys.stderr)
        exit(-1)

    print("""WARN: This is a naive implementation of KMeans Clustering and is given
       as an example! Please refer to examples/src/main/python/ml/kmeans_example.py for an
       example on how to use ML's KMeans implementation.""", file=sys.stderr)

    spark = SparkSession\
        .builder\
        .appName("PythonKMeans")\
        .getOrCreate()

    lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])
    data = lines.map(parseVector).cache()
    K = int(sys.argv[2])  #设置的中心点的个数
    convergeDist = float(sys.argv[3])  #设置的阈值

    kPoints = data.takeSample(False, K, 1)  #在data点集中随机选取K个点作为中心点
    tempDist = 1.0

    while tempDist > convergeDist:
        closest = data.map(
            lambda p: (closestPoint(p, kPoints), (p, 1)))#返回的结果是(p点所分配的集合的index值,(p点的值,1))
        pointStats = closest.reduceByKey(   #reduceByKey针对具有相同键(在这里指的是p点被分配到相同的点集)的二元组
#则p1_c1和p2_c2指的都是二元组的value值,也就是(p,1)。所以,该句子表示将在同一个点集上的点的(p,1)分别求和,其中别忘了p是一个矢量
lambda p1_c1, p2_c2: (p1_c1[0] + p2_c2[0], p1_c1[1] + p2_c2[1])) #得到的是(点集的index,(点值的求和,该点集的点的个数)) newPoints = pointStats.map( # 返回(点集的index,点值求和/点的个数(矢量除法)),作为新的中心点 lambda st: (st[0], st[1][0] / st[1][1])).collect()        #计算新旧中心点的距离差 tempDist = sum(np.sum((kPoints[iK] - p) ** 2) for (iK, p) in newPoints) for (iK, p) in newPoints: kPoints[iK] = p #设置新的中心点的值 print("Final centers: " + str(kPoints)) spark.stop()

有可能表述不是很准确,但是能看懂就行。(随机可能学多了,喜欢用矩阵看问题)

你可能感兴趣的:(spark自带的kmeans例子分析)