Object detection: why DETR sets the class count to 91+1 while DINO uses 91

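DETR (DEtection TRansformer, https://github.com/facebookresearch/detr) has been gaining popularity thanks to advantages such as being end-to-end and needing no post-processing. DINO (https://github.com/IDEA-Research/DINO), built on the same framework, went further and achieved SOTA results on COCO 2017.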

In DETR, the number of classes is set to 91+1:

    # the `num_classes` naming here is somewhat misleading.
    # it indeed corresponds to `max_obj_id + 1`, where max_obj_id
    # is the maximum id for a class in your dataset. For example,
    # COCO has a max_obj_id of 90, so we pass `num_classes` to be 91.
    # As another example, for a dataset that has a single class with id 1,
    # you should pass `num_classes` to be 2 (max_obj_id + 1).
    # For more details on this, check the following discussion
    # https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223
    num_classes = 20 if args.dataset_file != 'coco' else 91
    if args.dataset_file == "coco_panoptic":
        # for panoptic, we just add a num_classes that is large enough to hold
        # max_obj_id + 1, but the exact value doesn't really matter
        num_classes = 250
    device = torch.device(args.device)

At class initialization:

[Image 1 of the original post (code screenshot), not preserved]

When computing the class loss:

[Image 2 of the original post (code screenshot), not preserved]

In DINO, by contrast, the number of classes is set to 91.

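COCO 2017 has 80 object categories. They are a subset of the original 91 categories (see the paper, https://arxiv.org/pdf/1405.0312.pdf), so the original category IDs run from 1 to 90, a range of 90 values, with gaps in between: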

{"1": "person", "2": "bicycle", "3": "car", "4": "motorcycle", "5": "airplane", "6": "bus", 
"7": "train", "8": "truck", "9": "boat", "10": "traffic light", "11": "fire hydrant", "13": 
"stop sign", "14": "parking meter", "15": "bench", "16": "bird", "17": "cat", "18": "dog", 
"19": "horse", "20": "sheep", "21": "cow", "22": "elephant", "23": "bear", "24": "zebra", 
"25": "giraffe", "27": "backpack", "28": "umbrella", "31": "handbag", "32": "tie", "33": 
"suitcase", "34": "frisbee", "35": "skis", "36": "snowboard", "37": "sports ball", "38": 
"kite", "39": "baseball bat", "40": "baseball glove", "41": "skateboard", "42": 
"surfboard", "43": "tennis racket", "44": "bottle", "46": "wine glass", "47": "cup", "48": 
"fork", "49": "knife", "50": "spoon", "51": "bowl", "52": "banana", "53": "apple", "54": 
"sandwich", "55": "orange", "56": "broccoli", "57": "carrot", "58": "hot dog", "59": 
"pizza", "60": "donut", "61": "cake", "62": "chair", "63": "couch", "64": "potted plant", 
"65": "bed", "67": "dining table", "70": "toilet", "72": "tv", "73": "laptop", "74": 
"mouse", "75": "remote", "76": "keyboard", "77": "cell phone", "78": "microwave", "79": 
"oven", "80": "toaster", "81": "sink", "82": "refrigerator", "84": "book", "85": "clock", 
"86": "vase", "87": "scissors", "88": "teddy bear", "89": "hair drier", "90": "toothbrush"}

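As a quick check of the mapping above, the following minimal sketch counts the used IDs and lists the gaps. The names `USED_RANGES`, `detr_num_classes` and `detr_logit_dim` are made up for illustration, and the last line assumes that DETR's classification head has `num_classes + 1` outputs, which is what the "91+1" above refers to.

```python
# Category IDs from the mapping above, written as inclusive (start, end) ranges.
USED_RANGES = [
    (1, 11), (13, 25), (27, 28), (31, 44), (46, 65),
    (67, 67), (70, 70), (72, 82), (84, 90),
]

used_ids = [i for start, end in USED_RANGES for i in range(start, end + 1)]
max_obj_id = max(used_ids)

# IDs inside 1..max_obj_id that no category uses (the "gaps").
missing_ids = sorted(set(range(1, max_obj_id + 1)) - set(used_ids))

num_used = len(used_ids)
detr_num_classes = max_obj_id + 1       # what DETR's code comment calls `num_classes`
detr_logit_dim = detr_num_classes + 1   # assumed head size: one more slot for "no object"

print(num_used, max_obj_id, missing_ids)
print(detr_num_classes, detr_logit_dim)
```

The point is that the sizes are driven by the largest ID (90), not by the number of categories actually used (80).
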
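So in both DETR and DINO, the valid COCO labels are still IDs in the range 1 to 90, covering 80 actual categories (the gap IDs are simply never used). In the original DETR, the class loss is computed with cross-entropy: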

        src_logits = outputs['pred_logits']

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o

        if torch.min(target_classes_o) < 1 or torch.max(target_classes_o) > 90:
            print(target_classes_o)

        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)

`cross_entropy` does not need one-hot targets; integer class IDs are enough. So in theory DETR only needs IDs 1 to 90 plus 1 (no object) = 91 dimensions. Why 92, then?

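This is because the authors assign no object the ID "91" (ID "0" is presumably left unused only to keep the tensor operations simple). With index 0 unused, 91 dimensions cover indices 0 to 90 and can therefore only represent IDs up to 90, so the size has to be 91+1=92.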
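The index arithmetic can be sketched with a toy example. This is not DETR's code: the batch and query sizes are arbitrary, the logits are random, and the class weighting (`empty_weight`) is left out. It only shows that a target filled with ID 91 needs 92 logits per query.

```python
import torch
import torch.nn.functional as F

num_classes = 91           # DETR's `num_classes` for COCO (max_obj_id + 1)
batch, queries = 2, 5      # toy sizes, chosen only for illustration

# 92 logits per query: indices 0..90 for object IDs, index 91 for "no object".
logits = torch.randn(batch, queries, num_classes + 1)

# Every query starts as "no object" (ID 91); matched queries get real IDs.
targets = torch.full((batch, queries), num_classes, dtype=torch.int64)
targets[0, 0] = 1      # person
targets[1, 3] = 90     # toothbrush

# Integer class IDs are used directly as targets, no one-hot needed.
loss = F.cross_entropy(logits.flatten(0, 1), targets.flatten())

# With only 91 logits, the "no object" ID 91 is out of range for the loss.
try:
    F.cross_entropy(logits[..., :num_classes].flatten(0, 1), targets.flatten())
    fits_in_91 = True
except (IndexError, RuntimeError):
    fits_in_91 = False

print(loss.item(), fits_in_91)
```

On CPU the second call raises an out-of-bounds error for target 91, which is why the sketch treats that case as "does not fit".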

Then why is it 91 in DINO?

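DINO actually needs 92 as well. No object is likewise initialized as ID "91", but the authors compute the loss with a sigmoid, so the COCO classes do not compete with each other: each box predicts, class by class, whether it belongs to that class or not, and no object needs no explicit dimension of its own. That is why the class count we see at initialization is 91.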

However, when the class loss is computed, one extra dimension appears, giving 92:

        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o

        target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2]+1],
                                            dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device)
        target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)

        target_classes_onehot = target_classes_onehot[:,:,:-1]
        loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1]
        losses = {'loss_ce': loss_ce}

    target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2]+1],
                                        dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device)

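Look closely and you will see that the `shape[2]+1` here exists for the `scatter_` call in the next step, so that ID "91" can be handled; without it, `scatter_` would fail.

Afterwards, this temporary extra dimension is simply dropped:

    target_classes_onehot = target_classes_onehot[:,:,:-1]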
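The same trick can be reproduced in isolation. The sketch below uses made-up toy sizes and zero logits and leaves out the focal loss itself; it only shows that the extra column absorbs ID 91 and that, once it is dropped, no-object queries end up with an all-zero target row.

```python
import torch

num_classes = 91           # DINO's class count for COCO
batch, queries = 1, 4      # toy sizes, chosen only for illustration
src_logits = torch.zeros(batch, queries, num_classes)   # 91 logits per query

# Every query starts as "no object" (ID 91); matched queries get real IDs.
target_classes = torch.full((batch, queries), num_classes, dtype=torch.int64)
target_classes[0, 1] = 17   # cat
target_classes[0, 2] = 90   # toothbrush

# One extra column (index 91) so that scatter_ has somewhere to put ID 91.
onehot = torch.zeros(batch, queries, num_classes + 1, dtype=src_logits.dtype)
onehot.scatter_(2, target_classes.unsqueeze(-1), 1)

# Drop the temporary column: "no object" rows are now all zeros.
onehot = onehot[:, :, :-1]

# Without the extra column, index 91 is out of range for a 91-wide tensor.
try:
    torch.zeros(batch, queries, num_classes).scatter_(2, target_classes.unsqueeze(-1), 1)
    scatter_fits_in_91 = True
except (IndexError, RuntimeError):
    scatter_fits_in_91 = False

print(onehot.shape, onehot.sum().item(), scatter_fits_in_91)
```

Each of the 91 remaining columns is then an independent binary target for the sigmoid-based loss, so an all-zero row is how a query says "no object".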
