pytorch 实现SSD详细理解 (二)ssd网络

摘要

前面讲解了 SSD 网络的特征图提取。SSD 分为两个阶段:测试阶段和训练阶段。测试阶段直接使用已经训练好的参数,产生回归后的边框和分类类别;训练阶段则利用 SSD 产生的预测结果计算损失、训练参数,使网络能够完成分类和回归。

multibox函数

这个函数用来为 6 个特征图分别创建预测边框坐标(loc)和分类类别置信度(conf)的卷积层。

def multibox(vgg, extra_layers, cfg, num_classes):
    """Create the loc/conf prediction heads for SSD's six source feature maps.

    Sources, in order: ``vgg[21]`` (conv4_3), ``vgg[-2]`` (fc7), then every
    second layer of ``extra_layers`` starting at index 1. ``cfg[k]`` is the
    number of default boxes per location for source ``k``. So source ``k``
    gets a 3x3 loc conv with ``cfg[k] * 4`` output channels and a 3x3 conf
    conv with ``cfg[k] * num_classes`` output channels (padding=1 keeps H, W).

    Returns:
        ``(vgg, extra_layers, (loc_layers, conf_layers))``. The first two
        arguments are passed through unchanged.
    """
    source_layers = [vgg[21], vgg[-2]] + list(extra_layers[1::2])
    loc_layers, conf_layers = [], []
    for k, layer in enumerate(source_layers):
        channels = layer.out_channels
        n_boxes = cfg[k]
        # loc conv is built before conf conv for each source (same order as
        # always, so parameter initialisation draws from the RNG identically).
        loc_layers.append(
            nn.Conv2d(channels, n_boxes * 4, kernel_size=3, padding=1))
        conf_layers.append(
            nn.Conv2d(channels, n_boxes * num_classes, kernel_size=3, padding=1))
    return vgg, extra_layers, (loc_layers, conf_layers)

extra_layers[1::2] 这一步也很好理解:从下标 1 开始每隔一层取一个,恰好对应 extras 中需要输出的四个特征图(再加上 VGG 中的 conv4_3 和 fc7,共 6 个特征图)。下面看一下具体的输出:

print(nn.Sequential(*loc_layers))  # wrapping the list in nn.Sequential gives a more readable printout
print(nn.Sequential(*conf_layers))

Sequential(
  (0): Conv2d(512, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): Conv2d(1024, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (2): Conv2d(512, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): Conv2d(256, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (4): Conv2d(256, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (5): Conv2d(256, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
Sequential(
  (0): Conv2d(512, 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): Conv2d(1024, 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (2): Conv2d(512, 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): Conv2d(256, 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (4): Conv2d(256, 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (5): Conv2d(256, 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

原代码中 multibox 函数的调用方式如下:

# Build the SSD300 pieces: the VGG base (3 input channels), the extra layers
# (fed by the 1024-channel fc7 output) and the multibox heads for 21 classes.
base_, extras_, head_ = multibox(vgg(base[str(300)], 3),
                                     add_extras(extras[str(300)], 1024),
                                     mbox[str(300)], 21)
对应的输出如下(依次为 base_、extras_、head_):
[Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace), Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace), MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False), Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace), Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace), MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False), Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace), Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace), Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace), MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True), Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace), Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace), Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace), MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False), Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace), Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace), Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace), MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False), Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(6, 6), dilation=(6, 6)), ReLU(inplace), Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1)), ReLU(inplace)]
[Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1)), Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)), Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1)), Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)), Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1)), Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1)), Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1)), Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))]
([Conv2d(512, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), Conv2d(1024, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), Conv2d(512, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), Conv2d(256, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), Conv2d(256, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), Conv2d(256, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))], [Conv2d(512, 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), Conv2d(1024, 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), Conv2d(512, 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), Conv2d(256, 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), Conv2d(256, 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), Conv2d(256, 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))])


SSD网络

class SSD(nn.Module):
    """Single Shot Multibox Architecture
    The network is composed of a base VGG network followed by the
    added multibox conv layers.  Each multibox layer branches into
        1) conv2d for class conf scores
        2) conv2d for localization predictions
        3) associated priorbox layer to produce default bounding
           boxes specific to the layer's feature map size.
    See: https://arxiv.org/pdf/1512.02325.pdf for more details.

    Args:
        phase: (string) Can be "test" or "train"
        size: input image size
        base: VGG16 layers for input, size of either 300 or 500
        extras: extra layers that feed to multibox loc and conf layers
        head: "multibox head" consists of loc and conf conv layers,
            as the tuple (loc_layers, conf_layers)
        num_classes: number of classes including background; 21 selects
            the ``voc`` config, any other value the ``coco`` config.
    """

    def __init__(self, phase, size, base, extras, head, num_classes):
        super(SSD, self).__init__()
        self.phase = phase
        self.num_classes = num_classes
        # (coco, voc)[True] -> voc, i.e. 21 classes means the VOC config.
        self.cfg = (coco, voc)[num_classes == 21]
        self.priorbox = PriorBox(self.cfg)
        # Default (prior) boxes are fixed, so no gradient is tracked for them.
        # NOTE(review): `volatile` is ignored (with a warning) on PyTorch >= 0.4,
        # where `torch.no_grad()` is the equivalent -- confirm target version.
        # This is a plain attribute, not a registered buffer, so model.cuda()
        # does not move it; forward() casts it where needed.
        self.priors = Variable(self.priorbox.forward(), volatile=True)
        self.size = size

        # SSD network
        self.vgg = nn.ModuleList(base)
        # Layer learns to scale the l2 normalized features from conv4_3
        self.L2Norm = L2Norm(512, 20)
        self.extras = nn.ModuleList(extras)

        self.loc = nn.ModuleList(head[0])
        self.conf = nn.ModuleList(head[1])

        if phase == 'test':
            self.softmax = nn.Softmax(dim=-1)
            # bkg_label=0, top_k=200, conf_thresh=0.01, nms_thresh=0.45
            self.detect = Detect(num_classes, 0, 200, 0.01, 0.45)

    def forward(self, x):
        """Apply the SSD layers to a batch of images.

        Args:
            x: input image batch, Shape: [batch, 3, 300, 300]

        Returns:
            test phase: the result of ``self.detect`` on the decoded
                predictions (class scores and boxes after NMS).
            train phase: tuple of
                1: localization preds, Shape: [batch, num_priors, 4]
                2: confidence preds,   Shape: [batch, num_priors, num_classes]
                3: prior boxes,        Shape: [num_priors, 4]

        Raises:
            ValueError: if ``self.phase`` is neither "test" nor "train".
        """
        sources = []  # feature maps used for prediction (the 6 source layers)
        loc = []      # per-source box regression outputs
        conf = []     # per-source class score outputs

        # Run VGG up to and including the ReLU after conv4_3 (layers 0..22).
        for k in range(23):
            x = self.vgg[k](x)
        sources.append(self.L2Norm(x))

        # Continue through the rest of VGG, up to fc7.
        for k in range(23, len(self.vgg)):
            x = self.vgg[k](x)
        sources.append(x)

        # Extra layers: the outputs of the odd-indexed layers (1, 3, 5, 7)
        # are prediction sources, matching extra_layers[1::2] in multibox().
        for k, v in enumerate(self.extras):
            x = F.relu(v(x), inplace=True)
            if k % 2 == 1:
                sources.append(x)

        # Apply the multibox heads. Conv outputs are [N, C, H, W]; permute to
        # [N, H, W, C] so each location's predictions are adjacent, then call
        # contiguous() because view() below requires contiguous memory.
        for (x, l, c) in zip(sources, self.loc, self.conf):
            loc.append(l(x).permute(0, 2, 3, 1).contiguous())
            conf.append(c(x).permute(0, 2, 3, 1).contiguous())

        # Flatten everything except the batch dim and concatenate all sources:
        # loc -> [batch, num_priors*4], conf -> [batch, num_priors*num_classes]
        loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
        conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)

        if self.phase == "test":
            output = self.detect(
                loc.view(loc.size(0), -1, 4),  # [batch, num_priors, 4]
                # [batch, num_priors, num_classes], softmax over classes
                self.softmax(conf.view(conf.size(0), -1, self.num_classes)),
                # Match the priors' dtype/device to the feature maps. type_as
                # uses x.type() (e.g. 'torch.cuda.FloatTensor'); the former
                # `.type(type(x.data))` got plain torch.Tensor on PyTorch >= 0.4
                # and so cast to the *default* tensor type instead.
                self.priors.type_as(x),
            )
        elif self.phase == "train":
            # Training: return raw predictions; the loss consumes them directly.
            output = (
                loc.view(loc.size(0), -1, 4),
                conf.view(conf.size(0), -1, self.num_classes),
                self.priors,
            )
        else:
            raise ValueError("Unknown phase: {!r}".format(self.phase))
        return output

    def load_weights(self, base_file):
        """Load a state dict from a ``.pkl`` or ``.pth`` file.

        Tensors are mapped to CPU storage while loading. Files with any other
        extension are rejected with a message and nothing is loaded.
        """
        _, ext = os.path.splitext(base_file)
        # Compare against each extension explicitly: `ext == ".pkl" or ".pth"`
        # is always truthy because the non-empty string ".pth" is truthy.
        if ext in (".pkl", ".pth"):
            print("Loading weights into state dict...")
            self.load_state_dict(torch.load(base_file, map_location=lambda storage, loc: storage))
            print("Finished!")
        else:
            print("Sorry only .pth and .pkl files supported")

重点需要看 Detect 部分。它的整体作用是:利用预测的偏移量对先验框进行回归(解码),使框的位置更准确,并给出分类得分;再进行 NMS,去除大多数重叠度高的框;最后返回类别得分和框的信息。

detect代码

class Detect(Function):
    """At test time, Detect is the final layer of SSD.  Decode location preds,
    apply non-maximum suppression to location predictions based on conf
    scores and threshold to a top_k number of output predictions for both
    confidence score and locations.

    NOTE(review): this is a legacy autograd ``Function`` with a non-static
    ``forward``; recent PyTorch releases refuse to call such instances --
    confirm against the PyTorch version in use.
    """
    def __init__(self, num_classes, bkg_label, top_k, conf_thresh, nms_thresh):
        """
        Args:
            num_classes: number of classes, including background.
            bkg_label: background label. It is stored, but forward() always
                skips class index 0 rather than reading this value.
            top_k: max detections kept per class (by nms) and per image.
            conf_thresh: minimum class score for a box to be considered.
            nms_thresh: IoU threshold for non-maximum suppression; must be > 0.

        Raises:
            ValueError: if nms_thresh <= 0.
        """
        self.num_classes = num_classes
        self.background_label = bkg_label
        self.top_k = top_k
        # Parameters used in nms.
        self.nms_thresh = nms_thresh
        if nms_thresh <= 0:
            raise ValueError('nms_threshold must be positive.')
        self.conf_thresh = conf_thresh
        # Box-encoding variances; `cfg` is a module-level config defined elsewhere.
        self.variance = cfg['variance']

    def forward(self, loc_data, conf_data, prior_data):
        """
        Args:
            loc_data: (tensor) Loc preds from loc layers
                Shape: [batch, num_priors, 4]
            conf_data: (tensor) Class scores (SSD passes softmax output),
                viewable as [batch, num_priors, num_classes]
            prior_data: (tensor) Prior boxes in center-offset form
                Shape: [num_priors, 4]
        Returns:
            (tensor) Shape: [batch, num_classes, top_k, 5]. Each row is
            (score, xmin, ymin, xmax, ymax). Unused rows are all zeros, and
            only the top_k highest-scoring rows per image are kept.
        """
        num = loc_data.size(0)  # batch size
        num_priors = prior_data.size(0)
        output = torch.zeros(num, self.num_classes, self.top_k, 5)
        # [batch, num_classes, num_priors]: one row of prior scores per class.
        conf_preds = conf_data.view(num, num_priors,
                                    self.num_classes).transpose(2, 1)

        # Each image in the batch is processed independently.
        for i in range(num):
            # Apply the regressed offsets to the priors -> corner-form boxes.
            decoded_boxes = decode(loc_data[i], prior_data, self.variance)
            # Copy of this image's per-class scores.
            conf_scores = conf_preds[i].clone()
            # Class 0 (background) is skipped; run NMS for every other class.
            for cl in range(1, self.num_classes):
                # Keep only boxes whose score for this class exceeds the threshold.
                c_mask = conf_scores[cl].gt(self.conf_thresh)
                scores = conf_scores[cl][c_mask]
                if scores.size(0) == 0:
                    # No candidate box for this class in this image.
                    continue
                l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes)
                boxes = decoded_boxes[l_mask].view(-1, 4)
                # ids: indices into boxes/scores of the kept, non-overlapping
                # boxes, highest score first; count <= top_k.
                ids, count = nms(boxes, scores, self.nms_thresh, self.top_k)
                # Fill the first `count` slots with (score, box) rows; the
                # remaining top_k - count slots stay zero.
                output[i, cl, :count] = \
                    torch.cat((scores[ids[:count]].unsqueeze(1),
                               boxes[ids[:count]]), 1)

        # Keep only the top_k detections per image across all classes.
        # `output` is contiguous, so flt is a view and writes to it update output.
        flt = output.contiguous().view(num, -1, 5)
        _, idx = flt[:, :, 0].sort(1, descending=True)
        # rank is the inverse permutation of idx: rank[b, j] is the position
        # of detection j in descending score order. E.g. scores [0.2, 0.9, 0.5]
        # give idx [1, 2, 0] and rank [2, 0, 1].
        _, rank = idx.sort(1)
        # Zero every detection ranked at or beyond top_k. This has to be an
        # indexed assignment: `flt[mask].fill_(0)` would fill the copy that
        # boolean indexing returns and leave output unchanged.
        flt[(rank >= self.top_k).unsqueeze(-1).expand_as(flt)] = 0
        return output

中间用到的函数decode和nms

def decode(loc, priors, variances):
    """Turn regression outputs back into corner-form boxes.

    This inverts the offset encoding used at train time:
      * the predicted centre is the prior centre shifted by
        ``loc[:, :2] * variances[0]`` prior widths/heights;
      * the predicted size is the prior size scaled by
        ``exp(loc[:, 2:] * variances[1])``.

    Args:
        loc (tensor): location predictions for loc layers,
            Shape: [num_priors,4]
        priors (tensor): Prior boxes in center-offset form (cx, cy, w, h).
            Shape: [num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        (tensor) Shape: [num_priors,4], boxes as (xmin, ymin, xmax, ymax).
    """
    prior_centers = priors[:, :2]
    prior_sizes = priors[:, 2:]

    centers = prior_centers + loc[:, :2] * variances[0] * prior_sizes
    sizes = prior_sizes * torch.exp(loc[:, 2:] * variances[1])

    top_left = centers - sizes / 2
    bottom_right = sizes + top_left
    return torch.cat((top_left, bottom_right), 1)
   
def nms(boxes, scores, overlap=0.5, top_k=200):
    """Apply non-maximum suppression at test time to avoid detecting too many
    overlapping bounding boxes for a given object.

    Greedy NMS: repeatedly keep the highest-scoring remaining box, then drop
    every remaining box whose IoU with it exceeds ``overlap``. Only the
    ``top_k`` highest-scoring boxes are considered at all.

    Args:
        boxes: (tensor) The location preds for the img, Shape: [num_priors,4],
            corner form (x1, y1, x2, y2).
        scores: (tensor) The class predscores for the img, Shape:[num_priors].
        overlap: (float) The overlap thresh for suppressing unnecessary boxes.
        top_k: (int) The Maximum number of box preds to consider.
    Return:
        (keep, count): ``keep`` is a long tensor of length num_priors. Its
        first ``count`` entries are the indices of the kept boxes, highest
        score first, and the remaining entries are zero.
    """
    keep = scores.new(scores.size(0)).zero_().long()
    if boxes.numel() == 0:
        # Return the same (keep, count) pair as the normal path so callers
        # can always unpack the result.
        return keep, 0
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    area = torch.mul(x2 - x1, y2 - y1)
    _, idx = scores.sort(0)  # ascending, so the best remaining box is last
    # Indices of the top_k largest scores. An explicit start index is used
    # because idx[-0:] would select everything when top_k == 0.
    idx = idx[max(idx.size(0) - top_k, 0):]

    count = 0
    while idx.numel() > 0:
        i = idx[-1]  # index of current largest score
        keep[count] = i
        count += 1
        if idx.size(0) == 1:
            break
        idx = idx[:-1]  # remove the kept element
        # Intersection rectangle of box i with every remaining box.
        xx1 = torch.clamp(torch.index_select(x1, 0, idx), min=x1[i])
        yy1 = torch.clamp(torch.index_select(y1, 0, idx), min=y1[i])
        xx2 = torch.clamp(torch.index_select(x2, 0, idx), max=x2[i])
        yy2 = torch.clamp(torch.index_select(y2, 0, idx), max=y2[i])
        # Non-overlapping boxes give negative extents; clamp them to zero.
        w = torch.clamp(xx2 - xx1, min=0.0)
        h = torch.clamp(yy2 - yy1, min=0.0)
        inter = w * h
        # IoU = inter / (area(a) + area(b) - inter)
        rem_areas = torch.index_select(area, 0, idx)
        union = (rem_areas - inter) + area[i]
        IoU = inter / union
        # keep only elements with an IoU <= overlap
        idx = idx[IoU.le(overlap)]
    return keep, count


到这里就可以看到整个 SSD 网络了。源代码中通过一个函数(build_ssd)来完成特征提取网络和 SSD 的构造:

def build_ssd(phase, size=300, num_classes=21):
    """Assemble an SSD network.

    Args:
        phase: "test" or "train".
        size: input image size; only 300 is accepted.
        num_classes: number of classes including background.

    Returns:
        An ``SSD`` instance. If ``phase`` or ``size`` is invalid, an error
        message is printed and None is returned instead.
    """
    if phase not in ("test", "train"):
        print("ERROR: Phase: " + phase + " not recognized")
        return None
    if size != 300:
        print("ERROR: You specified size {!r}. However, "
              "currently only SSD300 (size=300) is supported!".format(size))
        return None

    key = str(size)
    vgg_layers = vgg(base[key], 3)
    extra_layers = add_extras(extras[key], 1024)
    base_, extras_, head_ = multibox(vgg_layers, extra_layers,
                                     mbox[key], num_classes)
    return SSD(phase, size, base_, extras_, head_, num_classes)

基本上所有目标检测模型的 detect 部分都差不多,不同的是特征图的提取部分和先验框高、宽的生成方式。大部分研究都只针对某一种数据集,其实对于每一种数据集,都可以设计一套高和宽的生成规则,使先验框更好地适应该数据集。

你可能感兴趣的:(pytorch 实现SSD详细理解 (二)ssd网络)