yolov3选取正负样本

  • 负责预测目标网格中与ground truth的IOU最大的anchor为正样本(记住这里没有阈值的事情,否则会绕晕)
  • 剩下的anchor中,与全部ground truth的IOU都小于阈值的anchor为负样本
  • 其他是忽略样本
  • 代码未完待续
  • 获取正样本代码,参考这里
def calculate_iou(_box_a, _box_b):
		b1_x1, b1_x2 = _box_a[:, 0] - _box_a[:, 2] / 2, _box_a[:, 0] + _box_a[:, 2] / 2
        b1_y1, b1_y2 = _box_a[:, 1] - _box_a[:, 3] / 2, _box_a[:, 1] + _box_a[:, 3] / 2
        b2_x1, b2_x2 = _box_b[:, 0] - _box_b[:, 2] / 2, _box_b[:, 0] + _box_b[:, 2] / 2
        b2_y1, b2_y2 = _box_b[:, 1] - _box_b[:, 3] / 2, _box_b[:, 1] + _box_b[:, 3] / 2
        box_a = torch.zeros_like(_box_a)
        box_b = torch.zeros_like(_box_b)
        box_a[:, 0], box_a[:, 1], box_a[:, 2], box_a[:, 3] = b1_x1, b1_y1, b1_x2, b1_y2
        box_b[:, 0], box_b[:, 1], box_b[:, 2], box_b[:, 3] = b2_x1, b2_y1, b2_x2, b2_y2
        A = box_a.size(0)
        B = box_b.size(0)
        # intersection
        # expand to A*B*2 and compare
        max_xy  = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
        min_xy  = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), box_b[:, :2].unsqueeze(0).expand(A, B, 2))
        # minus and set 0 if result less than 0
        inter   = torch.clamp((max_xy - min_xy), min=0)
        # size:A*B
        inter   = inter[:, :, 0] * inter[:, :, 1]
        area_a = ((box_a[:, 2]-box_a[:, 0]) * (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter) 
        area_b = ((box_b[:, 2]-box_b[:, 0]) * (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter)
        union = area_a + area_b - inter
        return inter / union
'''
targets是标签列表,长度是batch_size,元素的shape是(真实框个数*5)
anchors是[[116,90],[156,198],[373,326]]或[[30,61],[62,45],[59,119]]或[[10,13],[16,30],[33,23]]
in_h, in_w是13,13或26,26或52,52
num_classes是类别数,voc是20,COCO是80
'''
def get_target(targets, anchors, in_h, in_w, num_classes):
    bs=len(targets)
    positive=torch.zeros(bs,len(anchors),in_h, in_w, 5+num_classes,requires_grad = False)
    negtive=torch.ones(bs,len(anchors),in_h, in_w, requires_grad = False)
    for b in range(bs):
        batch_target = torch.zeros_like(targets[b])
        # 计算该特征图上标签的值
        batch_target[:, [0,2]] = targets[b][:, [0,2]] * in_w
        batch_target[:, [1,3]] = targets[b][:, [1,3]] * in_h
        batch_target[:, 4] = targets[b][:, 4]
        batch_target = batch_target.cpu()
        # 计算标签和anchor的IOU
        # 这里可以随便选一个共同中心(0,0),根据高宽计算IOU
        gt_box= torch.FloatTensor(torch.cat((torch.zeros((batch_target.size(0), 2)), batch_target[:, 2:4]), 1))
        anchor_shapes=torch.FloatTensor(torch.cat((torch.zeros((len(anchors), 2)), torch.FloatTensor(anchors)), 1))
        iou=calculate_iou(gt_box, anchor_shapes)
        # 获得与标签最匹配的anchor的索引
        best_ns = torch.argmax(iou, dim=-1)
        for t, best_n in enumerate(best_ns):
            # 第t个标签中心所在网格,种类
            i = torch.floor(batch_target[t, 0]).long()
            j = torch.floor(batch_target[t, 1]).long()
            c = batch_target[t, 4].long()
            positive[b,best_n,j,i,0]=batch_target[t, 0] - i.float()
            positive[b,best_n,j,i,1]=batch_target[t, 1] - j.float()
            positive[b,best_n,j,i,2]=math.log(batch_target[t, 2] / anchors[best_n][0])
            positive[b,best_n,j,i,3]=math.log(batch_target[t, 3] / anchors[best_n][1])
            positive[b,best_n,j,i,4]=1
            positive[b,best_n,j,i,c+5]=1
            negtive[b,best_n,j,i]=0
    return positive,negtive

你可能感兴趣的:(deeplearning,深度学习,目标检测,yolov3)