DeepLabv3+训练代码详解

代码地址
训练代码train.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import tensorflow as tf
#import tensorflow.compat.v1 as tf
from tensorflow.contrib import quantize as contrib_quantize
from tensorflow.contrib import tfprof as contrib_tfprof

import sys
sys.path.append("..")
from deeplab import common
from deeplab import model

from deeplab.datasets import data_generator
from deeplab.utils import train_utils
from deployment import model_deploy

# Short aliases for the TF-Slim library and the command-line flag machinery.
slim = tf.contrib.slim
flags = tf.app.flags
FLAGS = flags.FLAGS

# Settings for multi-GPU / multi-replica deployment.
flags.DEFINE_integer('num_clones', 1, 'Number of clones to deploy.')

flags.DEFINE_boolean('clone_on_cpu', False, 'Use CPUs to deploy clones.')

flags.DEFINE_integer('num_replicas', 1, 'Number of worker replicas.')

flags.DEFINE_integer('startup_delay_steps', 15,
                     'Number of training steps between replicas startup.')

flags.DEFINE_integer(
    'num_ps_tasks', 0,
    'The number of parameter servers. If the value is 0, then '
    'the parameters are handled locally by the worker.')

flags.DEFINE_string('master', '', 'BNS name of the tensorflow server')

flags.DEFINE_integer('task', 0, 'The task ID.')

# Settings for logging, summaries and checkpoints.
flags.DEFINE_string('train_logdir', None,
                    'Where the checkpoint and logs are stored.')

flags.DEFINE_integer('log_steps', 10,
                     'Display logging information at every log_steps.')

flags.DEFINE_integer('save_interval_secs', 1200,
                     'How often, in seconds, we save the model to disk.')

flags.DEFINE_integer('save_summaries_secs', 600,
                     'How often, in seconds, we compute the summaries.')

flags.DEFINE_boolean(
    'save_summaries_images', False,
    'Save sample inputs, labels, and semantic predictions as '
    'images to summary.')

# Settings for profiling.
flags.DEFINE_string('profile_logdir', None,
                    'Where the profile files are stored.')

# Settings for the training optimizer.
flags.DEFINE_enum('optimizer', 'momentum', ['momentum', 'adam'],
                  'Which optimizer to use.')


# Settings for the learning rate policy.
flags.DEFINE_enum('learning_policy', 'poly', ['poly', 'step'],
                  'Learning rate policy for training.')

# Use 0.007 when training on PASCAL augmented training set, train_aug.
# When fine-tuning on PASCAL trainval set, use learning rate=0.0001.
flags.DEFINE_float('base_learning_rate', .0001,
                   'The base learning rate for model training.')

flags.DEFINE_float('decay_steps', 0.0,
                   'Decay steps for polynomial learning rate schedule.')

flags.DEFINE_float('end_learning_rate', 0.0,
                   'End learning rate for polynomial learning rate schedule.')

flags.DEFINE_float('learning_rate_decay_factor', 0.1,
                   'The rate to decay the base learning rate.')

flags.DEFINE_integer('learning_rate_decay_step', 2000,
                     'Decay the base learning rate at a fixed step.')

flags.DEFINE_float('learning_power', 0.9,
                   'The power value used in the poly learning policy.')

flags.DEFINE_integer('training_number_of_steps', 30000,
                     'The number of steps used for training')

flags.DEFINE_float('momentum', 0.9, 'The momentum value to use')

# Adam optimizer flags
flags.DEFINE_float('adam_learning_rate', 0.001,
                   'Learning rate for the adam optimizer.')
flags.DEFINE_float('adam_epsilon', 1e-08, 'Adam optimizer epsilon.')

# When fine_tune_batch_norm=True, use at least batch size larger than 12
# (batch size more than 16 is better). Otherwise, one could use smaller batch
# size and set fine_tune_batch_norm=False.
flags.DEFINE_integer('train_batch_size', 4,
                     'The number of images in each batch during training.')

# For weight_decay, use 0.00004 for MobileNet-V2 or Xception model variants.
# Use 0.0001 for ResNet model variants.
flags.DEFINE_float('weight_decay', 0.00004,
                   'The value of the weight decay for training.')

flags.DEFINE_list('train_crop_size', '513,513',
                  'Image crop size [height, width] during training.')

# Gradient multiplier for the last layers; a value > 1 boosts the gradients
# of those layers.
flags.DEFINE_float(
    'last_layer_gradient_multiplier', 1.0,
    'The gradient multiplier for last layers, which is used to '
    'boost the gradient of last layers if the value > 1.')

flags.DEFINE_boolean('upsample_logits', True,
                     'Upsample logits during training.')

# Hyper-parameters for NAS training strategy.

flags.DEFINE_float(
    'drop_path_keep_prob', 1.0,
    'Probability to keep each path in the NAS cell when training.')

# Settings for fine-tuning the network.

flags.DEFINE_string('tf_initial_checkpoint', None,
                    'The initial checkpoint in tensorflow format.')

# Set to False if one does not want to re-use the trained classifier weights.
# To reuse all of the pretrained weights, set initialize_last_layer=True.
# When the number of classes differs from the checkpoint, set
# initialize_last_layer=False and last_layers_contain_logits_only=True.
flags.DEFINE_boolean('initialize_last_layer', False,
                     'Initialize the last layer.')
# Whether only the logits are treated as the last layers. If False, several
# modules (ASPP, decoder, ...) are treated as last layers as well; if True,
# only the logits are (i.e. the ASPP module, decoder module etc. are excluded).
flags.DEFINE_boolean('last_layers_contain_logits_only', False,
                     'Only consider logits as last layers or not.')



flags.DEFINE_integer('slow_start_step', 0,
                     'Training model with small learning rate for few steps.')

flags.DEFINE_float('slow_start_learning_rate', 1e-4,
                   'Learning rate employed during slow start.')

# Set to True if one wants to fine-tune the batch norm parameters in DeepLabv3.
# Set to False and use small batch size to save GPU memory.
flags.DEFINE_boolean('fine_tune_batch_norm', False,
                     'Fine tune the batch norm parameters or not.')

flags.DEFINE_float('min_scale_factor', 0.5,
                   'Mininum scale factor for data augmentation.')

flags.DEFINE_float('max_scale_factor', 2.,
                   'Maximum scale factor for data augmentation.')

flags.DEFINE_float('scale_factor_step_size', 0.25,
                   'Scale factor step size for data augmentation.')

# For `xception_65`, use atrous_rates = [12, 24, 36] if output_stride = 8, or
# rates = [6, 12, 18] if output_stride = 16. For `mobilenet_v2`, use None. Note
# one could use different atrous_rates/output_stride during training/evaluation.
flags.DEFINE_multi_integer('atrous_rates', None,
                           'Atrous rates for atrous spatial pyramid pooling.')

flags.DEFINE_integer('output_stride', 16,
                     'The ratio of input to output spatial resolution.')

# Hard example mining related flags.
flags.DEFINE_integer(
    'hard_example_mining_step', 0,
    'The training step in which exact hard example mining kicks off. Note we '
    'gradually reduce the mining percent to the specified '
    'top_k_percent_pixels. For example, if hard_example_mining_step=100K and '
    'top_k_percent_pixels=0.25, then mining percent will gradually reduce from '
    '100% to 25% until 100K steps after which we only mine top 25% pixels.')

flags.DEFINE_float(
    'top_k_percent_pixels', 1.0,
    'The top k percent pixels (in terms of the loss values) used to compute '
    'loss during training. This is useful for hard pixel mining.')

# Quantization setting.
flags.DEFINE_integer(
    'quantize_delay_step', -1,
    'Steps to start quantized training. If < 0, will not quantize model.')

# Dataset settings.
flags.DEFINE_string('dataset', 'pascal_voc_seg',
                    'Name of the segmentation dataset.')

flags.DEFINE_string('train_split', 'train',
                    'Which split of the dataset to be used for training')

flags.DEFINE_string('dataset_dir', None, 'Where the dataset reside.')
def _build_deeplab(iterator, outputs_to_num_classes, ignore_label):
  """Builds one clone of the DeepLab network and registers its losses.

  The function returns nothing; the losses are added through
  train_utils.add_softmax_cross_entropy_loss_for_each_scale.

  Args:
    iterator: An iterator of type tf.data.Iterator for images and labels.
    outputs_to_num_classes: A map from output type to the number of classes.
      For example, for the task of semantic segmentation with 21 semantic
      classes, we would have outputs_to_num_classes['semantic'] = 21.
    ignore_label: Ignore label.
  """
  # Fetch the next batch of images and their labels.
  samples = iterator.get_next()

  # Add name to input and label nodes so we can add to summary.
  samples[common.IMAGE] = tf.identity(samples[common.IMAGE], name=common.IMAGE)
  samples[common.LABEL] = tf.identity(samples[common.LABEL], name=common.LABEL)

  # Model configuration.
  model_options = common.ModelOptions(
      outputs_to_num_classes=outputs_to_num_classes,  # Classes per output.
      crop_size=[int(sz) for sz in FLAGS.train_crop_size],  # Input size.
      atrous_rates=FLAGS.atrous_rates,  # Atrous rates for ASPP.
      output_stride=FLAGS.output_stride)  # Input/output resolution ratio.
  # Example of a resulting ModelOptions value (kept as a comment: a bare
  # string literal at column 0 here would terminate the function body):
  #   ModelOptions(outputs_to_num_classes={'semantic': 4},
  #     crop_size=[513, 513], atrous_rates=[6,12,18], output_stride=16,
  #     preprocessed_images_dtype=tf.float32, merge_method='max',
  #     add_image_level_feature=True, image_pooling_crop_size=None,
  #     image_pooling_stride=[1, 1], aspp_with_batch_norm=True,
  #     aspp_with_separable_conv=True, multi_grid=None,
  #     decoder_output_stride=[4], decoder_use_separable_conv=True,
  #     logits_kernel_size=1, model_variant='xception_65',
  #     depth_multiplier=1.0, divisible_by=None,
  #     prediction_with_upsampled_logits=True,
  #     dense_prediction_cell_config=None,
  #     nas_architecture_options={'nas_stem_output_num_conv_filters': 20,
  #       'nas_use_classification_head': False,
  #       'nas_remove_os32_stride': False},
  #     use_bounded_activation=False, aspp_with_concat_projection=True,
  #     aspp_with_squeeze_and_excitation=False, aspp_convs_filters=256,
  #     decoder_use_sum_merge=False, decoder_filters=256,
  #     decoder_output_is_logits=False, image_se_uses_qsigmoid=False,
  #     label_weights=1.0, sync_batch_norm_method='None',
  #     batch_norm_decay=0.9997)

  # Build the network and get its outputs. Per the original annotation the
  # returned multi-scale logits are all downsampled.
  outputs_to_scales_to_logits = model.multi_scale_logits(
      samples[common.IMAGE],  # Image tensor [batch, height, width, channels].
      model_options=model_options,
      image_pyramid=FLAGS.image_pyramid,  # Scales for multi-scale features.
      weight_decay=FLAGS.weight_decay,
      is_training=True,
      fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
      nas_training_hyper_parameters={
          # Probability to keep each path in the NAS cell when training.
          'drop_path_keep_prob': FLAGS.drop_path_keep_prob,
          # Total number of training steps.
          'total_training_steps': FLAGS.training_number_of_steps,
      })

  # Add name to graph node so we can add to summary. tf.identity creates a
  # new, identically-valued tensor, i.e. a new named node in the graph.
  output_type_dict = outputs_to_scales_to_logits[common.OUTPUT_TYPE]
  output_type_dict[model.MERGED_LOGITS_SCOPE] = tf.identity(
      output_type_dict[model.MERGED_LOGITS_SCOPE], name=common.OUTPUT_TYPE)

  # Add a softmax cross-entropy loss for the logits of every output type.
  for output, num_classes in six.iteritems(outputs_to_num_classes):
    train_utils.add_softmax_cross_entropy_loss_for_each_scale(
        outputs_to_scales_to_logits[output],
        samples[common.LABEL],
        num_classes,
        ignore_label,
        loss_weight=model_options.label_weights,
        upsample_logits=FLAGS.upsample_logits,
        hard_example_mining_step=FLAGS.hard_example_mining_step,
        top_k_percent_pixels=FLAGS.top_k_percent_pixels,
        scope=output)



def main(unused_argv):
  """Builds the DeepLab training graph and runs the slim training loop.

  Steps: create the deployment config, build the input pipeline, create the
  model clones, collect summaries, build the optimizer, compute and scale the
  gradients, and finally hand the train op to slim.learning.train.

  Args:
    unused_argv: Command-line arguments passed by the app runner; not used.

  Raises:
    AssertionError: If train_batch_size is not divisible by num_clones.
    ValueError: If the optimizer flag is unknown, or quantization is requested
      together with more than one clone.
  """
  # Show log messages of level INFO and above.
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

  # DeploymentConfig describes how the model is deployed across machines and
  # GPUs; on each machine the model is replicated num_clones times.
  config = model_deploy.DeploymentConfig(
      num_clones=FLAGS.num_clones,  # Clones (GPUs) per machine.
      clone_on_cpu=FLAGS.clone_on_cpu,  # Place the clones on CPU if True.
      replica_id=FLAGS.task,  # Index of this replica; the chief is 0.
      num_replicas=FLAGS.num_replicas,  # Number of machines; 1 = single.
      num_ps_tasks=FLAGS.num_ps_tasks)  # Parameter servers; 0 = local.

  # A batch is split evenly across all clones, so each clone processes
  # train_batch_size / num_clones images; hence the divisibility requirement.
  assert FLAGS.train_batch_size % config.num_clones == 0, (
      'Training batch size not divisble by number of clones (GPUs).')
  clone_batch_size = FLAGS.train_batch_size // config.num_clones

  # Create the directory in which checkpoints and logs are stored.
  tf.io.gfile.makedirs(FLAGS.train_logdir)
  tf.compat.v1.logging.info('Training on %s set', FLAGS.train_split)

  with tf.Graph().as_default() as graph:
    # Build the input pipeline on the inputs device (the original annotation
    # notes this defaults to /device:CPU:0 -- TODO confirm).
    with tf.device(config.inputs_device()):
      dataset = data_generator.Dataset(
          dataset_name=FLAGS.dataset,
          split_name=FLAGS.train_split,
          dataset_dir=FLAGS.dataset_dir,
          batch_size=clone_batch_size,
          crop_size=[int(sz) for sz in FLAGS.train_crop_size],
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          min_scale_factor=FLAGS.min_scale_factor,
          max_scale_factor=FLAGS.max_scale_factor,
          scale_factor_step_size=FLAGS.scale_factor_step_size,
          model_variant=FLAGS.model_variant,
          num_readers=4,
          is_training=True,
          should_shuffle=True,
          should_repeat=True)

    # Create the global step on the device storing the variables.
    with tf.device(config.variables_device()):
      global_step = tf.compat.v1.train.get_or_create_global_step()

      # Define the model and create clones.
      model_fn = _build_deeplab
      model_args = (dataset.get_one_shot_iterator(),
                    {common.OUTPUT_TYPE: dataset.num_of_classes},
                    dataset.ignore_label)

      # Build the model once per clone (GPU).
      clones = model_deploy.create_clones(config, model_fn, args=model_args)
      # Reference listing of model_deploy.create_clones / clone_device:
      #
      #   def create_clones(config, model_fn, args=None, kwargs=None):
      #     clones = []
      #     args = args or []
      #     kwargs = kwargs or {}
      #     with slim.arg_scope([slim.model_variable, slim.variable],
      #                         device=config.variables_device()):
      #       # Create clones.
      #       for i in range(0, config.num_clones):
      #         with tf.name_scope(config.clone_scope(i)) as clone_scope:
      #           clone_device = config.clone_device(i)
      #           with tf.device(clone_device):
      #             with tf.variable_scope(tf.get_variable_scope(),
      #                                    reuse=True if i > 0 else None):
      #               outputs = model_fn(*args, **kwargs)  # Runs model_fn.
      #             clones.append(Clone(outputs, clone_scope, clone_device))
      #     return clones
      #
      #   def clone_device(self, clone_index):
      #     if clone_index >= self._num_clones:
      #       raise ValueError('clone_index must be less than num_clones')
      #     device = ''
      #     if self._num_ps_tasks > 0:
      #       device += self._worker_device
      #     if self._clone_on_cpu:
      #       device += '/device:CPU:0'
      #     else:
      #       device += '/device:GPU:%d' % clone_index  # Selects the GPU.
      #     return device

      # Gather update_ops from the first clone. These contain, for example,
      # the updates for the batch_norm variables created by model_fn.
      first_clone_scope = config.clone_scope(0)
      update_ops = tf.compat.v1.get_collection(
          tf.compat.v1.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Gather initial summaries.
    summaries = set(
        tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))

    # Add summaries for model variables.
    for model_var in tf.compat.v1.model_variables():
      summaries.add(
          tf.compat.v1.summary.histogram(model_var.op.name, model_var))

    # Add summaries for images, labels, semantic predictions
    if FLAGS.save_summaries_images:
      summary_image = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
      summaries.add(
          tf.compat.v1.summary.image('samples/%s' % common.IMAGE,
                                     summary_image))

      first_clone_label = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
      # Scale up summary image pixel values for better visualization.
      pixel_scaling = max(1, 255 // dataset.num_of_classes)
      summary_label = tf.cast(first_clone_label * pixel_scaling, tf.uint8)
      summaries.add(
          tf.compat.v1.summary.image('samples/%s' % common.LABEL,
                                     summary_label))

      first_clone_output = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
      predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)

      summary_predictions = tf.cast(predictions * pixel_scaling, tf.uint8)
      summaries.add(
          tf.compat.v1.summary.image('samples/%s' % common.OUTPUT_TYPE,
                                     summary_predictions))

    # Add summaries for losses.
    for loss in tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.LOSSES,
                                            first_clone_scope):
      summaries.add(
          tf.compat.v1.summary.scalar('losses/%s' % loss.op.name, loss))

    # Build the optimizer based on the device specification.
    with tf.device(config.optimizer_device()):
      # Learning rate derived from the learning policy. It is used by the
      # momentum optimizer below; the Adam branch uses adam_learning_rate.
      learning_rate = train_utils.get_model_learning_rate(
          FLAGS.learning_policy,
          FLAGS.base_learning_rate,
          FLAGS.learning_rate_decay_step,
          FLAGS.learning_rate_decay_factor,
          FLAGS.training_number_of_steps,
          FLAGS.learning_power,
          FLAGS.slow_start_step,
          FLAGS.slow_start_learning_rate,
          decay_steps=FLAGS.decay_steps,
          end_learning_rate=FLAGS.end_learning_rate)
      # Add the learning rate to the summaries.
      summaries.add(
          tf.compat.v1.summary.scalar('learning_rate', learning_rate))

      # Select the optimizer.
      if FLAGS.optimizer == 'momentum':
        optimizer = tf.compat.v1.train.MomentumOptimizer(
            learning_rate, FLAGS.momentum)
      elif FLAGS.optimizer == 'adam':
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=FLAGS.adam_learning_rate,
            epsilon=FLAGS.adam_epsilon)
      else:
        raise ValueError('Unknown optimizer')

    # quantize_delay_step is the step at which quantized training starts; the
    # training graph is rewritten for quantization only when it is >= 0.
    if FLAGS.quantize_delay_step >= 0:
      if FLAGS.num_clones > 1:
        raise ValueError('Quantization doesn\'t support multi-clone yet.')
      contrib_quantize.create_training_graph(
          quant_delay=FLAGS.quantize_delay_step)

    # Number of steps this replica waits before it starts training.
    startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

    with tf.device(config.variables_device()):
      # Compute the total loss and the gradients across all clones.
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, optimizer)
      # Reference listing of model_deploy.optimize_clones, which computes the
      # loss and gradients for every clone in the given list:
      #
      #   def optimize_clones(clones, optimizer,
      #                       regularization_losses=None,
      #                       **kwargs):
      #     grads_and_vars = []
      #     clones_losses = []
      #     num_clones = len(clones)
      #     if regularization_losses is None:
      #       regularization_losses = tf.compat.v1.get_collection(
      #           tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
      #     for clone in clones:
      #       with tf.name_scope(clone.scope):
      #         clone_loss, clone_grad = _optimize_clone(
      #             optimizer, clone, num_clones, regularization_losses,
      #             **kwargs)
      #         if clone_loss is not None:
      #           clones_losses.append(clone_loss)
      #           grads_and_vars.append(clone_grad)
      #         # Only use regularization_losses for the first clone
      #         regularization_losses = None
      #     # Compute the total_loss summing all the clones_losses.
      #     total_loss = tf.add_n(clones_losses, name='total_loss')
      #     # Sum the gradients across clones.
      #     grads_and_vars = _sum_clones_gradients(grads_and_vars)
      #     return total_loss, grads_and_vars

      # Fail with the given message if the loss contains NaN or Inf values.
      total_loss = tf.debugging.check_numerics(total_loss,
                                               'Loss is inf or nan.')
      summaries.add(tf.compat.v1.summary.scalar('total_loss', total_loss))

      # Modify the gradients for biases and last layer variables. Semantic
      # segmentation models are usually fine-tuned from classification
      # models, and the last layers typically get a larger learning rate.
      #
      # last_layers_contain_logits_only selects which scopes count as "last
      # layers": [logits] when True, otherwise [logits, image_pooling, aspp,
      # concat_projection, decoder, meta_architecture].
      last_layers = model.get_extra_layer_scopes(
          FLAGS.last_layers_contain_logits_only)

      # Per-variable gradient multipliers, which effectively adjust the
      # learning rate of individual model variables.
      grad_mult = train_utils.get_model_gradient_multipliers(
          last_layers, FLAGS.last_layer_gradient_multiplier)
      # Reference listing of train_utils.get_model_gradient_multipliers:
      #
      #   def get_model_gradient_multipliers(last_layers,
      #                                      last_layer_gradient_multiplier):
      #     gradient_multipliers = {}
      #     # Iterate over all model variables (the weights of every layer).
      #     for var in tf.compat.v1.model_variables():
      #       # Double the learning rate for biases. (Original annotation:
      #       # only the last layer, logits/semantic/biases, has biases.)
      #       if 'biases' in var.op.name:
      #         gradient_multipliers[var.op.name] = 2.
      #
      #       # Use larger learning rate for last layer variables.
      #       for layer in last_layers:
      #         if layer in var.op.name and 'biases' in var.op.name:
      #           gradient_multipliers[var.op.name] = (
      #               2 * last_layer_gradient_multiplier)
      #           break
      #         elif layer in var.op.name:
      #           gradient_multipliers[var.op.name] = (
      #               last_layer_gradient_multiplier)
      #           break
      #     # Map from variable name to multiplier.
      #     return gradient_multipliers

      # Reference listing of slim.learning.multiply_gradients:
      #
      #   def multiply_gradients(grads_and_vars, gradient_multipliers):
      #     if not isinstance(grads_and_vars, list):
      #       raise ValueError('`grads_and_vars` must be a list.')
      #     if not gradient_multipliers:
      #       raise ValueError('`gradient_multipliers` is empty.')
      #     if not isinstance(gradient_multipliers, dict):
      #       raise ValueError('`gradient_multipliers` must be a dict.')
      #
      #     multiplied_grads_and_vars = []
      #     for grad, var in grads_and_vars:
      #       if (var in gradient_multipliers or
      #           var.op.name in gradient_multipliers):
      #         key = var if var in gradient_multipliers else var.op.name
      #         if grad is None:
      #           raise ValueError('Requested multiple of `None` gradient.')
      #
      #         multiplier = gradient_multipliers[key]
      #         if not isinstance(multiplier, ops.Tensor):
      #           multiplier = constant_op.constant(multiplier,
      #                                             dtype=grad.dtype)
      #
      #         if isinstance(grad, ops.IndexedSlices):
      #           tmp = grad.values * multiplier
      #           grad = ops.IndexedSlices(tmp, grad.indices,
      #                                    grad.dense_shape)
      #         else:
      #           grad *= multiplier
      #       multiplied_grads_and_vars.append((grad, var))
      #     return multiplied_grads_and_vars
      if grad_mult:
        # Scale the gradients by the per-variable multipliers.
        grads_and_vars = slim.learning.multiply_gradients(
            grads_and_vars, grad_mult)

      # Create gradient update op.
      grad_updates = optimizer.apply_gradients(
          grads_and_vars, global_step=global_step)
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops)
      # The control dependency forces update_op to run whenever train_tensor
      # is evaluated. tf.identity creates a new node for total_loss so that
      # the control dependency actually takes effect.
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or
    # _gather_clone_loss().
    summaries |= set(
        tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES,
                                    first_clone_scope))
    # Merge all summaries together.
    summary_op = tf.compat.v1.summary.merge(list(summaries))

    # Session configuration; soft placement lets ops fall back to another
    # device (e.g. CPU) when the requested one cannot run them.
    session_config = tf.compat.v1.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)

    # Start the training.
    # Create the profile output directory when profiling is requested.
    profile_dir = FLAGS.profile_logdir
    if profile_dir is not None:
      tf.io.gfile.makedirs(profile_dir)

    with contrib_tfprof.ProfileContext(
        enabled=profile_dir is not None, profile_dir=profile_dir):
      # Optionally initialize from a pretrained checkpoint.
      init_fn = None
      if FLAGS.tf_initial_checkpoint:
        init_fn = train_utils.get_model_init_fn(
            FLAGS.train_logdir,  # Where training checkpoints are stored.
            FLAGS.tf_initial_checkpoint,  # The pretrained checkpoint.
            FLAGS.initialize_last_layer,  # Whether to init the last layer.
            last_layers,  # Per the original annotation: layers not to init.
            ignore_missing_vars=True)

      # Run the training loop.
      slim.learning.train(
          train_tensor,
          logdir=FLAGS.train_logdir,
          log_every_n_steps=FLAGS.log_steps,
          master=FLAGS.master,
          number_of_steps=FLAGS.training_number_of_steps,
          is_chief=(FLAGS.task == 0),
          session_config=session_config,
          startup_delay_steps=startup_delay_steps,
          init_fn=init_fn,
          summary_op=summary_op,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)

if __name__ == '__main__':
  # Both flags default to None, so training cannot start without them.
  for required_flag in ('train_logdir', 'dataset_dir'):
    flags.mark_flag_as_required(required_flag)
  tf.compat.v1.app.run()

data_generator.py,用来处理训练数据

import collections
import os
import tensorflow as tf
from deeplab import common
from deeplab import input_preprocess

# Named tuple to describe the dataset properties.
# 命名元组以描述数据集属性。
DatasetDescriptor = collections.namedtuple(
    'DatasetDescriptor',
    (
        # Splits of the dataset into training, val and test.
        'splits_to_sizes',
        # Number of semantic classes, including the background class (if
        # exists). For example, there are 20 foreground classes + 1
        # background class in the PASCAL VOC 2012 dataset. Thus, we set
        # num_classes=21.
        'num_classes',
        # Ignore label value.
        'ignore_label',
    ))

#各种数据集数据
_CITYSCAPES_INFORMATION = DatasetDescriptor(
    splits_to_sizes=dict(
        train_fine=2975,
        train_coarse=22973,
        trainval_fine=3475,
        trainval_coarse=23473,
        val_fine=500,
        test_fine=1525,
    ),
    num_classes=19,
    ignore_label=255,
)

_PASCAL_VOC_SEG_INFORMATION = DatasetDescriptor(
    splits_to_sizes=dict(
        train=1464,
        train_aug=10582,
        trainval=2913,
        val=1449,
    ),
    num_classes=21,
    ignore_label=255,
)

_ADE20K_INFORMATION = DatasetDescriptor(
    splits_to_sizes=dict(
        train=20210,  # num of samples in images/training
        val=2000,  # num of samples in images/validation
    ),
    num_classes=151,
    ignore_label=0,
)

# Custom dataset added by the article's author. An earlier revision used
# train=649 / val=73.
_MYFJ = DatasetDescriptor(
    splits_to_sizes=dict(
        train=1070,  # num of samples in images/training
        val=119,  # num of samples in images/validation
    ),
    num_classes=4,
    ignore_label=255,
)

# Registry of all supported datasets, keyed by dataset name.
_DATASETS_INFORMATION = dict(
    cityscapes=_CITYSCAPES_INFORMATION,
    pascal_voc_seg=_PASCAL_VOC_SEG_INFORMATION,
    ade20k=_ADE20K_INFORMATION,
    myfj=_MYFJ,
)

# Default file pattern of TFRecord of TensorFlow Example.
_FILE_PATTERN = '%s-*'


def get_cityscapes_dataset_name():
  """Returns the dataset name string used for Cityscapes."""
  dataset_name = 'cityscapes'
  return dataset_name

#表示deeplab模型的输入数据集。
class Dataset(object):
  """Represents input dataset for deeplab model."""

  def __init__(self,
               dataset_name,
               split_name,
               dataset_dir,
               batch_size,
               crop_size,
               min_resize_value=None,
               max_resize_value=None,
               resize_factor=None,
               min_scale_factor=1.,
               max_scale_factor=1.,
               scale_factor_step_size=0,
               model_variant=None,
               num_readers=1,
               is_training=False,
               should_shuffle=False,
               should_repeat=False):
    """Initializes the dataset.

    Args:
      dataset_name: Dataset name; must be a key of _DATASETS_INFORMATION.
      split_name: Which split of the dataset to use (e.g. train or val).
      dataset_dir: The directory of the dataset sources.
      batch_size: Batch size.
      crop_size: The size [height, width] used to crop the image and label.
      min_resize_value: Desired size of the smaller image side.
      max_resize_value: Maximum allowed size of the larger image side.
      resize_factor: Resized dimensions are multiple of factor plus one.
      min_scale_factor: Minimum scale factor value.
      max_scale_factor: Maximum scale factor value.
      scale_factor_step_size: The step size from min scale factor to max
        scale factor. The input is randomly scaled based on the value of
        (min_scale_factor, max_scale_factor, scale_factor_step_size).
      model_variant: Model variant (string) for choosing how to mean-subtract
        the images. See feature_extractor.network_map for supported model
        variants.
      num_readers: Number of readers for data provider.
      is_training: Boolean, if dataset is for training or not.
      should_shuffle: Boolean, if should shuffle the input data.
      should_repeat: Boolean, if should repeat the input data.

    Raises:
      ValueError: Dataset name and split name are not supported.
    """
    # Reject dataset names that are not registered.
    if dataset_name not in _DATASETS_INFORMATION:
      raise ValueError('The specified dataset is not supported yet.')
    self.dataset_name = dataset_name

    # Reject split names that the chosen dataset does not define.
    splits_to_sizes = _DATASETS_INFORMATION[dataset_name].splits_to_sizes
    if split_name not in splits_to_sizes:
      raise ValueError('data split name %s not recognized' % split_name)

    # A missing model variant is only warned about, not rejected.
    if model_variant is None:
      tf.compat.v1.logging.warning('Please specify a model_variant. See '
                         'feature_extractor.network_map for supported model '
                         'variants.')

    # Store the constructor arguments.
    self.split_name = split_name
    self.dataset_dir = dataset_dir
    self.batch_size = batch_size
    self.crop_size = crop_size
    self.min_resize_value = min_resize_value
    self.max_resize_value = max_resize_value
    self.resize_factor = resize_factor
    self.min_scale_factor = min_scale_factor
    self.max_scale_factor = max_scale_factor
    self.scale_factor_step_size = scale_factor_step_size
    self.model_variant = model_variant
    self.num_readers = num_readers
    self.is_training = is_training
    self.should_shuffle = should_shuffle
    self.should_repeat = should_repeat

    # Number of classes and ignore label come from the dataset descriptor.
    self.num_of_classes = _DATASETS_INFORMATION[self.dataset_name].num_classes
    self.ignore_label = _DATASETS_INFORMATION[self.dataset_name].ignore_label

  def _parse_function(self, example_proto):
    """Parses one serialized TFRecord example into a sample dictionary.

    Args:
      example_proto: Proto in the format of tf.Example.

    Returns:
      A dictionary with the decoded image, its name, height and width, and
      (for every split other than the test set) the decoded label.

    Raises:
      ValueError: Label is of wrong shape.
    """

    # Currently only supports jpeg and png.
    # Need to use this logic because the shape is not known for
    # tf.image.decode_image and we rely on this info to
    # extend label if necessary.
    def _decode_image(content, channels):
      return tf.cond(
          tf.image.is_jpeg(content),
          lambda: tf.image.decode_jpeg(content, channels),
          lambda: tf.image.decode_png(content, channels))

    # The feature spec is the inverse of how the TFRecord was written.
    features = {
        'image/encoded':
            tf.io.FixedLenFeature((), tf.string, default_value=''),
        'image/filename':
            tf.io.FixedLenFeature((), tf.string, default_value=''),
        'image/format':
            tf.io.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/height':
            tf.io.FixedLenFeature((), tf.int64, default_value=0),
        'image/width':
            tf.io.FixedLenFeature((), tf.int64, default_value=0),
        'image/segmentation/class/encoded':
            tf.io.FixedLenFeature((), tf.string, default_value=''),
        'image/segmentation/class/format':
            tf.io.FixedLenFeature((), tf.string, default_value='png'),
    }

    parsed_features = tf.io.parse_single_example(example_proto, features)

    image = _decode_image(parsed_features['image/encoded'], channels=3)

    # The test split carries no label.
    label = None
    if self.split_name != common.TEST_SET:
      label = _decode_image(
          parsed_features['image/segmentation/class/encoded'], channels=1)

    image_name = parsed_features['image/filename']
    if image_name is None:
      image_name = tf.constant('')

    sample = {
        common.IMAGE: image,
        common.IMAGE_NAME: image_name,
        common.HEIGHT: parsed_features['image/height'],
        common.WIDTH: parsed_features['image/width'],
    }

    if label is not None:
      # Normalize the label to rank 3 with a single channel.
      if label.get_shape().ndims == 2:
        label = tf.expand_dims(label, 2)
      elif label.get_shape().ndims == 3 and label.shape.dims[2] == 1:
        pass
      else:
        raise ValueError('Input label shape must be [height, width], or '
                         '[height, width, 1].')

      label.set_shape([None, None, 1])

      sample[common.LABELS_CLASS] = label
    return sample

  def _preprocess_image(self, sample):
    """Preprocesses the image and label.

    Args:
      sample: A sample dictionary as produced by _parse_function.

    Returns:
      The same dictionary with the preprocessed image, the original image
      (when not training) and the preprocessed label (when available).
    """
    image = sample[common.IMAGE]
    # _parse_function omits the label for the test split, so look it up
    # without raising KeyError; the code below already handles label=None.
    # NOTE(review): assumes preprocess_image_and_label accepts label=None
    # when is_training is False -- confirm against input_preprocess.py.
    label = sample.get(common.LABELS_CLASS)

    # Implemented in input_preprocess.py; returns the original image, the
    # preprocessed image and the preprocessed label.
    original_image, image, label = input_preprocess.preprocess_image_and_label(
        image=image,
        label=label,
        crop_height=self.crop_size[0],
        crop_width=self.crop_size[1],
        min_resize_value=self.min_resize_value,
        max_resize_value=self.max_resize_value,
        resize_factor=self.resize_factor,
        min_scale_factor=self.min_scale_factor,
        max_scale_factor=self.max_scale_factor,
        scale_factor_step_size=self.scale_factor_step_size,
        ignore_label=self.ignore_label,
        is_training=self.is_training,
        model_variant=self.model_variant)
    sample[common.IMAGE] = image

    if not self.is_training:
      # Original image is only used during visualization.
      sample[common.ORIGINAL_IMAGE] = original_image

    if label is not None:
      sample[common.LABEL] = label

    # Remove common.LABEL_CLASS key in the sample since it is only used to
    # derive label and not used in training and evaluation.
    sample.pop(common.LABELS_CLASS, None)

    return sample

  def get_one_shot_iterator(self):
    """Gets an iterator that iterates across the dataset once.

    Returns:
      An iterator of type tf.data.Iterator.
    """
    # Paths of all TFRecord files belonging to the split.
    files = self._get_all_files()

    # Read the records, then parse and preprocess every example.
    dataset = (tf.data.TFRecordDataset(files, num_parallel_reads=self.num_readers)
              .map(self._parse_function, num_parallel_calls=self.num_readers)
              .map(self._preprocess_image, num_parallel_calls=self.num_readers))

    if self.should_shuffle:
      dataset = dataset.shuffle(buffer_size=100)

    if self.should_repeat:
      dataset = dataset.repeat()  # Repeat forever for training.
    else:
      dataset = dataset.repeat(1)

    dataset = dataset.batch(self.batch_size).prefetch(self.batch_size)
    return tf.compat.v1.data.make_one_shot_iterator(dataset)

  def _get_all_files(self):
    """Gets all the files to read data from.

    Returns:
      A list of input files matching '<split_name>-*' inside dataset_dir.
    """
    file_pattern = _FILE_PATTERN
    file_pattern = os.path.join(self.dataset_dir,file_pattern % self.split_name)
    return tf.io.gfile.glob(file_pattern)

model.py构建网络模型

import tensorflow as tf
from tensorflow.contrib import slim as contrib_slim
from deeplab.core import dense_prediction_cell
from deeplab.core import feature_extractor

from deeplab.core import utils

# Short alias for the contrib slim module; the model-building code in this
# file reaches slim.conv2d, slim.arg_scope, etc. through it.
slim = contrib_slim

# Scope names used while building the model.
LOGITS_SCOPE_NAME = 'logits'
# Dictionary key under which the merged (final) logits of an output are kept.
MERGED_LOGITS_SCOPE = 'merged_logits'
IMAGE_POOLING_SCOPE = 'image_pooling'
ASPP_SCOPE = 'aspp'
CONCAT_PROJECTION_SCOPE = 'concat_projection'
DECODER_SCOPE = 'decoder'
META_ARCHITECTURE_SCOPE = 'meta_architecture'

# Appended to an output name to form the key of its probability map in the
# prediction dictionaries.
PROB_SUFFIX = '_prob'

# Module-level aliases for helpers defined in deeplab.core.utils.
_resize_bilinear = utils.resize_bilinear
scale_dimension = utils.scale_dimension
split_separable_conv2d = utils.split_separable_conv2d

def get_extra_layer_scopes(last_layers_contain_logits_only=False):
  """Returns the scope names of the extra layers.

  Args:
    last_layers_contain_logits_only: Boolean. When True, only the logits scope
      is considered a last layer (i.e., the ASPP module, the decoder module
      and so on are excluded).

  Returns:
    A new list of scope-name strings.
  """
  extra_scopes = [LOGITS_SCOPE_NAME]
  if not last_layers_contain_logits_only:
    extra_scopes.extend([
        IMAGE_POOLING_SCOPE,
        ASPP_SCOPE,
        CONCAT_PROJECTION_SCOPE,
        DECODER_SCOPE,
        META_ARCHITECTURE_SCOPE,
    ])
  return extra_scopes


def predict_labels_multi_scale(images,
                               model_options,
                               eval_scales=(1.0,),
                               add_flipped_images=False):
  """Predicts segmentation labels by averaging over scales (and flips).

  Args:
    images: A tensor of size [batch, height, width, channels].
    model_options: A ModelOptions instance to configure models.
    eval_scales: The scales to resize images for evaluation.
    add_flipped_images: Add flipped images for evaluation or not.

  Returns:
    A dictionary. For every output type in
      model_options.outputs_to_num_classes (e.g., semantic prediction) it holds
      two entries: `output` maps to the predictions (argmax over channels) of
      size [batch, height, width], and `output + PROB_SUFFIX` maps to the
      softmax of the averaged per-scale predictions.
  """
  outputs_to_predictions = {
      output: []
      for output in model_options.outputs_to_num_classes
  }

  for i, image_scale in enumerate(eval_scales):
    # Variables are created for the first scale and reused for the others.
    with tf.compat.v1.variable_scope(
        tf.compat.v1.get_variable_scope(), reuse=True if i else None):
      outputs_to_scales_to_logits = multi_scale_logits(
          images,
          model_options=model_options,
          image_pyramid=[image_scale],
          is_training=False,
          fine_tune_batch_norm=False)

    if add_flipped_images:
      # Second pass on the images reversed along axis 2 (width); the
      # variables already exist, so they are always reused here.
      with tf.compat.v1.variable_scope(
          tf.compat.v1.get_variable_scope(), reuse=True):
        outputs_to_scales_to_logits_reversed = multi_scale_logits(
            tf.reverse_v2(images, [2]),
            model_options=model_options,
            image_pyramid=[image_scale],
            is_training=False,
            fine_tune_batch_norm=False)

    for output in sorted(outputs_to_scales_to_logits):
      scales_to_logits = outputs_to_scales_to_logits[output]
      # Bring the merged logits back to the input resolution.
      logits = _resize_bilinear(
          scales_to_logits[MERGED_LOGITS_SCOPE],
          tf.shape(images)[1:3],
          scales_to_logits[MERGED_LOGITS_SCOPE].dtype)
      # A trailing axis is added so all predictions can be concatenated and
      # averaged along it below.
      outputs_to_predictions[output].append(
          tf.expand_dims(tf.nn.softmax(logits), 4))

      if add_flipped_images:
        scales_to_logits_reversed = (
            outputs_to_scales_to_logits_reversed[output])
        # Flip the logits back before resizing so they line up with the
        # un-flipped predictions.
        logits_reversed = _resize_bilinear(
            tf.reverse_v2(scales_to_logits_reversed[MERGED_LOGITS_SCOPE], [2]),
            tf.shape(images)[1:3],
            scales_to_logits_reversed[MERGED_LOGITS_SCOPE].dtype)
        outputs_to_predictions[output].append(
            tf.expand_dims(tf.nn.softmax(logits_reversed), 4))

  # sorted() snapshots the keys, so adding the PROB_SUFFIX entries inside the
  # loop is safe.
  for output in sorted(outputs_to_predictions):
    predictions = outputs_to_predictions[output]
    # Compute average prediction across different scales and flipped images.
    predictions = tf.reduce_mean(tf.concat(predictions, 4), axis=4)
    outputs_to_predictions[output] = tf.argmax(predictions, 3)
    # Fix: the probabilities must be stored in the returned dictionary.
    # `predictions` is a Tensor at this point and does not support item
    # assignment. The extra softmax on the averaged probabilities is kept
    # as it was.
    outputs_to_predictions[output + PROB_SUFFIX] = tf.nn.softmax(predictions)

  return outputs_to_predictions

def predict_labels(images, model_options, image_pyramid=None):
  """Predicts segmentation labels from a single forward pass.

  Args:
    images: A tensor of size [batch, height, width, channels].
    model_options: A ModelOptions instance to configure models.
    image_pyramid: Input image scales for multi-scale feature extraction.

  Returns:
    A dictionary. For every output type (e.g., semantic prediction) it holds
      the predictions (argmax over channels) of size [batch, height, width]
      under `output`, and the matching probabilities under
      `output + PROB_SUFFIX`.
  """
  outputs_to_scales_to_logits = multi_scale_logits(
      images,
      model_options=model_options,
      image_pyramid=image_pyramid,
      is_training=False,
      fine_tune_batch_norm=False)

  predictions = {}
  for output in sorted(outputs_to_scales_to_logits):
    merged_logits = outputs_to_scales_to_logits[output][MERGED_LOGITS_SCOPE]

    # Two ways to obtain the final result: (1) bilinearly upsample the logits
    # and then take the argmax, or (2) take the argmax first and upsample it
    # with nearest neighbor. Option (2) may introduce the "blocking effect"
    # but is computationally efficient.
    if not model_options.prediction_with_upsampled_logits:
      coarse_labels = tf.argmax(merged_logits, 3)
      upsampled_labels = tf.image.resize_nearest_neighbor(
          tf.expand_dims(coarse_labels, 3),
          tf.shape(images)[1:3],
          align_corners=True,
          name='resize_prediction')
      predictions[output] = tf.squeeze(upsampled_labels, 3)
      predictions[output + PROB_SUFFIX] = tf.image.resize_bilinear(
          tf.nn.softmax(merged_logits),
          tf.shape(images)[1:3],
          align_corners=True,
          name='resize_prob')
      continue

    upsampled_logits = _resize_bilinear(
        merged_logits,
        tf.shape(images)[1:3],
        merged_logits.dtype)
    predictions[output] = tf.argmax(upsampled_logits, 3)
    predictions[output + PROB_SUFFIX] = tf.nn.softmax(upsampled_logits)

  return predictions

def multi_scale_logits(images,
                       model_options,
                       image_pyramid,
                       weight_decay=0.0001,
                       is_training=False,
                       fine_tune_batch_norm=False,
                       nas_training_hyper_parameters=None):
  """Computes the logits for every scale of the image pyramid and merges them.

  Args:
    images: A tensor of size [batch, height, width, channels].
    model_options: A ModelOptions instance to configure models.
    image_pyramid: Input image scales for multi-scale feature extraction. An
      empty or None value is replaced by [1.0].
    weight_decay: The weight decay for model variables.
    is_training: Is training or not.
    fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
    nas_training_hyper_parameters: Hyper-parameters for training nas models,
      forwarded unchanged to _get_logits.

  Returns:
    outputs_to_scales_to_logits: A dictionary mapping every output type to a
      dictionary of logits. With a single scale the inner dictionary only has
      the key MERGED_LOGITS_SCOPE. With several scales it has one
      'logits_<scale>' entry per scale (scale formatted with two decimals)
      plus the merged logits under MERGED_LOGITS_SCOPE.
  """
  # Without an explicit pyramid, run on the unscaled input only.
  if not image_pyramid:
    image_pyramid = [1.0]

  # Crop height/width: taken from crop_size when it is set, otherwise read
  # from the dynamic shape of the input.
  if model_options.crop_size:
    crop_height = model_options.crop_size[0]
    crop_width = model_options.crop_size[1]
  else:
    crop_height = tf.shape(images)[1]
    crop_width = tf.shape(images)[2]

  # Remember the image-pooling crop size, if one is configured; it is only
  # read again under the same condition inside the loop below.
  if model_options.image_pooling_crop_size:
    image_pooling_crop_height = model_options.image_pooling_crop_size[0]
    image_pooling_crop_width = model_options.image_pooling_crop_size[1]

  # Output stride of the logits: the smallest decoder stride when a decoder
  # is configured, otherwise the output stride of the backbone.
  decoder_strides = model_options.decoder_output_stride
  if decoder_strides:
    logits_output_stride = min(decoder_strides)
  else:
    logits_output_stride = model_options.output_stride

  # All scales are resized to one common logits size, derived from the
  # largest pyramid scale (never smaller than 1.0).
  # For reference, the author's note quotes utils.scale_dimension(dim, scale)
  # as computing (dim - 1.0) * scale + 1.0, truncated to an integer.
  logits_scale = max(1.0, max(image_pyramid)) / logits_output_stride
  logits_height = scale_dimension(crop_height, logits_scale)
  logits_width = scale_dimension(crop_width, logits_scale)

  # One (initially empty) scale -> logits map per output type.
  outputs_to_scales_to_logits = {}
  for output_name in model_options.outputs_to_num_classes:
    outputs_to_scales_to_logits[output_name] = {}

  num_channels = images.get_shape().as_list()[-1]
  single_scale = len(image_pyramid) == 1

  for image_scale in image_pyramid:
    if image_scale == 1.0:
      # Scale 1.0: the input and the configured sizes are used as they are.
      scaled_crop_size = model_options.crop_size
      scaled_images = images
      scaled_image_pooling_crop_size = model_options.image_pooling_crop_size
    else:
      # Any other scale: resize the input bilinearly to the scaled crop size.
      scaled_height = scale_dimension(crop_height, image_scale)
      scaled_width = scale_dimension(crop_width, image_scale)
      scaled_crop_size = [scaled_height, scaled_width]
      scaled_images = _resize_bilinear(images, scaled_crop_size, images.dtype)
      if model_options.crop_size:
        scaled_images.set_shape(
            [None, scaled_height, scaled_width, num_channels])

      # Adjust image_pooling_crop_size accordingly.
      scaled_image_pooling_crop_size = None
      if model_options.image_pooling_crop_size:
        scaled_image_pooling_crop_size = [
            scale_dimension(image_pooling_crop_height, image_scale),
            scale_dimension(image_pooling_crop_width, image_scale)]

    # Only crop_size and image_pooling_crop_size differ between scales.
    outputs_to_logits = _get_logits(
        scaled_images,
        model_options._replace(
            crop_size=scaled_crop_size,
            image_pooling_crop_size=scaled_image_pooling_crop_size),
        weight_decay=weight_decay,
        reuse=tf.compat.v1.AUTO_REUSE,
        is_training=is_training,
        fine_tune_batch_norm=fine_tune_batch_norm,
        nas_training_hyper_parameters=nas_training_hyper_parameters)

    # Resize the logits of this scale to the common size before merging.
    for output in sorted(outputs_to_logits):
      scale_logits = outputs_to_logits[output]
      outputs_to_logits[output] = _resize_bilinear(
          scale_logits, [logits_height, logits_width], scale_logits.dtype)

    # Return when only one input scale: there is nothing to merge.
    if single_scale:
      for output in sorted(model_options.outputs_to_num_classes):
        outputs_to_scales_to_logits[output][MERGED_LOGITS_SCOPE] = (
            outputs_to_logits[output])
      return outputs_to_scales_to_logits

    # Save logits to the output map.
    scale_key = 'logits_%.2f' % image_scale
    for output in sorted(model_options.outputs_to_num_classes):
      outputs_to_scales_to_logits[output][scale_key] = (
          outputs_to_logits[output])

  # Merge the logits from all the multi-scale inputs.
  for output in sorted(model_options.outputs_to_num_classes):
    # Stack the per-scale logits of this output along a new last axis, e.g.
    # two tensors of shape (?, 193, 193, 4, 1) become (?, 193, 193, 4, 2).
    expanded_logits = [
        tf.expand_dims(scale_logits, axis=4)
        for scale_logits in outputs_to_scales_to_logits[output].values()
    ]
    stacked_logits = tf.concat(expanded_logits, 4)
    # 'max' keeps the element-wise maximum; anything else averages.
    if model_options.merge_method == 'max':
      merge_fn = tf.reduce_max
    else:
      merge_fn = tf.reduce_mean
    outputs_to_scales_to_logits[output][MERGED_LOGITS_SCOPE] = merge_fn(
        stacked_logits, axis=4)

  return outputs_to_scales_to_logits


def extract_features(images,
                     model_options,
                     weight_decay=0.0001,
                     reuse=None,
                     is_training=False,
                     fine_tune_batch_norm=False,
                     nas_training_hyper_parameters=None):
  """Extracts features by the particular model_variant.

  The backbone features are returned directly when
  model_options.aspp_with_batch_norm is False. Otherwise they are refined
  either by a dense prediction cell (when a cell config is given) or by the
  ASPP module built inline below.

  Args:
    images: A tensor of size [batch, height, width, channels].
    model_options: A ModelOptions instance to configure models.
    weight_decay: The weight decay for model variables.
    reuse: Reuse the model variables or not.
    is_training: Is training or not.
    fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
    nas_training_hyper_parameters: A dictionary storing hyper-parameters for
      training nas models. Its keys are:
      - `drop_path_keep_prob`: Probability to keep each path in the cell when
        training.
      - `total_training_steps`: Total training steps to help drop path
        probability calculation.

  Returns:
    concat_logits: A tensor of size [batch, feature_height, feature_width,
      feature_channels], where feature_height/feature_width are determined by
      the images height/width and output_stride.
    end_points: A dictionary from components of the network to the corresponding
      activation.
  """
  # Run the backbone selected by model_options.model_variant.
  features, end_points = feature_extractor.extract_features(
      images,
      output_stride=model_options.output_stride,
      multi_grid=model_options.multi_grid,
      model_variant=model_options.model_variant,
      depth_multiplier=model_options.depth_multiplier,
      divisible_by=model_options.divisible_by,
      weight_decay=weight_decay,
      reuse=reuse,
      is_training=is_training,
      preprocessed_images_dtype=model_options.preprocessed_images_dtype,
      fine_tune_batch_norm=fine_tune_batch_norm,
      nas_architecture_options=model_options.nas_architecture_options,
      nas_training_hyper_parameters=nas_training_hyper_parameters,
      use_bounded_activation=model_options.use_bounded_activation)
  if not model_options.aspp_with_batch_norm:
    # No ASPP-with-batch-norm head requested: hand back the raw features.
    return features, end_points
  else:
    if model_options.dense_prediction_cell_config is not None:
      tf.logging.info('Using dense prediction cell config.')
      dense_prediction_layer = dense_prediction_cell.DensePredictionCell(
          config=model_options.dense_prediction_cell_config,
          hparams={
              'conv_rate_multiplier': 16 // model_options.output_stride,
          })
      concat_logits = dense_prediction_layer.build_cell(
          features,
          output_stride=model_options.output_stride,
          crop_size=model_options.crop_size,
          image_pooling_crop_size=model_options.image_pooling_crop_size,
          weight_decay=weight_decay,
          reuse=reuse,
          is_training=is_training,
          fine_tune_batch_norm=fine_tune_batch_norm)
      return concat_logits, end_points
    else:
      # The following codes employ the DeepLabv3 ASPP module. Note that we
      # could express the ASPP module as one particular dense prediction
      # cell architecture. We do not do so but leave the following codes
      # for backward compatibility.
      batch_norm_params = utils.get_batch_norm_params(
          decay=0.9997,
          epsilon=1e-5,
          scale=True,
          is_training=(is_training and fine_tune_batch_norm),
          sync_batch_norm_method=model_options.sync_batch_norm_method)
      batch_norm = utils.get_batch_norm_fn(model_options.sync_batch_norm_method)
      activation_fn = (tf.nn.relu6 if model_options.use_bounded_activation else tf.nn.relu)
      # Shared defaults for every convolution of the ASPP module.
      with slim.arg_scope(
          [slim.conv2d, slim.separable_conv2d],
          weights_regularizer=slim.l2_regularizer(weight_decay),
          activation_fn=activation_fn,
          normalizer_fn=batch_norm,
          padding='SAME',
          stride=1,
          reuse=reuse):
        with slim.arg_scope([batch_norm], **batch_norm_params):
          depth = model_options.aspp_convs_filters
          # Branches of the ASPP module; concatenated along channels below.
          branch_logits = []

          if model_options.add_image_level_feature:
            if model_options.crop_size is not None:
              image_pooling_crop_size = model_options.image_pooling_crop_size
              # If image_pooling_crop_size is not specified, use crop_size.
              if image_pooling_crop_size is None:
                image_pooling_crop_size = model_options.crop_size
              # Pooling window at feature resolution. The original author's
              # debug output shows 33 x 33 for crop_size [513, 513].
              pool_height = scale_dimension(
                  image_pooling_crop_size[0],
                  1. / model_options.output_stride)
              pool_width = scale_dimension(
                  image_pooling_crop_size[1],
                  1. / model_options.output_stride)
              image_feature = slim.avg_pool2d(
                  features, [pool_height, pool_width],
                  model_options.image_pooling_stride, padding='VALID')
              # Size the pooled feature is resized back to further below.
              resize_height = scale_dimension(
                  model_options.crop_size[0],
                  1. / model_options.output_stride)
              resize_width = scale_dimension(
                  model_options.crop_size[1],
                  1. / model_options.output_stride)
            else:
              # If crop_size is None, we simply do global pooling.
              pool_height = tf.shape(features)[1]
              pool_width = tf.shape(features)[2]
              image_feature = tf.reduce_mean(
                  features, axis=[1, 2], keepdims=True)
              resize_height = pool_height
              resize_width = pool_width

            image_feature_activation_fn = tf.nn.relu
            image_feature_normalizer_fn = batch_norm
            # With squeeze-and-excitation the image-level feature acts as a
            # multiplicative gate (applied at the end of this function), so
            # it uses a sigmoid and no normalizer.
            if model_options.aspp_with_squeeze_and_excitation:
              image_feature_activation_fn = tf.nn.sigmoid
              if model_options.image_se_uses_qsigmoid:
                image_feature_activation_fn = utils.q_sigmoid
              image_feature_normalizer_fn = None
            # 1x1 convolution projecting the pooled feature to `depth`
            # channels.
            image_feature = slim.conv2d(
                image_feature, depth, 1,
                activation_fn=image_feature_activation_fn,
                normalizer_fn=image_feature_normalizer_fn,
                scope=IMAGE_POOLING_SCOPE)
  
            # Resize the pooled feature back to the feature-map size at the
            # output stride.
            image_feature = _resize_bilinear(
                image_feature,
                [resize_height, resize_width],
                image_feature.dtype)
            # Set shape for resize_height/resize_width if they are not Tensor.
            if isinstance(resize_height, tf.Tensor):
              resize_height = None
            if isinstance(resize_width, tf.Tensor):
              resize_width = None
            image_feature.set_shape([None, resize_height, resize_width, depth])
            # Without squeeze-and-excitation the image-level feature is one
            # more branch to concatenate.
            if not model_options.aspp_with_squeeze_and_excitation:
              branch_logits.append(image_feature)
          # Employ a 1x1 convolution.
          branch_logits.append(slim.conv2d(features, depth, 1,
                                           scope=ASPP_SCOPE + str(0)))
          if model_options.atrous_rates:
            # Employ 3x3 convolutions with different atrous rates.
            for i, rate in enumerate(model_options.atrous_rates, 1):
              scope = ASPP_SCOPE + str(i)
              if model_options.aspp_with_separable_conv:
                aspp_features = split_separable_conv2d(
                    features,
                    filters=depth,
                    rate=rate,
                    weight_decay=weight_decay,
                    scope=scope)
              else:
                aspp_features = slim.conv2d(
                    features, depth, 3, rate=rate, scope=scope)
              branch_logits.append(aspp_features)
          # Merge branch logits.
          concat_logits = tf.concat(branch_logits, 3)
          if model_options.aspp_with_concat_projection:
            # Project the concatenated branches back to `depth` channels,
            # followed by dropout.
            concat_logits = slim.conv2d(
                concat_logits, depth, 1, scope=CONCAT_PROJECTION_SCOPE)
            concat_logits = slim.dropout(
                concat_logits,
                keep_prob=0.9,
                is_training=is_training,
                scope=CONCAT_PROJECTION_SCOPE + '_dropout')
          # Squeeze-and-excitation: scale the merged features by the
          # image-level feature computed above.
          if (model_options.add_image_level_feature and model_options.aspp_with_squeeze_and_excitation):
            concat_logits *= image_feature
          return concat_logits, end_points


def _get_logits(images,
                model_options,
                weight_decay=0.0001,
                reuse=None,
                is_training=False,
                fine_tune_batch_norm=False,
                nas_training_hyper_parameters=None):
  """Builds the logits of every output type for one input scale.

  The features come from extract_features, are optionally refined by the
  decoder, and are finally turned into one logits tensor per output type.

  Args:
    images: A tensor of size [batch, height, width, channels].
    model_options: A ModelOptions instance to configure models.
    weight_decay: The weight decay for model variables.
    reuse: Reuse the model variables or not.
    is_training: Is training or not.
    fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
    nas_training_hyper_parameters: A dictionary storing hyper-parameters for
      training nas models. Its keys are:
      - `drop_path_keep_prob`: Probability to keep each path in the cell when
        training.
      - `total_training_steps`: Total training steps to help drop path
        probability calculation.

  Returns:
    outputs_to_logits: A map from output_type to logits.
  """
  features, end_points = extract_features(
      images,
      model_options,
      weight_decay=weight_decay,
      reuse=reuse,
      is_training=is_training,
      fine_tune_batch_norm=fine_tune_batch_norm,
      nas_training_hyper_parameters=nas_training_hyper_parameters)

  decoder_strides = model_options.decoder_output_stride
  if decoder_strides:
    # The decoder needs a crop size; fall back to the dynamic image size
    # when none is configured.
    if model_options.crop_size is None:
      decoder_crop_size = [tf.shape(images)[1], tf.shape(images)[2]]
    else:
      decoder_crop_size = model_options.crop_size

    # Decoder.
    features = refine_by_decoder(
        features,
        end_points,
        crop_size=decoder_crop_size,
        decoder_output_stride=decoder_strides,
        decoder_use_separable_conv=model_options.decoder_use_separable_conv,
        decoder_use_sum_merge=model_options.decoder_use_sum_merge,
        decoder_filters=model_options.decoder_filters,
        decoder_output_is_logits=model_options.decoder_output_is_logits,
        model_variant=model_options.model_variant,
        weight_decay=weight_decay,
        reuse=reuse,
        is_training=is_training,
        fine_tune_batch_norm=fine_tune_batch_norm,
        use_bounded_activation=model_options.use_bounded_activation)

  # One logits tensor per output type, e.g. {'semantic': 4} gives a single
  # 'semantic' entry.
  outputs_to_logits = {}
  for output in sorted(model_options.outputs_to_num_classes):
    if model_options.decoder_output_is_logits:
      # The decoder output already is the logits; only name it.
      output_logits = tf.identity(features, name=output)
    else:
      output_logits = get_branch_logits(
          features,
          model_options.outputs_to_num_classes[output],
          model_options.atrous_rates,
          aspp_with_batch_norm=model_options.aspp_with_batch_norm,
          kernel_size=model_options.logits_kernel_size,
          weight_decay=weight_decay,
          reuse=reuse,
          scope_suffix=output)
    outputs_to_logits[output] = output_logits
  return outputs_to_logits


def refine_by_decoder(features,
                      end_points,
                      crop_size=None,
                      decoder_output_stride=None,
                      decoder_use_separable_conv=False,
                      decoder_use_sum_merge=False,
                      decoder_filters=256,
                      decoder_output_is_logits=False,
                      model_variant=None,
                      weight_decay=0.0001,
                      reuse=None,
                      is_training=False,
                      fine_tune_batch_norm=False,
                      use_bounded_activation=False,
                      sync_batch_norm_method='None'):
  """Refines the features with the decoder for sharper segmentation results.

  For every decoder output stride the matching low-level end points are
  projected with a 1x1 convolution, resized together with the current decoder
  features to a common size, and merged either by sum or by concatenation.

  Args:
    features: A tensor of size [batch, features_height, features_width,
      features_channels].
    end_points: A dictionary from components of the network to the corresponding
      activation.
    crop_size: A tuple [crop_height, crop_width] specifying whole patch crop
      size.
    decoder_output_stride: A list of integers specifying the output stride of
      low-level features used in the decoder module.
    decoder_use_separable_conv: Employ separable convolution for decoder or not.
    decoder_use_sum_merge: Boolean, decoder uses simple sum merge or not.
    decoder_filters: Integer, decoder filter size.
    decoder_output_is_logits: Boolean, using decoder output as logits or not.
    model_variant: Model variant for feature extraction.
    weight_decay: The weight decay for model variables.
    reuse: Reuse the model variables or not.
    is_training: Is training or not.
    fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
    use_bounded_activation: Whether or not to use bounded activations. Bounded
      activations better lend themselves to quantized inference.
    sync_batch_norm_method: String, method used to sync batch norm. Currently
      only support `None` (no sync batch norm) and `tpu` (use tpu code to
      sync batch norm).

  Returns:
    Decoder output with size [batch, decoder_height, decoder_width,
      decoder_channels].

  Raises:
    ValueError: If crop_size is None.
  """
  if crop_size is None:
    raise ValueError('crop_size must be provided when using decoder.')

  batch_norm_params = utils.get_batch_norm_params(
      decay=0.9997,
      epsilon=1e-5,
      scale=True,
      is_training=(is_training and fine_tune_batch_norm),
      sync_batch_norm_method=sync_batch_norm_method)
  batch_norm = utils.get_batch_norm_fn(sync_batch_norm_method)
  decoder_depth = decoder_filters

  # When using sum merge, the projected filters must be equal to decoder
  # filters; otherwise the low-level features are projected to 48 channels.
  projected_filters = decoder_filters if decoder_use_sum_merge else 48

  if not decoder_output_is_logits:
    # Default setting when decoder output is not logits.
    if use_bounded_activation:
      activation_fn = tf.nn.relu6
    else:
      activation_fn = tf.nn.relu
    normalizer_fn = batch_norm
    conv2d_kernel = 3
  else:
    # Overwrite the setting when decoder output is logits: plain 1x1
    # convolutions (no separable conv), no activation and no normalizer.
    activation_fn = None
    normalizer_fn = None
    conv2d_kernel = 1
    decoder_use_separable_conv = False

  with slim.arg_scope(
      [slim.conv2d, slim.separable_conv2d],
      weights_regularizer=slim.l2_regularizer(weight_decay),
      activation_fn=activation_fn,
      normalizer_fn=normalizer_fn,
      padding='SAME',
      stride=1,
      reuse=reuse):
    with slim.arg_scope([batch_norm], **batch_norm_params):
      with tf.compat.v1.variable_scope(
          DECODER_SCOPE, DECODER_SCOPE, [features]):
        decoder_features = features
        scope_suffix = ''
        # E.g. decoder_output_stride == [4] gives a single decoder stage.
        for decoder_stage, output_stride in enumerate(decoder_output_stride):
          feature_list = feature_extractor.networks_to_feature_maps[
              model_variant][
                  feature_extractor.DECODER_END_POINTS][output_stride]
          # If only one decoder stage, we do not change the scope name in
          # order for backward compatibility.
          if decoder_stage:
            scope_suffix = '_{}'.format(decoder_stage)

          for i, name in enumerate(feature_list):
            # MobileNet and NAS variants use different naming convention:
            # their end points are looked up by the bare feature name.
            uses_bare_name = ('mobilenet' in model_variant or
                              model_variant.startswith('mnas') or
                              model_variant.startswith('nas'))
            if uses_bare_name:
              feature_name = name
            else:
              feature_name = '{}/{}'.format(
                  feature_extractor.name_scope[model_variant], name)

            # 1x1 projection of the low-level feature.
            projected_feature = slim.conv2d(
                end_points[feature_name],
                projected_filters,
                1,
                scope='feature_projection' + str(i) + scope_suffix)

            # Determine the output size.
            decoder_height = scale_dimension(crop_size[0], 1.0 / output_stride)
            decoder_width = scale_dimension(crop_size[1], 1.0 / output_stride)
            # Static shape entries: unknown (None) when the size is a Tensor.
            static_height = (None if isinstance(decoder_height, tf.Tensor)
                             else decoder_height)
            static_width = (None if isinstance(decoder_width, tf.Tensor)
                            else decoder_width)

            # Resize to decoder_height/decoder_width.
            decoder_features_list = []
            for feature in (decoder_features, projected_feature):
              resized_feature = _resize_bilinear(
                  feature, [decoder_height, decoder_width], feature.dtype)
              resized_feature.set_shape(
                  [None, static_height, static_width, None])
              decoder_features_list.append(resized_feature)

            if decoder_use_sum_merge:
              decoder_features = _decoder_with_sum_merge(
                  decoder_features_list,
                  decoder_depth,
                  conv2d_kernel=conv2d_kernel,
                  decoder_use_separable_conv=decoder_use_separable_conv,
                  weight_decay=weight_decay,
                  scope_suffix=scope_suffix)
            else:
              # Without separable conv the feature index is prepended to the
              # suffix; the updated suffix is kept for later iterations.
              if not decoder_use_separable_conv:
                scope_suffix = str(i) + scope_suffix
              decoder_features = _decoder_with_concat_merge(
                  decoder_features_list,
                  decoder_depth,
                  decoder_use_separable_conv=decoder_use_separable_conv,
                  weight_decay=weight_decay,
                  scope_suffix=scope_suffix)
        return decoder_features


def _decoder_with_sum_merge(decoder_features_list,
                            decoder_depth,
                            conv2d_kernel=3,
                            decoder_use_separable_conv=True,
                            weight_decay=0.0001,
                            scope_suffix=''):
  """Merges two decoder features by convolving the first and adding the second.

  Args:
    decoder_features_list: A list of decoder features; must have length 2.
    decoder_depth: Integer, the filters used in the convolution.
    conv2d_kernel: Integer, the convolution kernel size. Only used when
      decoder_use_separable_conv is False.
    decoder_use_separable_conv: Boolean, use separable conv or not.
    weight_decay: Weight decay for the model variables. Only used when
      decoder_use_separable_conv is True.
    scope_suffix: String, used in the scope suffix.

  Returns:
    decoder features merged with sum.

  Raises:
    RuntimeError: If decoder_features_list have length not equal to 2.
  """
  if len(decoder_features_list) != 2:
    raise RuntimeError('Expect decoder_features has length 2.')

  conv_input = decoder_features_list[0]
  # Only apply one convolution when decoder use sum merge.
  if decoder_use_separable_conv:
    convolved = split_separable_conv2d(
        conv_input,
        filters=decoder_depth,
        rate=1,
        weight_decay=weight_decay,
        scope='decoder_split_sep_conv0'+scope_suffix)
  else:
    convolved = slim.conv2d(
        conv_input,
        decoder_depth,
        conv2d_kernel,
        scope='decoder_conv0'+scope_suffix)
  return convolved + decoder_features_list[1]


def _decoder_with_concat_merge(decoder_features_list,
                               decoder_depth,
                               decoder_use_separable_conv=True,
                               weight_decay=0.0001,
                               scope_suffix=''):
  """Merges decoder features by concatenation followed by two convolutions.

  The input features are concatenated along axis 3 and the result is smoothed
  by two consecutive convolutions. This is the decoder module proposed in the
  DeepLabv3+ paper.

  Args:
    decoder_features_list: A list of decoder features.
    decoder_depth: Integer, number of filters of each convolution.
    decoder_use_separable_conv: Boolean, use split_separable_conv2d instead of
      slim.conv2d.
    weight_decay: Weight decay forwarded to split_separable_conv2d.
    scope_suffix: String appended to the convolution scope names.

  Returns:
    Decoder features merged with concatenation.
  """
  merged = tf.concat(decoder_features_list, 3)

  if not decoder_use_separable_conv:
    # Two stacked regular convolutions with kernel size 3.
    return slim.repeat(
        merged,
        2,
        slim.conv2d,
        decoder_depth,
        3,
        scope='decoder_conv'+scope_suffix)

  # Two stacked separable convolutions, each under its own scope.
  smoothed = merged
  for conv_scope in ('decoder_conv0', 'decoder_conv1'):
    smoothed = split_separable_conv2d(
        smoothed,
        filters=decoder_depth,
        rate=1,
        weight_decay=weight_decay,
        scope=conv_scope+scope_suffix)
  return smoothed


def get_branch_logits(features,
                      num_classes,
                      atrous_rates=None,
                      aspp_with_batch_norm=False,
                      kernel_size=1,
                      weight_decay=0.0001,
                      reuse=None,
                      scope_suffix=''):
  """Builds one logits branch per atrous rate and sums the branches.

  One convolution (no activation, no normalizer) is applied to `features` for
  every entry of `atrous_rates`; the per-branch outputs are added together to
  form the final logits.

  Args:
    features: A float tensor of shape [batch, height, width, channels].
    num_classes: Number of classes to predict.
    atrous_rates: A list of atrous convolution rates for last layer.
    aspp_with_batch_norm: Use batch normalization layers for ASPP.
    kernel_size: Kernel size for convolution.
    weight_decay: Weight decay for the model variables.
    reuse: Reuse model variables or not.
    scope_suffix: Scope suffix for the model variables.

  Returns:
    Merged logits with shape [batch, height, width, num_classes].

  Raises:
    ValueError: Upon invalid input kernel_size value.
  """
  # With batch-normalized ASPP the pyramid was already applied during feature
  # extraction, so a single rate-1 convolution with kernel size 1 is all that
  # is allowed here.
  single_branch = aspp_with_batch_norm or atrous_rates is None
  if single_branch and kernel_size != 1:
    raise ValueError('Kernel size must be 1 when atrous_rates is None or '
                     'using aspp_with_batch_norm. Gets %d.' % kernel_size)
  if single_branch:
    atrous_rates = [1]

  with slim.arg_scope(
      [slim.conv2d],
      weights_regularizer=slim.l2_regularizer(weight_decay),
      weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
      reuse=reuse), tf.compat.v1.variable_scope(
          LOGITS_SCOPE_NAME, LOGITS_SCOPE_NAME, [features]):
    # The first branch uses the bare suffix as scope; later branches get an
    # extra '_<index>' so that every branch owns distinct variables.
    branch_logits = [
        slim.conv2d(
            features,
            num_classes,
            kernel_size=kernel_size,
            rate=rate,
            activation_fn=None,
            normalizer_fn=None,
            scope=scope_suffix + ('_%d' % index if index else ''))
        for index, rate in enumerate(atrous_rates)
    ]
    return tf.add_n(branch_logits)

train_utils.py

import six
import tensorflow as tf
from tensorflow.contrib import framework as contrib_framework

from deeplab.core import preprocess_utils
from deeplab.core import utils


def _div_maybe_zero(total_loss, num_present):
  """Divides total_loss by num_present, yielding 0 when num_present is 0.

  The denominator is clamped to at least 1e-5 and the quotient is multiplied
  by a 0/1 indicator of `num_present > 0`.
  """
  has_present = tf.cast(num_present > 0, float)
  clamped_count = tf.maximum(1e-5, num_present)
  return has_present * tf.math.divide(total_loss, clamped_count)

#为每个尺度的logit添加softmax交叉熵损失
def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
                                                  labels,
                                                  num_classes,
                                                  ignore_label,
                                                  loss_weight=1.0,
                                                  upsample_logits=True,
                                                  hard_example_mining_step=0,
                                                  top_k_percent_pixels=1.0,
                                                  gt_is_matting_map=False,
                                                  scope=None):
  """Adds a softmax cross entropy loss for the logits of each scale.

  For every (scale, logits) entry the logits and labels are brought to the
  same spatial size, flattened, and a per-pixel softmax cross entropy is
  computed. The weighted, normalized loss is registered in the TF losses
  collection; nothing is returned.

  Args:
    scales_to_logits: A map from scale key to the logits tensor of that scale.
    labels: Groundtruth labels tensor (resolved as rank 4; presumably
      [batch, height, width, 1] - TODO confirm against the caller).
    num_classes: Integer, number of classes; logits are reshaped to
      [-1, num_classes].
    ignore_label: Label value whose pixels are excluded from the loss.
    loss_weight: Forwarded as `label_weights` to utils.get_label_weight_mask.
      Must be 1.0 when gt_is_matting_map is True.
    upsample_logits: Boolean. If True the logits are bilinearly resized to
      the label size; otherwise the labels are resized to the logits size
      with nearest neighbor.
    hard_example_mining_step: Integer. When non-zero, the mined fraction is
      reduced gradually from 100% to top_k_percent_pixels over this many
      global steps; when 0, top_k_percent_pixels is used right away.
    top_k_percent_pixels: Float. 1.0 uses all pixels; any other value keeps
      only that fraction of pixels with the largest loss.
    gt_is_matting_map: Boolean, whether the groundtruth is a float matting
      map instead of an integer label mask.
    scope: Optional string; the name scope of each loss is '<scope>_<scale>'.

  Raises:
    ValueError: If labels is None, if gt_is_matting_map is True but labels
      are not floating point, or if gt_is_matting_map is True and
      loss_weight is not 1.0.
  """
  if labels is None:
    raise ValueError('No label for softmax cross entropy loss.')

  # A matting map carries per-pixel confidences, so it has to be float.
  if gt_is_matting_map and not labels.dtype.is_floating:
    raise ValueError('Labels must be floats if groundtruth is a matting map.')

  for scale, logits in six.iteritems(scales_to_logits):
    loss_scope = None
    if scope:
      loss_scope = '%s_%s' % (scope, scale)

    if upsample_logits:
      # Labels are left untouched; the logits are upsampled to label size.
      logits = tf.compat.v1.image.resize_bilinear(
          logits,
          preprocess_utils.resolve_shape(labels, 4)[1:3],
          align_corners=True)
      scaled_labels = labels
    else:
      # Labels are downsampled to the logits size. Nearest neighbor may
      # introduce artifacts for a matting map, but it is still used so that
      # ignore_label is never interpolated with other label values.
      # TODO(huizhongc): Change to bilinear interpolation by processing padded
      # and non-padded label separately.
      if gt_is_matting_map:
        tf.compat.v1.logging.warning('Label downsampling with nearest neighbor may introduce artifacts.')
      scaled_labels = tf.image.resize_nearest_neighbor(
          labels,
          preprocess_utils.resolve_shape(logits, 4)[1:3],
          align_corners=True)

    # Flatten the labels to one entry per pixel.
    scaled_labels = tf.reshape(scaled_labels, shape=[-1])
    # Per-pixel loss weights. loss_weight is passed as label_weights: a float
    # gives every label the same weight, a list gives one weight per class
    # index; ignored pixels get weight 0.
    weights = utils.get_label_weight_mask(scaled_labels, ignore_label, num_classes, label_weights=loss_weight)
    # Dimension of keep_mask is equal to the total number of pixels.
    keep_mask = tf.cast(tf.not_equal(scaled_labels, ignore_label), dtype=tf.float32)
    train_labels = None
    # Flatten the logits to [num_pixels, num_classes].
    logits = tf.reshape(logits, shape=[-1, num_classes])

    if gt_is_matting_map:
      # Class-dependent label weights only make sense for integer label masks;
      # for a confidence map the weight has to stay at 1.0.
      if loss_weight != 1.0:
        raise ValueError('loss_weight must equal to 1 if groundtruth is matting map.')

      # Ignored pixels are set to 0. Their exact value does not matter because
      # their loss is multiplied by a zero weight.
      train_labels = scaled_labels * keep_mask
      train_labels = tf.expand_dims(train_labels, 1)
      train_labels = tf.concat([1 - train_labels, train_labels], axis=1)
    else:
      # One-hot encode the labels so they match the layout of the logits.
      train_labels = tf.one_hot(scaled_labels, num_classes, on_value=1.0, off_value=0.0)

    default_loss_scope = ('softmax_all_pixel_loss'
                          if top_k_percent_pixels == 1.0
                          else 'softmax_hard_example_mining')

    with tf.name_scope(loss_scope, default_loss_scope,
                       [logits, train_labels, weights]):
      # Compute the loss for all pixels.
      pixel_losses = tf.nn.softmax_cross_entropy_with_logits_v2(
          labels=tf.stop_gradient(train_labels, name='train_labels_stop_gradient'),
          logits=logits,
          name='pixel_losses')
      weighted_pixel_losses = tf.multiply(pixel_losses, weights)

      if top_k_percent_pixels == 1.0:
        # Average over every pixel that is not ignored.
        total_loss = tf.reduce_sum(weighted_pixel_losses)
        num_present = tf.reduce_sum(keep_mask)
      else:
        # Hard example mining: keep only the pixels with the largest loss.
        num_pixels = tf.to_float(tf.shape(logits)[0])
        # Compute the top_k_percent pixels based on current training step.
        if hard_example_mining_step == 0:
          # Directly focus on the top_k pixels.
          top_k_pixels = tf.to_int32(top_k_percent_pixels * num_pixels)
        else:
          # Gradually reduce the mining percent to top_k_percent_pixels.
          global_step = tf.to_float(tf.train.get_or_create_global_step())
          ratio = tf.minimum(1.0, global_step / hard_example_mining_step)
          top_k_pixels = tf.to_int32((ratio * top_k_percent_pixels + (1.0 - ratio)) * num_pixels)

        top_k_losses, _ = tf.nn.top_k(weighted_pixel_losses,
                                      k=top_k_pixels,
                                      sorted=True,
                                      name='top_k_percent_pixels')
        total_loss = tf.reduce_sum(top_k_losses)
        num_present = tf.reduce_sum(tf.to_float(tf.not_equal(top_k_losses, 0.0)))

      # Normalize and register the loss in the losses collection.
      loss = _div_maybe_zero(total_loss, num_present)
      tf.compat.v1.losses.add_loss(loss)

# 从检查点获取初始化模型变量的函数
def get_model_init_fn(train_logdir,
                      tf_initial_checkpoint,
                      initialize_last_layer,
                      last_layers,
                      ignore_missing_vars=False):
  """Builds a function that initializes model variables from a checkpoint.

  Args:
    train_logdir: Directory in which training checkpoints are stored.
    tf_initial_checkpoint: Path of the checkpoint used for initialization, or
      None to skip initialization.
    initialize_last_layer: Boolean; when False the variables under
      `last_layers` are not restored.
    last_layers: Scopes of the last layers of the model.
    ignore_missing_vars: Forwarded to assign_from_checkpoint; whether
      variables missing from the checkpoint are ignored.

  Returns:
    A function taking a session that runs the restore op, or None when no
    initial checkpoint is given, when `train_logdir` already holds a
    checkpoint, or when there is nothing to restore.
  """
  if tf_initial_checkpoint is None:
    tf.compat.v1.logging.info('Not initializing the model from a checkpoint.')
    return None

  # A checkpoint already present in the training directory takes precedence
  # over the initial checkpoint.
  if tf.train.latest_checkpoint(train_logdir):
    tf.compat.v1.logging.info('Ignoring initialization; other checkpoint exists')
    return None

  tf.compat.v1.logging.info('Initializing model from path: %s', tf_initial_checkpoint)

  # Variables that will not be restored.
  skipped_scopes = ['global_step']
  if not initialize_last_layer:
    skipped_scopes += last_layers

  restorable_vars = contrib_framework.get_variables_to_restore(exclude=skipped_scopes)
  if not restorable_vars:
    return None

  assign_op, assign_feed_dict = contrib_framework.assign_from_checkpoint(
      tf_initial_checkpoint,
      restorable_vars,
      ignore_missing_vars=ignore_missing_vars)
  step_tensor = tf.compat.v1.train.get_or_create_global_step()

  def restore_fn(sess):
    # Everything above only built graph nodes; this actually executes them.
    sess.run(assign_op, assign_feed_dict)
    sess.run([step_tensor])

  return restore_fn


def get_model_gradient_multipliers(last_layers, last_layer_gradient_multiplier):
  """Builds the per-variable gradient multiplier map.

  Gradient multipliers scale the effective learning rate of individual model
  variables. Segmentation models are usually fine-tuned from image
  classification models, and for that the last layers typically get a larger
  (e.g., 10 times larger) learning rate.

  Args:
    last_layers: Scopes of last layers.
    last_layer_gradient_multiplier: The gradient multiplier for last layers.

  Returns:
    A dict mapping variable op names to multipliers. Biases get 2, last-layer
    weights get `last_layer_gradient_multiplier`, last-layer biases get twice
    that; all other variables are absent from the dict.
  """
  multipliers = {}

  for variable in tf.compat.v1.model_variables():
    var_name = variable.op.name
    is_bias = 'biases' in var_name

    # Double the learning rate for biases.
    if is_bias:
      multipliers[var_name] = 2.

    # Use larger learning rate for last layer variables.
    if any(layer in var_name for layer in last_layers):
      if is_bias:
        multipliers[var_name] = 2 * last_layer_gradient_multiplier
      else:
        multipliers[var_name] = last_layer_gradient_multiplier

  return multipliers

# 获取模型的学习率
def get_model_learning_rate(learning_policy,
                            base_learning_rate,
                            learning_rate_decay_step,
                            learning_rate_decay_factor,
                            training_number_of_steps,
                            learning_power,
                            slow_start_step,
                            slow_start_learning_rate,
                            slow_start_burnin_type='none',
                            decay_steps=0.0,
                            end_learning_rate=0.0,
                            boundaries=None,
                            boundary_learning_rates=None):
  """Returns the learning rate tensor for the requested learning policy.

  (1) The "step" policy computes
        current_learning_rate = base_learning_rate *
          learning_rate_decay_factor ^ (global_step / learning_rate_decay_step)
      See tf.train.exponential_decay for details.
  (2) The "poly" policy computes
        current_learning_rate = base_learning_rate *
          (1 - global_step / training_number_of_steps) ^ learning_power
  (3) The "cosine" policy uses tf.train.cosine_decay.
  (4) The "multi_steps" policy uses tf.train.piecewise_constant_decay.

  In every policy the step fed to the schedule is the global step minus
  `slow_start_step`, clamped at 0.

  Args:
    learning_policy: Learning rate policy: 'step', 'poly', 'cosine' or
      'multi_steps'.
    base_learning_rate: The base learning rate for model training.
    learning_rate_decay_step: Decay the base learning rate at a fixed step
      ('step' policy).
    learning_rate_decay_factor: The rate to decay the base learning rate
      ('step' policy).
    training_number_of_steps: Total number of training steps.
    learning_power: Power used by the 'poly' policy.
    slow_start_step: Number of initial steps trained with the slow-start
      learning rate.
    slow_start_learning_rate: The learning rate employed during slow start.
    slow_start_burnin_type: Burn-in type of the slow start phase. 'none' keeps
      `slow_start_learning_rate` constant; 'linear' increases linearly from
      `slow_start_learning_rate` towards `base_learning_rate` over
      `slow_start_step` steps.
    decay_steps: Decay steps of the 'poly' policy. When 0.0 it is set to
      `training_number_of_steps - slow_start_step`.
    end_learning_rate: Final learning rate of the 'poly' policy.
    boundaries: Boundaries passed to the 'multi_steps' policy (a list of
      tensors, ints or floats with strictly increasing entries).
    boundary_learning_rates: Learning rate values passed to the 'multi_steps'
      policy alongside `boundaries`.

  Returns:
    Learning rate tensor for the specified learning policy.

  Raises:
    ValueError: If `learning_policy` or `slow_start_burnin_type` is not
      recognized, or if `boundaries` / `boundary_learning_rates` is missing
      for the 'multi_steps' policy.
  """
  global_step = tf.compat.v1.train.get_or_create_global_step()
  # The schedule only starts counting once the slow-start phase is over.
  adjusted_global_step = tf.maximum(global_step - slow_start_step, 0)
  if decay_steps == 0.0:
    tf.compat.v1.logging.info('Setting decay_steps to total training steps.')
    decay_steps = training_number_of_steps - slow_start_step
  # Select the schedule that matches the learning policy.
  if learning_policy == 'step':
    learning_rate = tf.train.exponential_decay(
        base_learning_rate,
        adjusted_global_step,
        learning_rate_decay_step,
        learning_rate_decay_factor,
        staircase=True)
  elif learning_policy == 'poly':
    learning_rate = tf.compat.v1.train.polynomial_decay(
        base_learning_rate,
        adjusted_global_step,
        decay_steps=decay_steps,
        end_learning_rate=end_learning_rate,
        power=learning_power)
  elif learning_policy == 'cosine':
    learning_rate = tf.train.cosine_decay(
        base_learning_rate,
        adjusted_global_step,
        training_number_of_steps - slow_start_step)
  elif learning_policy == 'multi_steps':
    if boundaries is None or boundary_learning_rates is None:
      raise ValueError('Must set `boundaries` and `boundary_learning_rates` '
                       'for multi_steps learning rate decay.')
    learning_rate = tf.train.piecewise_constant_decay(
        adjusted_global_step,
        boundaries,
        boundary_learning_rates)
  else:
    raise ValueError('Unknown learning policy.')

  adjusted_slow_start_learning_rate = slow_start_learning_rate
  if slow_start_burnin_type == 'linear':
    # Do linear burnin. Increase linearly from slow_start_learning_rate and
    # reach base_learning_rate after (global_step >= slow_start_steps).
    adjusted_slow_start_learning_rate = (
        slow_start_learning_rate +
        (base_learning_rate - slow_start_learning_rate) *
        tf.to_float(global_step) / slow_start_step)
  elif slow_start_burnin_type != 'none':
    raise ValueError('Unknown burnin type.')

  # Employ small learning rate at the first few steps for warm start.
  return tf.where(global_step < slow_start_step,adjusted_slow_start_learning_rate, learning_rate)

你可能感兴趣的:(Tensorflow,DeepLabv3+,python)