A First Look at the TensorFlow Object Detection API Source Code (2.1): FasterRCNNMetaArch

  • class FasterRCNNMetaArch(model.DetectionModel)
    • __init__
    • preprocess
      • image_resizer
      • Calling the FeatureExtractor's preprocess on resized_inputs
    • predict
    • To be continued

The Faster R-CNN subclass of DetectionModel

object_detection/meta_architectures/faster_rcnn_meta_arch.py

class FasterRCNNMetaArch(model.DetectionModel):

__init__

A series of initializations; I will cover them in detail later.

preprocess

Implements the parent class's method of the same name (DetectionModel.preprocess).

def preprocess(self, inputs):

Args:
      inputs: a [batch, height_in, width_in, channels] float tensor representing
        a batch of images with values between 0 and 255.0.

A detail worth noting here:
inputs is a single tensor, so all batch images it contains are assumed to have the same size.
This affects two things:
1. batch_size in train_config
2. image_resizer

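I tried the following:
batch_size=2,
image_resizer=keep_aspect_ratio_resizer,
with input images of different sizes. Running train.py then failed inside an internal tf.concat, because the images did not all have the same size.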
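The failure above can be reproduced outside TensorFlow. The sketch below uses NumPy as a stand-in, on the assumption that `np.stack` enforces equal shapes the same way the concatenation inside train.py does. `pad_to_max` is a hypothetical helper that mimics the zero padding to the bottom and right that the proto comment describes for `pad_to_max_dimension`.

```python
import numpy as np


def pad_to_max(image, max_dimension):
    """Hypothetical helper: zero-pad an HxWxC image to max_dimension x max_dimension.

    Zeros go to the bottom and the right, as described for pad_to_max_dimension.
    """
    h, w, c = image.shape
    padded = np.zeros((max_dimension, max_dimension, c), dtype=image.dtype)
    padded[:h, :w, :] = image
    return padded


# Two images after keep_aspect_ratio_resizer (min 600, max 1024) that had
# different aspect ratios, so their resized shapes differ.
a = np.ones((600, 800, 3), dtype=np.float32)
b = np.ones((600, 1000, 3), dtype=np.float32)

try:
    np.stack([a, b])  # batch_size=2 with mismatched shapes
    stacked_ok = True
except ValueError:
    stacked_ok = False  # the same kind of failure as the tf.concat error

# After padding both to [1024, 1024], the batch can be formed.
batch = np.stack([pad_to_max(a, 1024), pad_to_max(b, 1024)])
print(stacked_ok, batch.shape)  # False (2, 1024, 1024, 3)
```

Under this reading, setting `pad_to_max_dimension: true` is one way to make batch_size > 1 work with keep_aspect_ratio_resizer, at the cost of padded pixels. The unpadded size is what `true_image_shapes` is meant to record.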

    if inputs.dtype is not tf.float32:
      raise ValueError('`preprocess` expects a tf.float32 tensor')
    with tf.name_scope('Preprocessor'):
      outputs = shape_utils.static_or_dynamic_map_fn(
          self._image_resizer_fn,
          elems=inputs,
          dtype=[tf.float32, tf.int32],
          parallel_iterations=self._parallel_iterations)
      resized_inputs = outputs[0]
      true_image_shapes = outputs[1]
      return (self._feature_extractor.preprocess(resized_inputs),
              true_image_shapes)
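The control flow of preprocess can be sketched in plain NumPy. My understanding is that when the batch size is known statically, static_or_dynamic_map_fn unstacks the batch, applies the function to each image, and stacks each output, falling back to tf.map_fn otherwise. The sketch below assumes the first case. `toy_resizer` and `toy_feature_preprocess` are hypothetical stand-ins for `self._image_resizer_fn` and `self._feature_extractor.preprocess`.

```python
import numpy as np


def map_over_batch(fn, batch):
    """Apply fn to each element of batch and stack each of fn's outputs.

    fn returns a list [resized_image, true_shape], mirroring
    dtype=[tf.float32, tf.int32] in the original call.
    """
    results = [fn(image) for image in batch]
    return [np.stack([r[i] for r in results]) for i in range(len(results[0]))]


def toy_resizer(image):
    # Hypothetical resizer: keeps the image as-is and reports its [h, w, c] shape.
    return [image.astype(np.float32), np.array(image.shape, dtype=np.int32)]


def toy_feature_preprocess(resized_inputs):
    # Hypothetical feature-extractor preprocessing: scale pixel values to [0, 1].
    return resized_inputs / 255.0


def preprocess(inputs):
    if inputs.dtype != np.float32:
        raise ValueError('`preprocess` expects a float32 array')
    resized_inputs, true_image_shapes = map_over_batch(toy_resizer, inputs)
    return toy_feature_preprocess(resized_inputs), true_image_shapes


images = np.full((2, 4, 6, 3), 255.0, dtype=np.float32)
out, shapes = preprocess(images)
print(out.shape, shapes.tolist())  # (2, 4, 6, 3) [[4, 6, 3], [4, 6, 3]]
```

The final stacking step is also why every resized image in a batch must share one shape.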

image_resizer

models/research/object_detection/protos/image_resizer.proto

message ImageResizer {
  oneof image_resizer_oneof {
    KeepAspectRatioResizer keep_aspect_ratio_resizer = 1;
    FixedShapeResizer fixed_shape_resizer = 2;
  }
}

// Configuration proto for image resizer that keeps aspect ratio.
message KeepAspectRatioResizer {
  // Desired size of the smaller image dimension in pixels.
  optional int32 min_dimension = 1 [default = 600];

  // Desired size of the larger image dimension in pixels.
  optional int32 max_dimension = 2 [default = 1024];

  // Desired method when resizing image.
  optional ResizeType resize_method = 3 [default = BILINEAR];

  // Whether to pad the image with zeros so the output spatial size is
  // [max_dimension, max_dimension]. Note that the zeros are padded to the
  // bottom and the right of the resized image.
  optional bool pad_to_max_dimension = 4 [default = false];

  // Whether to also resize the image channels from 3 to 1 (RGB to grayscale).
  optional bool convert_to_grayscale = 5 [default = false];
}

So I searched for keep_aspect_ratio_resizer and found it in

models/research/object_detection/builders/image_resizer_builder.py

  if image_resizer_config.WhichOneof(
      'image_resizer_oneof') == 'keep_aspect_ratio_resizer':
    keep_aspect_ratio_config = image_resizer_config.keep_aspect_ratio_resizer
    if not (keep_aspect_ratio_config.min_dimension <=
            keep_aspect_ratio_config.max_dimension):
      raise ValueError('min_dimension > max_dimension')
    method = _tf_resize_method(keep_aspect_ratio_config.resize_method)
    image_resizer_fn = functools.partial(
        preprocessor.resize_to_range,
        min_dimension=keep_aspect_ratio_config.min_dimension,
        max_dimension=keep_aspect_ratio_config.max_dimension,
        method=method,
        pad_to_max_dimension=keep_aspect_ratio_config.pad_to_max_dimension)
    if not keep_aspect_ratio_config.convert_to_grayscale:
      return image_resizer_fn
  elif image_resizer_config.WhichOneof(
      'image_resizer_oneof') == 'fixed_shape_resizer':
    fixed_shape_resizer_config = image_resizer_config.fixed_shape_resizer
    method = _tf_resize_method(fixed_shape_resizer_config.resize_method)
    image_resizer_fn = functools.partial(
        preprocessor.resize_image,
        new_height=fixed_shape_resizer_config.height,
        new_width=fixed_shape_resizer_config.width,
        method=method)
    if not fixed_shape_resizer_config.convert_to_grayscale:
      return image_resizer_fn
  else:
    raise ValueError('Invalid image resizer option.')

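Take keep_aspect_ratio_resizer as an example. It mainly does two things:
1. It keeps the aspect ratio of the source image.
2. Subject to 1, it scales the image so that the smaller side equals min_dimension. If that would make the larger side exceed max_dimension, it instead scales so that the larger side equals max_dimension.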
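The two cases above can be checked on concrete sizes. Below is a minimal sketch that computes only the output height and width, not the resized pixels. It follows the two cases described in the docstring of preprocessor.resize_to_range and assumes rounding to the nearest integer. `compute_new_size` is a hypothetical name, and the `functools.partial` call mirrors how image_resizer_builder.py binds the config values.

```python
import functools


def compute_new_size(height, width, min_dimension=600, max_dimension=1024):
    """Sketch of the keep_aspect_ratio_resizer output size (before any padding).

    Case 1: scale so the smaller side equals min_dimension, if the larger
    side then stays within max_dimension.
    Case 2: otherwise, scale so the larger side equals max_dimension.
    """
    if min_dimension > max_dimension:
        raise ValueError('min_dimension > max_dimension')  # same check as the builder
    large_scale = min_dimension / min(height, width)
    large_size = (round(height * large_scale), round(width * large_scale))
    if max(large_size) <= max_dimension:
        return large_size
    small_scale = max_dimension / max(height, width)
    return (round(height * small_scale), round(width * small_scale))


# Bind the config values once, as functools.partial does in the builder,
# so the resulting callable only needs the image size.
resize_fn = functools.partial(compute_new_size, min_dimension=600, max_dimension=1024)

print(resize_fn(480, 640))   # (600, 800): case 1
print(resize_fn(300, 1000))  # (307, 1024): case 2
```

The second call shows why two images can leave the resizer with different shapes even under the same config: the output depends on each image's own aspect ratio.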

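Calling the FeatureExtractor's preprocess on resized_inputs

For example, the ResNet v1 feature extractor (faster_rcnn_resnet_v1_feature_extractor.py) performs VGG-style per-channel mean subtraction: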

channel_means = [123.68, 116.779, 103.939]
return resized_inputs - [[channel_means]]
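The expression `resized_inputs - [[channel_means]]` relies on broadcasting. TensorFlow follows NumPy's broadcasting rules here, so the sketch below uses NumPy to show the shapes involved. The batch contents are made up for illustration.

```python
import numpy as np

channel_means = [123.68, 116.779, 103.939]  # per-channel (R, G, B) means from the snippet

# A batch of 2 images, 4x5 pixels, 3 channels: [batch, height, width, channels].
resized_inputs = np.full((2, 4, 5, 3), 128.0, dtype=np.float32)

# [[channel_means]] has shape (1, 1, 3). Broadcasting aligns it with the
# trailing channel axis, so each channel's mean is subtracted at every
# pixel of every image in the batch.
means = np.array([[channel_means]], dtype=np.float32)
print(means.shape)  # (1, 1, 3)

centered = resized_inputs - means
print(centered.shape)     # (2, 4, 5, 3)
print(centered[0, 0, 0])  # about [4.32, 11.221, 24.061]
```

Because the means sit on the last axis, the same subtraction works for any batch size and any image height or width.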


predict

To be continued
