Julia Machine Learning: KNN with the BallTree Structure, Plotting Leaf Nodes as Circles

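NearestNeighbors is a fairly efficient Julia package for nearest-neighbor searches, the core operation behind KNN classification. It provides several tree data structures, including BallTree and KDTree.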
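
As a minimal sketch of how the two tree types are queried, the snippet below builds a BallTree and a KDTree over the same points and asks each for the 3 nearest neighbors of one query point. The random data, the seed, and the variable names are assumptions made for illustration only; they do not come from the original post.

```julia
using NearestNeighbors
using Random

Random.seed!(1)              # fixed seed so repeated runs use the same points
data = rand(2, 100)          # 100 random 2-D points, one point per column (illustrative data)

# Both tree types are built from the same matrix and queried the same way.
balltree = BallTree(data, Euclidean(); leafsize = 10)
kdtree = KDTree(data; leafsize = 10)

query = [0.5, 0.5]
k = 3

# knn returns the indices (columns of `data`) and the distances of the k nearest points.
idxs_ball, dists_ball = knn(balltree, query, k)
idxs_kd, dists_kd = knn(kdtree, query, k)

# Sort by distance ourselves instead of relying on the order knn returns.
order = sortperm(dists_ball)
println("BallTree neighbors: ", idxs_ball[order], " at distances ", dists_ball[order])
println("KDTree neighbors:   ", sort(idxs_kd))
```

Both trees perform exact searches, so they should report the same set of neighbors; they differ only in how the space is partitioned internally.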

Here we use the BallTree structure and plot its leaf nodes as circles. As before, the data is the iris dataset.
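
Each BallTree node covers its points with a ball, which is a circle in two dimensions. To make that idea concrete, here is a rough sketch of a simple bounding ball: the center is the centroid of the points and the radius is the largest distance from the centroid to any point. The helper `bounding_ball` is hypothetical, and this is not necessarily how NearestNeighbors computes its spheres internally.

```julia
using LinearAlgebra

# Hypothetical helper: a simple (not necessarily minimal) bounding ball
# for the columns of `points`.
function bounding_ball(points::AbstractMatrix)
    n = size(points, 2)
    center = vec(sum(points, dims = 2)) ./ n                        # centroid of the points
    radius = maximum(norm(points[:, j] .- center) for j in 1:n)     # farthest point from the centroid
    return center, radius
end

# Four corners of a 2×2 square, one point per column.
pts = [0.0 2.0 0.0 2.0;
       0.0 0.0 2.0 2.0]
center, radius = bounding_ball(pts)
println("center = ", center, ", radius = ", radius)
```

Every point lies inside or on this ball by construction, which is the property the circles in the plot rely on.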

Code example

using RDatasets
using DataFrames
using CSV
using NearestNeighbors
using Colors
using PyPlot
using PyCall
import NearestNeighbors: HyperSphere

# matplotlib.patches provides the Circle patch used to draw each leaf's bounding ball
patch = pyimport("matplotlib.patches")

# Load the iris dataset
iris = dataset("datasets", "iris")
# Alternatively, read it from a local CSV file:
# iris = CSV.read("D:/leaning/Julia/pkg-other/Rdatasets/csv/datasets/iris.csv", DataFrame)
show(iris)

# Features used to build the tree. NearestNeighbors expects one point per column,
# so the four feature columns are transposed into a 4×150 matrix.
features = collect(Matrix(iris[:, 1:4])')

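tree = BallTree(features, Euclidean(); leafsize = 50)

# Skip the internal (non-leaf) nodes; the leaf nodes are stored after them.
# Note: tree_data is an internal field of NearestNeighbors and may change between versions.
offset = tree.tree_data.n_internal_nodes + 1
nleafs = tree.tree_data.n_leafs

# Index range of the leaf nodes
index_range = offset:(offset + nleafs - 1)

# Generate one distinguishable color per leaf
cols = distinguishable_colors(length(index_range), RGB(0,0,0))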

# Create the figure
cfig = figure()
ax = cfig.add_subplot(1, 1, 1)
ax.set_aspect("equal")
axis((2.5, 9.0, 1.0, 5.0))

# Add a circle to the axes: the leaf's bounding sphere projected onto the first two features
function add_sphere(ax, hs::HyperSphere, col)
    ell = patch.Circle((hs.center[1], hs.center[2]), radius = hs.r, facecolor = "none", edgecolor = col)
    ax.add_artist(ell)
end


for (i, idx) in enumerate(index_range)
    col = cols[i]
    rgb = (Float64(col.r), Float64(col.g), Float64(col.b))
    # Get the range of data points stored in this leaf node
    leaf_range = NearestNeighbors.get_leaf_range(tree.tree_data, idx)
    d = tree.data[leaf_range]
    # Plot the points first, using the first two features as x and y
    xs = [p[1] for p in d]
    ys = [p[2] for p in d]
    plot(xs, ys, "*", color = rgb)
    # Then draw the leaf's bounding circle in the same color
    sphere = tree.hyper_spheres[idx]
    add_sphere(ax, sphere, rgb)
end

title("Leaf nodes with their corresponding points")
cfig.savefig("iris.png")

Resulting plot

[Figure: leaf nodes of the BallTree with their corresponding points and bounding circles]
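
The tree above is only built and visualized. To show how such a tree could be used for actual KNN classification, here is a minimal sketch that predicts a label by majority vote among the k nearest neighbors. The toy dataset, the labels, and the helper `knn_classify` are all assumptions made for illustration; the original post does not include a classification step.

```julia
using NearestNeighbors

# Toy training set (illustrative): two well-separated 2-D clusters, one point per column.
train = [1.0 1.2 0.8 1.1 5.0 5.2 4.8 5.1;
         1.0 0.9 1.1 1.3 5.0 4.9 5.1 5.3]
labels = ["a", "a", "a", "a", "b", "b", "b", "b"]

tree = BallTree(train, Euclidean(); leafsize = 2)

# Hypothetical helper: predict the label of `point` by majority vote
# among its k nearest neighbors in the tree.
function knn_classify(tree, labels, point, k)
    idxs, _ = knn(tree, point, k)          # indices refer to columns of the training matrix
    votes = Dict{eltype(labels), Int}()
    for i in idxs
        votes[labels[i]] = get(votes, labels[i], 0) + 1
    end
    # Pick the label with the most votes
    best_label = labels[idxs[1]]
    for (label, count) in votes
        if count > votes[best_label]
            best_label = label
        end
    end
    return best_label
end

println(knn_classify(tree, labels, [1.0, 1.0], 3))
println(knn_classify(tree, labels, [5.0, 5.0], 3))
```

With the iris data, the same helper could be called with the species column as `labels`, assuming the tree was built from the matching feature matrix.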

 
