在学习西瓜书上的流形学习时,我们学习到了测地线的概念,那么如何画测地线呢?本文将使用python简单的实现一下在瑞士卷数据集上测地线的绘制。
前言
一、具体步骤
1.引入库
2.读入数据
3.绘图
4.测地线的绘制
4.1首先对每个点基于欧 氏距离找出其近邻点
4.2建立一个近邻连接图
4.3找出从源点到终点的最短路径
编辑
4.4绘制
5.结果展示
总结
源代码:
在学习西瓜书上的流形学习时,我们学习到了测地线的概念,那么如何画测地线呢?
首先看书上怎么说的
书上讲的很清楚了,求测地线的步骤大致为:
(1)首先对每个点基于欧 氏距离找出其近邻点
(2)建立一个近邻连接图,近邻点之间存在连接,而非近邻点之间不存在连接
(3)找出从源点到终点的最短路径,连接起来就是我们要的测地线了
所以接下来我将按照这个步骤一步步的实现它。
代码如下(示例):
import mat4py as mp
import numpy as np
# 载入必要库
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn import datasets
%matplotlib inline
import pandas as pd
import networkx as nx # 导入 NetworkX 工具包
from sklearn.neighbors import NearestNeighbors
代码如下(示例):
from sklearn.datasets import make_swiss_roll
# 用make_swiss_roll得到渐变色
X, t = make_swiss_roll(n_samples=1000, noise=0.2, random_state=42)
我们看一下原始数据集在3维空间上的分布,可以看到这是一个流形。
# 绘图
fig = plt.figure(figsize=(12, 8))
ax = Axes3D(fig, elev=10, azim=80)
ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=t, cmap=plt.cm.Spectral)
ax.set_title('S Curve', fontsize=20)
这里我们直接调用NearestNeighbors()方法计算就行了
返回值说明:
# 返回值indices:第0列元素为参考点的索引,后面是(n_neighbors - 1)个与之最近的点的索引
# 返回值distances:第0列元素为与自身的距离(为0),后面是(n_neighbors - 1)个与之最近的点与参考点的距离
# j 计算每个点的k近邻:
nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm='ball_tree').fit(X)
distances, indices = nbrs.kneighbors(X)
近邻点之间存在连接,而非近邻点之间不存在连接
初始化近邻矩阵:
dist_matrix=np.zeros((m,m))
获取近邻矩阵:
for i in range(m):
for j in range(m):
if j not in indices[i]:#若X[j]点不是X[i]的k近邻,则距离为0
dist_matrix[i][j]=0
else:#若X[j]点是X[i]的k近邻
for index in range(len(indices[i])):#求X[j]到X[i]的距离
if indices[i][index]==j:
dist_matrix[i][j]=distances[i][index]
break
这里可以使用NetworkX图去求
dfAdj = pd.DataFrame(dist_matrix)
G1 = nx.from_pandas_adjacency(dfAdj) # 由 pandas 顶点邻接矩阵 创建 NetworkX 图
# 两个指定顶点之间的最短加权路径
minWPath = nx.bellman_ford_path(G1, source=source, target=target) # 顶点 10 到 顶点 100 的最短加权路径
print("最短路径为:",minWPath)
有了最短路径,把路径上的点连起来就可以进行绘制了
(1)获得坐标
if len(X[0])==2:
x=[]
y=[]
for i in minWPath:
x.append(X[i,0])
y.append(X[i,1])
return x,y
if len(X[0])==3:
x=[]
y=[]
z=[]
for i in minWPath:
x.append(X[i,0])
y.append(X[i,1])
z.append(X[i,2])
return x,y,z
(2)绘制
import matplotlib as mpl
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import matplotlib.pyplot as plt
# 绘图
fig = plt.figure(figsize=(12, 8))
ax = Axes3D(fig, elev=10, azim=80)
ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=t, cmap=plt.cm.Spectral)
ax.set_title('S Curve', fontsize=20)
x,y,z=cedi_line(X)
ax.plot(x, y, z, label='parametric curve',color='red')
# 显示图例
ax.legend()
# 显示图形
plt.show()
降成二维后,测地线的绘制
以上就是今天要讲的内容,本文基于西瓜书上绘制测地线的方法进行了实现,至于有不有更简洁、更正确的画法,还请不吝赐教!
本文参考的文章:
https://blog.csdn.net/youcans/article/details/116999881https://blog.csdn.net/youcans/article/details/116999881