Reading the DeepLabv3+ code: train_utils.py
1. get_model_learning_rate()
def get_model_learning_rate(learning_policy,
                            base_learning_rate,
                            learning_rate_decay_step,
                            learning_rate_decay_factor,
                            training_number_of_steps,
                            learning_power,
                            slow_start_step,
                            slow_start_learning_rate,
                            slow_start_burnin_type='none'):
  """Gets model's learning rate.

  Computes the model's learning rate for different learning policy.
  Right now, only "step" and "poly" are supported.

  (1) The learning policy for "step" is computed as follows:
    current_learning_rate = base_learning_rate *
      learning_rate_decay_factor ^ (global_step / learning_rate_decay_step)
  See tf.train.exponential_decay for details.

  (2) The learning policy for "poly" is computed as follows:
    current_learning_rate = base_learning_rate *
      (1 - global_step / training_number_of_steps) ^ learning_power
  """
  global_step = tf.train.get_or_create_global_step()
  adjusted_global_step = global_step

  # If a burn-in (slow start) phase is used, the decay schedules start
  # counting steps only after that phase.
  if slow_start_burnin_type != 'none':
    adjusted_global_step -= slow_start_step

  if learning_policy == 'step':
    learning_rate = tf.train.exponential_decay(
        base_learning_rate,
        adjusted_global_step,
        learning_rate_decay_step,
        learning_rate_decay_factor,
        staircase=True)
  elif learning_policy == 'poly':
    learning_rate = tf.train.polynomial_decay(
        base_learning_rate,
        adjusted_global_step,
        training_number_of_steps,
        end_learning_rate=0,
        power=learning_power)
  else:
    raise ValueError('Unknown learning policy.')

  # During the slow-start phase, either use a constant small learning rate
  # or ramp it up linearly towards base_learning_rate.
  adjusted_slow_start_learning_rate = slow_start_learning_rate
  if slow_start_burnin_type == 'linear':
    adjusted_slow_start_learning_rate = (
        slow_start_learning_rate +
        (base_learning_rate - slow_start_learning_rate) *
        tf.to_float(global_step) / slow_start_step)
  elif slow_start_burnin_type != 'none':
    raise ValueError('Unknown burnin type.')

  return tf.where(global_step < slow_start_step,
                  adjusted_slow_start_learning_rate, learning_rate)
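To make the two policies concrete, here is a minimal pure-Python sketch (not part of train_utils.py) that evaluates the formulas from the docstring at a few global steps. The hyper-parameter values below (0.007, 2000, 0.1, 30000, 0.9) are illustrative numbers for this example, not necessarily the training flags you would use.

# Illustrative re-implementation of the two schedules described in the
# docstring above; the hyper-parameter values are made up for the example.
def step_lr(global_step, base_lr=0.007, decay_step=2000, decay_factor=0.1):
    # staircase=True: the exponent is the number of completed decay periods.
    return base_lr * decay_factor ** (global_step // decay_step)

def poly_lr(global_step, base_lr=0.007, training_steps=30000, power=0.9):
    # Decays smoothly to 0 at training_steps, matching end_learning_rate=0.
    return base_lr * (1.0 - global_step / float(training_steps)) ** power

for step in (0, 1000, 10000, 29000):
    print(step, step_lr(step), poly_lr(step))

The "poly" policy is the one commonly used for DeepLab training; it decays smoothly towards end_learning_rate=0 as global_step approaches training_number_of_steps.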
2. add_softmax_cross_entropy_loss_for_each_scale
Computes a softmax cross-entropy loss for the output logits at each scale. A small numeric sketch of the hard-example-mining arithmetic follows the function below.
Args:
  scales_to_logits: A map from logits names for different scales to the corresponding logits, each with shape [batch, logits_height, logits_width, num_classes].
  labels: Groundtruth labels, shape: [batch, image_height, image_width, 1].
  num_classes: Number of classes.
  ignore_label: Label value to ignore when computing the loss.
  loss_weight: Weight of the loss (default 1.0).
  upsample_logits: Whether to upsample the logits to the label resolution (otherwise the labels are downsampled to the logits resolution).
  hard_example_mining_step: Step at which the kept fraction of pixels finishes annealing down to top_k_percent_pixels (default 0, i.e. no annealing).
  top_k_percent_pixels: Fraction of pixels with the largest loss to keep (default 1.0, i.e. all pixels; values < 1.0 enable hard example mining).
  scope: The scope for the loss.
def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
                                                  labels,
                                                  num_classes,
                                                  ignore_label,
                                                  loss_weight=1.0,
                                                  upsample_logits=True,
                                                  hard_example_mining_step=0,
                                                  top_k_percent_pixels=1.0,
                                                  scope=None):
  if labels is None:
    raise ValueError('No label for softmax cross entropy loss.')

  for scale, logits in six.iteritems(scales_to_logits):
    loss_scope = None
    if scope:
      loss_scope = '%s_%s' % (scope, scale)

    if upsample_logits:
      # Labels are kept at full resolution; logits are upsampled to match.
      logits = tf.image.resize_bilinear(
          logits,
          preprocess_utils.resolve_shape(labels, 4)[1:3],
          align_corners=True)
      scaled_labels = labels
    else:
      # Labels are downsampled to the logits resolution.
      scaled_labels = tf.image.resize_nearest_neighbor(
          labels,
          preprocess_utils.resolve_shape(logits, 4)[1:3],
          align_corners=True)

    scaled_labels = tf.reshape(scaled_labels, shape=[-1])
    # Pixels whose label equals ignore_label get zero weight in the loss.
    not_ignore_mask = tf.to_float(tf.not_equal(scaled_labels,
                                               ignore_label)) * loss_weight
    one_hot_labels = tf.one_hot(
        scaled_labels, num_classes, on_value=1.0, off_value=0.0)

    if top_k_percent_pixels == 1.0:
      # Compute the loss over all (non-ignored) pixels.
      tf.losses.softmax_cross_entropy(
          one_hot_labels,
          tf.reshape(logits, shape=[-1, num_classes]),
          weights=not_ignore_mask,
          scope=loss_scope)
    else:
      # Hard example mining: only keep the top-k percent hardest pixels.
      logits = tf.reshape(logits, shape=[-1, num_classes])
      weights = not_ignore_mask
      with tf.name_scope(loss_scope, 'softmax_hard_example_mining',
                         [logits, one_hot_labels, weights]):
        one_hot_labels = tf.stop_gradient(
            one_hot_labels, name='labels_stop_gradient')
        pixel_losses = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=one_hot_labels,
            logits=logits,
            name='pixel_losses')
        weighted_pixel_losses = tf.multiply(pixel_losses, weights)
        num_pixels = tf.to_float(tf.shape(logits)[0])
        if hard_example_mining_step == 0:
          # Keep a fixed fraction of the pixels with the largest loss.
          top_k_pixels = tf.to_int32(top_k_percent_pixels * num_pixels)
        else:
          # Anneal the kept fraction from 100% down to top_k_percent_pixels
          # over hard_example_mining_step steps.
          global_step = tf.to_float(tf.train.get_or_create_global_step())
          ratio = tf.minimum(1.0, global_step / hard_example_mining_step)
          top_k_pixels = tf.to_int32(
              (ratio * top_k_percent_pixels + (1.0 - ratio)) * num_pixels)
        top_k_losses, _ = tf.nn.top_k(weighted_pixel_losses,
                                      k=top_k_pixels,
                                      sorted=True,
                                      name='top_k_percent_pixels')
        total_loss = tf.reduce_sum(top_k_losses)
        # Average only over the kept pixels that contribute a non-zero loss.
        num_present = tf.reduce_sum(
            tf.to_float(tf.not_equal(top_k_losses, 0.0)))
        loss = _div_maybe_zero(total_loss, num_present)
        tf.losses.add_loss(loss)
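The hard-example-mining branch is easier to follow with concrete numbers. The NumPy sketch below mirrors the annealing schedule and the top-k selection on a toy batch of per-pixel losses; it is an illustration of the same arithmetic, not the TensorFlow graph code above, and all values are made up.

import numpy as np

def kept_fraction(global_step, hard_example_mining_step, top_k_percent_pixels):
    # Linearly anneal from keeping 100% of pixels to keeping
    # top_k_percent_pixels over hard_example_mining_step steps.
    if hard_example_mining_step == 0:
        return top_k_percent_pixels
    ratio = min(1.0, global_step / float(hard_example_mining_step))
    return ratio * top_k_percent_pixels + (1.0 - ratio)

# Toy example: 8 pixel losses, keep the hardest 25% once annealing is done.
pixel_losses = np.array([0.1, 2.3, 0.0, 0.7, 1.5, 0.2, 0.05, 3.1])
frac = kept_fraction(global_step=100000, hard_example_mining_step=100000,
                     top_k_percent_pixels=0.25)
k = int(frac * pixel_losses.size)                 # k = 2
top_k = np.sort(pixel_losses)[::-1][:k]           # the two largest losses
loss = top_k.sum() / max((top_k != 0).sum(), 1)   # mean over non-zero losses
print(k, top_k, loss)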
3. get_model_gradient_multipliers
Gradient multipliers adjust the learning rate of individual model variables. For segmentation, the model is usually fine-tuned from a model pre-trained on an image classification task, and the last layers are typically given a larger (e.g. 10x) learning rate. An illustration of the returned mapping follows the function below.
Args:
  last_layers: Scopes of the last layers.
  last_layer_gradient_multiplier: The gradient multiplier for the last layers.
Returns:
  A map of gradient multipliers: {variable name: multiplier value}.
def get_model_gradient_multipliers(last_layers, last_layer_gradient_multiplier):
  gradient_multipliers = {}

  for var in tf.model_variables():
    # Double the learning rate for biases.
    if 'biases' in var.op.name:
      gradient_multipliers[var.op.name] = 2.

    # Use a larger learning rate for variables in the last layers.
    for layer in last_layers:
      if layer in var.op.name and 'biases' in var.op.name:
        gradient_multipliers[var.op.name] = 2 * last_layer_gradient_multiplier
        break
      elif layer in var.op.name:
        gradient_multipliers[var.op.name] = last_layer_gradient_multiplier
        break

  return gradient_multipliers
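As an illustration, the sketch below applies the same selection logic to a few hypothetical variable names (real names depend on the backbone and decoder), with 'logits' standing in for an entry of last_layers:

# Hypothetical variable names; real names depend on the network being trained.
last_layers = ['logits']
last_layer_gradient_multiplier = 10.0

fake_variable_names = [
    'xception_65/entry_flow/conv1_1/weights',
    'xception_65/entry_flow/conv1_1/biases',
    'logits/semantic/weights',
    'logits/semantic/biases',
]

# Same selection logic as get_model_gradient_multipliers, applied to names only.
gradient_multipliers = {}
for name in fake_variable_names:
    if 'biases' in name:
        gradient_multipliers[name] = 2.0
    for layer in last_layers:
        if layer in name and 'biases' in name:
            gradient_multipliers[name] = 2 * last_layer_gradient_multiplier
            break
        elif layer in name:
            gradient_multipliers[name] = last_layer_gradient_multiplier
            break

print(gradient_multipliers)
# {'.../conv1_1/biases': 2.0,
#  'logits/semantic/weights': 10.0,
#  'logits/semantic/biases': 20.0}

In the training script, a mapping like this is applied to the computed gradients before the optimizer's apply step (slim provides a multiply_gradients helper for this purpose).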
4. get_model_init_fn
Initializes the model from a checkpoint. A usage sketch follows the function below.
Args:
  train_logdir: Directory where training logs and checkpoints are stored.
  tf_initial_checkpoint: The checkpoint used for initialization.
  initialize_last_layer: Whether to initialize the last layers from the checkpoint.
  last_layers: The last layers of the model.
  ignore_missing_vars: Whether to ignore variables that are missing from the checkpoint.
Returns:
  An initialization function (restore_fn) that restores the variables from the checkpoint, or None if no initialization from the checkpoint is needed.
def get_model_init_fn(train_logdir,
                      tf_initial_checkpoint,
                      initialize_last_layer,
                      last_layers,
                      ignore_missing_vars=False):
  if tf_initial_checkpoint is None:
    tf.logging.info('Not initializing the model from a checkpoint.')
    return None

  # If training already produced a checkpoint in train_logdir, resume from it
  # instead of re-initializing from tf_initial_checkpoint.
  if tf.train.latest_checkpoint(train_logdir):
    tf.logging.info('Ignoring initialization; other checkpoint exists')
    return None

  tf.logging.info('Initializing model from path: %s', tf_initial_checkpoint)

  # Variables that will not be restored from the checkpoint.
  exclude_list = ['global_step']
  if not initialize_last_layer:
    exclude_list.extend(last_layers)

  variables_to_restore = tf.contrib.framework.get_variables_to_restore(
      exclude=exclude_list)

  if variables_to_restore:
    init_op, init_feed_dict = tf.contrib.framework.assign_from_checkpoint(
        tf_initial_checkpoint,
        variables_to_restore,
        ignore_missing_vars=ignore_missing_vars)
    global_step = tf.train.get_or_create_global_step()

    def restore_fn(unused_scaffold, sess):
      sess.run(init_op, init_feed_dict)
      sess.run([global_step])

    return restore_fn

  return None
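Note that restore_fn has the (scaffold, session) signature expected by tf.train.Scaffold's init_fn. A simplified usage sketch (TensorFlow 1.x; the paths and last-layer scope below are placeholders, not real flags or files):

import tensorflow as tf

# Sketch only: assumes a graph with model variables has already been built
# and that the checkpoint path points to an existing pre-trained model.
init_fn = get_model_init_fn(
    train_logdir='/tmp/train_logs',
    tf_initial_checkpoint='/tmp/xception_65/model.ckpt',
    initialize_last_layer=False,
    last_layers=['logits'],
    ignore_missing_vars=True)

scaffold = tf.train.Scaffold(init_fn=init_fn)
with tf.train.MonitoredTrainingSession(
    checkpoint_dir='/tmp/train_logs', scaffold=scaffold) as sess:
  pass  # run training steps here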