Plotting P-R Curves (Multi-class Problem)

This post plots precision-recall (P-R) curves using the iris dataset as sample data.

import matplotlib.pyplot as plt
import numpy as np
from sklearn import svm, datasets
from sklearn.metrics import precision_recall_curve, average_precision_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier

Load the iris dataset

iris = datasets.load_iris()
X = iris.data
y = iris.target

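Because target stores one integer class label (0, 1, or 2) per sample, it has to be converted into a binary indicator matrix with one column per class, so that each class can be scored as its own binary problem.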

y = label_binarize(y,classes=[0,1,2]) # binarize the labels: one 0/1 column per class
n_classes = y.shape[1]

Part of the resulting y is shown in the figure below.

[Figure 1: part of the binarized y]
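To make the binarization step concrete, here is a minimal sketch on a tiny, hypothetical label vector (the array `labels` is made up for illustration and is not part of the iris data). Each row of the result has a single 1 in the column of its class.

```python
import numpy as np
from sklearn.preprocessing import label_binarize

# Hypothetical integer class labels, chosen only for illustration
labels = np.array([0, 1, 2, 1])

# One column per class; a 1 marks the class each sample belongs to
indicator = label_binarize(labels, classes=[0, 1, 2])
print(indicator)
# [[1 0 0]
#  [0 1 0]
#  [0 0 1]
#  [0 1 0]]
print(indicator.shape)  # (4, 3)
```

The number of columns is what the post reads back as n_classes.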

To make the differences between the curves more visible, add some noise features.

random_state = np.random.RandomState(0)
n_samples,n_features = X.shape
# Add 200x as many noise features, i.e. append 200*4 random columns to the original X
X = np.c_[X,random_state.randn(n_samples,200*n_features)]

Train the model and compute decision_function()

X_train,X_test,y_train,y_test = train_test_split(X,y, test_size=0.8,random_state=random_state)
classifier = OneVsRestClassifier(svm.SVC(kernel = "linear",probability = True, random_state=random_state))
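# decision_function returns the signed (functional) distance from each sample to the separating hyperplane.
# The sign tells which side of the hyperplane a sample in X_test falls on, and the magnitude tells how far away it is.
# With OneVsRestClassifier there is one column per class: a large positive value means the sample
# is confidently assigned to that class, a large negative value means it is confidently rejected.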
y_score = classifier.fit(X_train,y_train).decision_function(X_test)
print(y_score)

[Figure 2: printed y_score output]
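The meaning of these scores can be checked on a simpler case. The sketch below is an illustration only, not part of the original code: it assumes a binary problem built from the first two iris classes and a linear-kernel SVC, where the decision value should equal w·x + b and its sign should agree with predict(). The names `X_bin`, `y_bin`, and `manual` are introduced here for the example.

```python
import numpy as np
from sklearn import datasets, svm

# Assumption for illustration: keep only classes 0 and 1 to get a binary problem
iris = datasets.load_iris()
mask = iris.target < 2
X_bin, y_bin = iris.data[mask], iris.target[mask]

clf = svm.SVC(kernel="linear").fit(X_bin, y_bin)
scores = clf.decision_function(X_bin)

# For a linear kernel, the decision value is w.x + b
manual = X_bin @ clf.coef_.ravel() + clf.intercept_[0]
scores_match = np.allclose(scores, manual, atol=1e-6)

# A positive score corresponds to classes_[1] (label 1 here), a negative one to label 0
sign_matches_predict = np.array_equal(clf.predict(X_bin), (scores > 0).astype(int))

print(scores_match, sign_matches_predict)
```

In the one-vs-rest setting used in this post, the same reading applies column by column to y_score.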

Compute precision and recall for each of the three classes in turn, then micro-average precision and recall over all classes (macro or weighted averaging can be used instead).

precision = dict()
recall = dict()
average_precision = dict()
for i in range(n_classes):
    # One-vs-rest: column i of y_test holds the true 0/1 labels, column i of y_score the scores
    precision[i],recall[i],_ = precision_recall_curve(y_test[:,i],y_score[:,i])
    average_precision[i] = average_precision_score(y_test[:,i],y_score[:,i])

# Micro-average: flatten all (sample, class) pairs into a single binary problem
precision["micro"],recall["micro"],_ = precision_recall_curve(y_test.ravel(),y_score.ravel())
average_precision["micro"] = average_precision_score(y_test,y_score,average="micro")
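The difference between the averaging modes can be seen on a small example. The following is a sketch on made-up data (the indicator matrix `y_true` and the random `y_score` below are hypothetical, not the iris results): macro is the plain mean of the per-class scores, weighted is the mean weighted by the number of positives in each class, and micro scores the flattened arrays as one binary problem.

```python
import numpy as np
from sklearn.metrics import average_precision_score

# Hypothetical indicator labels: 3 samples of class 0, 2 of class 1, 1 of class 2
y_true = np.array([[1, 0, 0],
                   [1, 0, 0],
                   [1, 0, 0],
                   [0, 1, 0],
                   [0, 1, 0],
                   [0, 0, 1]])
# Hypothetical scores, random and for illustration only
y_score = np.random.RandomState(0).randn(6, 3)

per_class = [average_precision_score(y_true[:, k], y_score[:, k]) for k in range(3)]

macro = average_precision_score(y_true, y_score, average="macro")
weighted = average_precision_score(y_true, y_score, average="weighted")
micro = average_precision_score(y_true, y_score, average="micro")

# Micro-averaging is the same as scoring the flattened arrays
micro_flat = average_precision_score(y_true.ravel(), y_score.ravel())

print(np.isclose(macro, np.mean(per_class)))
print(np.isclose(weighted, np.average(per_class, weights=y_true.sum(axis=0))))
print(np.isclose(micro, micro_flat))
```

This is also why the micro-averaged curve in the post is computed from y_test.ravel() and y_score.ravel().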

Plot the P-R curves

plt.clf()
plt.plot(recall["micro"],precision["micro"],label = "micro_average P_R(area={0:0.2f})".format(average_precision["micro"]))
for i in range(n_classes):
    plt.plot(recall[i],precision[i],label = "P_R curve of class{0}(area={1:0.2f})".format(i,average_precision[i]))

plt.xlim([0.0,1.0]) # recall ranges from 0 to 1
plt.ylim([0.0,1.05])
plt.legend(loc = "lower right")
plt.show()

[Figure 3: P-R curves for the three classes and their micro-average]
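Each point on a P-R curve corresponds to one decision threshold on the scores. As a minimal sketch, using four made-up samples and a hypothetical threshold of 0.35 (none of these values come from the iris experiment), precision and recall at that threshold can be computed by hand and compared with the matching entry returned by precision_recall_curve.

```python
import numpy as np
from sklearn.metrics import precision_recall_curve

# Hypothetical binary labels and scores, for illustration only
y_true = np.array([0, 0, 1, 1])
scores = np.array([0.1, 0.4, 0.35, 0.8])
threshold = 0.35

# A sample is predicted positive when its score reaches the threshold
predicted = scores >= threshold
tp = np.sum(predicted & (y_true == 1))   # true positives: 2
fp = np.sum(predicted & (y_true == 0))   # false positives: 1
fn = np.sum(~predicted & (y_true == 1))  # false negatives: 0

precision_at_t = tp / (tp + fp)  # 2/3
recall_at_t = tp / (tp + fn)     # 1.0

# The same point appears in the arrays returned by precision_recall_curve
precision, recall, thresholds = precision_recall_curve(y_true, scores)
idx = int(np.flatnonzero(np.isclose(thresholds, threshold))[0])
print(precision[idx], recall[idx])
```

Sweeping the threshold over every distinct score value traces out the whole curve.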

Source code adapted from: Fundamentals of Deep Learning, Harbin Institute of Technology, China University MOOC (icourse163.org)
