sklearn.metrics.multilabel_confusion_matrix(y_true, y_pred, *, sample_weight=None, labels=None, samplewise=False)
计算class-wise(默认)或sample-wise多标签混淆矩阵
计算class-wise multi_confusion时,输入y_true和y_pred形状为(n_samples, n_labels) (多类多标签情况)or (n_samples,)(多类单标签情况),输出multi_confusion形状为(n_labels, 2, 2)
上述n_labels即类别数。
计算sample-wise multi_confusion时,输入y_true和y_pred形状为(n_samples, n_labels) (多类多标签情况)or (n_samples,)(多类单标签情况),输出multi_confusion形状为(n_samples, 2, 2)
例子: 多类单标签 >>> import numpy as np >>> from sklearn.metrics import multilabel_confusion_matrix >>> y_true = np.array([[1, 0, 1], ... [0, 1, 0]]) >>> y_pred = np.array([[1, 0, 0], ... [0, 1, 1]]) >>> multilabel_confusion_matrix(y_true, y_pred) array([[[1, 0], [0, 1]], [[1, 0], [0, 1]], [[0, 1], [1, 0]]])
多类单标签 >>> y_true = ["cat", "ant", "cat", "cat", "ant", "bird"] >>> y_pred = ["ant", "ant", "cat", "cat", "ant", "cat"] >>> multilabel_confusion_matrix(y_true, y_pred, ... labels=["ant", "bird", "cat"]) array([[[3, 1], [0, 2]], [[5, 0], [1, 0]], [[2, 1], [1, 2]]])
参考:
sklearn.metrics.confusion_matrix — scikit-learn 1.1.1 documentation