sklearn学习笔记 半监督分类 之 标签传播对手写数字分类

手写数字数据集总共有1797个点,但只有30个将被标记。 混淆矩阵形式的结果和每个类的一系列指标将非常好。

标签传播模型将使用所有点进行训练,通过极少数标签对手写数字进行分类。

本次实验主要是简单展示下“半监督学习”的强大功能:

import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
from sklearn import datasets
from sklearn.semi_supervised import label_propagation
from sklearn.metrics import confusion_matrix, classification_report


digits = datasets.load_digits()  # 导入手写体数字 数据集
rng = np.random.RandomState(0)  # 生成随机种子
indices = np.arange(len(digits.data))  # 切片
rng.shuffle(indices)

X = digits.data[indices[:330]]
y = digits.target[indices[:330]]
images = digits.images[indices[:330]]

n_total_samples = len(y)
n_labeled_points = 30

indices = np.arange(n_total_samples)

# 无标签集合
unlabeled_set = indices[n_labeled_points:]

# 将 y_train 中含有无标签的值设为 -1(进行清除)
y_train = np.copy(y)
y_train[unlabeled_set] = -1

# 标签传播学习
lp_model = label_propagation.LabelSpreading(gamma = 0.25, max_iter = 5)
lp_model.fit(X, y_train)
predicted_labels = lp_model.transduction_[unlabeled_set]
true_labels = y[unlabeled_set]


cm = confusion_matrix(true_labels, predicted_labels, labels=lp_model.classes_)

print("Label Spreading model: %d labeled & %d unlabeled points (%d total)" %(n_labeled_points, n_total_samples - n_labeled_points, n_total_samples))
print(classification_report(true_labels, predicted_labels))
print("Confusion matrix")
print(cm)


# 计算每个转换分布的不确定值
pred_entropies = stats.distributions.entropy(lp_model.label_distributions_.T)

# 选取前10个最不准的标签
uncertainty_index = np.argsort(pred_entropies)[-10:]


# 绘图
f = plt.figure(figsize = (7, 5))  # 设置图片大小
for index, image_index in enumerate(uncertainty_index):
    image = images[image_index]

    sub = f.add_subplot(2, 5, index + 1)  # 绘制子图
    sub.imshow(image, cmap = plt.cm.gray_r)
    plt.xticks([ ])
    plt.yticks([ ])
    sub.set_title('predict: %i\ntrue: %i' %(lp_model.transduction_[image_index], y[image_index]))

f.suptitle('Learning with small amount of labeled data')
plt.show()  # 显示

基于 Anaconda + Jupyter Notebook 环境的运行结果为:

sklearn学习笔记 半监督分类 之 标签传播对手写数字分类_第1张图片

显示前10个最不确定的预测结果:

sklearn学习笔记 半监督分类 之 标签传播对手写数字分类_第2张图片

 

推荐阅读:

sklearn学习笔记 半监督分类 之 标签传播与SVM的决策边界

你可能感兴趣的:(sklearn学习笔记)