CIFAR10 T-SNE 绘制特征空间图

基本思想就是把测试数据输入模型,然后对模型提取的特征(未经过分类器)的部分进行降维绘图

1. 先引入包

import torch
from sklearn.manifold import TSNE # 这个是绘图关键
import random
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
from torchvision import datasets, transforms

2. 设置随机种子

为保证结果可复现,设置了随机种子

def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
setup_seed(1337)

3. 准备测试数据及模型

transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])  # 对读取数据做个处理,打个包
testset = datasets.CIFAR10(root='../data/MNIST/', train=False,
                           download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=32,
                                         shuffle=False, num_workers=2)
model_file = ["centralized_net.pkl"]
model = torch.load(model_name)

4. 输入测试数据得到特征表示

model.eval()
with torch.no_grad():
    for i, (image_batch, label_batch) in enumerate(testloader):
        image_batch, label_batch = image_batch.cuda(), label_batch.cuda()
        label_batch = label_batch.long().squeeze()
        inputs = image_batch
        logits, feature = model(inputs)
        if i == 0:
            feature_bank = feature
            label_bank = label_batch
            logits_bank = logits
        else:
            feature_bank = torch.cat((feature_bank, feature))
            label_bank = torch.cat((label_bank, label_batch))
            logits_bank = torch.cat((logits_bank, logits))

5. 绘图

针对feature_banklabel_bank进行绘图

 feature_bank = feature_bank.cpu().numpy()
 label_bank = label_bank.cpu().numpy()
 p, pseu = torch.max(torch.softmax(logits_bank, dim=-1), dim=-1)
 prob_bank = p.cpu().numpy()
 tsne = TSNE(2)
 output = tsne.fit_transform(feature_bank) # feature进行降维,降维至2维表示
 # 带真实值类别
 for i in range(10):	# 对每类的数据画上特定颜色的点
     index = (label_bank==i)
     plt.scatter(output[index, 0], output[index, 1],s=5, cmap=plt.cm.Spectral)
 plt.legend(["0", "1", "2", "3", "4", "5", "6","7", "8", "9"])
 plt.show()

CIFAR10 T-SNE 绘制特征空间图_第1张图片

你可能感兴趣的:(机器学习入门,pytorch,深度学习,python)