TSNE和PCA

模型见另一篇文章MNIST  softmax

导包;matplotlib 的 axes.unicode_minus 设置是为了让坐标轴上的负号正常显示

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
import numpy as np
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import matplotlib
matplotlib.rcParams['axes.unicode_minus']=False

 下载数据

# MNIST training split with images converted to tensors only (no normalization).
train_dataset = MNIST(root='data/', train=True, download=True,
                      transform=transforms.ToTensor())

# Preprocessing used for the loader: tensor conversion followed by
# Normalize with mean 0.1307 and std 0.3081.
normalize_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# A second MNIST training instance (rooted at './data/') carrying the
# normalizing transform; this is the dataset the loader iterates over.
normalized_train_set = MNIST('./data/', train=True, download=True,
                             transform=normalize_transform)

# Shuffled batches of 100 samples.
train_loader = torch.utils.data.DataLoader(normalized_train_set,
                                           batch_size=100,
                                           shuffle=True)

导入预训练模型

 

# A single linear layer taking 784 inputs (a flattened 28x28 image) to 10 outputs.
model = nn.Linear(784, 10)

# Read the saved parameters from 'model.ckpt', then copy them into the layer.
saved_state = torch.load('model.ckpt')
model.load_state_dict(saved_state)
print("load success")

 网络输出是tensor ,禁止更新梯度

# Run every batch of train_loader through the model and keep the raw outputs
# together with their labels, one list entry per batch.
pcadata = []  # model outputs, one (batch_size, 10) tensor per batch
labels = []   # matching label tensors, one per batch

# The outputs are only used for visualization, so gradient tracking is disabled.
with torch.no_grad():
    for images, label in train_loader:
        # Flatten each image to a 784-vector. The batch size is read from the
        # tensor itself instead of being hard-coded to 100, so a smaller final
        # batch (or a different loader batch size) does not break the reshape.
        images = images.reshape(images.size(0), 28 * 28)
        # Forward pass.
        outputs = model(images)
        pcadata.append(outputs)
        labels.append(label)

print(len(pcadata))
print(len(labels))
# Tensor.size is a method: call it so the actual size is printed rather than
# the bound-method representation.
print(pcadata[0].size())

 一个列表包含600个元素,每个元素是一个100x10的tensor,把它们沿第0维拼接成一个60000x10的tensor

# A holds the network outputs, B the corresponding labels.
# Join all per-batch tensors along dimension 0 in a single call each. This
# avoids re-copying the growing result on every iteration and does not depend
# on the number of batches being exactly 600.
A = torch.cat(pcadata, 0)
B = torch.cat(labels, 0)

# Sanity checks: overall shapes and the first output/label pair.
print(A.shape)
print(B.shape)
print(A[0])
print(B[0])

tensor转换为numpy数组

# Convert the network outputs to a NumPy array and keep an independent copy:
# A is used for the PCA branch, AT for the t-SNE branch.
A = np.array(A)
AT = A.copy()

# Same for the labels: B plus a separate copy B1.
B = np.array(B)
B1 = B.copy()

print(A.shape, AT.shape)

只取前2000个可视化,TSNE计算有点慢(之前取6w个需要算很久) 

# --- PCA branch: first 2000 samples, centred per column, reduced to 2 components ---
A = A[:2000]
meanVals = A.mean(axis=0)
A = A - meanVals
print('输入前矩阵', A.shape)
pca = PCA(n_components=2)
A = pca.fit_transform(A)
print('输入后矩阵', A.shape)


# --- t-SNE branch: the same 2000-sample subset, centred the same way ---
AT = AT[:2000]
mean = AT.mean(axis=0)
AT = AT - mean
print('输入前矩阵', AT.shape)
tsne = TSNE(n_components=2)
AT = tsne.fit_transform(AT)
print('输入后矩阵', AT.shape)

PCA可视化

# Scatter-plot the 2-D PCA projection with one colour per digit class.
# A holds the projected points (one row per sample, two columns) and B the
# digit labels. B may be longer than A, so only the first len(A) labels are
# paired with the points.
pca_points = np.asarray(A)
pca_labels = np.asarray(B)[:len(pca_points)]

# Colour for digit d is pca_colors[d].
pca_colors = ['k', 'g', 'b', 'r', 'y', 'c', 'm', 'peru', 'pink', 'gold']

for digit, color in enumerate(pca_colors):
    # A boolean mask picks out every point of this class in one step, which
    # replaces the twenty per-class lists and the ten-way if/elif chain.
    in_class = pca_labels == digit
    plt.scatter(pca_points[in_class, 0], pca_points[in_class, 1],
                c=color, label="label{}".format(digit))

# Legend entries come from the label= passed to each scatter call.
plt.legend()
plt.show()

 结果

TSNE和PCA_第1张图片

TSNE可视化

# Scatter-plot the 2-D t-SNE embedding with one colour per digit class.
# AT holds the embedded points (one row per sample, two columns) and B the
# digit labels. B may be longer than AT, so only the first len(AT) labels are
# paired with the points.
tsne_points = np.asarray(AT)
tsne_labels = np.asarray(B)[:len(tsne_points)]

# Colour for digit d is tsne_colors[d].
tsne_colors = ['k', 'g', 'b', 'r', 'y', 'c', 'm', 'peru', 'pink', 'gold']

for digit, color in enumerate(tsne_colors):
    # A boolean mask picks out every point of this class in one step, which
    # replaces the twenty per-class lists and the ten-way if/elif chain.
    in_class = tsne_labels == digit
    plt.scatter(tsne_points[in_class, 0], tsne_points[in_class, 1],
                c=color, label="label{}".format(digit))

# Legend entries come from the label= passed to each scatter call.
plt.legend()
plt.show()

结果

TSNE和PCA_第2张图片

你可能感兴趣的:(可视化,深度学习,可视化)