这里的源码是从train.py开始看的。之后还有eval.py
train.py
在train.py中,config文件被分成三部分:
model_config = configs['model']
train_config = configs['train_config']
input_config = configs['train_input_config']
model_config是用于构建模型的配置。
model_fn = functools.partial(
model_builder.build,
model_config=model_config,
is_training=True)
在model_builder.py中,build函数会根据model_config选择模型种类:
def build(model_config, is_training):
"""Builds a DetectionModel based on the model config.
Args:
model_config: A model.proto object containing the config for the desired
DetectionModel.
is_training: True if this model is being built for training purposes.
Returns:
DetectionModel based on the config.
Raises:
ValueError: On invalid meta architecture or model.
"""
if not isinstance(model_config, model_pb2.DetectionModel):
raise ValueError('model_config not of type model_pb2.DetectionModel.')
meta_architecture = model_config.WhichOneof('model')
if meta_architecture == 'ssd':
return _build_ssd_model(model_config.ssd, is_training)
if meta_architecture == 'faster_rcnn':
return _build_faster_rcnn_model(model_config.faster_rcnn, is_training)
raise ValueError('Unknown meta architecture: {}'.format(meta_architecture))
如果你选择faster-rcnn,build会调用_build_faster_rcnn_model,model_config.faster_rcnn中的各个字段就是构建faster-rcnn模型的参数
如果你有兴趣,在protos/model_pb2.py有很多model_config的默认值
这时候模型已经构建完了
回到train.py中
train_config = configs['train_config']
可以发现这是供trainer.py使用的配置,它在trainer.py的train函数中被用到。
在protos/train_pb2.py中的默认配置如下:
_descriptor.FieldDescriptor(
name='batch_size', full_name='object_detection.protos.TrainConfig.batch_size', index=0,
number=1, type=13, cpp_type=3, label=1,
has_default_value=True, default_value=32,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='data_augmentation_options', full_name='object_detection.protos.TrainConfig.data_augmentation_options', index=1,
number=2, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='sync_replicas', full_name='object_detection.protos.TrainConfig.sync_replicas', index=2,
number=3, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='keep_checkpoint_every_n_hours', full_name='object_detection.protos.TrainConfig.keep_checkpoint_every_n_hours', index=3,
number=4, type=13, cpp_type=3, label=1,
has_default_value=True, default_value=1000,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='optimizer', full_name='object_detection.protos.TrainConfig.optimizer', index=4,
number=5, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='gradient_clipping_by_norm', full_name='object_detection.protos.TrainConfig.gradient_clipping_by_norm', index=5,
number=6, type=2, cpp_type=6, label=1,
has_default_value=True, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='fine_tune_checkpoint', full_name='object_detection.protos.TrainConfig.fine_tune_checkpoint', index=6,
number=7, type=9, cpp_type=9, label=1,
has_default_value=True, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='from_detection_checkpoint', full_name='object_detection.protos.TrainConfig.from_detection_checkpoint', index=7,
number=8, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='num_steps', full_name='object_detection.protos.TrainConfig.num_steps', index=8,
number=9, type=13, cpp_type=3, label=1,
has_default_value=True, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='startup_delay_steps', full_name='object_detection.protos.TrainConfig.startup_delay_steps', index=9,
number=10, type=2, cpp_type=6, label=1,
has_default_value=True, default_value=15,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='bias_grad_multiplier', full_name='object_detection.protos.TrainConfig.bias_grad_multiplier', index=10,
number=11, type=2, cpp_type=6, label=1,
has_default_value=True, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='freeze_variables', full_name='object_detection.protos.TrainConfig.freeze_variables', index=11,
number=12, type=9, cpp_type=9, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='replicas_to_aggregate', full_name='object_detection.protos.TrainConfig.replicas_to_aggregate', index=12,
number=13, type=5, cpp_type=1, label=1,
has_default_value=True, default_value=1,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='batch_queue_capacity', full_name='object_detection.protos.TrainConfig.batch_queue_capacity', index=13,
number=14, type=5, cpp_type=1, label=1,
has_default_value=True, default_value=150,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='num_batch_queue_threads', full_name='object_detection.protos.TrainConfig.num_batch_queue_threads', index=14,
number=15, type=5, cpp_type=1, label=1,
has_default_value=True, default_value=8,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='prefetch_queue_capacity', full_name='object_detection.protos.TrainConfig.prefetch_queue_capacity', index=15,
number=16, type=5, cpp_type=1, label=1,
has_default_value=True, default_value=5,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='merge_multiple_label_boxes', full_name='object_detection.protos.TrainConfig.merge_multiple_label_boxes', index=16,
number=17, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
再看input_config
input_config = configs['train_input_config']
input_config在builders/input_reader_builder.py中被使用。
input_reader_pb2中默认值:
_descriptor.FieldDescriptor(
name='label_map_path', full_name='object_detection.protos.InputReader.label_map_path', index=0,
number=1, type=9, cpp_type=9, label=1,
has_default_value=True, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='shuffle', full_name='object_detection.protos.InputReader.shuffle', index=1,
number=2, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=True,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='queue_capacity', full_name='object_detection.protos.InputReader.queue_capacity', index=2,
number=3, type=13, cpp_type=3, label=1,
has_default_value=True, default_value=2000,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='min_after_dequeue', full_name='object_detection.protos.InputReader.min_after_dequeue', index=3,
number=4, type=13, cpp_type=3, label=1,
has_default_value=True, default_value=1000,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='num_epochs', full_name='object_detection.protos.InputReader.num_epochs', index=4,
number=5, type=13, cpp_type=3, label=1,
has_default_value=True, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='num_readers', full_name='object_detection.protos.InputReader.num_readers', index=5,
number=6, type=13, cpp_type=3, label=1,
has_default_value=True, default_value=8,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='load_instance_masks', full_name='object_detection.protos.InputReader.load_instance_masks', index=6,
number=7, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='tf_record_input_reader', full_name='object_detection.protos.InputReader.tf_record_input_reader', index=7,
number=8, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='external_input_reader', full_name='object_detection.protos.InputReader.external_input_reader', index=8,
number=9, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
eval.py
在eval.py中,配置同样被分为三个部分:
model_config = configs['model']
eval_config = configs['eval_config']
if FLAGS.eval_training_data:
input_config = configs['train_input_config']
else:
input_config = configs['eval_input_config']
这里
eval_config = configs['eval_config']
为新增的一项配置,用于计算评估指标。
evaluator.py文件使用了这里的config参数。
该文件开头定义了几种评估指标的计算方式:
EVAL_METRICS_CLASS_DICT = {
'pascal_voc_metrics':
object_detection_evaluation.PascalDetectionEvaluator,
'weighted_pascal_voc_metrics':
object_detection_evaluation.WeightedPascalDetectionEvaluator,
'open_images_metrics':
object_detection_evaluation.OpenImagesDetectionEvaluator
}
eval_pb2.py文件中eval_config各字段的默认值如下:
_descriptor.FieldDescriptor(
name='num_visualizations', full_name='object_detection.protos.EvalConfig.num_visualizations', index=0,
number=1, type=13, cpp_type=3, label=1,
has_default_value=True, default_value=10,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='num_examples', full_name='object_detection.protos.EvalConfig.num_examples', index=1,
number=2, type=13, cpp_type=3, label=1,
has_default_value=True, default_value=5000,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='eval_interval_secs', full_name='object_detection.protos.EvalConfig.eval_interval_secs', index=2,
number=3, type=13, cpp_type=3, label=1,
has_default_value=True, default_value=300,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='max_evals', full_name='object_detection.protos.EvalConfig.max_evals', index=3,
number=4, type=13, cpp_type=3, label=1,
has_default_value=True, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='save_graph', full_name='object_detection.protos.EvalConfig.save_graph', index=4,
number=5, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='visualization_export_dir', full_name='object_detection.protos.EvalConfig.visualization_export_dir', index=5,
number=6, type=9, cpp_type=9, label=1,
has_default_value=True, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='eval_master', full_name='object_detection.protos.EvalConfig.eval_master', index=6,
number=7, type=9, cpp_type=9, label=1,
has_default_value=True, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='metrics_set', full_name='object_detection.protos.EvalConfig.metrics_set', index=7,
number=8, type=9, cpp_type=9, label=1,
has_default_value=True, default_value=_b("pascal_voc_metrics").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='export_path', full_name='object_detection.protos.EvalConfig.export_path', index=8,
number=9, type=9, cpp_type=9, label=1,
has_default_value=True, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='ignore_groundtruth', full_name='object_detection.protos.EvalConfig.ignore_groundtruth', index=9,
number=10, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='use_moving_averages', full_name='object_detection.protos.EvalConfig.use_moving_averages', index=10,
number=11, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='eval_instance_masks', full_name='object_detection.protos.EvalConfig.eval_instance_masks', index=11,
number=12, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None)
以上所有超参数都可以在config文件中显式设置,从而覆盖这里列出的默认值。