scikit-learn------classification_report简介

sklearn中的classification_report函数用于显示主要分类指标的文本报告.在报告中显示每个类的精确度,召回率,F1值等信息。
主要参数:
y_true:1维数组,或标签指示器数组/稀疏矩阵,目标值。
y_pred:1维数组,或标签指示器数组/稀疏矩阵,分类器返回的估计值。
labels:array,shape = [n_labels],列表,需要评估的标签名称。
target_names:字符串列表,指定标签名称。
sample_weight:指定标签名称。
digits:评估报告中小数点的保留位数,如果 output_dict=True,此参数不起作用,返回的数值不作处理。
output_dict:若真,评估结果以字典形式返回。
classification_report用法示例:

# 鸢尾花数据集的随机森林结果评估
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
# 鸢尾花数据集
iris = load_iris()
X = iris.data
y = iris.target
#[0, 1, 2] 标签转换为名称 ['setosa' 'versicolor' 'virginica']
y_labels = iris.target_names[y]

# 数据集拆分为训练集与测试集
X_train, X_test, y_train, y_test = train_test_split(X, y_labels, test_size=0.2)

# 使用训练集训练模型
clf = RandomForestClassifier(n_estimators=100)
clf.fit(X_train, y_train)

# 使用测试集预测结果
y_pred = clf.predict(X_test)

# 生成文本型分类报告
print(classification_report(y_test, y_pred))
"""
              precision    recall  f1-score   support

      setosa       1.00      1.00      1.00        10
  versicolor       0.83      1.00      0.91        10
   virginica       1.00      0.80      0.89        10

   micro avg       0.93      0.93      0.93        30
   macro avg       0.94      0.93      0.93        30
weighted avg       0.94      0.93      0.93        30
"""

# 生成字典型分类报告
report = classification_report(y_test, y_pred, output_dict=True)
for key, value in report["setosa"].items():
    print(f"{key:10s}:{value:10.2f}")
    
"""
precision :      1.00
recall    :      1.00
f1-score  :      1.00
support   :     10.00
"""

其中列表左边的一列为分类的标签名,右边support列为每个标签的出现次数。avg / total行为各列的均值(support列为总和)。
precision,recall,f1-score三列分别为各个类别的精确度/召回率及 F1值。

精确度:precision,正确预测为正的,占全部预测为正的比例,TP / (TP+FP)
召回率:recall,正确预测为正的,占全部实际为正的比例,TP / (TP+FN)
F1-score:精确率和召回率的调和平均数,2 * precision*recall /(precision+recall)

同时还会给出总体的微平均值,宏平均值和加权平均值。

微平均值:micro average,所有数据结果的平均值
宏平均值:macro average,所有标签结果的平均值
加权平均值:weighted average,所有标签结果的加权平均值

你可能感兴趣的:(Machine,Learning)