在此程序中,我初次接触到了tf.estimator,除了官方教程,还有很多优秀的博客可供参考,这里对此模块不再详细介绍。
我们接下来所探讨的代码github链接,作者和上一篇文章DeeplabV3的作者相同。虽然DeeplabV3和DeeplabV3+的网络非常相似,但是这次DeeplabV3+使用了tf.data生成输入队列,tf.estimator生成和训练网络。可以说作者rishizek很与时俱进了。
改进:作者使用了TFRecord,这种格式的文件虽然有利于tensorflow的加速,但是弊端就是对于数据量很大的训练集,会占用极大的存储空间。我更喜欢直接从硬盘中读取数据,之后会附一份修改版的代码。
首先是工程目录:
其中最主要的是train.py(定义了训练的具体过程和参数)和deeplab_model.py(定义了网络结构)。
本文尝试顺序讲解代码,不再将整个py文件粘贴过来。
if __name__ == '__main__':
    # Route INFO-level TensorFlow log messages to the console.
    tf.logging.set_verbosity(tf.logging.INFO)
    # parse_known_args() does not fail on arguments the parser does not
    # recognise; it returns them separately so they can be forwarded below.
    FLAGS, unparsed = parser.parse_known_args()
    # Hand control to main(), passing the program name plus the leftover
    # arguments.
    forwarded_argv = [sys.argv[0]] + unparsed
    tf.app.run(main=main, argv=forwarded_argv)
接下来转到main():
def main(unused_argv):
    """Build the DeepLabv3+ Estimator and alternate training with evaluation.

    Runs ``FLAGS.train_epochs // FLAGS.epochs_per_eval`` cycles; each cycle
    calls ``model.train`` once and then ``model.evaluate`` once, printing the
    evaluation results.

    Args:
        unused_argv: Leftover command-line arguments forwarded by
            ``tf.app.run``; not read here.
    """
    # Using the Winograd non-fused algorithms provides a small performance boost.
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

    # Optionally wipe the whole model directory before starting.
    if FLAGS.clean_model_dir:
        shutil.rmtree(FLAGS.model_dir, ignore_errors=True)

    # Set up a RunConfig to only save checkpoints once per training cycle:
    # the time-based trigger is pushed out to 1e9 seconds.
    run_config = tf.estimator.RunConfig().replace(save_checkpoints_secs=1e9)

    # Hyper-parameters forwarded to the model function as its `params` dict.
    model_params = {
        'output_stride': FLAGS.output_stride,
        'batch_size': FLAGS.batch_size,
        'base_architecture': FLAGS.base_architecture,
        'pre_trained_model': FLAGS.pre_trained_model,
        'batch_norm_decay': _BATCH_NORM_DECAY,
        'num_classes': _NUM_CLASSES,
        'tensorboard_images_max_outputs': FLAGS.tensorboard_images_max_outputs,
        'weight_decay': FLAGS.weight_decay,
        'learning_rate_policy': FLAGS.learning_rate_policy,
        'num_train': _NUM_IMAGES['train'],
        'initial_learning_rate': FLAGS.initial_learning_rate,
        'max_iter': FLAGS.max_iter,
        'end_learning_rate': FLAGS.end_learning_rate,
        'power': _POWER,
        'momentum': _MOMENTUM,
        'freeze_batch_norm': FLAGS.freeze_batch_norm,
        'initial_global_step': FLAGS.initial_global_step,
    }
    model = tf.estimator.Estimator(
        # The model function that builds the graph for each mode.
        model_fn=deeplab_model.deeplabv3_plus_model_fn,
        # Directory for checkpoints and other outputs of the Estimator.
        model_dir=FLAGS.model_dir,
        # Execution-environment settings prepared above.
        config=run_config,
        params=model_params)

    # Display name -> name of the graph tensor to print during training.
    tensors_to_log = {
        'learning_rate': 'learning_rate',
        'cross_entropy': 'cross_entropy',
        'train_px_accuracy': 'train_px_accuracy',
        'train_mean_iou': 'train_mean_iou',
    }

    def train_input_fn():
        # First argument True selects the training data.
        return input_fn(True, FLAGS.data_dir, FLAGS.batch_size,
                        FLAGS.epochs_per_eval)

    def eval_input_fn():
        # Batch size must be 1 for testing because the images' size differs.
        return input_fn(False, FLAGS.data_dir, 1)

    num_cycles = FLAGS.train_epochs // FLAGS.epochs_per_eval
    for _ in range(num_cycles):
        # Print the tensors listed above every 10 iterations.
        train_hooks = [
            tf.train.LoggingTensorHook(tensors=tensors_to_log, every_n_iter=10)
        ]
        eval_hooks = None

        # TensorFlow's CLI debugger (https://www.tensorflow.org/guide/debugger)
        # is attached to both training and evaluation when --debug is set.
        if FLAGS.debug:
            debug_hook = tf_debug.LocalCLIDebugHook()
            train_hooks.append(debug_hook)
            eval_hooks = [debug_hook]

        tf.logging.info("Start training.")
        # Pass steps=1 here for a quick debug run.
        model.train(input_fn=train_input_fn, hooks=train_hooks)

        # Evaluation runs once after each train() call returns, i.e. once per
        # cycle of this loop.
        tf.logging.info("Start evaluation.")
        # Evaluate the model and print results (steps=1 also works for debug).
        eval_results = model.evaluate(input_fn=eval_input_fn, hooks=eval_hooks)
        print(eval_results)
train.py 完结,接下来我们要介绍模型的建立及数据的导入
首先是模型建立
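# 关于Estimator的使用方法请参考官方文档
# https://tensorflow.google.cn/api_docs/python/tf/estimator/Estimator
# train.py中并没有直接调用deeplabv3_plus_model_fn:它作为model_fn传给了tf.estimator.Estimator,由Estimator在train/evaluate/predict时回调。features和labels来自input_fn的返回值,mode由所调用的方法(train/evaluate/predict)决定,params就是构造Estimator时传入的params字典。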
def deeplabv3_plus_model_fn(features, labels, mode, params):
    """Model function for PASCAL VOC.

    Builds the DeepLabv3+ graph and returns the EstimatorSpec for `mode`.

    Args:
        features: Input image batch, or a dict holding it under the key
            'feature'.
        labels: Ground-truth label batch with a trailing channel axis (it is
            squeezed on axis 3). Not read in PREDICT mode.
        mode: One of the `tf.estimator.ModeKeys` values.
        params: Dict of hyper-parameters. Keys read here: 'num_classes',
            'output_stride', 'base_architecture', 'pre_trained_model',
            'batch_norm_decay', 'batch_size', 'freeze_batch_norm',
            'weight_decay' (optional, falls back to `_WEIGHT_DECAY`),
            'tensorboard_images_max_outputs', 'learning_rate_policy',
            'num_train', 'initial_learning_rate', 'initial_global_step',
            'max_iter', 'end_learning_rate', 'power' and 'momentum'.

    Returns:
        A `tf.estimator.EstimatorSpec`. In PREDICT mode it carries only the
        predictions and export outputs; otherwise it also carries the loss,
        the train op (None outside TRAIN mode) and the evaluation metrics.

    Raises:
        ValueError: In TRAIN mode, if params['learning_rate_policy'] is
            neither 'piecewise' nor 'poly'.
    """
    if isinstance(features, dict):
        features = features['feature']

    # uint8 copy of the inputs, used only for the TensorBoard image summary.
    # NOTE(review): mean_image_addition presumably adds the per-channel mean
    # back onto the mean-subtracted inputs -- confirm in preprocessing.
    images = tf.cast(
        tf.map_fn(preprocessing.mean_image_addition, features),
        tf.uint8)

    # deeplab_v3_plus_generator returns a function (its inner `model`);
    # `network` is that function.
    network = deeplab_v3_plus_generator(params['num_classes'],
                                        params['output_stride'],
                                        params['base_architecture'],
                                        params['pre_trained_model'],
                                        params['batch_norm_decay'])

    # The second argument is the is_training flag: True only in TRAIN mode.
    logits = network(features, mode == tf.estimator.ModeKeys.TRAIN)

    # Index of the largest logit along axis 3, with that axis kept as a
    # singleton dimension.
    pred_classes = tf.expand_dims(
        tf.argmax(logits, axis=3, output_type=tf.int32), axis=3)

    # Render the predicted class ids as a uint8 image for visualisation.
    # NOTE(review): presumably one colour per class -- confirm decode_labels.
    pred_decoded_labels = tf.py_func(
        preprocessing.decode_labels,
        [pred_classes, params['batch_size'], params['num_classes']],
        tf.uint8)

    predictions = {
        'classes': pred_classes,
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor'),
        'decoded_labels': pred_decoded_labels
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Delete 'decoded_labels' from predictions because custom functions
        # produce error when used with saved_model
        predictions_without_decoded_labels = predictions.copy()
        del predictions_without_decoded_labels['decoded_labels']

        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'preds': tf.estimator.export.PredictOutput(
                    predictions_without_decoded_labels)
            })

    # Render the ground truth the same way as the predictions so the two can
    # be shown side by side in the image summary.
    gt_decoded_labels = tf.py_func(
        preprocessing.decode_labels,
        [labels, params['batch_size'], params['num_classes']], tf.uint8)

    labels = tf.squeeze(labels, axis=3)  # reduce the channel dimension.

    # Flatten to one row of class scores / one label per pixel.
    logits_by_num_classes = tf.reshape(logits, [-1, params['num_classes']])
    labels_flat = tf.reshape(labels, [-1, ])

    # 1 where the label is a real class id (<= num_classes - 1), 0 otherwise.
    valid_indices = tf.to_int32(labels_flat <= params['num_classes'] - 1)
    # Partition 1 holds the pixels flagged valid above; partition 0 is dropped.
    valid_logits = tf.dynamic_partition(
        logits_by_num_classes, valid_indices, num_partitions=2)[1]
    valid_labels = tf.dynamic_partition(
        labels_flat, valid_indices, num_partitions=2)[1]

    preds_flat = tf.reshape(pred_classes, [-1, ])
    valid_preds = tf.dynamic_partition(
        preds_flat, valid_indices, num_partitions=2)[1]
    confusion_matrix = tf.confusion_matrix(
        valid_labels, valid_preds, num_classes=params['num_classes'])

    predictions['valid_preds'] = valid_preds
    predictions['valid_labels'] = valid_labels
    predictions['confusion_matrix'] = confusion_matrix

    cross_entropy = tf.losses.sparse_softmax_cross_entropy(
        logits=valid_logits, labels=valid_labels)

    # Create a tensor named cross_entropy for logging purposes.
    tf.identity(cross_entropy, name='cross_entropy')
    tf.summary.scalar('cross_entropy', cross_entropy)

    # With freeze_batch_norm set, variables whose name contains 'beta' or
    # 'gamma' are left out of both the L2 penalty and the optimizer's var_list.
    if not params['freeze_batch_norm']:
        train_var_list = list(tf.trainable_variables())
    else:
        train_var_list = [v for v in tf.trainable_variables()
                          if 'beta' not in v.name and 'gamma' not in v.name]

    # Add weight decay to the loss.
    with tf.variable_scope("total_loss"):
        loss = cross_entropy + params.get('weight_decay', _WEIGHT_DECAY) * tf.add_n(
            [tf.nn.l2_loss(v) for v in train_var_list])
    # loss = tf.losses.get_total_loss()  # obtain the regularization losses as well

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Input, ground truth and prediction concatenated along axis 2.
        tf.summary.image(
            'images',
            tf.concat(axis=2,
                      values=[images, gt_decoded_labels, pred_decoded_labels]),
            max_outputs=params['tensorboard_images_max_outputs'])

        global_step = tf.train.get_or_create_global_step()

        if params['learning_rate_policy'] == 'piecewise':
            # Scale the learning rate linearly with the batch size. When the
            # batch size is 128, the learning rate should be 0.1.
            initial_learning_rate = 0.1 * params['batch_size'] / 128
            batches_per_epoch = params['num_train'] / params['batch_size']
            # Multiply the learning rate by 0.1 at 100, 150, and 200 epochs.
            boundaries = [int(batches_per_epoch * epoch)
                          for epoch in [100, 150, 200]]
            values = [initial_learning_rate * decay
                      for decay in [1, 0.1, 0.01, 0.001]]
            learning_rate = tf.train.piecewise_constant(
                tf.cast(global_step, tf.int32), boundaries, values)
        elif params['learning_rate_policy'] == 'poly':
            # The step fed to the schedule is offset by initial_global_step.
            learning_rate = tf.train.polynomial_decay(
                params['initial_learning_rate'],
                tf.cast(global_step, tf.int32) - params['initial_global_step'],
                params['max_iter'], params['end_learning_rate'],
                power=params['power'])
        else:
            raise ValueError('Learning rate policy must be "piecewise" or "poly"')

        # Create a tensor named learning_rate for logging purposes
        tf.identity(learning_rate, name='learning_rate')
        tf.summary.scalar('learning_rate', learning_rate)

        optimizer = tf.train.MomentumOptimizer(
            learning_rate=learning_rate,
            momentum=params['momentum'])

        # Batch norm requires update ops to be added as a dependency to the train_op
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(
                loss, global_step, var_list=train_var_list)
    else:
        train_op = None

    # Pixel accuracy and mean IoU over the valid pixels only.
    accuracy = tf.metrics.accuracy(
        valid_labels, valid_preds)
    mean_iou = tf.metrics.mean_iou(
        valid_labels, valid_preds, params['num_classes'])
    metrics = {'px_accuracy': accuracy, 'mean_iou': mean_iou}

    # Create a tensor named train_accuracy for logging purposes
    tf.identity(accuracy[1], name='train_px_accuracy')
    tf.summary.scalar('train_px_accuracy', accuracy[1])

    def compute_mean_iou(total_cm, name='mean_iou'):
        """Compute the mean intersection-over-union via the confusion matrix."""
        sum_over_row = tf.to_float(tf.reduce_sum(total_cm, 0))
        sum_over_col = tf.to_float(tf.reduce_sum(total_cm, 1))
        cm_diag = tf.to_float(tf.diag_part(total_cm))
        denominator = sum_over_row + sum_over_col - cm_diag

        # The mean is only computed over classes that appear in the
        # label or prediction tensor. If the denominator is 0, we need to
        # ignore the class.
        num_valid_entries = tf.reduce_sum(tf.cast(
            tf.not_equal(denominator, 0), dtype=tf.float32))

        # If the value of the denominator is 0, set it to 1 to avoid
        # zero division.
        denominator = tf.where(
            tf.greater(denominator, 0),
            denominator,
            tf.ones_like(denominator))
        iou = tf.div(cm_diag, denominator)

        # Expose every per-class IoU as a named tensor and a scalar summary.
        for i in range(params['num_classes']):
            tf.identity(iou[i], name='train_iou_class{}'.format(i))
            tf.summary.scalar('train_iou_class{}'.format(i), iou[i])

        # If the number of valid entries is 0 (no classes) we return 0.
        result = tf.where(
            tf.greater(num_valid_entries, 0),
            tf.reduce_sum(iou, name=name) / num_valid_entries,
            0)
        return result

    # mean_iou[1] is the second element of the tf.metrics.mean_iou result; it
    # is passed in as the confusion matrix.
    train_mean_iou = compute_mean_iou(mean_iou[1])

    tf.identity(train_mean_iou, name='train_mean_iou')
    tf.summary.scalar('train_mean_iou', train_mean_iou)

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=metrics)
之前都是在配置训练的参数,配置网络的输出和输入,接下来要真正开始构造DeeplabV3+了:
def deeplab_v3_plus_generator(num_classes,
                              output_stride,
                              base_architecture,
                              pre_trained_model,
                              batch_norm_decay,
                              data_format='channels_last'):
    """Generator for DeepLab v3 plus models.

    Args:
        num_classes: The number of possible classes for image classification.
        output_stride: The ResNet unit's stride. Determines the rates for
            atrous convolution. The rates are (6, 12, 18) when the stride is
            16, and doubled when 8.
        base_architecture: The architecture of base Resnet building block;
            'resnet_v2_50' or 'resnet_v2_101'.
        pre_trained_model: The path to the directory that contains
            pre-trained models.
        batch_norm_decay: The moving average decay when estimating layer
            activation statistics in batch normalization. None selects
            `_BATCH_NORM_DECAY`.
        data_format: The input format ('channels_last', 'channels_first', or
            None). None is accepted and leaves the inputs untouched.
            Only 'channels_last' is supported currently.

    Returns:
        The model function that takes in `inputs` and `is_training` and
        returns the output tensor of the DeepLab v3 model.

    Raises:
        ValueError: If `base_architecture` is not one of the two supported
            ResNet variants.
    """
    if batch_norm_decay is None:
        batch_norm_decay = _BATCH_NORM_DECAY

    if base_architecture not in ('resnet_v2_50', 'resnet_v2_101'):
        raise ValueError(
            "'base_architecture' must be either 'resnet_v2_50' or 'resnet_v2_101'.")

    # The backbone is a plain ResNet; DeepLabv3+ only changes what is attached
    # to its output (ASPP encoder head plus the decoder built below).
    if base_architecture == 'resnet_v2_50':
        base_model = resnet_v2.resnet_v2_50
    else:
        base_model = resnet_v2.resnet_v2_101

    def model(inputs, is_training):
        """Constructs the ResNet model given the inputs."""
        if data_format == 'channels_first':
            # Convert the inputs from channels_last (NHWC) to channels_first (NCHW).
            # This provides a large performance boost on GPU. See
            # https://www.tensorflow.org/performance/performance_guide#data_formats
            # NOTE(review): inputs_size below is read from axes 1:3 after this
            # transpose, and the decoder concatenates on axis 3, so this path
            # looks unsupported (as the docstring says) -- confirm before use.
            inputs = tf.transpose(inputs, [0, 3, 1, 2])

        # tf.logging.info('net shape: {}'.format(inputs.shape))
        # encoder
        with tf.contrib.slim.arg_scope(
                resnet_v2.resnet_arg_scope(batch_norm_decay=batch_norm_decay)):
            # The backbone returns (net, end_points); only end_points is used.
            _, end_points = base_model(inputs,
                                       num_classes=None,
                                       is_training=is_training,
                                       global_pool=False,
                                       output_stride=output_stride)

        if is_training:
            # Initialise the variables created so far from the pre-trained
            # checkpoint, skipping the backbone's logits and the global step.
            exclude = [base_architecture + '/logits', 'global_step']
            variables_to_restore = tf.contrib.slim.get_variables_to_restore(
                exclude=exclude)
            tf.train.init_from_checkpoint(
                pre_trained_model,
                {v.name.split(':')[0]: v for v in variables_to_restore})

        # Size the final logits are resized back to.
        inputs_size = tf.shape(inputs)[1:3]
        net = end_points[base_architecture + '/block4']
        encoder_output = atrous_spatial_pyramid_pooling(
            net, output_stride, batch_norm_decay, is_training)

        with tf.variable_scope("decoder"):
            with tf.contrib.slim.arg_scope(
                    resnet_v2.resnet_arg_scope(batch_norm_decay=batch_norm_decay)):
                with arg_scope([layers.batch_norm], is_training=is_training):
                    with tf.variable_scope("low_level_features"):
                        # Early backbone features, projected to 48 channels
                        # with a 1x1 convolution.
                        low_level_features = end_points[
                            base_architecture + '/block1/unit_3/bottleneck_v2/conv1']
                        low_level_features = layers_lib.conv2d(
                            low_level_features, 48,
                            [1, 1], stride=1, scope='conv_1x1')
                        low_level_features_size = tf.shape(low_level_features)[1:3]

                    with tf.variable_scope("upsampling_logits"):
                        # Upsample the encoder output to the low-level feature
                        # size, fuse the two, refine, then classify per pixel.
                        net = tf.image.resize_bilinear(
                            encoder_output, low_level_features_size,
                            name='upsample_1')
                        net = tf.concat(
                            [net, low_level_features], axis=3, name='concat')
                        net = layers_lib.conv2d(
                            net, 256, [3, 3], stride=1, scope='conv_3x3_1')
                        net = layers_lib.conv2d(
                            net, 256, [3, 3], stride=1, scope='conv_3x3_2')
                        net = layers_lib.conv2d(
                            net, num_classes, [1, 1], activation_fn=None,
                            normalizer_fn=None, scope='conv_1x1')
                        # Resize the class scores back to the input size.
                        logits = tf.image.resize_bilinear(
                            net, inputs_size, name='upsample_2')

        return logits

    # The caller receives the closure, not a tensor.
    return model