从 YOLOv5 谈 Backbone、Neck 和 Head

从 YOLOv5 谈 Backbone、Neck 和 Head

  • yolov5 流程图
  • resnet+fpn

yolov5 流程图

图 1:YOLOv5 流程图(原文此处为配图)
PANet实现

class YoloBody(nn.Module):
    """Detector wiring a ``darknet53`` backbone to an SPP stage, a fusion neck and three heads.

    The neck first runs top-down (upsample the deeper map, concatenate it with
    the shallower one) and then bottom-up (stride-2 conv on the shallower map,
    concatenate it with the deeper one). Each of the three fused maps feeds its
    own ``yolo_head``.

    Args:
        num_anchors: anchors predicted per location on every head.
        num_classes: number of object classes.
    """

    def __init__(self, num_anchors, num_classes):
        super().__init__()
        # Every head emits num_anchors * (5 + num_classes) channels; the 5 is
        # 4 + 1 per anchor, e.g. 3 * (4 + 1 + 20) = 75.
        head_channels = num_anchors * (5 + num_classes)

        # Backbone: forward() expects it to return three feature maps.
        self.backbone = darknet53(None)

        # Deepest branch: convs -> spatial pyramid pooling -> convs.
        self.conv1 = make_three_conv([512, 1024], 1024)
        self.SPP = SpatialPyramidPooling()
        self.conv2 = make_three_conv([512, 1024], 2048)

        # Top-down step 1: merge the deepest branch into the middle one.
        self.upsample1 = Upsample(512, 256)
        self.conv_for_P4 = conv2d(512, 256, 1)
        self.make_five_conv1 = make_five_conv([256, 512], 512)

        # Top-down step 2: merge the middle branch into the shallowest one.
        self.upsample2 = Upsample(256, 128)
        self.conv_for_P3 = conv2d(256, 128, 1)
        self.make_five_conv2 = make_five_conv([128, 256], 256)
        self.yolo_head3 = yolo_head([256, head_channels], 128)

        # Bottom-up step 1: push the shallowest branch back into the middle one.
        self.down_sample1 = conv2d(128, 256, 3, stride=2)
        self.make_five_conv3 = make_five_conv([256, 512], 512)
        self.yolo_head2 = yolo_head([512, head_channels], 256)

        # Bottom-up step 2: push the middle branch back into the deepest one.
        self.down_sample2 = conv2d(256, 512, 3, stride=2)
        self.make_five_conv4 = make_five_conv([512, 1024], 1024)
        self.yolo_head1 = yolo_head([1024, head_channels], 512)

    def forward(self, x):
        """Run the network and return ``(out0, out1, out2)``.

        ``out0`` comes from the deepest fused map, ``out1`` from the middle
        one and ``out2`` from the shallowest one.
        """
        shallow, middle, deep = self.backbone(x)

        # Deepest branch through the SPP stage.
        p5 = self.conv1(deep)
        p5 = self.SPP(p5)
        p5 = self.conv2(p5)

        # Top-down fusion; the lateral feature goes first in the concat.
        p5_up = self.upsample1(p5)
        p4 = self.conv_for_P4(middle)
        p4 = self.make_five_conv1(torch.cat((p4, p5_up), dim=1))

        p4_up = self.upsample2(p4)
        p3 = self.conv_for_P3(shallow)
        p3 = self.make_five_conv2(torch.cat((p3, p4_up), dim=1))

        # Bottom-up fusion; the downsampled feature goes first in the concat.
        p3_down = self.down_sample1(p3)
        p4 = self.make_five_conv3(torch.cat((p3_down, p4), dim=1))

        p4_down = self.down_sample2(p4)
        p5 = self.make_five_conv4(torch.cat((p4_down, p5), dim=1))

        out2 = self.yolo_head3(p3)
        out1 = self.yolo_head2(p4)
        out0 = self.yolo_head1(p5)
        return out0, out1, out2

resnet+fpn

class YoloV6(nn.Module):
    """Detector built from a ResNet-18 trunk, a feature pyramid network and a ``Detect`` head.

    The trunk is wrapped in ``IntermediateLayerGetter`` so that ``layer2``,
    ``layer3`` and ``layer4`` are returned under the keys 0, 1 and 2. The FPN
    maps those three features (128/256/512 channels) to 256 channels each, and
    ``Detect`` consumes the three FPN outputs.

    Args:
        cfg: path to a YAML file providing at least ``nc`` and ``anchors``.
        ch: number of input image channels, used for the dummy forward pass
            that measures the stride of every detection layer.
    """

    def __init__(self, cfg, ch=3):
        super(YoloV6, self).__init__()
        with open(cfg) as f:
            # NOTE(review): FullLoader is not intended for untrusted input;
            # prefer yaml.safe_load if cfg can come from outside the project.
            self.md = yaml.load(f, Loader=yaml.FullLoader)
        self.nc = self.md['nc']
        self.anchors = self.md['anchors']
        self.na = len(self.anchors[0]) // 2  # number of anchors
        self.backbone = resnet18()
        self.backbone = IntermediateLayerGetter(
            self.backbone, {'layer2': 0, 'layer3': 1, 'layer4': 2})
        # FPN
        in_channels = [128, 256, 512]
        self.fpn = FeaturePyramidNetwork(in_channels_list=in_channels, out_channels=256)
        self.detect = Detect(self.nc, self.anchors, [256, 256, 256])
        # Run one dummy forward pass to derive each detection layer's stride
        # from the ratio between the input size and the output size.
        s = 512  # 2x min stride
        self.detect.stride = torch.tensor(
            [s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward
        # Express the anchors in units of each layer's stride.
        self.detect.anchors /= self.detect.stride.view(-1, 1, 1)
        check_anchor_order(self.detect)
        self.stride = self.detect.stride
        self._initialize_biases()
        initialize_weights(self)

        # for compatible, in check_anchors in train.py
        self.model = [self.detect]

    def _initialize_biases(self, cf=None):
        """Shift the objectness and class biases of every ``Detect`` output conv.

        Args:
            cf: optional tensor of per-class values; when given, the class
                biases are shifted by ``log(cf / cf.sum())`` instead of the
                default constant.
        """
        m = self.detect  # Detect() module
        for mi, s in zip(m.m, m.stride):  # from
            # Edit a detached copy: an in-place write through a view of a
            # parameter that requires grad raises a RuntimeError in autograd.
            b = mi.bias.detach().clone().view(m.na, -1)  # conv.bias(255) to (3,85)
            # obj (8 objects per 640 image)
            b[:, 4] += math.log(8 / (640 / s) ** 2)
            b[:, 5:] += math.log(0.6 / (m.nc - 0.99)
                                 ) if cf is None else torch.log(cf / cf.sum())  # cls
            mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)

    def forward(self, x, augment=False):
        """Return the ``Detect`` head output for image batch ``x``.

        ``augment`` is accepted for call-site compatibility and is not used.
        """
        feas = self.backbone(x)
        a = self.fpn(feas)
        # Keys 0/1/2 are the names assigned to layer2/layer3/layer4 in __init__.
        x_s, x_m, x_l = a[0], a[1], a[2]
        x = self.detect([x_s, x_m, x_l])
        return x

detect实现

class Detect(nn.Module):
    """Detection head: one 1x1 conv per input feature map plus box decoding.

    Every conv output is reshaped from ``(bs, na * no, ny, nx)`` to
    ``(bs, na, ny, nx, no)``. Outside training (or whenever ``export`` is set)
    the predictions are also decoded with the cell grid, the layer stride and
    the anchors.

    Return value of ``forward``:
        * training: the list of per-layer tensors;
        * ``export`` set (and not training): the concatenated decoded
          predictions, with an extra arg-max class index column;
        * otherwise: ``(concatenated decoded predictions, per-layer list)``.

    Args:
        nc: number of classes.
        anchors: one flat ``[w0, h0, w1, h1, ...]`` sequence per detection layer.
        ch: input channel count of every detection layer.
    """

    stride = None  # strides computed during build
    export = False  # onnx export

    def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
        super(Detect, self).__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.zeros(1) for _ in range(self.nl)]  # placeholder grids
        a = torch.tensor(anchors).float().view(self.nl, -1, 2)
        self.register_buffer('anchors', a)  # shape(nl,na,2)
        self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
        # Legacy copy kept so code that reads this attribute keeps working.
        # It no longer depends on a module-level `device` name, and forward()
        # reads the `anchor_grid` buffer instead, which follows .to() and
        # load_state_dict().
        self.anchor_grid_awesome = a.clone().view(self.nl, 1, -1, 1, 1, 2)

    def _grid_for(self, i, nx, ny, like):
        """Return the cached cell grid of layer ``i``, rebuilt if its size changed."""
        grid = self.grid[i]
        if grid.shape[2:4] != like.shape[2:4]:
            grid = self._make_grid(nx, ny)
        grid = grid.to(like.device)
        self.grid[i] = grid
        return grid

    def forward(self, x):
        """Apply the output convs to the feature list ``x`` (modified in place)."""
        z = []  # inference output
        decode = self.export or not self.training
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            if self.export:
                print('exporting...')
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            pred = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
            if not self.export:
                # In export mode x[i] keeps the raw conv output.
                x[i] = pred

            if not decode:
                continue

            grid = self._grid_for(i, nx, ny, pred)
            y = pred.sigmoid()
            if self.export:
                print('[WARN] you are calling export...')
                # Same decode as the branch below, written without in-place
                # slice assignment.
                box_xy, box_wh, conf, prob = torch.split(y, [2, 2, 1, self.nc], dim=4)
                box_xy = ((box_xy * 2. - 0.5 + grid) * self.stride[i]).type(pred.dtype)
                box_wh = (box_wh * 2) ** 2 * self.anchor_grid[i]
                boxes = torch.cat((box_xy, box_wh), dim=4)
                # add a idx (label ids before prob)
                idxs = torch.argmax(prob, dim=-1).unsqueeze(-1).type(pred.dtype)
                y = torch.cat((boxes, conf, idxs, prob), dim=4)
                # we added idxs so no+1
                z.append(y.view(bs, -1, self.no + 1))
            else:
                y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + grid) * self.stride[i]  # xy
                y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                z.append(y.view(bs, -1, self.no))

        if self.training:
            return x
        if self.export:
            return torch.cat(z, 1)
        return (torch.cat(z, 1), x)

    @staticmethod
    def _make_grid(nx=20, ny=20):
        """Return a float ``(1, 1, ny, nx, 2)`` grid whose last axis holds ``(x, y)`` cell indices."""
        # Broadcasting gives the same row-major grid as torch.meshgrid without
        # relying on its default indexing mode.
        yv = torch.arange(ny).view(ny, 1).expand(ny, nx)
        xv = torch.arange(nx).view(1, nx).expand(ny, nx)
        return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()

你可能感兴趣的:(pytorch,深度学习)