目标检测笔记No.6 一行代码背后,寻DETR中的测试过程中的边框

目标检测笔记No.6 一行代码背后,寻DETR中的测试过程中的边框

  • 第一部分
    • evaluate 函数
    • postprocessors
    • PostProcess()
    • 串一遍
  • 第二部分
    • coco api 调用 引子
    • CocoEvaluator
  • 总结

首先,来说说我的指代的代码在detr源码项目中 main.py中:

test_stats, coco_evaluator = evaluate(model, criterion, postprocessors,
                                              data_loader_val, base_ds, device, args.output_dir)

起因是出于想要看看训练好的模型,想要可视化一下图片的检测效果,结果以为程序封装之后所生成的bbox属于函数中间变量,并没有输出。

第一部分

evaluate 函数

位于主目录下 engine.py中,
源码github链接
摘抄一些代码

@torch.no_grad() # 测试不需要梯度反向传播
def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, output_dir):
    model.eval() 
    criterion.eval()
    ...
    # 加载数据集
    for samples, targets in metric_logger.log_every(data_loader, 10, header):
	    samples = samples.to(device)
	    targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
	    # 模型输出,注意后续还有操作
	    outputs = model(samples)
	    loss_dict = criterion(outputs, targets)
	    weight_dict = criterion.weight_dict
	    ...
	    results = postprocessors['bbox'](outputs, orig_target_sizes) # 这个关键唉
		...
		res = {target['image_id'].item(): output for target, output in zip(targets, results)}
	 	

postprocessors

由于使用postprocessors,它的定义是在models/conditional_detr.pybuild(args)

def build(args):
	...
	postprocessors = {'bbox': PostProcess()} # 好像又有一个调用
	...
	return model, criterion, postprocessors    

全部代码如下 github地址

PostProcess()

源代码说明这个函数是将output转换位coco api的格式

class PostProcess(nn.Module):
    """ This module converts the model's output into the format expected by the coco api"""
    @torch.no_grad()
    def forward(self, outputs, target_sizes):
        """ Perform the computation
        Parameters:
            outputs: raw outputs of the model
            target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
                          For evaluation, this must be the original image size (before any data augmentation)
                          For visualization, this should be the image size after data augment, but before padding
        """
        out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']

        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2

        prob = out_logits.sigmoid()
        topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1)
        scores = topk_values
        topk_boxes = topk_indexes // out_logits.shape[2]
        labels = topk_indexes % out_logits.shape[2]
        boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
        boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4))
        
        # and from relative [0, 1] to absolute [0, height] coordinates
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        boxes = boxes * scale_fct[:, None, :]

        results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)]

        return results

感觉要差不多了,就要达到目的了。

串一遍

 results = postprocessors['bbox'](outputs, orig_target_sizes) # 调用PostProcess()
 # 等同于
 results = PostProcess(outputs, orig_target_sizes)

results包含的就是输出的分数,标签和边框,用一个字典进行存储。
串串香~

第二部分

coco api 调用 引子

engine.py 中 evaluate 函数

@torch.no_grad()
def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, output_dir):
	coco_evaluator = CocoEvaluator(base_ds, iou_types)
	...
	if coco_evaluator is not None:
	    coco_evaluator.update(res)
	...
    if coco_evaluator is not None:
        coco_evaluator.synchronize_between_processes()
    ...
    # accumulate predictions from all images
    if coco_evaluator is not None:
        coco_evaluator.accumulate()
        coco_evaluator.summarize()
	...
    if coco_evaluator is not None:
        if 'bbox' in postprocessors.keys():
            stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist()
    ...
    return stats, coco_evaluator

CocoEvaluator

datasets/coco_eval.py 中 CocoEvaluator
未完待续…

总结

这一过程可谓剥洋葱般层层寻觅,说多了都是泪。耗时一天,望大佬指点。因为是在看目标检测这部分,所以分割的代码就不讲了。各位客官若有意见或者建议可留言交流一下。
整体程序看起来复杂原因可能是既有分割也有检测的功能。
接下来,就是将其导入绘图程序,期待后续。COCO api 简单说明,最近也看看。

你可能感兴趣的:(目标检测或识别,目标检测,pytorch,计算机视觉)