object_detection/train.py
TFRecord 文件的路径一般在 samples/configs/*.config 中配置
举例:samples/configs/ssd_mobilenet_v1_coco.config
train_input_reader: {
tf_record_input_reader {
input_path: "PATH_TO_BE_CONFIGURED/mscoco_train.record"
}
label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
}
utils/config_util.py从config文件中读取各个配置
def create_configs_from_pipeline_proto(pipeline_config):
  """Splits a TrainEvalPipelineConfig proto into a dictionary of sub-configs.

  Args:
    pipeline_config: pipeline_pb2.TrainEvalPipelineConfig proto object.

  Returns:
    Dictionary of configuration objects. The keys `model`, `train_config`,
    `train_input_config`, `eval_config` and `eval_input_configs` are always
    present; `eval_input_configs` holds the (repeated) eval input readers.
    `eval_input_config` (the first eval input reader) is added only when at
    least one eval input reader exists, and `graph_rewriter_config` is added
    only when the proto has its `graph_rewriter` field set.
  """
  configs = {
      "model": pipeline_config.model,
      "train_config": pipeline_config.train_config,
      "train_input_config": pipeline_config.train_input_reader,
      "eval_config": pipeline_config.eval_config,
      "eval_input_configs": pipeline_config.eval_input_reader,
  }

  # `eval_input_config` is kept only for backwards compatibility; clients
  # should read `eval_input_configs` instead.
  eval_readers = configs["eval_input_configs"]
  if eval_readers:
    configs["eval_input_config"] = eval_readers[0]

  if pipeline_config.HasField("graph_rewriter"):
    configs["graph_rewriter_config"] = pipeline_config.graph_rewriter

  return configs
train.py读取数据
input_config = configs['train_input_config']


def get_next(config):
  """Builds the dataset for `config` and returns its next-element op."""
  dataset = dataset_builder.build(config)
  # make_initializable_iterator returns a tf.data.Iterator; get_next() is
  # called on that iterator to obtain the next element.
  iterator = dataset_util.make_initializable_iterator(dataset)
  return iterator.get_next()


# Bind the training input config so the input function takes no arguments.
create_input_dict_fn = functools.partial(get_next, input_config)
builders/dataset_builder.py
整个流程
core/standard_fields.py
class InputDataFields(object):
  """Names for the input tensors.

  Holds the standard data field names to use for identifying input tensors. This
  should be used by the decoder to identify keys for the returned tensor_dict
  containing input tensors. And it should be used by the model to identify the
  tensors it needs.

  Every attribute is a string constant whose value equals the attribute name.

  Attributes:
    image: image.
    image_additional_channels: additional channels.
    original_image: image in the original input size.
    key: unique key corresponding to image.
    source_id: source of the original image.
    filename: original filename of the dataset (without common path).
    groundtruth_image_classes: image-level class labels.
    groundtruth_boxes: coordinates of the ground truth boxes in the image.
    groundtruth_classes: box-level class labels.
    groundtruth_label_types: box-level label types (e.g. explicit negative).
    groundtruth_is_crowd: [DEPRECATED, use groundtruth_group_of instead]
      is the groundtruth a single object or a crowd.
    groundtruth_area: area of a groundtruth segment.
    groundtruth_difficult: is a `difficult` object
    groundtruth_group_of: is a `group_of` objects, e.g. multiple objects of the
      same class, forming a connected group, where instances are heavily
      occluding each other.
    proposal_boxes: coordinates of object proposal boxes.
    proposal_objectness: objectness score of each proposal.
    groundtruth_instance_masks: ground truth instance masks.
    groundtruth_instance_boundaries: ground truth instance boundaries.
    groundtruth_instance_classes: instance mask-level class labels.
    groundtruth_keypoints: ground truth keypoints.
    groundtruth_keypoint_visibilities: ground truth keypoint visibilities.
    groundtruth_label_scores: groundtruth label scores.
    groundtruth_weights: groundtruth weight factor for bounding boxes.
    num_groundtruth_boxes: number of groundtruth boxes.
    true_image_shape: true shapes of images in the resized images, as resized
      images can be padded with zeros.
    verified_labels: list of human-verified image-level labels (note, that a
      label can be verified both as positive and negative).
    multiclass_scores: the label score per class for each box.
  """
  image = 'image'
  image_additional_channels = 'image_additional_channels'
  original_image = 'original_image'
  key = 'key'
  source_id = 'source_id'
  filename = 'filename'
  groundtruth_image_classes = 'groundtruth_image_classes'
  groundtruth_boxes = 'groundtruth_boxes'
  groundtruth_classes = 'groundtruth_classes'
  groundtruth_label_types = 'groundtruth_label_types'
  groundtruth_is_crowd = 'groundtruth_is_crowd'
  groundtruth_area = 'groundtruth_area'
  groundtruth_difficult = 'groundtruth_difficult'
  groundtruth_group_of = 'groundtruth_group_of'
  proposal_boxes = 'proposal_boxes'
  proposal_objectness = 'proposal_objectness'
  groundtruth_instance_masks = 'groundtruth_instance_masks'
  groundtruth_instance_boundaries = 'groundtruth_instance_boundaries'
  groundtruth_instance_classes = 'groundtruth_instance_classes'
  groundtruth_keypoints = 'groundtruth_keypoints'
  groundtruth_keypoint_visibilities = 'groundtruth_keypoint_visibilities'
  groundtruth_label_scores = 'groundtruth_label_scores'
  groundtruth_weights = 'groundtruth_weights'
  num_groundtruth_boxes = 'num_groundtruth_boxes'
  true_image_shape = 'true_image_shape'
  verified_labels = 'verified_labels'
  multiclass_scores = 'multiclass_scores'
object_detection/builders/dataset_builder.py
# Returns the padding shape for each key present in dataset.output_shapes.
def _get_padding_shapes(dataset, max_num_boxes=None, num_classes=None,
                        spatial_image_shape=None):
  """Returns shapes to pad dataset tensors to before batching.

  Args:
    dataset: tf.data.Dataset object.
    max_num_boxes: Max number of groundtruth boxes needed to computes shapes for
      padding.
    num_classes: Number of classes in the dataset needed to compute shapes for
      padding.
    spatial_image_shape: A list of two integers of the form [height, width]
      containing expected spatial shape of the image. A falsy value or
      [-1, -1] leaves both dimensions dynamic.

  Returns:
    A dictionary keyed by fields.InputDataFields containing padding shapes for
    tensors in the dataset. Only keys present in `dataset.output_shapes` are
    returned.

  Raises:
    ValueError: If groundtruth classes is neither rank 1 nor rank 2.
    KeyError: If the dataset contains a tensor for which no padding shape is
      defined in this function.
  """
  if not spatial_image_shape or spatial_image_shape == [-1, -1]:
    height, width = None, None
  else:
    height, width = spatial_image_shape  # pylint: disable=unpacking-non-sequence

  output_shapes = dataset.output_shapes

  # Each key is listed exactly once (the original literal repeated
  # groundtruth_is_crowd and groundtruth_group_of).
  padding_shapes = {
      fields.InputDataFields.image: [height, width, 3],
      fields.InputDataFields.source_id: [],
      fields.InputDataFields.filename: [],
      fields.InputDataFields.key: [],
      fields.InputDataFields.groundtruth_difficult: [max_num_boxes],
      fields.InputDataFields.groundtruth_boxes: [max_num_boxes, 4],
      fields.InputDataFields.groundtruth_instance_masks: [max_num_boxes, height,
                                                          width],
      fields.InputDataFields.groundtruth_is_crowd: [max_num_boxes],
      fields.InputDataFields.groundtruth_group_of: [max_num_boxes],
      fields.InputDataFields.groundtruth_area: [max_num_boxes],
      fields.InputDataFields.groundtruth_weights: [max_num_boxes],
      fields.InputDataFields.num_groundtruth_boxes: [],
      fields.InputDataFields.groundtruth_label_types: [max_num_boxes],
      fields.InputDataFields.groundtruth_label_scores: [max_num_boxes],
      fields.InputDataFields.true_image_shape: [3],
      fields.InputDataFields.multiclass_scores: [
          max_num_boxes, num_classes + 1 if num_classes is not None else None],
  }

  # Determine whether groundtruth_classes are integers or one-hot encodings, and
  # apply batching appropriately.
  classes_shape = output_shapes[fields.InputDataFields.groundtruth_classes]
  if len(classes_shape) == 1:  # Class integers.
    padding_shapes[fields.InputDataFields.groundtruth_classes] = [max_num_boxes]
  elif len(classes_shape) == 2:  # One-hot or k-hot encoding.
    padding_shapes[fields.InputDataFields.groundtruth_classes] = [
        max_num_boxes, num_classes]
  else:
    raise ValueError('Groundtruth classes must be a rank 1 tensor (classes) or '
                     'rank 2 tensor (one-hot encodings)')

  if fields.InputDataFields.original_image in output_shapes:
    padding_shapes[fields.InputDataFields.original_image] = [None, None, 3]

  # The decoder emits image_additional_channels when additional channels are
  # requested; without an entry here the final lookup raised KeyError.
  # NOTE(review): the spatial dimensions are left dynamic (as for
  # original_image) because it is not visible here whether
  # transform_input_data_fn resizes the additional channels — confirm.
  additional_channels = fields.InputDataFields.image_additional_channels
  if additional_channels in output_shapes:
    channels_shape = output_shapes[additional_channels]
    padding_shapes[additional_channels] = [None, None, channels_shape[2].value]

  if fields.InputDataFields.groundtruth_keypoints in output_shapes:
    tensor_shape = output_shapes[fields.InputDataFields.groundtruth_keypoints]
    # Only the leading (box) dimension is padded; the rest is taken from the
    # static shape of the decoded tensor.
    padding_shapes[fields.InputDataFields.groundtruth_keypoints] = [
        max_num_boxes, tensor_shape[1].value, tensor_shape[2].value]

  if (fields.InputDataFields.groundtruth_keypoint_visibilities
      in output_shapes):
    tensor_shape = output_shapes[
        fields.InputDataFields.groundtruth_keypoint_visibilities]
    padding_shapes[
        fields.InputDataFields.groundtruth_keypoint_visibilities] = [
            max_num_boxes, tensor_shape[1].value]

  # Restrict the result to tensors that actually exist in the dataset.
  return {tensor_key: padding_shapes[tensor_key]
          for tensor_key in output_shapes}
def build(input_reader_config,
          transform_input_data_fn=None,
          batch_size=None,
          max_num_boxes=None,
          num_classes=None,
          spatial_image_shape=None,
          num_additional_channels=0):
  """Builds a tf.data.Dataset.

  Builds a tf.data.Dataset by applying the `transform_input_data_fn` on all
  records. Applies a padded batch to the resulting dataset.

  Args:
    input_reader_config: A input_reader_pb2.InputReader object.
    transform_input_data_fn: Function to apply to all records, or None if
      no extra decoding is required.
    batch_size: Batch size. If None, batching is not performed.
    max_num_boxes: Max number of groundtruth boxes needed to compute shapes for
      padding. If None, will use a dynamic shape.
    num_classes: Number of classes in the dataset needed to compute shapes for
      padding. If None, will use a dynamic shape.
    spatial_image_shape: A list of two integers of the form [height, width]
      containing expected spatial shape of the image after applying
      transform_input_data_fn. If None, will use dynamic shapes.
    num_additional_channels: Number of additional channels to use in the input.

  Returns:
    A tf.data.Dataset based on the input_reader_config.

  Raises:
    ValueError: On invalid input reader proto.
    ValueError: If no input paths are specified.
  """
  # Example of the pipeline-config section that ends up in
  # `input_reader_config`:
  #
  #   train_input_reader: {
  #     tf_record_input_reader {
  #       input_path: "data/wider_udlr.record"
  #     }
  #     label_map_path: "data/udlr_face.pbtxt"
  #   }

  # input_reader_config must be an input_reader_pb2.InputReader instance.
  if not isinstance(input_reader_config, input_reader_pb2.InputReader):
    raise ValueError('input_reader_config not of type '
                     'input_reader_pb2.InputReader.')

  if input_reader_config.WhichOneof('input_reader') == 'tf_record_input_reader':
    # input_path holds the TFRecord file path(s) to read.
    config = input_reader_config.tf_record_input_reader
    if not config.input_path:
      raise ValueError('At least one input path must be specified in '
                       '`input_reader_config`.')

    label_map_proto_file = None
    if input_reader_config.HasField('label_map_path'):
      # Path of the label map file.
      label_map_proto_file = input_reader_config.label_map_path

    # The decoder declares which features to parse and the handler used for
    # each of them. num_keypoints is forwarded so that the InputReader's
    # `num_keypoints` setting is honoured; with the proto default of 0 the
    # decoder behaves exactly as when the argument is omitted.
    decoder = tf_example_decoder.TfExampleDecoder(
        load_instance_masks=input_reader_config.load_instance_masks,
        instance_mask_type=input_reader_config.mask_type,
        label_map_proto_file=label_map_proto_file,
        use_display_name=input_reader_config.use_display_name,
        num_keypoints=input_reader_config.num_keypoints,
        num_additional_channels=num_additional_channels)

    def process_fn(value):
      """Decodes one serialized record, then applies the optional transform."""
      processed = decoder.decode(value)
      if transform_input_data_fn is not None:
        return transform_input_data_fn(processed)
      return processed

    # Read records from config.input_path with tf.data.TFRecordDataset, decode
    # them with process_fn, and prefetch input_reader_config.prefetch_size
    # decoded records (see dataset_util.read_dataset).
    dataset = dataset_util.read_dataset(
        functools.partial(tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000),
        process_fn, config.input_path[:], input_reader_config)

    if batch_size:
      padding_shapes = _get_padding_shapes(dataset, max_num_boxes, num_classes,
                                           spatial_image_shape)
      dataset = dataset.apply(
          tf.contrib.data.padded_batch_and_drop_remainder(batch_size,
                                                          padding_shapes))
    return dataset

  raise ValueError('Unsupported input_reader_config.')
def make_initializable_iterator(dataset):
  """Creates an initializable iterator and registers its initializer.

  The iterator's initializer is added to the TABLE_INITIALIZERS graph
  collection. This is useful in cases where make_one_shot_iterator wouldn't
  work because the graph contains a hash table that needs to be initialized.

  Args:
    dataset: A `tf.data.Dataset` object.

  Returns:
    A `tf.data.Iterator`.
  """
  initializable_iterator = dataset.make_initializable_iterator()
  tf.add_to_collection(
      tf.GraphKeys.TABLE_INITIALIZERS, initializable_iterator.initializer)
  return initializable_iterator
# Reads records from input_files with file_read_func, decodes each record with
# decode_func, and prefetches config.prefetch_size decoded elements.
def read_dataset(file_read_func, decode_func, input_files, config):
  """Reads a dataset, and handles repetition and shuffling.

  Args:
    file_read_func: Function to use in tf.data.Dataset.interleave, to read
      every individual file into a tf.data.Dataset.
    decode_func: Function to apply to all records.
    input_files: A list of file paths to read.
    config: A input_reader_builder.InputReader object.

  Returns:
    A tf.data.Dataset based on config.
  """
  # Expand the input paths into the concrete list of files.
  matched_files = tf.gfile.Glob(input_files)

  # Never use more readers than there are files.
  reader_count = config.num_readers
  if len(matched_files) < reader_count:
    reader_count = len(matched_files)
    tf.logging.warning('num_readers has been reduced to %d to match input file '
                       'shards.' % reader_count)

  file_dataset = tf.data.Dataset.from_tensor_slices(tf.unstack(matched_files))
  if config.shuffle:
    file_dataset = file_dataset.shuffle(config.filenames_shuffle_buffer_size)
  elif reader_count > 1:
    tf.logging.warning('`shuffle` is false, but the input data stream is '
                       'still slightly shuffled since `num_readers` > 1.')

  # num_epochs == 0 is turned into None, i.e. repeat indefinitely.
  file_dataset = file_dataset.repeat(config.num_epochs or None)

  # Read the files in parallel, interleaving blocks of records.
  interleave_fn = tf.contrib.data.parallel_interleave(
      file_read_func,
      cycle_length=reader_count,
      block_length=config.read_block_length,
      sloppy=config.shuffle)
  record_dataset = file_dataset.apply(interleave_fn)
  if config.shuffle:
    record_dataset = record_dataset.shuffle(config.shuffle_buffer_size)

  decoded_dataset = record_dataset.map(
      decode_func, num_parallel_calls=config.num_parallel_map_calls)
  return decoded_dataset.prefetch(config.prefetch_size)
protos/input_reader.proto
syntax = "proto2";
package object_detection.protos;
// Configuration proto for defining input readers that generate Object Detection
// Examples from input sources. Input readers are expected to generate a
// dictionary of tensors, with the following fields populated:
//
// 'image': an [image_height, image_width, channels] image tensor that detection
// will be run on.
// 'groundtruth_classes': a [num_boxes] int32 tensor storing the class
// labels of detected boxes in the image.
// 'groundtruth_boxes': a [num_boxes, 4] float tensor storing the coordinates of
// detected boxes in the image.
// 'groundtruth_instance_masks': (Optional), a [num_boxes, image_height,
// image_width] float tensor storing binary mask of the objects in boxes.
// Instance mask format, selected through InputReader.mask_type.
// Note that PNG masks are much more space efficient.
enum InstanceMaskType {
DEFAULT = 0; // Default implementation, currently NUMERICAL_MASKS
NUMERICAL_MASKS = 1; // [num_masks, H, W] float32 binary masks.
PNG_MASKS = 2; // Encoded PNG masks.
}
// Top-level input reader configuration: shuffling, repetition, parallelism and
// decoding options, plus exactly one concrete reader chosen through the
// `input_reader` oneof.
message InputReader {
// Path to StringIntLabelMap pbtxt file specifying the mapping from string
// labels to integer ids.
optional string label_map_path = 1 [default=""];
// Whether data should be processed in the order they are read in, or
// shuffled randomly.
optional bool shuffle = 2 [default=true];
// Buffer size to be used when shuffling records.
optional uint32 shuffle_buffer_size = 11 [default = 2048];
// Buffer size to be used when shuffling file names.
optional uint32 filenames_shuffle_buffer_size = 12 [default = 100];
// Maximum number of records to keep in reader queue.
// NOTE(review): presumably only relevant to a queue-based reader; confirm
// whether the tf.data input path uses it.
optional uint32 queue_capacity = 3 [default=2000];
// Minimum number of records to keep in reader queue. A large value is needed
// to generate a good random shuffle.
// NOTE(review): presumably only relevant to a queue-based reader; confirm
// whether the tf.data input path uses it.
optional uint32 min_after_dequeue = 4 [default=1000];
// The number of times a data source is read. If set to zero, the data source
// will be reused indefinitely.
optional uint32 num_epochs = 5 [default=0];
// Number of reader instances to create.
optional uint32 num_readers = 6 [default=32];
// Number of records to read from each reader at once.
optional uint32 read_block_length = 15 [default=32];
// Number of decoded records to prefetch before batching.
optional uint32 prefetch_size = 13 [default = 512];
// Number of parallel decode ops to apply.
optional uint32 num_parallel_map_calls = 14 [default = 64];
// Number of groundtruth keypoints per object.
optional uint32 num_keypoints = 16 [default = 0];
// Whether to load groundtruth instance masks.
optional bool load_instance_masks = 7 [default = false];
// Type of instance mask. Only meaningful when load_instance_masks is true.
optional InstanceMaskType mask_type = 10 [default = NUMERICAL_MASKS];
// Whether to use the display name when decoding examples. This is only used
// when mapping class text strings to integers.
optional bool use_display_name = 17 [default = false];
// The concrete reader to use; at most one of these may be set.
oneof input_reader {
TFRecordInputReader tf_record_input_reader = 8;
ExternalInputReader external_input_reader = 9;
}
}
// An input reader that reads TF Example protos from local TFRecord files.
message TFRecordInputReader {
// Path(s) to `TFRecordFile`s. At least one entry is required by the dataset
// builder, which raises an error when this list is empty.
repeated string input_path = 1;
}
// An externally defined input reader. Users may define an extension to this
// proto to interface their own input readers.
message ExternalInputReader {
// Field numbers 1 through 999 are reserved for such extensions.
extensions 1 to 999;
}
data_decoders/tf_example_decoder.py
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tensorflow Example proto decoder for object detection.
A decoder to decode string tensors containing serialized tensorflow.Example
protos for object detection.
"""
import tensorflow as tf
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from object_detection.core import data_decoder
from object_detection.core import standard_fields as fields
from object_detection.protos import input_reader_pb2
from object_detection.utils import label_map_util
slim_example_decoder = tf.contrib.slim.tfexample_decoder
# TODO(lzc): keep LookupTensor and BackupHandler in sync with
# tf.contrib.slim.tfexample_decoder version.
class LookupTensor(slim_example_decoder.Tensor):
  """An ItemHandler that returns a parsed Tensor, the result of a lookup."""

  def __init__(self,
               tensor_key,
               table,
               shape_keys=None,
               shape=None,
               default_value=''):
    """Initializes the LookupTensor handler.

    Parses the tensor like the base `Tensor` handler and then maps it through
    a vocabulary table (most often, a label mapping).

    Args:
      tensor_key: the name of the `TFExample` feature to read the tensor from.
      table: A tf.lookup table.
      shape_keys: Optional name or list of names of the TF-Example feature in
        which the tensor shape is stored. If a list, then each corresponds to
        one dimension of the shape.
      shape: Optional output shape of the `Tensor`. If provided, the `Tensor` is
        reshaped accordingly.
      default_value: The value used when the `tensor_key` is not found in a
        particular `TFExample`.

    Raises:
      ValueError: if both `shape_keys` and `shape` are specified.
    """
    # Store the table before delegating the remaining setup to the base class.
    self._table = table
    super(LookupTensor, self).__init__(
        tensor_key, shape_keys, shape, default_value)

  def tensors_to_item(self, keys_to_tensors):
    """Returns the base handler's tensor mapped through the lookup table."""
    raw_values = super(LookupTensor, self).tensors_to_item(keys_to_tensors)
    return self._table.lookup(raw_values)
class BackupHandler(slim_example_decoder.ItemHandler):
  """An ItemHandler that tries two ItemHandlers in order."""

  def __init__(self, handler, backup):
    """Initializes the BackupHandler handler.

    If the first Handler's tensors_to_item returns a Tensor with no elements,
    the second Handler is used.

    Args:
      handler: The primary ItemHandler.
      backup: The backup ItemHandler.

    Raises:
      ValueError: if either is not an ItemHandler.
    """
    handler_base = slim_example_decoder.ItemHandler
    if not isinstance(handler, handler_base):
      raise ValueError('Primary handler is of type %s instead of ItemHandler' %
                       type(handler))
    if not isinstance(backup, handler_base):
      raise ValueError(
          'Backup handler is of type %s instead of ItemHandler' % type(backup))

    self._handler = handler
    self._backup = backup
    # This handler needs the feature keys of both wrapped handlers.
    super(BackupHandler, self).__init__(handler.keys + backup.keys)

  def tensors_to_item(self, keys_to_tensors):
    """Returns the primary handler's item, or the backup's if it is empty."""
    primary_item = self._handler.tensors_to_item(keys_to_tensors)
    # A zero product of the shape dimensions means the tensor has no elements.
    primary_is_empty = math_ops.equal(
        math_ops.reduce_prod(array_ops.shape(primary_item)), 0)
    return control_flow_ops.cond(
        pred=primary_is_empty,
        true_fn=lambda: self._backup.tensors_to_item(keys_to_tensors),
        false_fn=lambda: primary_item)
class TfExampleDecoder(data_decoder.DataDecoder):
  """Tensorflow Example proto decoder."""

  # The constructor declares which features are parsed and which handler
  # turns them into each output item.
  def __init__(self,
               load_instance_masks=False,
               instance_mask_type=input_reader_pb2.NUMERICAL_MASKS,
               label_map_proto_file=None,
               use_display_name=False,
               dct_method='',
               num_keypoints=0,
               num_additional_channels=0):
    """Constructor sets keys_to_features and items_to_handlers.

    Args:
      load_instance_masks: whether or not to load and handle instance masks.
      instance_mask_type: type of instance masks. Options are provided in
        input_reader.proto. This is only used if `load_instance_masks` is True.
      label_map_proto_file: a file path to a
        object_detection.protos.StringIntLabelMap proto. If provided, then the
        mapped IDs of 'image/object/class/text' will take precedence over the
        existing 'image/object/class/label' ID.  Also, if provided, it is
        assumed that 'image/object/class/text' will be in the data.
      use_display_name: whether or not to use the `display_name` for label
        mapping (instead of `name`).  Only used if label_map_proto_file is
        provided.
      dct_method: An optional string. Defaults to the empty string, in which
        case no hint is passed on. It only takes effect when image format is
        jpeg, used to specify a hint about the algorithm used for jpeg
        decompression. Currently valid values are ['INTEGER_FAST',
        'INTEGER_ACCURATE']. The hint may be ignored, for example, the jpeg
        library does not have that specific option.
      num_keypoints: the number of keypoints per object.
      num_additional_channels: how many additional channels to use.

    Raises:
      ValueError: If `load_instance_masks` is True and `instance_mask_type` is
        not one of input_reader_pb2.DEFAULT, input_reader_pb2.NUMERICAL_MASKS,
        or input_reader_pb2.PNG_MASKS.
    """
    self.keys_to_features = {
        'image/encoded':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format':
            tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/filename':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/key/sha256':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/source_id':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/height':
            tf.FixedLenFeature((), tf.int64, default_value=1),
        'image/width':
            tf.FixedLenFeature((), tf.int64, default_value=1),
        # Object boxes and classes.
        'image/object/bbox/xmin':
            tf.VarLenFeature(tf.float32),
        'image/object/bbox/xmax':
            tf.VarLenFeature(tf.float32),
        'image/object/bbox/ymin':
            tf.VarLenFeature(tf.float32),
        'image/object/bbox/ymax':
            tf.VarLenFeature(tf.float32),
        'image/object/class/label':
            tf.VarLenFeature(tf.int64),
        'image/object/class/text':
            tf.VarLenFeature(tf.string),
        'image/object/area':
            tf.VarLenFeature(tf.float32),
        'image/object/is_crowd':
            tf.VarLenFeature(tf.int64),
        'image/object/difficult':
            tf.VarLenFeature(tf.int64),
        'image/object/group_of':
            tf.VarLenFeature(tf.int64),
        'image/object/weight':
            tf.VarLenFeature(tf.float32),
    }

    # We are checking `dct_method` instead of passing it directly in order to
    # ensure TF version 1.6 compatibility: the keyword is only forwarded when
    # a hint was actually given.
    image_kwargs = {'dct_method': dct_method} if dct_method else {}
    # The Image handlers decode the encoded bytes.
    # NOTE(review): per the original notes the slim Image handler picks its
    # decode op from the 'image/format' value (raw / jpeg / generic) — confirm
    # against tf.contrib.slim.tfexample_decoder.
    image = slim_example_decoder.Image(
        image_key='image/encoded',
        format_key='image/format',
        channels=3,
        **image_kwargs)
    additional_channel_image = slim_example_decoder.Image(
        image_key='image/additional_channels/encoded',
        format_key='image/format',
        channels=1,
        repeated=True,
        **image_kwargs)

    self.items_to_handlers = {
        fields.InputDataFields.image:
            image,
        fields.InputDataFields.source_id: (
            slim_example_decoder.Tensor('image/source_id')),
        fields.InputDataFields.key: (
            slim_example_decoder.Tensor('image/key/sha256')),
        fields.InputDataFields.filename: (
            slim_example_decoder.Tensor('image/filename')),
        # Object boxes and classes.
        fields.InputDataFields.groundtruth_boxes: (
            slim_example_decoder.BoundingBox(['ymin', 'xmin', 'ymax', 'xmax'],
                                             'image/object/bbox/')),
        fields.InputDataFields.groundtruth_area:
            slim_example_decoder.Tensor('image/object/area'),
        fields.InputDataFields.groundtruth_is_crowd: (
            slim_example_decoder.Tensor('image/object/is_crowd')),
        fields.InputDataFields.groundtruth_difficult: (
            slim_example_decoder.Tensor('image/object/difficult')),
        fields.InputDataFields.groundtruth_group_of: (
            slim_example_decoder.Tensor('image/object/group_of')),
        fields.InputDataFields.groundtruth_weights: (
            slim_example_decoder.Tensor('image/object/weight')),
    }

    if num_additional_channels > 0:
      self.keys_to_features[
          'image/additional_channels/encoded'] = tf.FixedLenFeature(
              (num_additional_channels,), tf.string)
      self.items_to_handlers[
          fields.InputDataFields.
          image_additional_channels] = additional_channel_image

    self._num_keypoints = num_keypoints
    if num_keypoints > 0:
      self.keys_to_features['image/object/keypoint/x'] = (
          tf.VarLenFeature(tf.float32))
      self.keys_to_features['image/object/keypoint/y'] = (
          tf.VarLenFeature(tf.float32))
      self.items_to_handlers[fields.InputDataFields.groundtruth_keypoints] = (
          slim_example_decoder.ItemHandlerCallback(
              ['image/object/keypoint/y', 'image/object/keypoint/x'],
              self._reshape_keypoints))

    if load_instance_masks:
      if instance_mask_type in (input_reader_pb2.DEFAULT,
                                input_reader_pb2.NUMERICAL_MASKS):
        self.keys_to_features['image/object/mask'] = (
            tf.VarLenFeature(tf.float32))
        self.items_to_handlers[
            fields.InputDataFields.groundtruth_instance_masks] = (
                slim_example_decoder.ItemHandlerCallback(
                    ['image/object/mask', 'image/height', 'image/width'],
                    self._reshape_instance_masks))
      elif instance_mask_type == input_reader_pb2.PNG_MASKS:
        self.keys_to_features['image/object/mask'] = tf.VarLenFeature(tf.string)
        self.items_to_handlers[
            fields.InputDataFields.groundtruth_instance_masks] = (
                slim_example_decoder.ItemHandlerCallback(
                    ['image/object/mask', 'image/height', 'image/width'],
                    self._decode_png_instance_masks))
      else:
        raise ValueError('Did not recognize the `instance_mask_type` option.')

    if label_map_proto_file:
      label_map = label_map_util.get_label_map_dict(label_map_proto_file,
                                                    use_display_name)
      # We use a default_value of -1, but we expect all labels to be contained
      # in the label map.
      table = tf.contrib.lookup.HashTable(
          initializer=tf.contrib.lookup.KeyValueTensorInitializer(
              keys=tf.constant(list(label_map.keys())),
              values=tf.constant(list(label_map.values()), dtype=tf.int64)),
          default_value=-1)
      # If the label_map_proto is provided, try to use it in conjunction with
      # the class text, and fall back to a materialized ID.
      # TODO(lzc): note that here we are using BackupHandler defined in this
      # file(which is branching slim_example_decoder.BackupHandler). Need to
      # switch back to slim_example_decoder.BackupHandler once tf 1.5 becomes
      # more popular.
      label_handler = BackupHandler(
          LookupTensor('image/object/class/text', table, default_value=''),
          slim_example_decoder.Tensor('image/object/class/label'))
    else:
      label_handler = slim_example_decoder.Tensor('image/object/class/label')
    self.items_to_handlers[
        fields.InputDataFields.groundtruth_classes] = label_handler

  # decode() parses the serialized example with a slim TFExampleDecoder and
  # runs the handler registered for every item.
  def decode(self, tf_example_string_tensor):
    """Decodes serialized tensorflow example and returns a tensor dictionary.

    Args:
      tf_example_string_tensor: a string tensor holding a serialized tensorflow
        example proto.

    Returns:
      A dictionary of the following tensors.
      fields.InputDataFields.image - 3D uint8 tensor of shape [None, None, 3]
        containing image.
      fields.InputDataFields.source_id - string tensor containing original
        image id.
      fields.InputDataFields.key - string tensor with unique sha256 hash key.
      fields.InputDataFields.filename - string tensor with original dataset
        filename.
      fields.InputDataFields.groundtruth_boxes - 2D float32 tensor of shape
        [None, 4] containing box corners.
      fields.InputDataFields.groundtruth_classes - 1D int64 tensor of shape
        [None] containing classes for the boxes.
      fields.InputDataFields.groundtruth_weights - 1D float32 tensor of
        shape [None] indicating the weights of groundtruth boxes. Defaults to
        all ones when the example stores no weights.
      fields.InputDataFields.num_groundtruth_boxes - int32 scalar indicating
        the number of groundtruth_boxes.
      fields.InputDataFields.groundtruth_area - 1D float32 tensor of shape
        [None] containing containing object mask area in pixel squared.
      fields.InputDataFields.groundtruth_is_crowd - 1D bool tensor of shape
        [None] indicating if the boxes enclose a crowd.
      fields.InputDataFields.groundtruth_difficult - 1D tensor of shape
        [None] indicating if the boxes represent `difficult` instances.
        Parsed from an int64 feature and, unlike groundtruth_is_crowd, not
        cast to bool by this method.
      fields.InputDataFields.groundtruth_group_of - 1D tensor of shape
        [None] indicating if the boxes represent `group_of` instances.
        Parsed from an int64 feature and, unlike groundtruth_is_crowd, not
        cast to bool by this method.

    Optional:
      fields.InputDataFields.image_additional_channels - 3D uint8 tensor of
        shape [None, None, num_additional_channels]. 1st dim is height; 2nd dim
        is width; 3rd dim is the number of additional channels.
      fields.InputDataFields.groundtruth_keypoints - 3D float32 tensor of
        shape [None, None, 2] containing keypoints, where the coordinates of
        the keypoints are ordered (y, x).
      fields.InputDataFields.groundtruth_instance_masks - 3D float32 tensor of
        shape [None, None, None] containing instance masks.
    """
    serialized_example = tf.reshape(tf_example_string_tensor, shape=[])
    # Build a slim decoder from the features / handlers chosen in __init__.
    decoder = slim_example_decoder.TFExampleDecoder(self.keys_to_features,
                                                    self.items_to_handlers)
    keys = decoder.list_items()
    # Decode every registered item; `tensors` is ordered like `keys`.
    tensors = decoder.decode(serialized_example, items=keys)
    tensor_dict = dict(zip(keys, tensors))

    is_crowd = fields.InputDataFields.groundtruth_is_crowd
    tensor_dict[is_crowd] = tf.cast(tensor_dict[is_crowd], dtype=tf.bool)
    tensor_dict[fields.InputDataFields.image].set_shape([None, None, 3])
    tensor_dict[fields.InputDataFields.num_groundtruth_boxes] = tf.shape(
        tensor_dict[fields.InputDataFields.groundtruth_boxes])[0]

    if fields.InputDataFields.image_additional_channels in tensor_dict:
      channels = tensor_dict[fields.InputDataFields.image_additional_channels]
      # Drop the trailing axis and move the leading axis last, matching the
      # [height, width, num_additional_channels] layout documented above.
      channels = tf.squeeze(channels, axis=3)
      channels = tf.transpose(channels, perm=[1, 2, 0])
      tensor_dict[fields.InputDataFields.image_additional_channels] = channels

    def default_groundtruth_weights():
      """Returns a weight of 1.0 for every groundtruth box."""
      return tf.ones(
          [tf.shape(tensor_dict[fields.InputDataFields.groundtruth_boxes])[0]],
          dtype=tf.float32)

    # Bind the decoded weights to a local so the branch below does not depend
    # on the dictionary entry that is being replaced.
    decoded_weights = tensor_dict[fields.InputDataFields.groundtruth_weights]
    tensor_dict[fields.InputDataFields.groundtruth_weights] = tf.cond(
        tf.greater(tf.shape(decoded_weights)[0], 0),
        lambda: decoded_weights,
        default_groundtruth_weights)
    return tensor_dict

  def _reshape_keypoints(self, keys_to_tensors):
    """Reshape keypoints.

    The keypoint coordinates are reshaped to [num_instances, num_keypoints, 2].

    Args:
      keys_to_tensors: a dictionary from keys to tensors. Must contain
        'image/object/keypoint/y' and 'image/object/keypoint/x'.

    Returns:
      A 3-D float tensor of shape [num_instances, num_keypoints, 2] whose last
      dimension holds the (y, x) coordinates of a keypoint.
    """
    y = keys_to_tensors['image/object/keypoint/y']
    if isinstance(y, tf.SparseTensor):
      y = tf.sparse_tensor_to_dense(y)
    y = tf.expand_dims(y, 1)
    x = keys_to_tensors['image/object/keypoint/x']
    if isinstance(x, tf.SparseTensor):
      x = tf.sparse_tensor_to_dense(x)
    x = tf.expand_dims(x, 1)
    # Pair every y with its x, then group the pairs per instance.
    keypoints = tf.concat([y, x], 1)
    keypoints = tf.reshape(keypoints, [-1, self._num_keypoints, 2])
    return keypoints

  def _reshape_instance_masks(self, keys_to_tensors):
    """Reshape instance segmentation masks.

    The instance segmentation masks are reshaped to [num_instances, height,
    width].

    Args:
      keys_to_tensors: a dictionary from keys to tensors. Must contain
        'image/object/mask', 'image/height' and 'image/width'.

    Returns:
      A 3-D float tensor of shape [num_instances, height, width] with values
        in {0, 1}.
    """
    height = keys_to_tensors['image/height']
    width = keys_to_tensors['image/width']
    to_shape = tf.cast(tf.stack([-1, height, width]), tf.int32)
    masks = keys_to_tensors['image/object/mask']
    if isinstance(masks, tf.SparseTensor):
      masks = tf.sparse_tensor_to_dense(masks)
    # Binarize (any value > 0 becomes 1.0); tf.to_float already yields
    # float32, so no further cast is needed.
    return tf.reshape(tf.to_float(tf.greater(masks, 0.0)), to_shape)

  def _decode_png_instance_masks(self, keys_to_tensors):
    """Decode PNG instance segmentation masks and stack into dense tensor.

    The instance segmentation masks are reshaped to [num_instances, height,
    width].

    Args:
      keys_to_tensors: a dictionary from keys to tensors. Must contain
        'image/object/mask', 'image/height' and 'image/width'.

    Returns:
      A 3-D float tensor of shape [num_instances, height, width] with values
        in {0, 1}.
    """

    def decode_png_mask(image_buffer):
      """Decodes one encoded mask into a 2-D float tensor of zeros and ones."""
      image = tf.squeeze(
          tf.image.decode_image(image_buffer, channels=1), axis=2)
      image.set_shape([None, None])
      image = tf.to_float(tf.greater(image, 0))
      return image

    png_masks = keys_to_tensors['image/object/mask']
    height = keys_to_tensors['image/height']
    width = keys_to_tensors['image/width']
    if isinstance(png_masks, tf.SparseTensor):
      png_masks = tf.sparse_tensor_to_dense(png_masks, default_value='')
    # With no masks there is nothing to decode, so an empty
    # [0, height, width] tensor is returned instead.
    return tf.cond(
        tf.greater(tf.size(png_masks), 0),
        lambda: tf.map_fn(decode_png_mask, png_masks, dtype=tf.float32),
        lambda: tf.zeros(tf.to_int32(tf.stack([0, height, width]))))
参考:https://blog.csdn.net/wenxueliu/article/details/80727563