Clustering latitude/longitude coordinates with DBSCAN in a PySpark environment

I have come across three implementations of DBSCAN on Spark:

  1. https://github.com/alitouka/spark_dbscan  Written in Scala. The author also provides two small tools written in R that help you choose the parameters of the DBSCAN algorithm.
  2. https://github.com/irvingc/dbscan-on-spark  Also written in Scala; reportedly quite memory-hungry.
  3. An Implementation of DBSCAN on PySpark  A PySpark implementation, simple and easy to use, but it never finished running once I switched to latitude/longitude coordinates. The author optimizes the computation by partitioning the data around a reference point c: if a point x satisfies mε ≤ d(x, c) ≤ (m+1)ε, then by the triangle inequality any point within ε of x must satisfy (m-1)ε ≤ d(·, c) ≤ (m+2)ε, so points y with d(y, c) < (m-1)ε and points z with d(z, c) > (m+2)ε can be filtered out. Partitioning the data this way cuts down redundant distance computations (points in far-apart partitions are too distant to be neighbours, so there is no need to check them). How are points on a partition boundary handled? The partitions are shifted by half an epsilon and overlapped with the originals, so boundary points can still be assigned to the right cluster (see the sketch after this list).

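To make that partitioning trick concrete, here is a minimal sketch of the ring-filtering idea, not the code from the linked repo: plain Euclidean distance and an arbitrary reference point c are assumptions made only for illustration. Points are bucketed into rings of width ε around c, and the ε-neighbours of a point in ring m can only come from rings m-1, m and m+1, so every other ring is skipped.

# Minimal sketch of the ring-filtering idea, not the linked implementation;
# Euclidean distance and the choice of reference point c are illustrative assumptions.
import numpy as np

def ring_partition(points, c, eps):
    """Bucket points into rings: ring m holds points with m*eps <= d(p, c) < (m+1)*eps."""
    d = np.linalg.norm(points - c, axis=1)
    ring_of = (d // eps).astype(int)
    return {m: np.where(ring_of == m)[0] for m in np.unique(ring_of)}

def candidate_neighbours(rings, m):
    """By the triangle inequality, eps-neighbours of ring m lie in rings m-1, m, m+1 only."""
    return np.concatenate([rings[k] for k in (m - 1, m, m + 1) if k in rings])
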
In the end I used none of the three. I wrapped sklearn's DBSCAN in a UDF instead, and the results turned out quite well. PySpark is wonderfully concise! :-)

# import findspark
# findspark.init()
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import *
import pandas as pd
import numpy as np
# import os
from sklearn.cluster import DBSCAN,KMeans


def dbscan_x(coords):
    # mean Earth radius in km, i.e. kilometres per radian of arc
    kms_per_radian = 6371.0086
    # require at least 20 points within a 200 m radius
    epsilon = 0.2 / kms_per_radian
    # the haversine metric expects [lat, lng] pairs in radians
    db = DBSCAN(eps=epsilon, min_samples=20, algorithm='ball_tree', metric='haversine').fit(np.radians(coords))
    cluster_labels = db.labels_
    num_clusters = len(set(cluster_labels) - set([-1]))

    result = []
    coords_np = np.array(coords)
    kmeans = KMeans(n_clusters=1, n_init=1, max_iter=10, random_state=7)
    for n in range(num_clusters):
        # get the center of cluster 'n' (a 1-cluster KMeans is just a cheap way to get the centroid)
        one_cluster = coords_np[cluster_labels == n]
        kk = kmeans.fit(one_cluster)
        center = kk.cluster_centers_
        latlng = center[0].tolist()  # center is [lat, lng]
        result.append([n, latlng[1], latlng[0]])  # row order matches the UDF schema: [clusterid, lng, lat]
    return result


if __name__ == "__main__":
    # enableHiveSupport() is assumed here so that spark.sql can read and saveAsTable can
    # write the Hive tables below; drop it if the cluster already defaults to the Hive catalog
    spark = SparkSession.builder \
        .appName("stop_cluster") \
        .enableHiveSupport() \
        .getOrCreate()
    
    # data_file = os.path.normpath('E:\\datas\\trj_data\\tmp_stop3_loc.csv')
    # traj_schema = StructType([
    #     StructField("cid", StringType()),
    #     StructField("lng", FloatType()), StructField("lat", FloatType())
    # ])
    # dataDF = spark.read.csv(data_file, schema=traj_schema)

    dataDF = spark.sql("select cid, lng, lat from tmp.tmp_some_data")

    schema_dbs = ArrayType(StructType([
        StructField("clusterid", IntegerType(), False),
        StructField("lng", FloatType(), False),
        StructField("lat", FloatType(), False)
    ]))
    udf_dbscan = F.udf(lambda x: dbscan_x(x), schema_dbs)
    # each point is [lat, lng], the order dbscan_x (and the haversine metric) expects
    dataDF = dataDF.withColumn('point', F.array(F.col('lat'), F.col('lng'))) \
                    .groupBy('cid').agg(F.collect_list('point').alias('point_list')) \
                    .withColumn('cluster', udf_dbscan(F.col('point_list')))
    
    resultDF = dataDF.withColumn('centers', F.explode('cluster')) \
                    .select('cid', F.col('centers').getItem('lng').alias('lng'), 
                            F.col('centers').getItem('lat').alias('lat'), 
                            F.col('centers').getItem('clusterid').alias('clusterid')
                        )
    resultDF.write.mode("overwrite").format("orc").saveAsTable("tmp.tmp_cluster_ret")
    resultDF.show()


    spark.stop()
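
Before submitting the job, dbscan_x can be sanity-checked locally without Spark. The snippet below is such a check under made-up data: two tight groups of 25 synthetic points a few hundred metres apart, passed as [lat, lng] pairs exactly as the UDF passes them.

# Local sanity check for dbscan_x; the coordinates are synthetic, made up for this test.
import numpy as np

rng = np.random.default_rng(7)
# two tight groups of 25 points each (1e-4 deg ≈ 11 m of jitter), well over 200 m apart
group_a = np.c_[31.2300 + rng.normal(0, 1e-4, 25), 121.4700 + rng.normal(0, 1e-4, 25)]
group_b = np.c_[31.2400 + rng.normal(0, 1e-4, 25), 121.4800 + rng.normal(0, 1e-4, 25)]
coords = np.vstack([group_a, group_b]).tolist()

# expect two rows, each [clusterid, lng, lat]
print(dbscan_x(coords))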

 
