[代码解读]Superpoint代码解读

代码地址(非官方公开,他人复现)

文件目录

[代码解读]Superpoint代码解读_第1张图片
根据作者的介绍,安装好相关的库,生成数据文件目录和实验记录的目录以后,直接进入**/superpoint**路径看,这个路径下有model,datasets,config等文件目录,分别存放了相关的网络,数据加载与预处理,配置文件等.

训练

训练网络的程序接口是experiment.py文件.输入的时候需要包含模型名和数据集名(按照作者介绍的运行就行).
运行前,大多数时候需要确认下数据模式,即data_format:,默认是’channels_first’

    default_config = {
            'data_format': 'channels_first',   # 这里显示的是magic_point网络在函数中写的默认参数
            'kernel_reg': 0.,
            'grid_size': 8,
            'detection_threshold': 0.4,
            'homography_adaptation': {'num': 0},
            'nms': 0,
            'top_k': 0

如果是’channels_last’,需要在配置文件中加入:

data_format: 'channels_last'

每次运行前都需要先:

export TMPDIR=/tmp/

所以可以写成一个脚本xx.sh,示例内容如下:

#!/bin/sh
export TMPDIR=/tmp/
python3 experiment.py train configs/magic-point_shapes.yaml magic-point_synth  # 根据具体的训练任务更改

然后要确保运行是在superpoint路径下的termianl的,所以,将该xx.sh文件扔到该目录下.

代码分析

进入正题,开始分析下该项目的代码...

程序接口

进入experiment.py函数:

  1. 加载并解析命令行参数
	subparsers = parser.add_subparsers(dest='command') 
    # Training command
    p_train = subparsers.add_parser('train')
    p_train.add_argument('config', type=str)
    p_train.add_argument('exper_name', type=str)
    p_train.add_argument('--eval', action='store_true')
    p_train.set_defaults(func=_cli_train)
    
    # Evaluation command
    p_train = subparsers.add_parser('evaluate')
    p_train.add_argument('config', type=str)
    p_train.add_argument('exper_name', type=str)
    p_train.set_defaults(func=_cli_eval)

    # Inference command
    p_train = subparsers.add_parser('predict')
    p_train.add_argument('config', type=str)
    p_train.add_argument('exper_name', type=str)
    p_train.set_defaults(func=_cli_pred)
    # 解析参数
    args = parser.parse_args()
  1. 加载配置文件
  2. 设定输出位置,就是EXPER_DIR的路径
  3. 调用函数,执行程序
    args.func(config, output_dir, args)
  1. 训练
def _cli_train(config, output_dir, args):
	# 保存参数
 # 训练train()
 # 评估:_cli_eval(config, output_dir, args)

--->

def train(config, n_iter, output_dir, checkpoint_name='model.ckpt'):
	# 提取当前的模型节点
	# 在_init_graph中加载模型和数据类对象,将模型类对象返回
	# 调用模型的train()成员函数(在/model/base_model.py中)进行真正的训练
    # 保存当前的模型参数

模型部分

转到model/目录下,
base_model.py是基础的模型类.其中初始化的时候定义了相关参数,构建了所需的计算图(tensorflow第一步):

__init()__--->_build_grph()--->_train_graph()/_eval_graph()/_pred_graph

而具体的执行操作(tensorflow第二步),是在 train(),predict(),evaluate() 中去实现的,这三个函数是由程序的接口(experiment.py)可以直接调用的.

回到计算图的构建中,在 _train_graph()/_eval_graph()/_pred_graph 首先就是调用 _gpu_tower(),在其中进行了相关计算的定义,主要包括网络输出,损失函数的计算,

 with tf.device(device_setter):
                    net_outputs = self._model(shards[i], mode, **self.config)  # 网络输出
                    if mode == Mode.TRAIN:
                        loss = self._loss(net_outputs, shards[i], **self.config)  # 损失函数计算
                        loss += tf.reduce_sum(
                                tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                                                  scope))
                        model_params = tf.trainable_variables()
                        grad = tf.gradients(loss, model_params)
                        tower_losses.append(loss)
                        tower_gradvars.append(zip(grad, model_params))
                        if i == 0:
                            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                                           scope)
                    elif mode == Mode.EVAL:
                        tower_metrics.append(self._metrics(  # 模型的评估
                            net_outputs, shards[i], **self.config))
                    else:
                        tower_preds.append(net_outputs)

其中, _model(), _loss(), _metrics() 是每个网络单独实现的部分.例如:

 #_model()定义了网络结构
    def _model(self, inputs, mode, **config):
        config['training'] = (mode == Mode.TRAIN)
        image = inputs['image']

        def net(image):
            if config['data_format'] == 'channels_first':
                image = tf.transpose(image, [0, 3, 1, 2])
            features = vgg_backbone(image, **config)
            outputs = detector_head(features, **config)
            return outputs

        if (mode == Mode.PRED) and config['homography_adaptation']['num']:
            outputs = homography_adaptation(image, net, config['homography_adaptation'])
        else:
            outputs = net(image)

        prob = outputs['prob']
        if config['nms']:
            prob = tf.map_fn(lambda p: box_nms(p, config['nms'],
                                               min_prob=config['detection_threshold'],
                                               keep_top_k=config['top_k']), prob)
            outputs['prob_nms'] = prob
        pred = tf.to_int32(tf.greater_equal(prob, config['detection_threshold']))
        outputs['pred'] = pred

        return outputs

网络的输出是一个score_map,需要进行nms以及以一定的阈值进行滤除.最终生成预测的关键点的map.

 # _loss()定义了损失函数的计算
    def _loss(self, outputs, inputs, **config):
        if config['data_format'] == 'channels_first':
            outputs['logits'] = tf.transpose(outputs['logits'], [0, 2, 3, 1])
        return detector_loss(inputs['keypoint_map'], outputs['logits'],
                             valid_mask=inputs['valid_mask'], **config)

在损失函数计算中,调用了utils.py中的detector_loss,(这个文件中就只有网络的两个head,两个损失函数,还有box_nms),在计算损失函数的时候先将label调整到网络的直接输出的维度,用直接输出的结果(logits)计算loss,这样不易引起梯度的异常和不可传导.

    # _metrics()定义了评测的计算
    def _metrics(self, outputs, inputs, **config):
        pred = outputs['pred']
        labels = inputs['keypoint_map']

        precision = tf.reduce_sum(pred*labels) / tf.reduce_sum(pred)
        recall = tf.reduce_sum(pred*labels) / tf.reduce_sum(labels)

        return {'precision': precision, 'recall': recall}

也就是说这些相关操作的相同部分在base_model.py里,不同部分在各自的网络模型中.

在模型计算过程中,所需要的相应的几何操作在homographies.py中.

你可能感兴趣的:(特征点检测)