yolo源码注释4——yolo-py

代码基于yolov5 v6.0

目录:

  • yolo源码注释1——文件结构
  • yolo源码注释2——数据集配置文件
  • yolo源码注释3——模型配置文件
  • yolo源码注释4——yolo-py

yolo.py 用于搭建 yolov5 的网络模型,主要包含 3 部分:

  • Detect:Detect 层
  • Model:搭建网络
  • parse_model:根据配置实例化模块

Model(仅注释了 init 函数):

class Model(nn.Module):
    # YOLOv5 model
    def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None):  # model, input channels, number of classes
        super().__init__()
        if isinstance(cfg, dict):
            self.yaml = cfg  # model dict
        else:  # is *.yaml
            import yaml
            self.yaml_file = Path(cfg).name
            with open(cfg, encoding='ascii', errors='ignore') as f:
                self.yaml = yaml.safe_load(f)

        # Define model
        ch = self.yaml['ch'] = self.yaml.get('ch', ch)  # input channels
        if nc and nc != self.yaml['nc']:
            LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
            self.yaml['nc'] = nc  # override yaml value
        if anchors:
            LOGGER.info(f'Overriding model.yaml anchors with anchors={anchors}')
            self.yaml['anchors'] = round(anchors)  # override yaml value

        # 根据配置搭建网络
        self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch])

        self.names = [str(i) for i in range(self.yaml['nc'])]  # default names
        self.inplace = self.yaml.get('inplace', True)

        # 计算生成 anchors 时的步长
        m = self.model[-1]  # Detect()
        if isinstance(m, Detect):
            s = 256  # 2x min stride

            m.inplace = self.inplace
            m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward
            check_anchor_order(m)  # must be in pixel-space (not grid-space)
            m.anchors /= m.stride.view(-1, 1, 1)

            self.stride = m.stride
            self._initialize_biases()  # only run once

        # Init weights, biases
        initialize_weights(self)
        self.info()
        LOGGER.info('')

parse_model:

def parse_model(d, ch):  # model_dict, input_channels(3)
    LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10}  {'module':<40}{'arguments':<30}")

    anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)

    # layers: 保存每一层的结构
    # save: 记录 from 不是 -1 的层,即需要多个输入的层如 Concat 和 Detect 层
    # c2: 当前层输出的特征图数量
    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(
            d['backbone'] + d['head']):  # from:-1, number:1, module:'Conv', args:[64, 6, 2, 2]
        m = eval(m) if isinstance(m, str) else m  # eval strings, m:

        # 数字、列表直接放入args[i],字符串通过 eval 函数变成模块
        for j, a in enumerate(args):
            try:
                args[j] = eval(a) if isinstance(a, str) else a  # eval strings, [64, 6, 2, 2]
            except NameError:
                pass

        # 对数量大于1的模块和 depth_multiple 相乘然后四舍五入
        n = n_ = max(round(n * gd), 1) if n > 1 else n  # depth gain

        # 实例化 ymal 文件中的每个模块
        if m in (Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
                 BottleneckCSP, C3, C3TR, C3SPP, C3Ghost,SE, FSM):
            c1, c2 = ch[f], args[0]  # 输入特征图数量(f指向的层的输出特征图数量),输出特征图数量

            # 如果输出层的特征图数量不等于 no (Detect输出层)
            # 则将输出图的特征图数量乘 width_multiple ,并调整为 8 的倍数
            if c2 != no:  # if not output
                c2 = make_divisible(c2 * gw, 8)

            args = [c1, c2, *args[1:]]  # 默认参数格式:[输入, 输出, 其他参数……]

            # 参数有特殊格式要求的模块
            if m in [BottleneckCSP, C3, C3TR, C3Ghost, CSPStage]:
                args.insert(2, n)  # number of repeats
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
        elif m is Detect:
            args.append([ch[x] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
        elif m is Contract:
            c2 = ch[f] * args[0] ** 2
        elif m is Expand:
            c2 = ch[f] // args[0] ** 2
        else:
            c2 = ch[f]

        m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace('__main__.', '')  # module type
        np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
        LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f}  {t:<40}{str(args):<30}')  # print

        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)

        if i == 0:
            ch = []
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)

你可能感兴趣的:(计算机视觉,YOLO,目标检测)