代码地址(非官方公开,他人复现)
根据作者的介绍,安装好相关的库,生成数据文件目录和实验记录的目录以后,直接进入**/superpoint**路径看,这个路径下有model,datasets,config等文件目录,分别存放了相关的网络,数据加载与预处理,配置文件等.
训练网络的程序接口是experiment.py文件.输入的时候需要包含模型名和数据集名(按照作者介绍的运行就行).
运行前,大多数时候需要确认下数据模式,即data_format:,默认是’channels_first’
default_config = {
'data_format': 'channels_first', # 这里显示的是magic_point网络在函数中写的默认参数
'kernel_reg': 0.,
'grid_size': 8,
'detection_threshold': 0.4,
'homography_adaptation': {'num': 0},
'nms': 0,
'top_k': 0
如果是’channels_last’,需要在配置文件中加入:
data_format: 'channels_last'
每次运行前都需要先:
export TMPDIR=/tmp/
所以可以写成一个脚本xx.sh,示例内容如下:
#!/bin/sh
export TMPDIR=/tmp/
python3 experiment.py train configs/magic-point_shapes.yaml magic-point_synth # 根据具体的训练任务更改
然后要确保运行是在superpoint路径下的termianl的,所以,将该xx.sh文件扔到该目录下.
进入正题,开始分析下该项目的代码...
进入experiment.py函数:
subparsers = parser.add_subparsers(dest='command')
# Training command
p_train = subparsers.add_parser('train')
p_train.add_argument('config', type=str)
p_train.add_argument('exper_name', type=str)
p_train.add_argument('--eval', action='store_true')
p_train.set_defaults(func=_cli_train)
# Evaluation command
p_train = subparsers.add_parser('evaluate')
p_train.add_argument('config', type=str)
p_train.add_argument('exper_name', type=str)
p_train.set_defaults(func=_cli_eval)
# Inference command
p_train = subparsers.add_parser('predict')
p_train.add_argument('config', type=str)
p_train.add_argument('exper_name', type=str)
p_train.set_defaults(func=_cli_pred)
# 解析参数
args = parser.parse_args()
args.func(config, output_dir, args)
def _cli_train(config, output_dir, args):
# 保存参数
# 训练train()
# 评估:_cli_eval(config, output_dir, args)
--->
def train(config, n_iter, output_dir, checkpoint_name='model.ckpt'):
# 提取当前的模型节点
# 在_init_graph中加载模型和数据类对象,将模型类对象返回
# 调用模型的train()成员函数(在/model/base_model.py中)进行真正的训练
# 保存当前的模型参数
转到model/目录下,
base_model.py是基础的模型类.其中初始化的时候定义了相关参数,构建了所需的计算图(tensorflow第一步):
__init()__--->_build_grph()--->_train_graph()/_eval_graph()/_pred_graph
而具体的执行操作(tensorflow第二步),是在 train(),predict(),evaluate() 中去实现的,这三个函数是由程序的接口(experiment.py)可以直接调用的.
回到计算图的构建中,在 _train_graph()/_eval_graph()/_pred_graph 首先就是调用 _gpu_tower(),在其中进行了相关计算的定义,主要包括网络输出,损失函数的计算,
with tf.device(device_setter):
net_outputs = self._model(shards[i], mode, **self.config) # 网络输出
if mode == Mode.TRAIN:
loss = self._loss(net_outputs, shards[i], **self.config) # 损失函数计算
loss += tf.reduce_sum(
tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
scope))
model_params = tf.trainable_variables()
grad = tf.gradients(loss, model_params)
tower_losses.append(loss)
tower_gradvars.append(zip(grad, model_params))
if i == 0:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
scope)
elif mode == Mode.EVAL:
tower_metrics.append(self._metrics( # 模型的评估
net_outputs, shards[i], **self.config))
else:
tower_preds.append(net_outputs)
其中, _model(), _loss(), _metrics() 是每个网络单独实现的部分.例如:
#_model()定义了网络结构
def _model(self, inputs, mode, **config):
config['training'] = (mode == Mode.TRAIN)
image = inputs['image']
def net(image):
if config['data_format'] == 'channels_first':
image = tf.transpose(image, [0, 3, 1, 2])
features = vgg_backbone(image, **config)
outputs = detector_head(features, **config)
return outputs
if (mode == Mode.PRED) and config['homography_adaptation']['num']:
outputs = homography_adaptation(image, net, config['homography_adaptation'])
else:
outputs = net(image)
prob = outputs['prob']
if config['nms']:
prob = tf.map_fn(lambda p: box_nms(p, config['nms'],
min_prob=config['detection_threshold'],
keep_top_k=config['top_k']), prob)
outputs['prob_nms'] = prob
pred = tf.to_int32(tf.greater_equal(prob, config['detection_threshold']))
outputs['pred'] = pred
return outputs
网络的输出是一个score_map,需要进行nms以及以一定的阈值进行滤除.最终生成预测的关键点的map.
# _loss()定义了损失函数的计算
def _loss(self, outputs, inputs, **config):
if config['data_format'] == 'channels_first':
outputs['logits'] = tf.transpose(outputs['logits'], [0, 2, 3, 1])
return detector_loss(inputs['keypoint_map'], outputs['logits'],
valid_mask=inputs['valid_mask'], **config)
在损失函数计算中,调用了utils.py中的detector_loss,(这个文件中就只有网络的两个head,两个损失函数,还有box_nms),在计算损失函数的时候先将label调整到网络的直接输出的维度,用直接输出的结果(logits)计算loss,这样不易引起梯度的异常和不可传导.
# _metrics()定义了评测的计算
def _metrics(self, outputs, inputs, **config):
pred = outputs['pred']
labels = inputs['keypoint_map']
precision = tf.reduce_sum(pred*labels) / tf.reduce_sum(pred)
recall = tf.reduce_sum(pred*labels) / tf.reduce_sum(labels)
return {'precision': precision, 'recall': recall}
也就是说这些相关操作的相同部分在base_model.py里,不同部分在各自的网络模型中.
在模型计算过程中,所需要的相应的几何操作在homographies.py中.