TensorFlow slim.metrics: how the common metrics Accuracy, Precision and Recall are computed for multi-class classification

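While testing a multi-class model with the eval_image_classifier.py script, I noticed that the Accuracy reported by slim.metrics did not match the value I calculated myself, so I read the source code. It turns out that Accuracy is the overall fraction of correctly classified samples (for single-label data this is the same as top-1 recall). It is not the accuracy formula built from TP, TN, FP and FN. Meanwhile, the Precision and Recall metrics are computed for binary classification and do not apply to a multi-class model at all. So if you want evaluation metrics for a multi-class model, you have to compute them separately yourself. Let's go through the source code.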
(1) Add the metrics you want to inspect to the code

# Define the metrics. In eval_image_classifier.py, predictions = tf.argmax(logits, 1)
# and labels = tf.squeeze(labels), so both hold integer class ids.
names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
    'TP': slim.metrics.streaming_true_positives(predictions, labels),
    'TN': slim.metrics.streaming_true_negatives(predictions, labels),
    'FP': slim.metrics.streaming_false_positives(predictions, labels),
    'FN': slim.metrics.streaming_false_negatives(predictions, labels),
    'Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
    'Precision': slim.metrics.streaming_precision(predictions, labels),
    'Recall': slim.metrics.streaming_recall(predictions, labels),
    'Recall_1': slim.metrics.streaming_recall_at_k(logits, labels, 1),
})

The resulting metric values:
(Figure 1: metric values printed by the evaluation run)
Accuracy = Recall_1 = 0.88
Precision = 0.93
Recall = 0.929
Accuracy equals Recall_1, and it does not equal (TP+TN)/(TP+FP+TN+FN).
This question has also been raised on Stack Overflow: https://stackoverflow.com/questions/43408200/tf-slim-computation-of-accuracy
(2) How each metric is computed in the source code
First, let's see how the four counts TP, TN, FP and FN are computed.
TP:

def true_positives(labels,
                   predictions,
                   weights=None,
                   metrics_collections=None,
                   updates_collections=None,
                   name=None):
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.true_positives is not '
                       'supported when eager execution is enabled.')

  with variable_scope.variable_scope(name, 'true_positives',
                                     (predictions, labels, weights)):

    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
        labels=math_ops.cast(labels, dtype=dtypes.bool),
        weights=weights)
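    # Labels and predictions have been cast to bool: class id 0 becomes False
    # (negative), any nonzero class id becomes True (positive).
    is_true_positive = math_ops.logical_and(
        math_ops.equal(labels, True), math_ops.equal(predictions, True))
    # Count the samples where both the label and the prediction are True.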
    return _count_condition(is_true_positive, weights, metrics_collections,
                            updates_collections)
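To see what this bool cast does to multi-class labels, here is a pure-Python sketch of the same counting rule. It is an illustration written for this post, not TensorFlow code; `true_positives_bool_cast` is a hypothetical name.

```python
# Pure-Python mirror of the bool-cast rule in tf.metrics.true_positives
# (illustrative sketch only, not the TensorFlow implementation).

def true_positives_bool_cast(labels, predictions):
    """Count samples where bool(label) and bool(prediction) are both True.

    As in the TF source, every nonzero class id is cast to True,
    so the actual class identity is lost.
    """
    return sum(1 for y, p in zip(labels, predictions) if bool(y) and bool(p))


labels      = [0, 1, 2, 3, 2]
predictions = [0, 1, 3, 2, 0]

# The 2nd, 3rd and 4th samples all have a nonzero label and a nonzero
# prediction, so all three count as "true positives", although only the
# 2nd sample is actually classified correctly.
print(true_positives_bool_cast(labels, predictions))  # 3
```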

FP:

def false_positives(labels,
                    predictions,
                    weights=None,
                    metrics_collections=None,
                    updates_collections=None,
                    name=None):
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.false_positives is not supported when '
                       'eager execution is enabled.')

  with variable_scope.variable_scope(name, 'false_positives',
                                     (predictions, labels, weights)):

    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
        labels=math_ops.cast(labels, dtype=dtypes.bool),
        weights=weights)
    is_false_positive = math_ops.logical_and(
        math_ops.equal(labels, False), math_ops.equal(predictions, True))
    # Count the samples where the label is False and the prediction is True.
    return _count_condition(is_false_positive, weights, metrics_collections,
                            updates_collections)

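The other two (TN and FN) work the same way, so I won't show them. The source shows that all four counts cast labels and predictions to bool first, so they are only meaningful for binary classification with labels 0 and 1. For a multi-class model they do not apply: every nonzero class collapses into a single "positive" class, so a sample with true label 2 that is predicted as 3 is counted as a true positive.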
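Here is a pure-Python sketch of all four counts, assuming the same bool cast as the TF source (0 becomes False, every other class id becomes True). `binary_counts` is a hypothetical helper, not a slim or TensorFlow function. It also shows why, in part (1), Accuracy did not equal (TP+TN)/(TP+FP+TN+FN): that formula is a binary accuracy over "class 0 vs. everything else", while Accuracy compares class ids directly.

```python
def binary_counts(labels, predictions):
    """Return (TP, TN, FP, FN) after casting class ids to bool,
    mimicking the tf.metrics counters: 0 -> False, nonzero -> True."""
    tp = tn = fp = fn = 0
    for y, p in zip(labels, predictions):
        y, p = bool(y), bool(p)
        if y and p:
            tp += 1
        elif not y and not p:
            tn += 1
        elif p:          # label False, prediction True
            fp += 1
        else:            # label True, prediction False
            fn += 1
    return tp, tn, fp, fn


labels      = [0, 1, 2, 3, 2, 0]
predictions = [0, 1, 3, 2, 0, 1]

tp, tn, fp, fn = binary_counts(labels, predictions)
# "Accuracy" from the bool-cast counts: class 0 vs. all other classes.
binary_acc = (tp + tn) / (tp + tn + fp + fn)
# Multi-class accuracy: the class id itself must match.
multiclass_acc = sum(y == p for y, p in zip(labels, predictions)) / len(labels)

print(tp, tn, fp, fn)             # 3 1 1 1
print(round(binary_acc, 3))       # 0.667
print(round(multiclass_acc, 3))   # 0.333
```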

Consequently, the precision and recall computed from these four counts do not apply to multi-class models either:
Precision = TP / (TP + FP)
Recall = TP / (TP + FN)
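For a multi-class model, one common alternative is to apply the same two formulas one-vs-rest, once per class, and then average them (macro average). Below is a minimal pure-Python sketch of that idea. The helper name `per_class_precision_recall` and the choice of 0.0 for an empty denominator are my own assumptions, not something slim provides. If scikit-learn is available, its `precision_recall_fscore_support` (with `average=None` or `average='macro'`) computes the same quantities.

```python
from collections import Counter


def per_class_precision_recall(labels, predictions, num_classes):
    """One-vs-rest precision and recall for each class id in [0, num_classes).

    For class c: TP_c = #(label == c and pred == c),
                 FP_c = #(label != c and pred == c),
                 FN_c = #(label == c and pred != c).
    An empty denominator yields 0.0 (an assumed convention).
    """
    pairs = Counter(zip(labels, predictions))  # (label, pred) -> count
    precision, recall = [], []
    for c in range(num_classes):
        tp = pairs[(c, c)]
        predicted_c = sum(n for (y, p), n in pairs.items() if p == c)  # TP_c + FP_c
        actual_c = sum(n for (y, p), n in pairs.items() if y == c)     # TP_c + FN_c
        precision.append(tp / predicted_c if predicted_c else 0.0)
        recall.append(tp / actual_c if actual_c else 0.0)
    return precision, recall


labels      = [0, 1, 2, 3, 2, 0]
predictions = [0, 1, 3, 2, 0, 1]

precision, recall = per_class_precision_recall(labels, predictions, num_classes=4)
macro_precision = sum(precision) / len(precision)
macro_recall = sum(recall) / len(recall)

print(precision)                      # [0.5, 0.5, 0.0, 0.0]
print(recall)                         # [0.5, 1.0, 0.0, 0.0]
print(macro_precision, macro_recall)  # 0.25 0.375
```

Each class is scored separately, so a label 2 predicted as 3 now counts as a miss for class 2 and a false positive for class 3, instead of a true positive.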

Now let's look at Accuracy:

def accuracy(labels,
             predictions,
             weights=None,
             metrics_collections=None,
             updates_collections=None,
             name=None):
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.accuracy is not supported when eager '
                       'execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
  if labels.dtype != predictions.dtype:
    predictions = math_ops.cast(predictions, labels.dtype)
  is_correct = math_ops.to_float(math_ops.equal(predictions, labels))
  # is_correct is 1.0 where the prediction equals the label, else 0.0.
  return mean(is_correct, weights, metrics_collections, updates_collections,
              name or 'accuracy')
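  # mean() averages is_correct over all samples (weighted if weights are given),
  # i.e. correct / total. This is the overall accuracy, not an average of
  # per-class recalls; for single-label data it equals top-1 recall (Recall_1).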
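The following pure-Python sketch makes the difference concrete on an imbalanced toy dataset. `accuracy` mirrors `mean(is_correct)`. `recall_at_1` illustrates top-1 recall, the idea behind `streaming_recall_at_k(logits, labels, 1)`; its first-maximum tie-breaking is my assumption and may differ from TensorFlow's handling of ties. `macro_recall` is the arithmetic mean of per-class recalls. All three are hypothetical helpers written for this post.

```python
def accuracy(labels, predictions):
    """Fraction of samples with prediction == label (mean of is_correct)."""
    return sum(y == p for y, p in zip(labels, predictions)) / len(labels)


def recall_at_1(labels, scores):
    """Fraction of samples whose true label has the highest score.
    Ties go to the first maximum here, which may differ from TF."""
    hits = 0
    for y, row in zip(labels, scores):
        top = max(range(len(row)), key=row.__getitem__)
        hits += (top == y)
    return hits / len(labels)


def macro_recall(labels, predictions):
    """Arithmetic mean of per-class recall over the classes present in labels."""
    per_class = []
    for c in sorted(set(labels)):
        idx = [i for i, y in enumerate(labels) if y == c]
        per_class.append(sum(predictions[i] == c for i in idx) / len(idx))
    return sum(per_class) / len(per_class)


# Imbalanced toy data: class 0 has 4 samples, class 1 has 1 sample.
labels = [0, 0, 0, 0, 1]
scores = [[0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.6, 0.4], [0.6, 0.4]]
predictions = [max(range(len(r)), key=r.__getitem__) for r in scores]  # argmax

print(accuracy(labels, predictions))      # 0.8
print(recall_at_1(labels, scores))        # 0.8
print(macro_recall(labels, predictions))  # 0.5
```

Accuracy and top-1 recall agree (0.8), which matches Accuracy = Recall_1 in part (1). The mean of per-class recalls (0.5) is a different number.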
