yolo_v4模型解析(head部分)

1.backbone部分
2.neck部分
3.head部分

最关键的部分来了~~~
yolo_v4模型解析(head部分)_第1张图片
yolo_v4模型解析(head部分)_第2张图片
yolo_v4模型解析(head部分)_第3张图片
ok,上代码,基本每一句都已经注释完毕:

import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np

# yolo_decode: the raw head output (B, A*n_ch, H, W) is rearranged to (B, A, H, W, n_ch), decoded, and returned flattened as (B, A*H*W, n_ch)
def yolo_decode(output, num_classes, anchors, num_anchors, scale_x_y):
    """Decode a raw YOLO head feature map into normalized box predictions.

    Args:
        output: Tensor of shape ``(B, A * (5 + num_classes), H, W)`` where
            ``A`` is ``num_anchors``.
        num_classes: Number of class channels per anchor.
        anchors: Flat indexable sequence ``[w0, h0, w1, h1, ...]`` with at
            least ``2 * num_anchors`` entries. The values multiply
            ``exp(tw)`` / ``exp(th)`` before division by the grid size, so
            they are expected in grid-cell units.
        num_anchors: Number of anchors predicted at every grid cell.
        scale_x_y: Scale factor; ``1`` leaves the decoding unscaled.

    Returns:
        Tensor of shape ``(B, A * H * W, 5 + num_classes)``. The last axis
        holds ``(cx, cy, w, h)`` divided by the grid width/height, then the
        objectness probability, then the per-class probabilities.
    """
    n_ch = 4 + 1 + num_classes
    A = num_anchors
    B = output.size(0)  # batch size
    H = output.size(2)  # grid height
    W = output.size(3)  # grid width

    # (B, A*n_ch, H, W) -> (B, A, H, W, n_ch). reshape (instead of view)
    # also accepts non-contiguous inputs.
    output = output.reshape(B, A, n_ch, H, W).permute(0, 1, 3, 4, 2).contiguous()
    bx, by = output[..., 0], output[..., 1]
    bw, bh = output[..., 2], output[..., 3]

    det_confs = output[..., 4]  # objectness logit; integer index drops the last axis
    cls_confs = output[..., 5:]  # class logits; the slice keeps the last axis

    # Centre offset inside the owning grid cell, in (0, 1).
    bx = torch.sigmoid(bx)
    by = torch.sigmoid(by)
    # NOTE(review): scale_x_y is applied to the width/height terms here. In
    # darknet's YOLOv4 layer the factor scales the sigmoid x/y offsets
    # instead. Behaviour is kept as-is (identical when scale_x_y == 1);
    # confirm the intent before using any other value.
    bw = torch.exp(bw) * scale_x_y - 0.5 * (scale_x_y - 1)
    bh = torch.exp(bh) * scale_x_y - 0.5 * (scale_x_y - 1)

    det_confs = torch.sigmoid(det_confs)  # probability that a cell holds an object
    cls_confs = torch.sigmoid(cls_confs)  # independent per-class probabilities

    # Column / row index of every cell, shaped to broadcast against
    # (B, A, H, W). Building them as (1, 1, 1, W) and (1, 1, H, 1) works for
    # any anchor count and for non-square grids, and they are created
    # directly on the input's device.
    grid_x = torch.arange(W, dtype=torch.float, device=output.device).view(1, 1, 1, W)
    grid_y = torch.arange(H, dtype=torch.float, device=output.device).view(1, 1, H, 1)
    bx += grid_x
    by += grid_y

    # Scale each anchor's predicted size by that anchor's prior (w, h).
    for i in range(A):
        bw[:, i] *= anchors[2 * i]
        bh[:, i] *= anchors[2 * i + 1]

    # bx/by are absolute grid positions at this point; dividing by the grid
    # size expresses everything relative to the whole feature map.
    bx = (bx / W).unsqueeze(-1)
    by = (by / H).unsqueeze(-1)
    bw = (bw / W).unsqueeze(-1)
    bh = (bh / H).unsqueeze(-1)

    # (B, A, H, W, 4) -> (B, A*H*W, 4)
    boxes = torch.cat((bx, by, bw, bh), dim=-1).reshape(B, A * H * W, 4)
    det_confs = det_confs.unsqueeze(-1).reshape(B, A * H * W, 1)
    cls_confs = cls_confs.reshape(B, A * H * W, num_classes)
    # Final layout: (B, A*H*W, 4 + 1 + num_classes)
    outputs = torch.cat([boxes, det_confs, cls_confs], dim=-1)
    return outputs


class YoloLayer(nn.Module):
    """YOLO detection head wrapper.

    In training mode ``forward`` returns the raw feature map unchanged. In
    evaluation mode it selects the anchors for the incoming feature map's
    scale, converts them to grid-cell units and decodes the map with
    ``yolo_decode``.

    Args:
        img_size: Indexable input size; only ``img_size[0]`` is read, both to
            derive the supported feature-map widths and to compute the stride.
        anchor_masks: One list of anchor indices per scale, ordered to match
            ``feature_length`` (``img_size[0] // 8``, ``// 16``, ``// 32``).
            Defaults to an empty list.
        num_classes: Number of object classes.
        anchors: Flat list of anchor sizes ``[w0, h0, w1, h1, ...]``, e.g.
            ``[12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110,
            192, 243, 459, 401]`` for nine anchors. Defaults to an empty list.
        num_anchors: Total number of anchors described by ``anchors``.
        stride: Stored on the instance; ``forward`` derives the stride from
            the feature-map width instead of reading this value.
        scale_x_y: Scale factor forwarded to ``yolo_decode``.
    """

    def __init__(self, img_size, anchor_masks=None, num_classes=80, anchors=None, num_anchors=9, stride=32, scale_x_y=1):
        super().__init__()
        # None sentinels replace mutable default arguments; omitting either
        # parameter still yields an empty list, but no longer one shared
        # between instances.
        self.anchor_masks = [] if anchor_masks is None else anchor_masks
        self.num_classes = num_classes

        self.anchors = [] if anchors is None else anchors
        self.num_anchors = num_anchors
        # Values per anchor in the flat list (18 // 9 == 2 for w/h pairs).
        self.anchor_step = len(self.anchors) // num_anchors
        self.stride = stride
        self.scale_x_y = scale_x_y

        # Feature-map widths of the three heads (strides 8, 16 and 32).
        self.feature_length = [img_size[0] // 8, img_size[0] // 16, img_size[0] // 32]
        self.img_size = img_size

    def forward(self, output):
        """Return ``output`` untouched when training, decoded boxes otherwise.

        Raises:
            ValueError: If the feature-map width matches none of the widths
                in ``feature_length``.
        """
        if self.training:
            return output

        in_w = output.size(3)
        try:
            scale_idx = self.feature_length.index(in_w)
        except ValueError:
            raise ValueError(
                "feature map width {} does not match any supported width {} "
                "for img_size {}".format(in_w, self.feature_length, self.img_size)
            ) from None
        anchor_index = self.anchor_masks[scale_idx]
        stride_w = self.img_size[0] / in_w

        # Gather the (w, h) entries of every anchor assigned to this scale.
        masked_anchors = []
        for m in anchor_index:
            masked_anchors += self.anchors[m * self.anchor_step:(m + 1) * self.anchor_step]
        # Convert to grid-cell units, e.g. [142, 110, 192, 243, 459, 401] / 32.
        masked_anchors = [anchor / stride_w for anchor in masked_anchors]

        return yolo_decode(output, self.num_classes, masked_anchors, len(anchor_index), scale_x_y=self.scale_x_y)
       

你可能感兴趣的:(python,自动驾驶,深度学习)