多目标检测与识别 YOLOV3 解读四 反算

YOLOV3的反算

1.1 怎么造的样本

前面已经讲了怎么造的样本,我们这里回顾一下。
1.标注目标位置(xmin,ymin,xmax,ymax),并算出中心点坐标、宽高(cx,cy,w,h)
2.中心点坐标除以图片的宽高(cx/size(0),cy/size(2),w/size(0),h/size(1))然后resize(416,416)
3.把(cx/size(0),cy/size(2),w/size(0),h/size(1))所有值乘以416,得到在缩放图片下中心点宽高的位置。
4.把中心点除以缩放比例,整数部分得到中心点在特征图的位置,小数得到中心点相对于网格左上角的偏移。
5.相对位置用真实框和建议框的宽高比表示。并且宽高比前要加一个log,置信度用真实框和建议框的iou表示。

1.2 反算回原图

简单说怎么制造的样本就怎么反算回去。
先上图
多目标检测与识别 YOLOV3 解读四 反算_第1张图片
上图是造样本所用的图,实现框是真实框,虚线框是建议框。tx、ty是中心点相当于中心点所在网格左上角的偏移,然后用sigmiod激活。cx,cy是中心点所在特征的位置把他们加起来,再乘以缩放比例就是原图中心点的位置。同理宽高业是怎么造的数据就怎么反算回去。

import cfg
from darknet53 import *
from utils import *
import torch

device = torch.device(cfg.DEVICE)


class Detector(torch.nn.Module):

    def __init__(self):
        super(Detector, self).__init__()

        self.net = MainNet(cfg.CLASS_NUM).to(device)
        self.net.load_state_dict(torch.load('weights/darknet53.pt'))
        self.net.eval()

    def forward(self, input, thresh, anchors):

        output_13, output_26, output_52 = self.net(input.to(device))

        idxs_13, vecs_13 = self._filter(output_13, thresh)
        boxes_13 = self._parse(idxs_13, vecs_13, 32, anchors[13])

        idxs_26, vecs_26 = self._filter(output_26, thresh)
        boxes_26 = self._parse(idxs_26, vecs_26, 16, anchors[26])

        idxs_52, vecs_52 = self._filter(output_52, thresh)
        boxes_52 = self._parse(idxs_52, vecs_52, 8, anchors[52])

        boxes_all = torch.cat([boxes_13, boxes_26, boxes_52], dim=0)

        last_boxes = []
        for n in range(input.size(0)):
            n_boxes = []
            boxes_n = boxes_all[boxes_all[:, 6] == n]
            for cls in range(cfg.CLASS_NUM):
                boxes_c = boxes_n[boxes_n[:, 5] == cls]
                if boxes_c.size(0) > 0:
                    n_boxes.extend(nms(boxes_c, 0.3))
                else:
                    pass
            last_boxes.append(torch.stack(n_boxes))
        return last_boxes

    def _filter(self, output, thresh):
        output = output.permute(0, 2, 3, 1)
        output = output.reshape(output.size(0), output.size(1), output.size(2), 3, -1)

        output = output.cpu()

        torch.sigmoid_(output[..., 4])  # 置信度加sigmoid激活
        torch.sigmoid_(output[..., 0:2])  # 中心点加sigmoid激活

        # 在计算置信度损失的时候使用的sigmoid
        mask = output[..., 4] > thresh
        idxs = mask.nonzero()
        vecs = output[mask]
        # print(vecs[..., 4])
        return idxs, vecs

    def _parse(self, idxs, vecs, t, anchors):
        if idxs.size(0) == 0:
            return torch.Tensor([])

        anchors = torch.Tensor(anchors)

        n = idxs[:, 0]  # 所属的图片
        a = idxs[:, 3]  # 建议框
        c = vecs[:, 4]  # 置信度

        cls = torch.argmax(vecs[:, 5:], dim=1)

        cy = (idxs[:, 1].float() + vecs[:, 1]) * t  # 原图的中心点y
        cx = (idxs[:, 2].float() + vecs[:, 0]) * t  # 原图的中心点x

        w = anchors[a, 0] * torch.exp(vecs[:, 2])
        h = anchors[a, 1] * torch.exp(vecs[:, 3])

        w0_5, h0_5 = w / 2, h / 2
        x1, y1, x2, y2 = cx - w0_5, cy - h0_5, cx + w0_5, cy + h0_5

        return torch.stack([x1, y1, x2, y2, c, cls.float(), n.float()], dim=1)

你可能感兴趣的:(深度学习)