t-SNE 由 t 和 SNE 两部分组成:t 分布(Student's t-distribution)和随机近邻嵌入(Stochastic Neighbor Embedding, SNE)。
t-SNE 是一种可视化工具,可以将高维数据降到 2~3 维,然后绘制成图。
t-SNE的缺点是:占用内存大,运行时间长。
本文以工程化的思路,使用 t-SNE 对 NLP 领域的数据进行降维与可视化。如有问题,还请各位不吝赐教。
详细的参数介绍请参阅 scikit-learn 官方文档(sklearn.manifold.TSNE)。
输入 X 是二维数据,类型为 numpy array,形状为 (n_samples, n_features):
# Toy input X: 3 samples x 6 features, stored as a 2-D integer numpy array.
data = np.array(
    [
        [1, 2, 3, 4, 5, 6],
        [2] * 6,
        [0, 0, 1, 2, 1, 1],
    ]
)
输入的标签 Y(代码中的 label)同样是 numpy array 类型,每个样本对应一个标签:
# One label per sample, stored as a column vector of shape (n_samples, 1).
label = np.array([[1], [1], [4]])
# TSNE's default perplexity is 30, and recent scikit-learn raises ValueError
# when perplexity >= n_samples; cap it so tiny inputs like this one still run.
tsne = TSNE(n_components=2, init='pca', random_state=0,
            perplexity=min(30.0, data.shape[0] - 1))
# 2-D embedding, one row per sample.
result = tsne.fit_transform(data)
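作图时,根据 label 的值为坐标中的每个点标注编号并上色。
由于 Y 是形状为 (n_samples, 1) 的二维 array,label[i] 是只含一个元素的数组而不是标量,因此需要先把它转换为整数,例如 int(label[i])。注意:较新版本的 NumPy 对 ndim>0 的数组调用 int() 会给出弃用警告,更稳妥的写法是先用 label.ravel() 展平再转换。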
def plot_embedding(data, label, title):
    """Draw a 2-D embedding as coloured label text and return the figure.

    Each axis of ``data`` is min-max normalised into [0, 1], and each point
    is drawn as the text of its integer label, coloured from ``Set1``.

    Args:
        data: array-like of shape (n_samples, >=2); only the first two
            columns are plotted (e.g. the output of ``TSNE.fit_transform``).
        label: array-like with exactly one label per sample; a column vector
            of shape (n_samples, 1) is accepted and flattened.
        title: figure title.

    Returns:
        The newly created matplotlib figure.

    Raises:
        ValueError: if the number of labels differs from the number of rows.
    """
    data = np.asarray(data, dtype=float)
    # Flatten (n, 1) label columns to plain Python ints; calling int() on a
    # 1-element ndim>0 array is deprecated in recent NumPy.
    labels = np.asarray(label).ravel().astype(int).tolist()
    if len(labels) != data.shape[0]:
        raise ValueError(
            f"expected {data.shape[0]} labels, got {len(labels)}")

    # Min-max normalise each axis. A constant axis has zero range, which
    # would divide by zero, so use a span of 1 there (points map to 0).
    x_min, x_max = np.min(data, 0), np.max(data, 0)
    span = x_max - x_min
    span[span == 0] = 1.0
    data = (data - x_min) / span

    fig = plt.figure()
    ax = fig.add_subplot(111)
    for (x, y), lab in zip(data[:, :2], labels):
        ax.text(x, y, str(lab),
                color=plt.cm.Set1(lab / 10.),
                fontdict={'weight': 'bold', 'size': 9})
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)
    return fig
# Complete, runnable example script
from time import time
import numpy as np
import matplotlib.pyplot as plt
import pickle
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from sklearn import datasets
from sklearn.manifold import TSNE
def get_test_data():
    """Return a tiny hard-coded dataset for exercising the t-SNE demo.

    Prints the sample and feature counts as a side effect.

    Returns:
        tuple ``(data, label, n_samples, n_features)``:
        ``data`` is an integer array of shape (3, 6); ``label`` is an integer
        column vector of shape (3, 1); ``n_samples`` and ``n_features`` come
        from ``data.shape``.
    """
    data = np.array([
        [1, 2, 3, 4, 5, 6],
        [2, 2, 2, 2, 2, 2],
        [0, 0, 1, 2, 1, 1],
    ])
    # One label per sample, kept as a column vector (n_samples, 1).
    label = np.array([[1], [1], [4]])
    n_samples, n_features = data.shape
    print(f"n_sample {n_samples}, n_features: {n_features}")
    return data, label, n_samples, n_features
def plot_embedding(data, label, title):
    """Draw a 2-D embedding as coloured label text and return the figure.

    Each axis of ``data`` is min-max normalised into [0, 1], and each point
    is drawn as the text of its integer label, coloured from ``Set1``.

    Args:
        data: array-like of shape (n_samples, >=2); only the first two
            columns are plotted (e.g. the output of ``TSNE.fit_transform``).
        label: array-like with exactly one label per sample; a column vector
            of shape (n_samples, 1) is accepted and flattened.
        title: figure title.

    Returns:
        The newly created matplotlib figure.

    Raises:
        ValueError: if the number of labels differs from the number of rows.
    """
    data = np.asarray(data, dtype=float)
    # Flatten (n, 1) label columns to plain Python ints; calling int() on a
    # 1-element ndim>0 array is deprecated in recent NumPy.
    labels = np.asarray(label).ravel().astype(int).tolist()
    if len(labels) != data.shape[0]:
        raise ValueError(
            f"expected {data.shape[0]} labels, got {len(labels)}")

    # Min-max normalise each axis. A constant axis has zero range, which
    # would divide by zero, so use a span of 1 there (points map to 0).
    x_min, x_max = np.min(data, 0), np.max(data, 0)
    span = x_max - x_min
    span[span == 0] = 1.0
    data = (data - x_min) / span

    fig = plt.figure()
    ax = fig.add_subplot(111)
    for (x, y), lab in zip(data[:, :2], labels):
        ax.text(x, y, str(lab),
                color=plt.cm.Set1(lab / 10.),
                fontdict={'weight': 'bold', 'size': 9})
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)
    return fig
def main():
    """Embed the toy data from ``get_test_data`` with t-SNE and show a plot."""
    data, label, n_samples, _ = get_test_data()
    print('Computing t-SNE embedding')
    # TSNE's default perplexity is 30, and recent scikit-learn raises
    # ValueError when perplexity >= n_samples (the toy data has only 3),
    # so cap it at n_samples - 1.
    perplexity = min(30.0, n_samples - 1)
    tsne = TSNE(n_components=2, init='pca', random_state=0,
                perplexity=perplexity)
    t0 = time()
    result = tsne.fit_transform(data)
    # Time only the t-SNE fit, not the plotting.
    elapsed = time() - t0
    plot_embedding(result, label,
                   't-SNE embedding of the digits (time %.2fs)'
                   % elapsed)
    # pyplot.show() does not take a figure; it displays all open figures.
    plt.show()
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()