【YOLOv5 Head解耦】

目录

  • 前言
  • 一、Head代码
  • 二、Detect调用

前言

YOLOv5原理 【YOLOV5-5.x 源码讲解】整体项目文件导航

一、Head代码

这部分定义的比较复杂,可以根据自己任务自定义啊,建议看下别人解耦头怎么定义的,然后可以在这里自定义一下。

class DecoupledHead(nn.Module):
    def __init__(self, ch=256, nc=80, width=1.0, anchors=()):
        super().__init__()
        self.nc = nc  # number of classes
        self.nl = len(anchors)  # number of detection layers 3
        self.na = len(anchors[0]) // 2  # number of anchors 3
        self.merge = Conv(ch, 256 * width, 1, 1)
        self.cls_convs1 = Conv(256 * width, 256 * width, 3, 1, 1)
        self.cls_convs2 = Conv(256 * width, 256 * width, 3, 1, 1)
        self.reg_convs1 = Conv(256 * width, 256 * width, 3, 1, 1)
        self.reg_convs2 = Conv(256 * width, 256 * width, 3, 1, 1)
        self.cls_preds = nn.Conv2d(256 * width, self.nc * self.na, 1)
        self.reg_preds = nn.Conv2d(256 * width, 4 * self.na, 1)
        self.obj_preds = nn.Conv2d(256 * width, 1 * self.na, 1)

    def forward(self, x):
        x = self.merge(x)
        # 分类=3x3conv + 3x3conv + 1x1convpred
        x1 = self.cls_convs1(x)
        x1 = self.cls_convs2(x1)
        x1 = self.cls_preds(x1)
        # 回归=3x3conv(共享) + 3x3conv(共享) + 1x1pred
        x2 = self.reg_convs1(x)
        x2 = self.reg_convs2(x2)
        x21 = self.reg_preds(x2)
        # 置信度=3x3conv(共享)+ 3x3conv(共享) + 1x1pred
        x22 = self.obj_preds(x2)
        out = torch.cat([x21, x22, x1], 1)
        return out

二、Detect调用

Detect的__init__函数中更改self.m,如下所示

        # output conv 对每个输出的feature map都要调用一次conv1x1
        # self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)
        self.m = nn.ModuleList(DecoupledHead(x, nc, 1, anchors) for x in ch)
        # use in-place ops (e.g. slice assignment) 一般都是True 默认不使用AWS Inferentia加速
        self.inplace = inplace

Model中的__init__的self._initialize_biases()注释掉

就可以运行了。

没来得及做实验,最近在准备秋招,感兴趣的可以自己做下实验。

你可能感兴趣的:(#,YOLOV5-5.x,源码讲解,YOLOv5,head解耦)