【TensorFlow】Evaluating model performance with slim.metrics

Yes, this is outdated; yes, we are already in the TensorFlow 2.0 era; yes, the official advice is to use tf.metrics. But outdated things are not necessarily useless. I recently found myself working with TensorFlow's slim models again, so here are my notes on slim.metrics, the part of slim used to evaluate model performance.

The main evaluation functions live in tensorflow/contrib/metrics/python/ops/metric_ops.py.

The actual implementations are in tensorflow/python/ops/metrics_impl.py.

The main evaluation functions are:

  • streaming_true_positives
  • streaming_true_negatives
  • streaming_false_positives
  • streaming_false_negatives
  • streaming_mean
  • streaming_mean_tensor
  • streaming_accuracy
  • streaming_precision
  • streaming_recall
  • streaming_false_positive_rate
  • streaming_false_negative_rate
  • streaming_true_positives_at_thresholds
  • streaming_false_negatives_at_thresholds
  • streaming_true_negatives_at_thresholds
  • streaming_false_positives_at_thresholds
  • streaming_curve_points
  • streaming_auc
  • streaming_dynamic_auc
  • streaming_specificity_at_sensitivity
  • streaming_sensitivity_at_specificity
  • streaming_precision_at_thresholds
  • streaming_recall_at_thresholds
  • streaming_false_positive_rate_at_thresholds
  • streaming_false_negative_rate_at_thresholds
  • streaming_recall_at_k
  • streaming_sparse_recall_at_k
  • streaming_sparse_precision_at_k
  • streaming_sparse_precision_at_top_k
  • streaming_sparse_average_precision_at_k
  • streaming_sparse_average_precision_at_top_k
  • streaming_mean_absolute_error
  • streaming_mean_relative_error
  • streaming_mean_squared_error
  • streaming_root_mean_squared_error
  • streaming_pearson_correlation
  • streaming_mean_cosine_distance
  • streaming_percentage_less
  • streaming_mean_iou

In practice, the commonly used metrics come down to a handful: accuracy, precision, and recall.
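For context, here is a minimal sketch of the usual slim evaluation pattern for these three metrics, using aggregate_metric_map to bundle the (value, update_op) pairs. The placeholder inputs, the toy batches, and the 'eval/...' names are mine for illustration; it assumes TF 1.x, where tf.contrib is still available.

```python
import numpy as np
import tensorflow as tf  # assumes TF 1.x, where tf.contrib is still available

slim = tf.contrib.slim

labels = tf.placeholder(tf.int64, [None])       # binary ground truth: 0 or 1
predictions = tf.placeholder(tf.int64, [None])  # binary predictions: 0 or 1

# aggregate_metric_map bundles several (value, update_op) pairs; note the
# streaming_* wrappers take (predictions, labels), the opposite order of
# tf.metrics.accuracy(labels, predictions).
names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
    'eval/accuracy': slim.metrics.streaming_accuracy(predictions, labels),
    'eval/precision': slim.metrics.streaming_precision(predictions, labels),
    'eval/recall': slim.metrics.streaming_recall(predictions, labels),
})

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())  # metric counters are local vars
    for preds, labs in [(np.array([1, 0, 1]), np.array([1, 1, 1])),
                        (np.array([0, 1]), np.array([0, 0]))]:
        sess.run(list(names_to_updates.values()),
                 feed_dict={predictions: preds, labels: labs})
    for name, value in names_to_values.items():
        print(name, sess.run(value))  # accuracy 0.6, precision 2/3, recall 2/3
```

The flipped argument order between the slim streaming_* wrappers and tf.metrics is an easy thing to trip over, hence the comment above.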

What puzzled me is how the true_positives-related functions are computed, since that bears directly on how accuracy and the other metrics are calculated.

Let's look at how accuracy is computed:

def accuracy(labels,
             predictions,
             weights=None,
             metrics_collections=None,
             updates_collections=None,
             name=None):
  """Calculates how often `predictions` matches `labels`.

  The `accuracy` function creates two local variables, `total` and
  `count` that are used to compute the frequency with which `predictions`
  matches `labels`. This frequency is ultimately returned as `accuracy`: an
  idempotent operation that simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `accuracy`.
  Internally, an `is_correct` operation computes a `Tensor` with elements 1.0
  where the corresponding elements of `predictions` and `labels` match and 0.0
  otherwise. Then `update_op` increments `total` with the reduced sum of the
  product of `weights` and `is_correct`, and it increments `count` with the
  reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: The ground truth values, a `Tensor` whose shape matches
      `predictions`.
    predictions: The predicted values, a `Tensor` of any shape.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `accuracy` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    accuracy: A `Tensor` representing the accuracy, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `accuracy`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.accuracy is not supported when eager '
                       'execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
  if labels.dtype != predictions.dtype:
    predictions = math_ops.cast(predictions, labels.dtype)
  is_correct = math_ops.to_float(math_ops.equal(predictions, labels))
  return mean(is_correct, weights, metrics_collections, updates_collections,
              name or 'accuracy')

Note that accuracy does not build a confusion matrix, and it does not cast anything to bool either: it simply casts predictions to the labels' dtype and counts element-wise equality. So it works for multi-class as well, as long as predictions are already class IDs (e.g. after an argmax) rather than probabilities.
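A tiny graph-mode experiment (toy data of mine; TF 1.x assumed) shows the two-op protocol and the multi-class behavior: update_op accumulates total and count across batches, and the value tensor reads off the running mean.

```python
import numpy as np
import tensorflow as tf  # assumes TF 1.x graph mode

labels = tf.placeholder(tf.int64, [None])
predictions = tf.placeholder(tf.int64, [None])

# tf.metrics.accuracy takes (labels, predictions).
accuracy, update_op = tf.metrics.accuracy(labels, predictions)

with tf.Session() as sess:
    # The `total` and `count` counters are local variables.
    sess.run(tf.local_variables_initializer())
    # Multi-class class IDs work fine: only element-wise equality is checked.
    sess.run(update_op, feed_dict={labels: np.array([0, 1, 2, 2]),
                                   predictions: np.array([0, 1, 1, 2])})
    print(sess.run(accuracy))  # 3/4 = 0.75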

Next, look at how true_positives is computed:

def true_positives(labels,
                   predictions,
                   weights=None,
                   metrics_collections=None,
                   updates_collections=None,
                   name=None):
  """Sum the weights of true_positives.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
      be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.true_positives is not '
                       'supported when eager execution is enabled.')

  with variable_scope.variable_scope(name, 'true_positives',
                                     (predictions, labels, weights)):

    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
        labels=math_ops.cast(labels, dtype=dtypes.bool),
        weights=weights)
    is_true_positive = math_ops.logical_and(
        math_ops.equal(labels, True), math_ops.equal(predictions, True))
    return _count_condition(is_true_positive, weights, metrics_collections,
                            updates_collections)

Here, unlike accuracy, predictions and labels really are cast to bool, and the op counts the elements where both are True. So this metric is only meaningful for binary classification.
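A toy run (my own data, TF 1.x assumed) makes the bool-cast semantics concrete: anything nonzero counts as positive, so class IDs of 2, 3, ... would silently collapse to True.

```python
import numpy as np
import tensorflow as tf  # assumes TF 1.x graph mode

labels = tf.placeholder(tf.int64, [None])
predictions = tf.placeholder(tf.int64, [None])

# Both inputs are cast to bool internally before comparison.
tp, tp_update = tf.metrics.true_positives(labels, predictions)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(tp_update, feed_dict={labels: np.array([0, 1, 1, 1]),
                                   predictions: np.array([1, 1, 0, 1])})
    print(sess.run(tp))  # 2.0: positions where label and prediction are both 1
```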

There is also a related variant, true_positives_at_thresholds:

def true_positives_at_thresholds(labels,
                                 predictions,
                                 thresholds,
                                 weights=None,
                                 metrics_collections=None,
                                 updates_collections=None,
                                 name=None):
  """Computes true positives at provided threshold values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `true_positives`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    true_positives:  A float `Tensor` of shape `[len(thresholds)]`.
    update_op: An operation that updates the `true_positives` variable and
      returns its current value.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.true_positives_at_thresholds is not '
                       'supported when eager execution is enabled.')

  with variable_scope.variable_scope(name, 'true_positives',
                                     (predictions, labels, weights)):
    values, update_ops = _confusion_matrix_at_thresholds(
        labels, predictions, thresholds, weights=weights, includes=('tp',))

    tp_value = _aggregate_variable(values['tp'], metrics_collections)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_ops['tp'])

    return tp_value, update_ops['tp']

This variant does compute a confusion matrix, via _confusion_matrix_at_thresholds, which accumulates tp/fn/tn/fp counts at each threshold. Note, though, that the labels are still cast to bool, so the *_at_thresholds family is still binary classification, just evaluated at several probability thresholds. For genuinely multi-class evaluation, the streaming_sparse_*_at_k family (e.g. streaming_sparse_precision_at_k) from the list above looks like the intended tool.
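A small example of the thresholded variant (toy data of mine, TF 1.x assumed): one TP count is maintained per threshold, and a prediction counts as positive when its score is strictly greater than the threshold.

```python
import numpy as np
import tensorflow as tf  # assumes TF 1.x graph mode

labels = tf.placeholder(tf.bool, [None])
scores = tf.placeholder(tf.float32, [None])  # probabilities in [0, 1]

# One TP count is kept per threshold; output shape is [len(thresholds)].
tp, tp_update = tf.metrics.true_positives_at_thresholds(
    labels, scores, thresholds=[0.3, 0.5, 0.7])

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(tp_update, feed_dict={labels: np.array([True, True, False, True]),
                                   scores: np.array([0.9, 0.4, 0.8, 0.2])})
    print(sess.run(tp))  # [2., 1., 1.]
```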

But there is no streaming_accuracy_at_thresholds interface for accuracy, so it seems you have to roll your own, e.g. something like the sketch below.
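Here is one way it could be pieced together from the four *_at_thresholds counters, using accuracy = (tp + tn) / (tp + tn + fp + fn) at each threshold. This streaming_accuracy_at_thresholds helper is hypothetical, my own sketch rather than anything in slim or tf.metrics.

```python
import tensorflow as tf  # assumes TF 1.x graph mode

def streaming_accuracy_at_thresholds(labels, predictions, thresholds):
    """Hypothetical helper (not part of slim/tf.metrics): binary accuracy
    per threshold, accuracy = (tp + tn) / (tp + tn + fp + fn)."""
    tp, tp_up = tf.metrics.true_positives_at_thresholds(labels, predictions, thresholds)
    tn, tn_up = tf.metrics.true_negatives_at_thresholds(labels, predictions, thresholds)
    fp, fp_up = tf.metrics.false_positives_at_thresholds(labels, predictions, thresholds)
    fn, fn_up = tf.metrics.false_negatives_at_thresholds(labels, predictions, thresholds)

    def _acc(tp_t, tn_t, fp_t, fn_t):
        total = tp_t + tn_t + fp_t + fn_t
        # Guard against 0/0 before the first update; counts are >= 1 otherwise.
        return (tp_t + tn_t) / tf.maximum(total, 1.0)

    # The *_at_thresholds update ops return the freshly updated counts, so the
    # combined "update op" is itself a tensor holding the current accuracy.
    return _acc(tp, tn, fp, fn), _acc(tp_up, tn_up, fp_up, fn_up)
```

Both returned tensors have shape [len(thresholds)], matching the convention of the built-in *_at_thresholds metrics.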
