在瑞士卷数据集上使用python绘制测地线

在学习西瓜书上的流形学习时,我们学习到了测地线的概念,那么如何画测地线呢?本文将使用python简单的实现一下在瑞士卷数据集上测地线的绘制。

 

文章目录

前言

一、具体步骤

1.引入库

2.读入数据

3.绘图

 4.测地线的绘制

 4.1首先对每个点基于欧 氏距离找出其近邻点

 4.2建立一个近邻连接图

 4.3找出从源点到终点的最短路径

​编辑

 4.4绘制

5.结果展示

总结

源代码:


 

 


前言

 在学习西瓜书上的流形学习时,我们学习到了测地线的概念,那么如何画测地线呢?

在瑞士卷数据集上使用python绘制测地线_第1张图片

首先看书上怎么说的

 

在瑞士卷数据集上使用python绘制测地线_第2张图片

 书上讲的很清楚了,求测地线的步骤大致为:

(1)首先对每个点基于欧 氏距离找出其近邻点

(2)建立一个近邻连接图,近邻点之间存在连接,而非近邻点之间不存在连接

(3)找出从源点到终点的最短路径,连接起来就是我们要的测地线了

所以接下来我将按照这个步骤一步步的实现它。

一、具体步骤

1.引入库

代码如下(示例):

import mat4py as mp
import numpy as np
# 载入必要库
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn import datasets
%matplotlib inline
import pandas as pd
import networkx as nx  # 导入 NetworkX 工具包

from sklearn.neighbors import NearestNeighbors

2.读入数据

代码如下(示例):

from sklearn.datasets import make_swiss_roll
# 用make_swiss_roll得到渐变色
X, t = make_swiss_roll(n_samples=1000, noise=0.2, random_state=42)

3.绘图

我们看一下原始数据集在3维空间上的分布,可以看到这是一个流形。

# 绘图
fig = plt.figure(figsize=(12, 8))
ax = Axes3D(fig, elev=10, azim=80)
ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=t, cmap=plt.cm.Spectral)
ax.set_title('S Curve', fontsize=20)

在瑞士卷数据集上使用python绘制测地线_第3张图片


 4.测地线的绘制

 4.1首先对每个点基于欧 氏距离找出其近邻点

这里我们直接调用NearestNeighbors()方法计算就行了

返回值说明:

# 返回值indices:第0列元素为参考点的索引,后面是(n_neighbors - 1)个与之最近的点的索引
# 返回值distances:第0列元素为与自身的距离(为0),后面是(n_neighbors - 1)个与之最近的点与参考点的距离

# j 计算每个点的k近邻:
    nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm='ball_tree').fit(X)
    distances, indices = nbrs.kneighbors(X)

 4.2建立一个近邻连接图

近邻点之间存在连接,而非近邻点之间不存在连接

初始化近邻矩阵:

dist_matrix=np.zeros((m,m))

 获取近邻矩阵:

 for i in range(m):
        for j in range(m):
            if j not in indices[i]:#若X[j]点不是X[i]的k近邻,则距离为0
                dist_matrix[i][j]=0
            else:#若X[j]点是X[i]的k近邻
                for index in range(len(indices[i])):#求X[j]到X[i]的距离
                    if indices[i][index]==j:
                        dist_matrix[i][j]=distances[i][index]
                        break

 4.3找出从源点到终点的最短路径

这里可以使用NetworkX图去求

dfAdj = pd.DataFrame(dist_matrix)
G1 = nx.from_pandas_adjacency(dfAdj)  # 由 pandas 顶点邻接矩阵 创建 NetworkX 图
# 两个指定顶点之间的最短加权路径
minWPath = nx.bellman_ford_path(G1, source=source, target=target)  # 顶点 10 到 顶点 100 的最短加权路径
print("最短路径为:",minWPath)

d3b022bca565435c803ab5f4bd829202.png

 4.4绘制

有了最短路径,把路径上的点连起来就可以进行绘制了

(1)获得坐标

if len(X[0])==2:
        x=[]
        y=[]
        for i in minWPath:
            x.append(X[i,0])
            y.append(X[i,1])
        return x,y
    if len(X[0])==3:
        x=[]
        y=[]
        z=[]
        for i in minWPath:
            x.append(X[i,0])
            y.append(X[i,1])
            z.append(X[i,2])
        return x,y,z

(2)绘制

import matplotlib as mpl
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import matplotlib.pyplot as plt

# 绘图
fig = plt.figure(figsize=(12, 8))
ax = Axes3D(fig, elev=10, azim=80)
ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=t, cmap=plt.cm.Spectral)
ax.set_title('S Curve', fontsize=20)

x,y,z=cedi_line(X)
ax.plot(x, y, z, label='parametric curve',color='red')

# 显示图例
ax.legend()

# 显示图形
plt.show()

5.结果展示

在瑞士卷数据集上使用python绘制测地线_第4张图片

降成二维后,测地线的绘制

在瑞士卷数据集上使用python绘制测地线_第5张图片

 

 

总结

以上就是今天要讲的内容,本文基于西瓜书上绘制测地线的方法进行了实现,至于有不有更简洁、更正确的画法,还请不吝赐教!

源代码:

 

本文参考的文章:

https://blog.csdn.net/youcans/article/details/116999881icon-default.png?t=M4ADhttps://blog.csdn.net/youcans/article/details/116999881

 

 

 

 

 

 

 

 

 

你可能感兴趣的:(机器学习,python,机器学习)