CV+Deep Learning - Network Architecture PyTorch Reproduction Series - Detection (1: SSD: Single Shot MultiBox Detector, 4. Inference with Detect)

Previous part

CV+Deep Learning - Network Architecture PyTorch Reproduction Series - Detection (1: SSD: Single Shot MultiBox Detector, 3. Loss) https://blog.csdn.net/XiaoyYidiaodiao/article/details/127535159?spm=1001.2014.3001.5501


For object detection, the network architectures that will be reproduced are:

1.SSD: Single Shot MultiBox Detector(√)

2.RetinaNet

3.Faster RCNN

4.YOLO series

....

Code:

https://github.com/HanXiaoyiGitHub/Simple-CV-Pytorch-master


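1. Reproducing SSD

1.4 Inference

The inference stage is relatively straightforward. Detect does four things:

- decodes the predicted offsets against the prior boxes;
- drops low-confidence scores for each class;
- runs per-class NMS;
- keeps the top_k detections.

The code is as follows.

Code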

import torch
from torch import nn
from torchvision.ops import nms
from models.detection.SSD.utils.box_utils import decode


class Detect(nn.Module):
    """At test time, Detect is the final layer of SSD. Decode location preds,
    apply non-maximum suppression to location predictions based on conf
    scores and threshold, and keep a top_k number of output predictions for
    both confidence scores and locations.

    Subclasses nn.Module rather than torch.autograd.Function: a legacy
    Function with a non-static forward raises an error when called as
    detect(...) on recent PyTorch versions.
    """

    def __init__(self, num_classes, top_k, conf_thresh, nms_thresh, bkg_label=0):
        super().__init__()
        self.num_classes = num_classes
        self.top_k = top_k
        self.conf_thresh = conf_thresh
        # Parameters used in nms.
        self.nms_thresh = nms_thresh
        if nms_thresh <= 0:
            raise ValueError('nms_threshold must be positive.')
        self.background_label = bkg_label
        # Variances used by decode() to undo the offset encoding.
        self.variance = [0.1, 0.2]

    def forward(self, loc_data, conf_data, prior_data):
        """
        Args:
            loc_data: (tensor) Loc preds from loc layers
                Shape: [batch, num_priors, 4]
            conf_data: (tensor) Conf preds (class confidence scores) from conf layers
                Shape: [batch, num_priors, num_classes]
            prior_data: (tensor) Prior boxes from priorbox layers
                Shape: [num_priors, 4]
        Returns:
            (tensor) Shape: [batch, num_classes, top_k, 5]; each row is
                (score, x1, y1, x2, y2) and unused slots are zero.
        """
        num = loc_data.size(0)  # batch size
        num_priors = prior_data.size(0)  # 8732 for SSD300
        # Allocate the output on the same device/dtype as the predictions.
        output = torch.zeros(num, self.num_classes, self.top_k, 5,
                             device=loc_data.device, dtype=loc_data.dtype)
        # [batch, num_priors, num_classes] -> [batch, num_classes, num_priors]
        conf_preds = conf_data.view(num, num_priors,
                                    self.num_classes).transpose(2, 1)

        for i in range(num):
            # Decode predictions into bboxes (x1, y1, x2, y2).
            decoded_boxes = decode(loc_data[i], prior_data, self.variance)
            conf_scores = conf_preds[i].clone()
            # For each foreground class (class 0 is background), perform nms.
            for cl in range(1, self.num_classes):
                c_mask = conf_scores[cl].gt(self.conf_thresh)
                scores = conf_scores[cl][c_mask]
                if scores.size(0) == 0:
                    continue
                boxes = decoded_boxes[c_mask]  # [k, 4]
                # nms returns kept indices sorted by decreasing score,
                # so slicing keeps the top_k highest-scoring boxes.
                ids = nms(boxes, scores, self.nms_thresh)[:self.top_k]
                output[i, cl, :len(ids)] = torch.cat((scores[ids].unsqueeze(1), boxes[ids]), 1)
        # Keep only the top_k detections per image across all classes.
        flt = output.view(num, -1, 5)  # a view, so writes go into output
        _, idx = flt[:, :, 0].sort(1, descending=True)
        _, rank = idx.sort(1)  # rank of each candidate by score
        # Boolean-mask indexing + fill_() only fills a copy; assign in place
        # instead, zeroing every entry ranked >= top_k.
        flt[rank >= self.top_k] = 0
        return output


if __name__ == '__main__':
    from options.detection.SSD.train_options import cfg

    num_classes = cfg['DATA']['NUM_CLASSES']
    detect = Detect(num_classes=num_classes,
                    top_k=cfg['TEST']['TOP_K'],
                    conf_thresh=cfg['TEST']['CONF_THRESH'],
                    nms_thresh=cfg['TEST']['NMS_THRESH'])
    loc_data = torch.randn(16, 8732, 4)
    # conf_data must have num_classes channels; softmax turns logits into scores.
    conf_data = torch.randn(16, 8732, num_classes).softmax(dim=-1)
    # Random priors with positive width/height (assumed center form, as decode expects).
    prior_data = torch.rand(8732, 4)

    out = detect.forward(loc_data, conf_data, prior_data)
    print('Detect output shape:', out.shape)

Result

Detect output shape: torch.Size([16, 81, 150, 5])
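The `decode` function is imported from `models.detection.SSD.utils.box_utils` but not shown in this post. Below is a minimal sketch of what it likely does. It assumes the common ssd.pytorch convention, which may differ from the repo's actual code:

- priors in center form (cx, cy, w, h);
- variances [0.1, 0.2];
- output in corner form (x1, y1, x2, y2), the form `torchvision.ops.nms` expects.

`decode_sketch` is a hypothetical name, not the repo's function.

```python
import torch


def decode_sketch(loc, priors, variances):
    """Hypothetical stand-in for box_utils.decode (assumed ssd.pytorch convention).

    loc:    [num_priors, 4] predicted offsets (tx, ty, tw, th)
    priors: [num_priors, 4] prior boxes in center form (cx, cy, w, h)
    Returns [num_priors, 4] boxes in corner form (x1, y1, x2, y2).
    """
    # Center: shift the prior center by the offset, scaled by prior size and variance.
    centers = priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:]
    # Size: scale the prior width/height exponentially.
    sizes = priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])
    # Convert (cx, cy, w, h) -> (x1, y1, x2, y2).
    return torch.cat((centers - sizes / 2, centers + sizes / 2), dim=1)


if __name__ == '__main__':
    priors = torch.tensor([[0.5, 0.5, 0.2, 0.4]])
    # Zero offsets should decode to the prior's own corners (0.4, 0.3, 0.6, 0.7).
    print(decode_sketch(torch.zeros(1, 4), priors, [0.1, 0.2]))
```

A zero offset reproduces the prior exactly. This is a quick sanity check for whichever `decode` the repo actually uses.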

Next part

CV+Deep Learning - Network Architecture PyTorch Reproduction Series - Detection (2: RetinaNet) Replacing backbones https://blog.csdn.net/XiaoyYidiaodiao/article/details/128689730?csdn_share_tail=%7B%22type%22%3A%22blog%22%2C%22rType%22%3A%22article%22%2C%22rId%22%3A%22128689730%22%2C%22source%22%3A%22XiaoyYidiaodiao%22%7D


References

[1] Liu W, Anguelov D, Erhan D, et al. SSD: Single shot multibox detector[C]//European Conference on Computer Vision. Springer, Cham, 2016: 21-37.
