数据读取与处理

       

 如何处理数据

批量读取数据前,通过 data_augmentation_options 配置项指定预处理(数据增强)操作。
这一系列预处理操作在 samples/configs/ssd_mobilenet_v2_coco.config 中的 data_augmentation_options 里指定,例如:
data_augmentation_options {
    random_horizontal_flip {
    }
  }
  data_augmentation_options {
    ssd_random_crop {
    }
  }

trainer.py 从配置文件中读入指定的预处理操作
data_augmentation_options = [
      preprocessor_builder.build(step)
      for step in train_config.data_augmentation_options]
object_detection/builders/preprocessor_builder.py  预处理的构建

preprocessor.proto中提供了选项列表:

NormalizeImage normalize_image = 1; //归一化
RandomHorizontalFlip random_horizontal_flip = 2; //水平翻转
RandomPixelValueScale random_pixel_value_scale = 3; //像素值缩放
RandomImageScale random_image_scale = 4; //图片缩放
RandomRGBtoGray random_rgb_to_gray = 5; //rgb到gray
RandomAdjustBrightness random_adjust_brightness = 6; //亮度
RandomAdjustContrast random_adjust_contrast = 7; //对比度
RandomAdjustHue random_adjust_hue = 8; //色度
RandomAdjustSaturation random_adjust_saturation = 9; //饱和度
RandomDistortColor random_distort_color = 10; //扭曲颜色
RandomJitterBoxes random_jitter_boxes = 11; // 随机抖动
RandomCropImage random_crop_image = 12; //剪切
RandomPadImage random_pad_image = 13; //填充
RandomCropPadImage random_crop_pad_image = 14; //裁剪并填充
RandomCropToAspectRatio random_crop_to_aspect_ratio = 15;
RandomBlackPatches random_black_patches = 16;
RandomResizeMethod random_resize_method = 17; //缩放
ScaleBoxesToPixelCoordinates scale_boxes_to_pixel_coordinates = 18;
ResizeImage resize_image = 19; //缩放尺寸
SubtractChannelMean subtract_channel_mean = 20;
SSDRandomCrop ssd_random_crop = 21;
SSDRandomCropPad ssd_random_crop_pad = 22;
SSDRandomCropFixedAspectRatio ssd_random_crop_fixed_aspect_ratio = 23;
SSDRandomCropPadFixedAspectRatio ssd_random_crop_pad_fixed_aspect_ratio = 24;
RandomVerticalFlip random_vertical_flip = 25; //垂直翻转
RandomRotation90 random_rotation90 = 26; //旋转90
RGBtoGray rgb_to_gray = 27;
ConvertClassLogitsToSoftmax convert_class_logits_to_softmax = 28;
RandomAbsolutePadImage random_absolute_pad_image = 29;
RandomSelfConcatImage random_self_concat_image = 30;

目前支持的预处理操作有如下多种,所有的预处理操作详见 preprocessor.py。参数可以作为键值对提供。

prep_func_arg_map = {
      normalize_image: (fields.InputDataFields.image,),
      random_horizontal_flip: (
          fields.InputDataFields.image,
          fields.InputDataFields.groundtruth_boxes,
          groundtruth_instance_masks,
          groundtruth_keypoints,
      ),
      random_vertical_flip: (
          fields.InputDataFields.image,
          fields.InputDataFields.groundtruth_boxes,
          groundtruth_instance_masks,
          groundtruth_keypoints,
      ),
      random_rotation90: (
          fields.InputDataFields.image,
          fields.InputDataFields.groundtruth_boxes,
          groundtruth_instance_masks,
          groundtruth_keypoints,
      ),
      random_pixel_value_scale: (fields.InputDataFields.image,),
      random_image_scale: (
          fields.InputDataFields.image,
          groundtruth_instance_masks,
      ),
      random_rgb_to_gray: (fields.InputDataFields.image,),
      random_adjust_brightness: (fields.InputDataFields.image,),
      random_adjust_contrast: (fields.InputDataFields.image,),
      random_adjust_hue: (fields.InputDataFields.image,),
      random_adjust_saturation: (fields.InputDataFields.image,),
      random_distort_color: (fields.InputDataFields.image,),
      random_jitter_boxes: (fields.InputDataFields.groundtruth_boxes,),
      random_crop_image: (fields.InputDataFields.image,
                          fields.InputDataFields.groundtruth_boxes,
                          fields.InputDataFields.groundtruth_classes,
                          groundtruth_label_weights,
                          groundtruth_label_confidences,
                          multiclass_scores,
                          groundtruth_instance_masks,
                          groundtruth_keypoints),
      random_pad_image: (fields.InputDataFields.image,
                         fields.InputDataFields.groundtruth_boxes,
                         groundtruth_keypoints),
      random_absolute_pad_image: (fields.InputDataFields.image,
                                  fields.InputDataFields.groundtruth_boxes),
      random_crop_pad_image: (fields.InputDataFields.image,
                              fields.InputDataFields.groundtruth_boxes,
                              fields.InputDataFields.groundtruth_classes,
                              groundtruth_label_weights,
                              groundtruth_label_confidences,
                              multiclass_scores),
      random_crop_to_aspect_ratio: (
          fields.InputDataFields.image,
          fields.InputDataFields.groundtruth_boxes,
          fields.InputDataFields.groundtruth_classes,
          groundtruth_label_weights,
          groundtruth_label_confidences,
          multiclass_scores,
          groundtruth_instance_masks,
          groundtruth_keypoints,
      ),
      random_pad_to_aspect_ratio: (
          fields.InputDataFields.image,
          fields.InputDataFields.groundtruth_boxes,
          groundtruth_instance_masks,
          groundtruth_keypoints,
      ),
      random_black_patches: (fields.InputDataFields.image,),
      retain_boxes_above_threshold: (
          fields.InputDataFields.groundtruth_boxes,
          fields.InputDataFields.groundtruth_classes,
          groundtruth_label_weights,
          groundtruth_label_confidences,
          multiclass_scores,
          groundtruth_instance_masks,
          groundtruth_keypoints,
      ),
      image_to_float: (fields.InputDataFields.image,),
      random_resize_method: (fields.InputDataFields.image,),
      resize_to_range: (
          fields.InputDataFields.image,
          groundtruth_instance_masks,
      ),
      resize_to_min_dimension: (
          fields.InputDataFields.image,
          groundtruth_instance_masks,
      ),
      scale_boxes_to_pixel_coordinates: (
          fields.InputDataFields.image,
          fields.InputDataFields.groundtruth_boxes,
          groundtruth_keypoints,
      ),
      resize_image: (
          fields.InputDataFields.image,
          groundtruth_instance_masks,
      ),
      subtract_channel_mean: (fields.InputDataFields.image,),
      one_hot_encoding: (fields.InputDataFields.groundtruth_image_classes,),
      rgb_to_gray: (fields.InputDataFields.image,),
      random_self_concat_image: (fields.InputDataFields.image,
                                 fields.InputDataFields.groundtruth_boxes,
                                 fields.InputDataFields.groundtruth_classes,
                                 groundtruth_label_weights,
                                 groundtruth_label_confidences,
                                 multiclass_scores),
      ssd_random_crop: (fields.InputDataFields.image,
                        fields.InputDataFields.groundtruth_boxes,
                        fields.InputDataFields.groundtruth_classes,
                        groundtruth_label_weights,
                        groundtruth_label_confidences,
                        multiclass_scores,
                        groundtruth_instance_masks,
                        groundtruth_keypoints),
      ssd_random_crop_pad: (fields.InputDataFields.image,
                            fields.InputDataFields.groundtruth_boxes,
                            fields.InputDataFields.groundtruth_classes,
                            groundtruth_label_weights,
                            groundtruth_label_confidences,
                            multiclass_scores),
      ssd_random_crop_fixed_aspect_ratio: (
          fields.InputDataFields.image,
          fields.InputDataFields.groundtruth_boxes,
          fields.InputDataFields.groundtruth_classes,
          groundtruth_label_weights,
          groundtruth_label_confidences,
          multiclass_scores,
          groundtruth_instance_masks,
          groundtruth_keypoints),
      ssd_random_crop_pad_fixed_aspect_ratio: (
          fields.InputDataFields.image,
          fields.InputDataFields.groundtruth_boxes,
          fields.InputDataFields.groundtruth_classes,
          groundtruth_label_weights,
          groundtruth_label_confidences,
          multiclass_scores,
          groundtruth_instance_masks,
          groundtruth_keypoints,
      ),
      convert_class_logits_to_softmax: (multiclass_scores,),
  }

  return prep_func_arg_map

批量数据读取:创建两个队列 

队列1:开启 N 个线程,每个线程从数据集依次读一条数据,写入队列 1。一个线程从队列 1 每次读 batch_size 条数据。
队列2:将队列 1 出队列的数据写入队列 2,当调用 dequeue 的时候,从队列 2 读取 batch_size 条数据。
批量读数据后,先通过模型的预处理函数 detection_model.preprocess 进行预处理,之后再喂给模型。

trainer.py

def create_input_queue(batch_size_per_clone, create_tensor_dict_fn,
                       batch_queue_capacity, num_batch_queue_threads,
                       prefetch_queue_capacity, data_augmentation_options):
  """Sets up reader, prefetcher and returns input queue.

  Reads a single example via `create_tensor_dict_fn`, gives its image a
  leading dimension of size 1, casts the image to float32, optionally applies
  the configured data augmentation, and wraps the resulting tensor dict in a
  `batcher.BatchQueue`.

  Args:
    batch_size_per_clone: batch size to use per clone.
    create_tensor_dict_fn: function to create tensor dictionary.
    batch_queue_capacity: maximum number of elements to store within a queue.
    num_batch_queue_threads: number of threads to use for batching.
    prefetch_queue_capacity: maximum capacity of the queue used to prefetch
                             assembled batches.
    data_augmentation_options: a list of tuples, where each tuple contains a
      data augmentation function and a dictionary containing arguments and their
      values (see preprocessor.py). If empty or None, no augmentation is
      applied.

  Returns:
    input queue: a batcher.BatchQueue object holding enqueued tensor_dicts
      (which hold images, boxes and targets).  To get a batch of tensor_dicts,
      call input_queue.dequeue().
  """
  image_key = fields.InputDataFields.image

  # Read one example.
  tensor_dict = create_tensor_dict_fn()

  # Add a leading dimension of size 1 to the image, then convert it to
  # float32. tf.cast replaces the deprecated tf.to_float and yields the same
  # values.
  expanded_image = tf.expand_dims(tensor_dict[image_key], 0)
  tensor_dict[image_key] = tf.cast(expanded_image, tf.float32)

  # Detect which optional groundtruth fields are present so that the
  # augmentation argument map only references tensors that actually exist.
  include_instance_masks = (fields.InputDataFields.groundtruth_instance_masks
                            in tensor_dict)
  include_keypoints = (fields.InputDataFields.groundtruth_keypoints
                       in tensor_dict)
  include_multiclass_scores = (fields.InputDataFields.multiclass_scores
                               in tensor_dict)

  # Apply data augmentation, if any was configured.
  if data_augmentation_options:
    func_arg_map = preprocessor.get_default_func_arg_map(
        include_label_weights=True,
        include_multiclass_scores=include_multiclass_scores,
        include_instance_masks=include_instance_masks,
        include_keypoints=include_keypoints)
    tensor_dict = preprocessor.preprocess(
        tensor_dict, data_augmentation_options, func_arg_map=func_arg_map)

  # NOTE(review): per the parameter names, BatchQueue appears to combine a
  # batching queue (filled by num_batch_queue_threads threads, read
  # batch_size_per_clone examples at a time) with a prefetch queue of capacity
  # prefetch_queue_capacity that dequeue() reads from -- confirm against
  # batcher.py.
  input_queue = batcher.BatchQueue(
      tensor_dict,
      batch_size=batch_size_per_clone,
      batch_queue_capacity=batch_queue_capacity,
      num_batch_queue_threads=num_batch_queue_threads,
      prefetch_queue_capacity=prefetch_queue_capacity)
  return input_queue
def get_inputs(input_queue,
               num_classes,
               merge_multiple_label_boxes=False,
               use_multiclass_scores=False):
  """Dequeues batch and constructs inputs to object detection model.

  Args:
    input_queue: BatchQueue object holding enqueued tensor_dicts.
    num_classes: Number of classes.
    merge_multiple_label_boxes: Whether to merge boxes with multiple labels
      or not. Defaults to false. Merged boxes are represented with a single
      box and a k-hot encoding of the multiple labels associated with the
      boxes.
    use_multiclass_scores: Whether to use multiclass scores instead of
      groundtruth_classes.

  Returns:
    The per-example tuples transposed with zip(*...), i.e. one sequence per
    field, in this order:

    images: a list of 3-D float tensor of images.
    image_keys: a list of string keys for the images.
    locations_list: a list of tensors of shape [num_boxes, 4]
      containing the corners of the groundtruth boxes.
    classes_list: a list of padded one-hot (or K-hot) float32 tensors containing
      target classes.
    masks_list: a list of 3-D float tensors of shape [num_boxes, image_height,
      image_width] containing instance masks for objects if present in the
      input_queue. Else returns None.
    keypoints_list: a list of 3-D float tensors of shape [num_boxes,
      num_keypoints, 2] containing keypoints for objects if present in the
      input queue. Else returns None.
    weights_lists: a list of 1-D float32 tensors of shape [num_boxes]
      containing groundtruth weight for each box.

  Raises:
    ValueError: if both merge_multiple_label_boxes and use_multiclass_scores
      are set.
    NotImplementedError: if merge_multiple_label_boxes is set and the dequeued
      data contains instance masks or keypoints.
  """
  # Dequeue one batch; it is iterated below as a sequence of per-example
  # dicts.
  read_data_list = input_queue.dequeue()
  # Subtracted from the groundtruth class ids below.
  # NOTE(review): presumably converts 1-based label ids to 0-based -- confirm.
  label_id_offset = 1

  def extract_images_and_targets(read_data):
    """Extract images and targets from the input dict.

    Args:
      read_data: a dict for a single example, keyed by
        fields.InputDataFields names.

    Returns:
      A tuple (image, key, location_gt, classes_gt, masks_gt, keypoints_gt,
      weights_gt). `key` is '' when the example has no source_id; masks,
      keypoints and weights are None when absent from `read_data`.
    """
    image = read_data[fields.InputDataFields.image]
    key = ''
    if fields.InputDataFields.source_id in read_data:
      key = read_data[fields.InputDataFields.source_id]
    location_gt = read_data[fields.InputDataFields.groundtruth_boxes]
    classes_gt = tf.cast(read_data[fields.InputDataFields.groundtruth_classes],
                         tf.int32)
    classes_gt -= label_id_offset
    if merge_multiple_label_boxes and use_multiclass_scores:
      # The trailing space after "is" matters: the two literals are
      # concatenated, and without it the message read "isnot supported".
      raise ValueError(
          'Using both merge_multiple_label_boxes and use_multiclass_scores is '
          'not supported'
      )
    if merge_multiple_label_boxes:
      # Collapse boxes that carry several labels into a single box with a
      # k-hot class vector.
      location_gt, classes_gt, _ = util_ops.merge_boxes_with_multiple_labels(
          location_gt, classes_gt, num_classes)
      classes_gt = tf.cast(classes_gt, tf.float32)
    elif use_multiclass_scores:
      # Use the provided per-class scores as the classification target.
      classes_gt = tf.cast(read_data[fields.InputDataFields.multiclass_scores],
                           tf.float32)
    else:
      # Default: one-hot encode the offset-adjusted class ids.
      classes_gt = util_ops.padded_one_hot_encoding(
          indices=classes_gt, depth=num_classes, left_pad=0)
    masks_gt = read_data.get(fields.InputDataFields.groundtruth_instance_masks)
    keypoints_gt = read_data.get(fields.InputDataFields.groundtruth_keypoints)
    if (merge_multiple_label_boxes and (
        masks_gt is not None or keypoints_gt is not None)):
      raise NotImplementedError('Multi-label support is only for boxes.')
    weights_gt = read_data.get(
        fields.InputDataFields.groundtruth_weights)
    return (image, key, location_gt, classes_gt, masks_gt, keypoints_gt,
            weights_gt)

  # Transpose the list of per-example tuples into one sequence per field.
  return zip(*map(extract_images_and_targets, read_data_list))

参考: 

https://blog.csdn.net/weixin_39881922/article/details/87982524

你可能感兴趣的:(【AI入门】)