tensorflow版本的deeplabv3+源码解读1

目录

    • 1.deeplabv3+整体结构
    • 2.train.py
    • 3 总结

读源码太痛苦了,各种看不懂。因为刚接触语义分割,用了deeplab这个模型,所以想好好地把源码看一下。读第一遍时只能把API查一下,了解函数的作用。这是读的第二遍,把各模块的注释写一下。如果有人有更好的方法读懂源代码,求告知。

1.deeplabv3+整体结构

看一下deeplabv3+整个文件夹结构:
[图1:deeplabv3+ 源码文件夹结构]

我是从 local_test_mobilenetv2.sh 作为入口开始读的。

2.train.py

2.1 首先看main函数:

def main(unused_argv):
  """Builds the DeepLab training graph and runs the training loop.

  Args:
    unused_argv: Leftover command-line arguments after flag parsing; ignored.

  Raises:
    ValueError: If FLAGS.train_batch_size is not divisible by
      FLAGS.num_clones.
  """
  # Send TensorFlow INFO-level log messages to the console.
  tf.logging.set_verbosity(tf.logging.INFO)

  # Create the training log directory (succeeds if it already exists).
  tf.gfile.MakeDirs(FLAGS.train_logdir)
  tf.logging.info('Training on %s set', FLAGS.train_split)

  graph = tf.Graph()
  with graph.as_default():
    # replica_device_setter places variables on parameter-server tasks when
    # FLAGS.num_ps_tasks > 0 (distributed training).
    with tf.device(tf.train.replica_device_setter(ps_tasks=FLAGS.num_ps_tasks)):
      # Explicit raise instead of `assert` so the check is not stripped when
      # Python runs with -O.
      if FLAGS.train_batch_size % FLAGS.num_clones != 0:
        raise ValueError(
            'Training batch size not divisible by number of clones (GPUs).')
      # Each clone (GPU) receives an equal share of the global batch.
      clone_batch_size = FLAGS.train_batch_size // FLAGS.num_clones

      dataset = data_generator.Dataset(
          dataset_name=FLAGS.dataset,
          split_name=FLAGS.train_split,
          dataset_dir=FLAGS.dataset_dir,
          batch_size=clone_batch_size,
          crop_size=[int(sz) for sz in FLAGS.train_crop_size],
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          min_scale_factor=FLAGS.min_scale_factor,
          max_scale_factor=FLAGS.max_scale_factor,
          scale_factor_step_size=FLAGS.scale_factor_step_size,
          model_variant=FLAGS.model_variant,
          num_readers=2,
          is_training=True,
          should_shuffle=True,
          should_repeat=True)

      # Build the per-clone losses, gradient averaging and update ops.
      # Returns the op that performs one training step and the summary op.
      train_tensor, summary_op = _train_deeplab_model(
          dataset.get_one_shot_iterator(), dataset.num_of_classes,
          dataset.ignore_label)

      # Soft placement allows placing on CPU ops without GPU implementation.
      session_config = tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=False)

      last_layers = model.get_extra_layer_scopes(
          FLAGS.last_layers_contain_logits_only)
      init_fn = None
      # Optionally initialize weights from a pre-trained checkpoint.
      if FLAGS.tf_initial_checkpoint:
        init_fn = train_utils.get_model_init_fn(
            FLAGS.train_logdir,
            FLAGS.tf_initial_checkpoint,
            FLAGS.initialize_last_layer,
            last_layers,
            ignore_missing_vars=True)

      scaffold = tf.train.Scaffold(
          init_fn=init_fn,
          summary_op=summary_op,
      )
      # Stop training once the global step reaches the configured step count.
      stop_hook = tf.train.StopAtStepHook(
          last_step=FLAGS.training_number_of_steps)

      profile_dir = FLAGS.profile_logdir
      if profile_dir is not None:
        tf.gfile.MakeDirs(profile_dir)

      # Profiling is enabled only when a profile directory is configured.
      with tf.contrib.tfprof.ProfileContext(
          enabled=profile_dir is not None, profile_dir=profile_dir):
        # MonitoredTrainingSession handles initialization, checkpoint
        # restore/save, summary writing and hooks.
        with tf.train.MonitoredTrainingSession(
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            config=session_config,
            scaffold=scaffold,
            checkpoint_dir=FLAGS.train_logdir,
            summary_dir=FLAGS.train_logdir,
            log_step_count_steps=FLAGS.log_steps,
            # The flag holds a period in seconds, so pass it as
            # save_summaries_secs; passing it as save_summaries_steps would
            # reinterpret it as a number of steps.
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_checkpoint_secs=FLAGS.save_interval_secs,
            hooks=[stop_hook]) as sess:
          # Run training steps until the session requests a stop (e.g. the
          # StopAtStepHook reaching last_step).
          while not sess.should_stop():
            sess.run([train_tensor])

2.2 _train_deeplab_model函数

def _train_deeplab_model(iterator, num_of_classes, ignore_label):
  """Trains the deeplab model.

  Builds one loss tower per clone (GPU), optionally rewrites the graph for
  quantization-aware training, averages the per-tower gradients on the CPU,
  applies per-variable gradient multipliers, and groups the optimizer step
  with the ops in the UPDATE_OPS collection.

  Args:
    iterator: An iterator of type tf.data.Iterator for images and labels.
    num_of_classes: Number of classes for the dataset.
    ignore_label: Ignore label for the dataset.

  Returns:
    train_tensor: A tensor to update the model variables. It carries a
      control dependency on the update op and evaluates to the total loss.
    summary_op: An operation to log the summaries.

  Raises:
    ValueError: If quantization is requested (FLAGS.quantize_delay_step >= 0)
      together with more than one clone.
  """
  # Global step counter; apply_gradients(global_step=...) below increments it
  # once per training step.
  global_step = tf.train.get_or_create_global_step()
  # Learning-rate schedule built by train_utils from the policy/decay flags.
  learning_rate = train_utils.get_model_learning_rate(
      FLAGS.learning_policy, FLAGS.base_learning_rate,
      FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
      FLAGS.training_number_of_steps, FLAGS.learning_power,
      FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
  # Log the learning rate as a scalar summary.
  tf.summary.scalar('learning_rate', learning_rate)

  # Momentum optimizer; momentum value comes from FLAGS.momentum.
  optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)

  tower_losses = []
  tower_grads = []
  # Build one loss tower per clone, each pinned to its own GPU.
  for i in range(FLAGS.num_clones):
    with tf.device('/gpu:%d' % i):
      # First tower has default name scope.
      name_scope = ('clone_%d' % i) if i else ''
      with tf.name_scope(name_scope) as scope:
        # Total loss for this tower's batch. Towers after the first reuse the
        # variables created by the first.
        loss = _tower_loss(
            iterator=iterator,
            num_of_classes=num_of_classes,
            ignore_label=ignore_label,
            scope=scope,
            reuse_variable=(i != 0))
        tower_losses.append(loss)

  # Quantization-aware training; a negative quantize_delay_step disables it.
  # The rewrite is placed after all losses are built and before gradients are
  # computed, presumably so the gradients are taken on the rewritten graph.
  if FLAGS.quantize_delay_step >= 0:
    if FLAGS.num_clones > 1:
      raise ValueError('Quantization doesn\'t support multi-clone yet.')
    tf.contrib.quantize.create_training_graph(
        quant_delay=FLAGS.quantize_delay_step)

  # Compute each tower's gradients on the same device/name scope as its loss.
  for i in range(FLAGS.num_clones):
    with tf.device('/gpu:%d' % i):
      name_scope = ('clone_%d' % i) if i else ''
      with tf.name_scope(name_scope) as scope:
        grads = optimizer.compute_gradients(tower_losses[i])
        tower_grads.append(grads)

  with tf.device('/cpu:0'):
    # Average each shared variable's gradient across towers; yields a list of
    # (gradient, variable) pairs.
    grads_and_vars = _average_gradients(tower_grads)

    # Modify the gradients for biases and last layer variables.
    last_layers = model.get_extra_layer_scopes(
        FLAGS.last_layers_contain_logits_only)
    # Mapping from variable to gradient multiplier.
    grad_mult = train_utils.get_model_gradient_multipliers(
        last_layers, FLAGS.last_layer_gradient_multiplier)
    # Scale the selected gradients by their multipliers.
    if grad_mult:
      grads_and_vars = tf.contrib.training.multiply_gradients(
          grads_and_vars, grad_mult)

    # Create gradient update op.
    grad_updates = optimizer.apply_gradients(
        grads_and_vars, global_step=global_step)

    # Gather update_ops. These contain, for example,
    # the updates for the batch_norm variables created by model_fn.
    # They are grouped with the gradient update into a single op.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    update_ops.append(grad_updates)
    update_op = tf.group(*update_ops)

    total_loss = tf.losses.get_total_loss(add_regularization_losses=True)

    # Print total loss to the terminal.
    # This implementation is mirrored from tf.slim.summaries.
    # should_log is true on steps that are a multiple of FLAGS.log_steps.
    should_log = math_ops.equal(math_ops.mod(global_step, FLAGS.log_steps), 0)
    # Graph-level if/else: on logging steps, wrap the loss in tf.Print.
    total_loss = tf.cond(
        should_log,
        lambda: tf.Print(total_loss, [total_loss], 'Total loss is :'),
        lambda: total_loss)

    tf.summary.scalar('total_loss', total_loss)
    # The identity op below runs only after update_op has completed, so
    # fetching train_tensor performs one full update step.
    with tf.control_dependencies([update_op]):
      train_tensor = tf.identity(total_loss, name='train_op')

    # Excludes summaries from towers other than the first one.
    summary_op = tf.summary.merge_all(scope='(?!clone_)')

  return train_tensor, summary_op

2.3 _tower_loss函数

def _tower_loss(iterator, num_of_classes, ignore_label, scope, reuse_variable):
  """Builds one DeepLab tower and returns its total (data + regularization) loss.

  The model is constructed inside the current variable scope. When
  `reuse_variable` is set, existing variables are reused; otherwise reuse is
  inherited from the enclosing scope. Every loss registered under `scope`
  (and the scope's regularization loss) is logged as a scalar summary.

  Args:
    iterator: An iterator of type tf.data.Iterator for images and labels.
    num_of_classes: Number of classes for the dataset.
    ignore_label: Ignore label for the dataset.
    scope: Unique prefix string identifying the deeplab tower.
    reuse_variable: If the variable should be reused.

  Returns:
    The sum of all losses collected under `scope` plus its regularization loss.
  """
  # None (not False) means "inherit the parent scope's reuse setting".
  reuse = True if reuse_variable else None
  with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
    _build_deeplab(iterator, {common.OUTPUT_TYPE: num_of_classes}, ignore_label)

  scoped_losses = tf.losses.get_losses(scope=scope)
  for loss_tensor in scoped_losses:
    tf.summary.scalar('Losses/%s' % loss_tensor.op.name, loss_tensor)

  reg_loss = tf.losses.get_regularization_loss(scope=scope)
  tf.summary.scalar('Losses/%s' % reg_loss.op.name,
                    reg_loss)

  return tf.add_n([tf.add_n(scoped_losses), reg_loss])

2.4 _build_deeplab函数

def _build_deeplab(iterator, outputs_to_num_classes, ignore_label):
  """Builds a clone of DeepLab.

  Pulls one batch from `iterator` and runs the multi-scale DeepLab model on it.
  For every output type it then adds a softmax cross-entropy loss over the
  logits at each scale, and writes TensorBoard summaries.

  Args:
    iterator: An iterator of type tf.data.Iterator for images and labels.
    outputs_to_num_classes: A map from output type to the number of classes. For
      example, for the task of semantic segmentation with 21 semantic classes,
      we would have outputs_to_num_classes['semantic'] = 21.
    ignore_label: Ignore label.
  """
  samples = iterator.get_next()

  # Add name to input and label nodes so we can add to summary.
  samples[common.IMAGE] = tf.identity(samples[common.IMAGE], name=common.IMAGE)
  samples[common.LABEL] = tf.identity(samples[common.LABEL], name=common.LABEL)

  model_options = common.ModelOptions(
      outputs_to_num_classes=outputs_to_num_classes,
      crop_size=[int(sz) for sz in FLAGS.train_crop_size],
      atrous_rates=FLAGS.atrous_rates,
      output_stride=FLAGS.output_stride)

  # Mapping output type -> {scale name -> logits}.
  outputs_to_scales_to_logits = model.multi_scale_logits(
      samples[common.IMAGE],
      model_options=model_options,
      image_pyramid=FLAGS.image_pyramid,
      weight_decay=FLAGS.weight_decay,
      is_training=True,
      fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
      nas_training_hyper_parameters={
          'drop_path_keep_prob': FLAGS.drop_path_keep_prob,
          'total_training_steps': FLAGS.training_number_of_steps,
      })

  # Add name to graph node so we can add to summary.
  output_type_dict = outputs_to_scales_to_logits[common.OUTPUT_TYPE]
  output_type_dict[model.MERGED_LOGITS_SCOPE] = tf.identity(
      output_type_dict[model.MERGED_LOGITS_SCOPE], name=common.OUTPUT_TYPE)

  for output, num_classes in six.iteritems(outputs_to_num_classes):
    # Add a softmax cross-entropy loss for the logits at every scale of this
    # output type. (This call and _log_summaries must be indented inside the
    # loop; at loop level the file does not parse.)
    train_utils.add_softmax_cross_entropy_loss_for_each_scale(
        outputs_to_scales_to_logits[output],
        samples[common.LABEL],
        num_classes,
        ignore_label,
        loss_weight=1.0,
        upsample_logits=FLAGS.upsample_logits,
        hard_example_mining_step=FLAGS.hard_example_mining_step,
        top_k_percent_pixels=FLAGS.top_k_percent_pixels,
        scope=output)

    # Log the summary.
    _log_summaries(samples[common.IMAGE], samples[common.LABEL], num_classes,
                   output_type_dict[model.MERGED_LOGITS_SCOPE])

2.5 _average_gradients函数
计算所有tower每个共享变量的平均梯度

def _average_gradients(tower_grads):
  """Calculates average of gradient for each shared variable across all towers.
  Note that this function provides a synchronization point across all towers.
  Args:
    tower_grads: List of lists of (gradient, variable) tuples. The outer list is
      over individual gradients. The inner list is over the gradient calculation
      for each tower.

  Returns:
     List of pairs of (gradient, variable) where the gradient has been summed
       across all towers.
  """
  average_grads = []
  for grad_and_vars in zip(*tower_grads): # zip将可迭代参数中的元素打包成一个个元组形成列表
    # Note that each grad_and_vars looks like the following:
    #   ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
    grads, variables = zip(*grad_and_vars) # grads应为[grad0_gpu0,...,grad0_gpuN]
    # 按某一维度计算均值,这里计算的是grad0_gpu0...grad0_gpuN的均值
    grad = tf.reduce_mean(tf.stack(grads, axis=0), axis=0) # tf.stack按某一维度拼接。这里形成一个2维矩阵

    # All vars are of the same value, using the first tower here.
    average_grads.append((grad, variables[0]))

  return average_grads

2.6 _log_summaries函数
输出到日志文件,用tensorboard查看

def _log_summaries(input_image, label, num_of_classes, output):
  """Writes TensorBoard summaries for model variables and, optionally, images.

  A histogram is always written for every model variable. When
  FLAGS.save_summaries_images is set, the input images, the labels and the
  predicted classes are also written as image summaries.

  Args:
    input_image: Input image of the model, passed to tf.summary.image, so it
      must be 4-D: [batch_size, height, width, channel].
    label: Label of the image. It is passed to tf.summary.image, so it must
      also be 4-D, presumably [batch_size, height, width, 1] — TODO confirm.
    num_of_classes: The number of classes of the dataset.
    output: Output of the model. Its argmax is taken over axis 3, so it must
      have rank >= 4, presumably [batch_size, height, width, num_classes]
      logits.
  """
  # Histogram of every model variable.
  for variable in tf.model_variables():
    tf.summary.histogram(variable.op.name, variable)

  if not FLAGS.save_summaries_images:
    return

  tf.summary.image('samples/%s' % common.IMAGE, input_image)

  # Class ids are small integers; stretch them toward [0, 255] so different
  # classes are visually distinguishable.
  scale = max(1, 255 // num_of_classes)

  label_image = tf.cast(label * scale, tf.uint8)
  tf.summary.image('samples/%s' % common.LABEL, label_image)

  # argmax returns the index (class id) of the largest score along axis 3.
  # A trailing channel axis is re-added so the result is a 4-D image tensor.
  predicted_ids = tf.expand_dims(tf.argmax(output, 3), -1)
  predicted_image = tf.cast(predicted_ids * scale, tf.uint8)
  tf.summary.image('samples/%s' % common.OUTPUT_TYPE, predicted_image)

3 总结

train.py总的逻辑调用:
main
 └─> _train_deeplab_model
      ├─> _tower_loss
      │    └─> _build_deeplab
      │         └─> _log_summaries
      └─> _average_gradients

你可能感兴趣的:(tensorflow)