python数据分析中使用plt散点图展示DBSCAN聚类结果

做数据分析时可能会用到聚类,此时我们可以借助散点图直观地查看聚类结果并调试参数。
下面正式开始, 先说明一下数据:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import DBSCAN

plt.style.use('ggplot')  # 美化`

x = np.array([171, 649, 172, 653, 170, 636, 179, 651, 175, 356, 644, 173, 651, 166, 209, 257, 649, 173, 652, 177, 651,
               180, 649, 181, 648, 183, 649, 178, 652, 181, 648, 180, 650, 178, 649, 177, 292, 650, 180, 649, 177, 285,
               644, 177, 649, 178, 221, 649, 178, 648, 175, 651, 182, 653, 180, 275, 651, 178, 646, 177, 652, 183, 650,
               182, 165, 280, 180, 177, 177])`
y = np.array([764, 766, 806, 806, 852, 848, 893, 893, 937, 937, 935, 977, 981, 1017, 1017, 1017, 1024, 1067, 1068, 1109,
               1112, 1156, 1156, 1197, 1199, 1242, 1239, 1282, 1283, 1327, 1328, 1372, 1373, 1416, 1415, 1458, 1458, 1458,
               1500, 1502, 1543, 1543, 1544, 1586, 1590, 1628, 1628, 1634, 1676, 1677, 1713, 1721, 1765, 1762, 1805, 1805,
               1806, 1848, 1848, 1890, 1896, 1938, 1938, 1983, 2015, 2015, 2070, 2109, 2153])
data = np.vstack((x2, y2)).T

mean_height, mean_width = 36, 29

其中(x, y)是每个物体的中心坐标,另外还有物体的平均高度和宽度。接下来进行聚类,并按聚类结果整合数据。

# 因为需要甄别异常值,所有使用DBSCAN聚类
db = DBSCAN(eps=1.5*mean_height, min_samples=2,).fit(data) 
labels = db.labels_  # -1为异常值

cluster_label = list(set(labels))
cluster_data = {k: [] for k in cluster_label}  # key=聚类标签, value=该列别下的所有数据
for index, ele in enumerate(labels):
    cluster_data[ele].append(data[index])

然后用散点图直观的看一下聚类结果。

color = ["#6A5ACD", '#228B22', '#B8860B', '#B22222', '#FF69B4',
         '#1E90FF', '#4B0082', "#00FF7F", '#FFFACD', '#0000FF']*10  #  万一类别太多,颜色不够用

for index, ele in enumerate(cluster_label):
    scatter_x = [x[0] for x in cluster_data[ele]]
    scatter_y = [y[1] for y in cluster_data[ele]]

    if cluster_label[index] == -1:
        s = 90  # 异常点突出显示
    else:
        s = 30
    #  颜色,大小,标签
    plt.scatter(scatter_x, scatter_y, c=color[index], s=s, label=f"c_{cluster_label[index]}")

plt.legend(loc=0, ncol=2,)  # 加上图列
plt.show()

最后是结果图:

聚类结果.png

当聚类邻域eps=1.5*mean_height时,聚类产生了4个类,图上5个红色点就是初步判断出来的异常点,其他三类是正常的聚类结果。但是x轴200~300之间的三个蓝点实际上也是异常点,c_0, c_2实际上是一类,这些在聚类分析里面没有判断出来。实际生产上要借助其他信息判断了。
再看一下第一次eps=math.sqrt(mean_height*mean_height + mean_width*mean_width)时的聚类结果:
聚类结果1.png

异常点是多识别出来一个,但是类别太多了。还是eps=1.5*mean_height时效果好一些。

你可能感兴趣的:(python数据分析中使用plt散点图展示DBSCAN聚类结果)