参考:sklearn手册
t-SNE 是目前来说效果最好的数据降维与可视化方法,但是它的缺点也很明显,比如:占内存大,运行时间长。
但是,当我们想要对高维数据进行分类,又不清楚这个数据集有没有很好的可分性(即同类之间间隔小,异类之间间隔大),可以通过 t-SNE 投影到 2 维或者 3 维的空间中观察一下。如果在低维空间中具有可分性,则数据是可分的;如果在高维空间中不具有可分性,可能是数据不可分,也可能仅仅是因为不能投影到低维空间。
使用 t-SNE 的缺点大概是:
t-SNE 的计算复杂度很高,在数百万个样本数据集中可能需要几个小时,而 PCA 可以在几秒钟或几分钟内完成
Barnes-Hut t-SNE 方法(下面讲)限于二维或三维嵌入。
算法是随机的,具有不同种子的多次实验可以产生不同的结果。虽然选择 loss 最小的结果就行,但可能需要多次实验以选择超参数。
全局结构未明确保留。这个问题可以通过 PCA 初始化点(使用init =‘pca’)来缓解。
t-SNE 的主要目的是高维数据的可视化。因此,当数据嵌入二维或三维时,效果最好。有时候优化 KL 散度可能有点棘手。有五个参数可以控制 t-SNE 的优化,即会影响最后的可视化质量:
perplexity 困惑度
early exaggeration factor 前期放大系数
learning rate 学习率
maximum number of iterations 最大迭代次数
angle 角度
Barnes-Hut t-SNE 主要是对传统 t-SNE 在速度上做了优化,是现在最流行的 t-SNE 方法,同时它与传统 t-SNE 还有一些不同:
Barnes-Hut 仅在目标维度为 3 或更小时才起作用。以 2D 可视化为主。
Barnes-Hut 仅适用于密集的输入数据。稀疏数据矩阵只能用特定的方法嵌入,或者可以通过投影近似,例如使用sklearn.decomposition.TruncatedSVD
Barnes-Hut 是一个近似值。使用 angle 参数对近似进行控制,因此当参数method="exact"时,TSNE()使用传统方法,此时 angle 参数不能使用。
Barnes-Hut 可以处理更多的数据。 Barnes-Hut 可用于嵌入数十万个数据点。
为了可视化的目的(这是 t-SNE 的主要用处),强烈建议使用 Barnes-Hut 方法。method="exact"时,传统的 t-SNE 方法尽管可以达到该算法的理论极限,效果更好,但受制于计算约束,只能对小数据集的可视化。
对于 MNIST 来说,t-SNE 可视化后可以自然的将字符按标签分开,见本文最后的例程;而 PCA 降维可视化后的手写字符,不同类别之间会重叠在一起,这也证明了 t-SNE 的非线性特性的强大之处。值得注意的是:未能在 2D 中用 t-SNE 显现良好分离的均匀标记的组不一定意味着数据不能被监督模型正确分类,还可能是因为 2 维不足以准确地表示数据的内部结构。
一个简单的例子,输入 4 个 3 维的数据,然后通过 t-SNE 降维称 2 维的数据。
import numpy as np
from sklearn.manifold import TSNE
X = np.array([[0,0,0],[0,1,1],[1,0,1],[1,1,1]])
tsne = TSNE(n_components = 2)
tsne.fit_transform(X)
print(tsne.embedding_)
'''输出
[[ 3.17274952 -186.43092346]
[ 43.70787048 -283.6920166 ]
[ 100.43157196 -145.89025879]
[ 140.96669006 -243.15138245]]'
S 曲线上的数据是高维的数据,其中不同颜色表示数据的不同类别。当我们通过 t-SNE 嵌入到二维空间中后,可以看到数据点之间的类别信息完美的保留了下来
# coding='utf-8'
"""# 一个对 S 曲线数据集上进行各种降维的说明。"""
from time import time
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.ticker import NullFormatter
from sklearn import manifold, datasets
# # Next line to silence pyflakes. This import is needed.
# Axes3D
n_points = 1000
# X 是一个(1000, 3)的 2 维数据,color 是一个(1000,)的 1 维数据
X, color = datasets.samples_generator.make_s_curve(n_points, random_state=0)
n_neighbors = 10
n_components = 2
fig = plt.figure(figsize=(8, 8))
# 创建了一个 figure,标题为"Manifold Learning with 1000 points, 10 neighbors"
plt.suptitle("Manifold Learning with %i points, %i neighbors"
% (1000, n_neighbors), fontsize=14)
'''绘制 S 曲线的 3D 图像'''
ax = fig.add_subplot(211, projection='3d')
ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=color, cmap=plt.cm.Spectral)
ax.view_init(4, -72) # 初始化视角
'''t-SNE'''
t0 = time()
tsne = manifold.TSNE(n_components=n_components, init='pca', random_state=0)
Y = tsne.fit_transform(X) # 转换后的输出
t1 = time()
print("t-SNE: %.2g sec" % (t1 - t0)) # 算法用时
ax = fig.add_subplot(2, 1, 2)
plt.scatter(Y[:, 0], Y[:, 1], c=color, cmap=plt.cm.Spectral)
plt.title("t-SNE (%.2g sec)" % (t1 - t0))
ax.xaxis.set_major_formatter(NullFormatter()) # 设置标签显示格式为空
ax.yaxis.set_major_formatter(NullFormatter())
# plt.axis('tight')
plt.show()
这里的手写数字数据集是一堆 8*8 的数组,每一个数组都代表着一个手写数字。如下图所示
# coding='utf-8'
"""t-SNE 对手写数字进行可视化"""
from time import time
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.manifold import TSNE
def get_data():
digits = datasets.load_digits(n_class=6)
data = digits.data
label = digits.target
n_samples, n_features = data.shape
return data, label, n_samples, n_features
def plot_embedding(data, label, title):
x_min, x_max = np.min(data, 0), np.max(data, 0)
data = (data - x_min) / (x_max - x_min)
fig = plt.figure()
ax = plt.subplot(111)
for i in range(data.shape[0]):
plt.text(data[i, 0], data[i, 1], str(label[i]),
color=plt.cm.Set1(label[i] / 10.),
fontdict={'weight': 'bold', 'size': 9})
plt.xticks([])
plt.yticks([])
plt.title(title)
return fig
def main():
data, label, n_samples, n_features = get_data()
print('Computing t-SNE embedding')
tsne = TSNE(n_components=2, init='pca', random_state=0)
t0 = time()
result = tsne.fit_transform(data)
fig = plot_embedding(result, label,
't-SNE embedding of the digits (time %.2fs)'
% (time() - t0))
plt.show(fig)
if __name__ == '__main__':
main()