How the data is processed
Before data is read in batches, the preprocessing (data augmentation) operations are specified via data_augmentation_options.
The list of data_augmentation_options is given in the pipeline config, e.g. samples/configs/ssd_mobilenet_v2_coco.config:
data_augmentation_options {
  random_horizontal_flip {
  }
}
data_augmentation_options {
  ssd_random_crop {
  }
}
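Operations that take parameters supply them as key-value pairs inside the block. For example (the max_delta field is defined on RandomAdjustBrightness in preprocessor.proto; the value below is only illustrative):
data_augmentation_options {
  random_adjust_brightness {
    max_delta: 0.2
  }
}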
trainer.py reads the specified preprocessing operations from the config file:
data_augmentation_options = [
preprocessor_builder.build(step)
for step in train_config.data_augmentation_options]
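Each element that preprocessor_builder.build returns is a (function, kwargs) tuple: the preprocessing function from preprocessor.py plus the keyword arguments parsed from that config block. A minimal sketch of building a single step by hand (assumes the object_detection package is importable):
from google.protobuf import text_format
from object_detection.protos import preprocessor_pb2
from object_detection.builders import preprocessor_builder

# Parse one PreprocessingStep from config text and build it.
step = preprocessor_pb2.PreprocessingStep()
text_format.Merge('random_horizontal_flip {}', step)
func, kwargs = preprocessor_builder.build(step)
# func should be preprocessor.random_horizontal_flip; kwargs should be empty
# here because the config block carries no parameters.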
The preprocessing steps are built in object_detection/builders/preprocessor_builder.py.
preprocessor.proto defines the list of available options:
NormalizeImage normalize_image = 1; // normalize the image
RandomHorizontalFlip random_horizontal_flip = 2; // random horizontal flip
RandomPixelValueScale random_pixel_value_scale = 3; // randomly scale pixel values
RandomImageScale random_image_scale = 4; // randomly scale the image
RandomRGBtoGray random_rgb_to_gray = 5; // random RGB-to-grayscale conversion
RandomAdjustBrightness random_adjust_brightness = 6; // random brightness adjustment
RandomAdjustContrast random_adjust_contrast = 7; // random contrast adjustment
RandomAdjustHue random_adjust_hue = 8; // random hue adjustment
RandomAdjustSaturation random_adjust_saturation = 9; // random saturation adjustment
RandomDistortColor random_distort_color = 10; // random color distortion
RandomJitterBoxes random_jitter_boxes = 11; // randomly jitter the boxes
RandomCropImage random_crop_image = 12; // random crop
RandomPadImage random_pad_image = 13; // random padding
RandomCropPadImage random_crop_pad_image = 14; // random crop and pad
RandomCropToAspectRatio random_crop_to_aspect_ratio = 15;
RandomBlackPatches random_black_patches = 16;
RandomResizeMethod random_resize_method = 17; // random choice of resize method
ScaleBoxesToPixelCoordinates scale_boxes_to_pixel_coordinates = 18;
ResizeImage resize_image = 19; // resize to a fixed size
SubtractChannelMean subtract_channel_mean = 20;
SSDRandomCrop ssd_random_crop = 21;
SSDRandomCropPad ssd_random_crop_pad = 22;
SSDRandomCropFixedAspectRatio ssd_random_crop_fixed_aspect_ratio = 23;
SSDRandomCropPadFixedAspectRatio ssd_random_crop_pad_fixed_aspect_ratio = 24;
RandomVerticalFlip random_vertical_flip = 25; // random vertical flip
RandomRotation90 random_rotation90 = 26; // random 90-degree rotation
RGBtoGray rgb_to_gray = 27;
ConvertClassLogitsToSoftmax convert_class_logits_to_softmax = 28;
RandomAbsolutePadImage random_absolute_pad_image = 29;
RandomSelfConcatImage random_self_concat_image = 30;
The preprocessing operations currently supported are the ones listed above; they are all implemented in preprocessor.py, and each operation's parameters are supplied as key-value pairs in the config. preprocessor.get_default_func_arg_map() maps every operation to the tensor_dict fields it consumes; the bare names below (groundtruth_instance_masks, groundtruth_keypoints, groundtruth_label_weights, groundtruth_label_confidences, multiclass_scores) are local variables in that function which resolve to the corresponding fields.InputDataFields keys, or to None when the matching include_* flag is off:
prep_func_arg_map = {
normalize_image: (fields.InputDataFields.image,),
random_horizontal_flip: (
fields.InputDataFields.image,
fields.InputDataFields.groundtruth_boxes,
groundtruth_instance_masks,
groundtruth_keypoints,
),
random_vertical_flip: (
fields.InputDataFields.image,
fields.InputDataFields.groundtruth_boxes,
groundtruth_instance_masks,
groundtruth_keypoints,
),
random_rotation90: (
fields.InputDataFields.image,
fields.InputDataFields.groundtruth_boxes,
groundtruth_instance_masks,
groundtruth_keypoints,
),
random_pixel_value_scale: (fields.InputDataFields.image,),
random_image_scale: (
fields.InputDataFields.image,
groundtruth_instance_masks,
),
random_rgb_to_gray: (fields.InputDataFields.image,),
random_adjust_brightness: (fields.InputDataFields.image,),
random_adjust_contrast: (fields.InputDataFields.image,),
random_adjust_hue: (fields.InputDataFields.image,),
random_adjust_saturation: (fields.InputDataFields.image,),
random_distort_color: (fields.InputDataFields.image,),
random_jitter_boxes: (fields.InputDataFields.groundtruth_boxes,),
random_crop_image: (fields.InputDataFields.image,
fields.InputDataFields.groundtruth_boxes,
fields.InputDataFields.groundtruth_classes,
groundtruth_label_weights,
groundtruth_label_confidences,
multiclass_scores,
groundtruth_instance_masks,
groundtruth_keypoints),
random_pad_image: (fields.InputDataFields.image,
fields.InputDataFields.groundtruth_boxes,
groundtruth_keypoints),
random_absolute_pad_image: (fields.InputDataFields.image,
fields.InputDataFields.groundtruth_boxes),
random_crop_pad_image: (fields.InputDataFields.image,
fields.InputDataFields.groundtruth_boxes,
fields.InputDataFields.groundtruth_classes,
groundtruth_label_weights,
groundtruth_label_confidences,
multiclass_scores),
random_crop_to_aspect_ratio: (
fields.InputDataFields.image,
fields.InputDataFields.groundtruth_boxes,
fields.InputDataFields.groundtruth_classes,
groundtruth_label_weights,
groundtruth_label_confidences,
multiclass_scores,
groundtruth_instance_masks,
groundtruth_keypoints,
),
random_pad_to_aspect_ratio: (
fields.InputDataFields.image,
fields.InputDataFields.groundtruth_boxes,
groundtruth_instance_masks,
groundtruth_keypoints,
),
random_black_patches: (fields.InputDataFields.image,),
retain_boxes_above_threshold: (
fields.InputDataFields.groundtruth_boxes,
fields.InputDataFields.groundtruth_classes,
groundtruth_label_weights,
groundtruth_label_confidences,
multiclass_scores,
groundtruth_instance_masks,
groundtruth_keypoints,
),
image_to_float: (fields.InputDataFields.image,),
random_resize_method: (fields.InputDataFields.image,),
resize_to_range: (
fields.InputDataFields.image,
groundtruth_instance_masks,
),
resize_to_min_dimension: (
fields.InputDataFields.image,
groundtruth_instance_masks,
),
scale_boxes_to_pixel_coordinates: (
fields.InputDataFields.image,
fields.InputDataFields.groundtruth_boxes,
groundtruth_keypoints,
),
resize_image: (
fields.InputDataFields.image,
groundtruth_instance_masks,
),
subtract_channel_mean: (fields.InputDataFields.image,),
one_hot_encoding: (fields.InputDataFields.groundtruth_image_classes,),
rgb_to_gray: (fields.InputDataFields.image,),
random_self_concat_image: (fields.InputDataFields.image,
fields.InputDataFields.groundtruth_boxes,
fields.InputDataFields.groundtruth_classes,
groundtruth_label_weights,
groundtruth_label_confidences,
multiclass_scores),
ssd_random_crop: (fields.InputDataFields.image,
fields.InputDataFields.groundtruth_boxes,
fields.InputDataFields.groundtruth_classes,
groundtruth_label_weights,
groundtruth_label_confidences,
multiclass_scores,
groundtruth_instance_masks,
groundtruth_keypoints),
ssd_random_crop_pad: (fields.InputDataFields.image,
fields.InputDataFields.groundtruth_boxes,
fields.InputDataFields.groundtruth_classes,
groundtruth_label_weights,
groundtruth_label_confidences,
multiclass_scores),
ssd_random_crop_fixed_aspect_ratio: (
fields.InputDataFields.image,
fields.InputDataFields.groundtruth_boxes,
fields.InputDataFields.groundtruth_classes,
groundtruth_label_weights,
groundtruth_label_confidences,
multiclass_scores,
groundtruth_instance_masks,
groundtruth_keypoints),
ssd_random_crop_pad_fixed_aspect_ratio: (
fields.InputDataFields.image,
fields.InputDataFields.groundtruth_boxes,
fields.InputDataFields.groundtruth_classes,
groundtruth_label_weights,
groundtruth_label_confidences,
multiclass_scores,
groundtruth_instance_masks,
groundtruth_keypoints,
),
convert_class_logits_to_softmax: (multiclass_scores,),
}
return prep_func_arg_map
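preprocessor.preprocess consumes this map: for every (function, kwargs) pair built from the config it looks up which tensor_dict fields the function needs, calls the function on those tensors, and writes the results back into tensor_dict. A simplified sketch of that loop (illustrative only; the real code in preprocessor.py does more validation and handles optional fields more carefully):
def apply_preprocess_options(tensor_dict, preprocess_options, func_arg_map):
  """Simplified illustration of preprocessor.preprocess (not the library code)."""
  for func, kwargs in preprocess_options:
    arg_names = func_arg_map[func]  # e.g. (image, groundtruth_boxes, ...)
    # Gather the tensors this op needs, skipping fields that are absent.
    present = [(name, tensor_dict[name]) for name in arg_names
               if name is not None and name in tensor_dict]
    results = func(*[t for _, t in present], **kwargs)
    if not isinstance(results, (list, tuple)):
      results = (results,)
    # Write each returned tensor back under the field it came from.
    for (name, _), res in zip(present, results):
      tensor_dict[name] = res
  return tensor_dict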
Batch reading creates two queues.
Queue 1: N (num_batch_queue_threads) threads are started; each reads one example at a time from the dataset and enqueues it into queue 1, and a single thread dequeues batch_size examples at a time from queue 1.
Queue 2: the batches dequeued from queue 1 are enqueued into queue 2; when dequeue is called, a batch of batch_size examples is read from queue 2.
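A minimal TF1-style sketch of this two-queue pattern (illustrative only; the real implementation is batcher.BatchQueue in object_detection/core/batcher.py together with object_detection/core/prefetcher.py, and the helper below is a hypothetical simplification):
import tensorflow as tf

def two_queue_batch(tensor_dict, batch_size, batch_queue_capacity,
                    num_batch_queue_threads, prefetch_queue_capacity):
  # Queue 1: num_batch_queue_threads threads each enqueue one example at a
  # time; tf.train.batch dequeues batch_size examples per call.
  batched = tf.train.batch(
      tensor_dict,
      batch_size=batch_size,
      num_threads=num_batch_queue_threads,
      capacity=batch_queue_capacity,
      dynamic_pad=True)
  # Queue 2: stage the assembled batches in a second (prefetch) queue so the
  # next batch is already waiting when the model calls dequeue().
  names = list(batched.keys())
  prefetch_q = tf.FIFOQueue(
      prefetch_queue_capacity,
      dtypes=[batched[name].dtype for name in names],
      names=names)
  tf.train.add_queue_runner(
      tf.train.QueueRunner(prefetch_q, [prefetch_q.enqueue(batched)]))
  return prefetch_q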
After a batch has been read, it is run through the model's own preprocessing function, detection_model.preprocess, and then fed to the model.
trainer.py
def create_input_queue(batch_size_per_clone, create_tensor_dict_fn,
                       batch_queue_capacity, num_batch_queue_threads,
                       prefetch_queue_capacity, data_augmentation_options):
  """Sets up reader, prefetcher and returns input queue.

  Args:
    batch_size_per_clone: batch size to use per clone.
    create_tensor_dict_fn: function to create tensor dictionary.
    batch_queue_capacity: maximum number of elements to store within a queue.
    num_batch_queue_threads: number of threads to use for batching.
    prefetch_queue_capacity: maximum capacity of the queue used to prefetch
      assembled batches.
    data_augmentation_options: a list of tuples, where each tuple contains a
      data augmentation function and a dictionary containing arguments and
      their values (see preprocessor.py).

  Returns:
    input queue: a batcher.BatchQueue object holding enqueued tensor_dicts
      (which hold images, boxes and targets). To get a batch of tensor_dicts,
      call input_queue.Dequeue().
  """
  # Read one example.
  tensor_dict = create_tensor_dict_fn()
  # Add a batch dimension to the image.
  tensor_dict[fields.InputDataFields.image] = tf.expand_dims(
      tensor_dict[fields.InputDataFields.image], 0)
  # Convert the image to float.
  images = tensor_dict[fields.InputDataFields.image]
  float_images = tf.to_float(images)
  tensor_dict[fields.InputDataFields.image] = float_images
  # Whether instance masks are present.
  include_instance_masks = (fields.InputDataFields.groundtruth_instance_masks
                            in tensor_dict)
  # Whether keypoints are present.
  include_keypoints = (fields.InputDataFields.groundtruth_keypoints
                       in tensor_dict)
  # Whether multiclass scores are present.
  include_multiclass_scores = (fields.InputDataFields.multiclass_scores
                               in tensor_dict)
  # Apply the data augmentation (preprocessing) options.
  if data_augmentation_options:
    tensor_dict = preprocessor.preprocess(
        tensor_dict, data_augmentation_options,
        func_arg_map=preprocessor.get_default_func_arg_map(
            include_label_weights=True,
            include_multiclass_scores=include_multiclass_scores,
            include_instance_masks=include_instance_masks,
            include_keypoints=include_keypoints))
  # Create the two queues.
  # Queue 1: num_batch_queue_threads threads each read one example at a time
  #   from the dataset and enqueue it; one thread dequeues batch_size examples
  #   at a time from this queue.
  # Queue 2: batches dequeued from queue 1 are enqueued into queue 2 (capacity
  #   prefetch_queue_capacity); calling dequeue reads a batch of batch_size
  #   examples from queue 2.
  input_queue = batcher.BatchQueue(
      tensor_dict,
      batch_size=batch_size_per_clone,
      batch_queue_capacity=batch_queue_capacity,
      num_batch_queue_threads=num_batch_queue_threads,
      prefetch_queue_capacity=prefetch_queue_capacity)
  return input_queue
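In trainer.py the queue is constructed from fields of train_config; roughly like this (a hypothetical call site modelled on trainer.train(), not a verbatim excerpt):
input_queue = create_input_queue(
    batch_size_per_clone=train_config.batch_size // num_clones,
    create_tensor_dict_fn=create_tensor_dict_fn,
    batch_queue_capacity=train_config.batch_queue_capacity,
    num_batch_queue_threads=train_config.num_batch_queue_threads,
    prefetch_queue_capacity=train_config.prefetch_queue_capacity,
    data_augmentation_options=data_augmentation_options)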
def get_inputs(input_queue,
               num_classes,
               merge_multiple_label_boxes=False,
               use_multiclass_scores=False):
  """Dequeues batch and constructs inputs to object detection model.

  Args:
    input_queue: BatchQueue object holding enqueued tensor_dicts.
    num_classes: Number of classes.
    merge_multiple_label_boxes: Whether to merge boxes with multiple labels
      or not. Defaults to false. Merged boxes are represented with a single
      box and a k-hot encoding of the multiple labels associated with the
      boxes.
    use_multiclass_scores: Whether to use multiclass scores instead of
      groundtruth_classes.

  Returns:
    images: a list of 3-D float tensor of images.
    image_keys: a list of string keys for the images.
    locations_list: a list of tensors of shape [num_boxes, 4] containing the
      corners of the groundtruth boxes.
    classes_list: a list of padded one-hot (or K-hot) float32 tensors
      containing target classes.
    masks_list: a list of 3-D float tensors of shape [num_boxes, image_height,
      image_width] containing instance masks for objects if present in the
      input_queue. Else returns None.
    keypoints_list: a list of 3-D float tensors of shape [num_boxes,
      num_keypoints, 2] containing keypoints for objects if present in the
      input queue. Else returns None.
    weights_lists: a list of 1-D float32 tensors of shape [num_boxes]
      containing groundtruth weight for each box.
  """
  # Dequeue a batch of data from the prefetch queue: [batch_size, height, width, 3].
  read_data_list = input_queue.dequeue()
  label_id_offset = 1

  # Parse one dequeued example.
  def extract_images_and_targets(read_data):
    """Extract images and targets from the input dict."""
    image = read_data[fields.InputDataFields.image]
    key = ''
    if fields.InputDataFields.source_id in read_data:
      key = read_data[fields.InputDataFields.source_id]
    location_gt = read_data[fields.InputDataFields.groundtruth_boxes]
    classes_gt = tf.cast(read_data[fields.InputDataFields.groundtruth_classes],
                         tf.int32)
    classes_gt -= label_id_offset

    if merge_multiple_label_boxes and use_multiclass_scores:
      raise ValueError(
          'Using both merge_multiple_label_boxes and use_multiclass_scores is '
          'not supported')

    if merge_multiple_label_boxes:
      location_gt, classes_gt, _ = util_ops.merge_boxes_with_multiple_labels(
          location_gt, classes_gt, num_classes)
      classes_gt = tf.cast(classes_gt, tf.float32)
    elif use_multiclass_scores:
      classes_gt = tf.cast(read_data[fields.InputDataFields.multiclass_scores],
                           tf.float32)
    else:
      classes_gt = util_ops.padded_one_hot_encoding(
          indices=classes_gt, depth=num_classes, left_pad=0)
    masks_gt = read_data.get(fields.InputDataFields.groundtruth_instance_masks)
    keypoints_gt = read_data.get(fields.InputDataFields.groundtruth_keypoints)
    if (merge_multiple_label_boxes and (
        masks_gt is not None or keypoints_gt is not None)):
      raise NotImplementedError('Multi-label support is only for boxes.')
    weights_gt = read_data.get(
        fields.InputDataFields.groundtruth_weights)
    return (image, key, location_gt, classes_gt, masks_gt, keypoints_gt,
            weights_gt)

  return zip(*map(extract_images_and_targets, read_data_list))
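Downstream in trainer.py, these lists feed the detection model: every image is passed through detection_model.preprocess, the groundtruth lists are registered via provide_groundtruth, and predict/loss are called on the batch. A condensed, hypothetical sketch of that flow (assumes the tf import and the get_inputs function above; not a verbatim excerpt from trainer.py):
def build_losses(input_queue, detection_model, num_classes):
  # Condensed sketch of how trainer.py feeds the model (illustrative only).
  (images, _, gt_boxes_list, gt_classes_list, gt_masks_list,
   gt_keypoints_list, gt_weights_list) = get_inputs(input_queue, num_classes)

  preprocessed_images = []
  true_image_shapes = []
  for image in images:
    # detection_model.preprocess resizes/normalizes one image and returns the
    # preprocessed tensor plus its true (unpadded) shape.
    resized, true_shape = detection_model.preprocess(image)
    preprocessed_images.append(resized)
    true_image_shapes.append(true_shape)
  images = tf.concat(preprocessed_images, 0)
  true_image_shapes = tf.concat(true_image_shapes, 0)

  # (trainer.py additionally drops masks/keypoints lists that contain None.)
  detection_model.provide_groundtruth(
      groundtruth_boxes_list=gt_boxes_list,
      groundtruth_classes_list=gt_classes_list,
      groundtruth_masks_list=gt_masks_list,
      groundtruth_keypoints_list=gt_keypoints_list,
      groundtruth_weights_list=gt_weights_list)
  prediction_dict = detection_model.predict(images, true_image_shapes)
  return detection_model.loss(prediction_dict, true_image_shapes)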
Reference:
https://blog.csdn.net/weixin_39881922/article/details/87982524