以iris数据为样本实现P-R曲线的绘制
import matplotlib.pyplot as plt
import numpy as np
from sklearn import svm, datasets
from sklearn.metrics import precision_recall_curve, average_precision_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
导入iris数据集
iris = datasets.load_iris()
X = iris.data
y = iris.target
因为target列是iris分类的文字描述形式,需将其转换为类别标签
y = label_binarize(y,classes=[0,1,2]) # 运用标签二值化的方法
n_classes = y.shape[1]
形成的部分y如下图
为使得曲线效果变化显著,适当增加噪声样本
random_state = np.random.RandomState(0)
n_samples,n_features = X.shape
# 增加200倍的噪声值,即在原始x的列上增加200*4列
X = np.c_[X,random_state.randn(n_samples,200*n_features)]
训练模型,并计算decision_function()
X_train,X_test,y_train,y_test = train_test_split(X,y, test_size=0.8,random_state=random_state)
classifier = OneVsRestClassifier(svm.SVC(kernel = "linear",probability = True, random_state=random_state))
# decision_function计算样本点到分割超平面的函数距离。输出表示分类器对x_test的预测样本是位于超平面的右侧还是左侧,以及离它有多远。它还告诉我们分类器为x_test预测的每个值是正的(大幅度正值)还是负的(大幅度负值)。
y_score = classifier.fit(X_train,y_train).decision_function(X_test)
print(y_score)
对三个分类以此计算precision、recall,并且运用micro方式对precision、recall求平均(也可以使用macro、weighted的方式进行求平均
precision = dict()
recall = dict()
average_precision = dict()
for i in range(n_classes):
precision[i],recall[i],_ = precision_recall_curve(y_test[:,i],y_score[:,i])
average_precision[i] = average_precision_score(y_test[:,i],y_score[:,i])
precision["micro"],recall["micro"],_ = precision_recall_curve(y_test.ravel(),y_score.ravel())
average_precision["micro"] = average_precision_score(y_test,y_score,average="micro")
绘制P-R曲线
plt.clf()
plt.plot(recall["micro"],precision["micro"],label = "micro_average P_R(area={0:0.2f})".format(average_precision["micro"]))
for i in range(n_classes):
plt.plot(recall[i],precision[i],label = "P_R curve of class{0}(area={1:0.2f})".format(i,average_precision[i]))
plt.xlim([0.0,0.1])
plt.ylim([0.0,1.05])
plt.legend(loc = "lower right")
plt.show()
源代码源自深度学习基础_哈尔滨工业大学_中国大学MOOC(慕课) (icourse163.org)