# Faster R-CNN 代码解读之 test.py

# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Xinlei Chen
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import cv2
import numpy as np

try:
    import cPickle as pickle
except ImportError:
    import pickle
import os
import math

from utils.timer import Timer
from utils.blob import im_list_to_blob

from model.config import cfg, get_output_dir
from model.bbox_transform import clip_boxes, bbox_transform_inv
from model.nms_wrapper import nms

'''用于测试的时候调用模型进行测试,在tools/test_net.py调用'''


def _get_image_blob(im):
    """Turn one BGR image into the network's input blob.

    The image is copied to float32 and ``cfg.PIXEL_MEANS`` is subtracted.
    Then, for every entry of ``cfg.TEST.SCALES``, a resized copy is made.
    The scale factor maps the image's shorter side to the target size,
    unless that would make the longer side round to more than
    ``cfg.TEST.MAX_SIZE``. In that case the longer side is mapped to
    ``cfg.TEST.MAX_SIZE`` instead.

    Arguments:
      im (ndarray): a color image in BGR order.

    Returns:
      blob (ndarray): the result of ``im_list_to_blob`` over the resized
        copies, in ``cfg.TEST.SCALES`` order.
      im_scale_factors (ndarray): the scale factor used for each copy,
        relative to ``im``.
    """
    normalized = im.astype(np.float32, copy=True)
    normalized -= cfg.PIXEL_MEANS

    spatial = normalized.shape[0:2]
    short_side = np.min(spatial)
    long_side = np.max(spatial)

    def _scale_for(target_size):
        # Fit the short side to the target, but cap the long side.
        factor = float(target_size) / float(short_side)
        if np.round(factor * long_side) > cfg.TEST.MAX_SIZE:
            factor = float(cfg.TEST.MAX_SIZE) / float(long_side)
        return factor

    factors = [_scale_for(target) for target in cfg.TEST.SCALES]
    resized = [
        cv2.resize(normalized, None, None, fx=f, fy=f,
                   interpolation=cv2.INTER_LINEAR)
        for f in factors
    ]

    # Pack the resized images into a single numpy blob.
    return im_list_to_blob(resized), np.array(factors)


def _get_blobs(im):
    """Convert an image into the network input dictionary.

    Returns:
      tuple: ``(blobs, im_scale_factors)``. ``blobs`` is a dict whose only
        key is ``'data'`` (the image blob from ``_get_image_blob``), and
        ``im_scale_factors`` is the array of scales used to build it.
    """
    data_blob, scale_factors = _get_image_blob(im)
    return {'data': data_blob}, scale_factors


def _clip_boxes(boxes, im_shape):
    """Clip boxes to image boundaries, modifying ``boxes`` in place.

    Each row of ``boxes`` is read as repeated groups of four columns
    ``(x1, y1, x2, y2)``. The clamping is one-sided per coordinate:
    x1 and y1 are raised to at least 0, x2 is lowered to at most
    ``im_shape[1] - 1``, and y2 is lowered to at most ``im_shape[0] - 1``.

    Returns:
      ndarray: the same ``boxes`` object, after clipping.
    """
    max_x = im_shape[1] - 1
    max_y = im_shape[0] - 1
    # Strided column selectors for every x1 / y1 / x2 / y2 in a row.
    x1, y1, x2, y2 = (slice(offset, None, 4) for offset in range(4))

    boxes[:, x1] = np.maximum(boxes[:, x1], 0)
    boxes[:, y1] = np.maximum(boxes[:, y1], 0)
    boxes[:, x2] = np.minimum(boxes[:, x2], max_x)
    boxes[:, y2] = np.minimum(boxes[:, y2], max_y)
    return boxes


def _rescale_boxes(boxes, inds, scales):
    """Rescale boxes according to image rescaling, in place.

    Row ``r`` of ``boxes`` is divided by ``scales[int(inds[r])]``. Only the
    first ``boxes.shape[0]`` entries of ``inds`` are used. The result is
    written back into ``boxes``, which is also returned.

    This is done as one vectorized numpy division instead of a Python loop
    over rows.

    Arguments:
      boxes (ndarray): 2-D array, one row per box.
      inds: sequence giving, for each row, an index into ``scales``; values
        are truncated to integers.
      scales: sequence of scale factors.

    Returns:
      ndarray: ``boxes`` itself, after the in-place update.
    """
    num_rows = boxes.shape[0]
    # astype() truncates toward zero, matching int() in the per-row form.
    row_scale_idx = np.asarray(inds)[:num_rows].astype(np.intp)
    row_factors = np.asarray(scales)[row_scale_idx]
    # Assign through [...] so the caller's array is updated in place.
    boxes[...] = boxes / row_factors[:, np.newaxis]
    return boxes


# Inference on a single image: returns per-class scores and per-class boxes.
def im_detect(sess, net, im):
    """Run the trained network on one image (test time, not training).

    Arguments:
      sess: session object passed straight through to ``net.test_image``.
      net: object whose ``test_image(sess, data, im_info)`` returns a
        4-tuple; the last three items are used as class scores, box
        regression outputs and RoIs.
      im (ndarray): the input image; its ``shape`` bounds the final boxes.

    Returns:
      tuple: ``(scores, pred_boxes)``.
        ``scores`` is the score output reshaped to 2-D, one row per RoI.
        ``pred_boxes`` is in original-image coordinates (RoIs are divided
        by the input scale). When ``cfg.TEST.BBOX_REG`` is truthy, it holds
        the regressed boxes clipped by ``_clip_boxes``. Otherwise it holds
        each RoI box tiled once per score column.
    """
    blobs, im_scales = _get_blobs(im)
    assert len(im_scales) == 1, "Only single-image batch implemented"

    scale = im_scales[0]
    data_shape = blobs['data'].shape
    # Dims 1 and 2 of the input blob plus the resize factor, as float32.
    blobs['im_info'] = np.array([data_shape[1], data_shape[2], scale],
                                dtype=np.float32)

    _, cls_scores, deltas, rois = net.test_image(sess, blobs['data'],
                                                 blobs['im_info'])

    # Drop column 0 of the RoIs and undo the input resize.
    proposals = rois[:, 1:5] / scale
    cls_scores = np.reshape(cls_scores, [cls_scores.shape[0], -1])
    deltas = np.reshape(deltas, [deltas.shape[0], -1])

    if not cfg.TEST.BBOX_REG:
        # No regression: reuse each proposal once per class column.
        return cls_scores, np.tile(proposals, (1, cls_scores.shape[1]))

    # Decode regression deltas against the proposals, then clip to the image.
    refined = bbox_transform_inv(proposals, deltas)
    return cls_scores, _clip_boxes(refined, im.shape)


def apply_nms(all_boxes, thresh):
    """Apply non-maximum suppression to all predicted boxes output by the
    test_net method.

    Arguments:
      all_boxes: nested list indexed as ``all_boxes[cls][image]``. Each
        entry is either an empty list or a 2-D array whose first four
        columns are ``x1, y1, x2, y2``. ``test_net`` fills it with
        ``N x 5`` arrays whose last column is the score.
      thresh: overlap threshold forwarded to ``nms``.

    Returns:
      list: the same ``[cls][image]`` layout. An entry is ``[]`` when the
        input was empty, every box was degenerate, or ``nms`` kept nothing.
        Otherwise it is a copy of the rows ``nms`` kept.
    """
    num_classes = len(all_boxes)
    num_images = len(all_boxes[0]) if num_classes else 0
    nms_boxes = [[[] for _ in range(num_images)] for _ in range(num_classes)]
    for cls_ind in range(num_classes):
        for im_ind in range(num_images):
            dets = all_boxes[cls_ind][im_ind]
            # len() detects emptiness for both the [] placeholder and an
            # ndarray; ``ndarray == []`` is an elementwise comparison, not
            # an emptiness test (and raises on recent NumPy).
            if len(dets) == 0:
                continue

            # Keep only boxes with strictly positive width and height.
            x1, y1, x2, y2 = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3]
            dets = dets[(x2 > x1) & (y2 > y1)]
            if dets.shape[0] == 0:
                continue

            keep = nms(dets, thresh)
            if len(keep) == 0:
                continue
            nms_boxes[cls_ind][im_ind] = dets[keep, :].copy()
    return nms_boxes


# Evaluation entry point: detect on every image of an imdb and score results.
def test_net(sess, net, imdb, weights_filename, max_per_image=100, thresh=0.):
    """Test a Fast R-CNN network on an image database.

    For every image, ``im_detect`` is run. For each non-background class
    (index >= 1), the following steps are applied:
      1. Detections with score above ``thresh`` are kept.
      2. They are suppressed with ``nms`` at ``cfg.TEST.NMS``.
      3. When ``max_per_image > 0``, the image is capped at its
         ``max_per_image`` highest-scoring detections across all classes.
         Detections tied with the cutoff score are all kept.
    All detections are pickled to ``<output_dir>/detections.pkl`` and then
    passed to ``imdb.evaluate_detections``.

    Arguments:
      sess: session object forwarded to ``net.test_image`` via ``im_detect``.
      net: network object used by ``im_detect``.
      imdb: image database providing ``image_index``, ``num_classes``,
        ``image_path_at`` and ``evaluate_detections``.
      weights_filename: forwarded to ``get_output_dir`` to pick the
        output directory.
      max_per_image (int): per-image cap on detections; ``<= 0`` disables it.
      thresh (float): minimum class score for a detection to be kept.

    Raises:
      IOError: if ``cv2.imread`` cannot read an image.
    """
    np.random.seed(cfg.RNG_SEED)
    num_images = len(imdb.image_index)
    num_classes = imdb.num_classes
    # all detections are collected into:
    #  all_boxes[cls][image] = N x 5 array of detections in
    #  (x1, y1, x2, y2, score)
    # Background (class 0) entries stay as empty lists.
    all_boxes = [[[] for _ in range(num_images)]
                 for _ in range(num_classes)]
    # Directory that receives detections.pkl and the evaluation output.
    output_dir = get_output_dir(imdb, weights_filename)
    _t = {'im_detect': Timer(), 'misc': Timer()}

    for i in range(num_images):
        im_path = imdb.image_path_at(i)
        im = cv2.imread(im_path)
        # cv2.imread returns None instead of raising on unreadable files.
        if im is None:
            raise IOError('Could not read image: {}'.format(im_path))

        _t['im_detect'].tic()
        scores, boxes = im_detect(sess, net, im)
        _t['im_detect'].toc()

        _t['misc'].tic()

        # skip j = 0, because it's the background class
        for j in range(1, num_classes):
            inds = np.where(scores[:, j] > thresh)[0]
            cls_scores = scores[inds, j]
            cls_boxes = boxes[inds, j * 4:(j + 1) * 4]
            # Rows of (x1, y1, x2, y2, score).
            cls_dets = np.hstack((cls_boxes, cls_scores[:, np.newaxis])) \
                .astype(np.float32, copy=False)
            keep = nms(cls_dets, cfg.TEST.NMS)
            all_boxes[j][i] = cls_dets[keep, :]

        # Limit to max_per_image detections *over all classes*.
        # With only a background class there is nothing to stack.
        if max_per_image > 0 and num_classes > 1:
            image_scores = np.hstack([all_boxes[j][i][:, -1]
                                      for j in range(1, num_classes)])
            if len(image_scores) > max_per_image:
                image_thresh = np.sort(image_scores)[-max_per_image]
                for j in range(1, num_classes):
                    keep = np.where(all_boxes[j][i][:, -1] >= image_thresh)[0]
                    all_boxes[j][i] = all_boxes[j][i][keep, :]
        _t['misc'].toc()

        print('im_detect: {:d}/{:d} {:.3f}s {:.3f}s' \
              .format(i + 1, num_images, _t['im_detect'].average_time,
                      _t['misc'].average_time))

    det_file = os.path.join(output_dir, 'detections.pkl')
    with open(det_file, 'wb') as f:
        pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL)

    print('Evaluating detections')
    imdb.evaluate_detections(all_boxes, output_dir)

# 感谢 WYX 同志

# 你可能感兴趣的: Faster R-CNN