- 在负责预测目标的网格中,与ground truth的IOU最大的anchor为正样本(注意:正样本的选取只看IOU是否最大,不涉及任何阈值,不要与下面负样本的阈值混淆)
- 剩下的anchor中,与全部ground truth的IOU都小于阈值的anchor为负样本
- 其余的anchor(既不是正样本,也不满足负样本条件)为忽略样本
- 代码未完待续
- 获取正样本代码,参考这里
def calculate_iou(_box_a, _box_b):
    """Compute the pairwise IoU between two sets of centre-format boxes.

    Args:
        _box_a: Tensor of shape (A, 4) (extra trailing columns are ignored)
            whose first four columns are (center_x, center_y, width, height).
        _box_b: Tensor of shape (B, 4) in the same format.

    Returns:
        Tensor of shape (A, B); entry ``[i, j]`` is the IoU of box ``i`` of
        ``_box_a`` with box ``j`` of ``_box_b``. Pairs whose union area is
        zero (both boxes degenerate) yield 0 instead of NaN.
    """
    # Integer inputs would truncate the half-width / half-height offsets,
    # so all geometry is done in floating point.
    box_a = _box_a if torch.is_floating_point(_box_a) else _box_a.float()
    box_b = _box_b if torch.is_floating_point(_box_b) else _box_b.float()

    # Convert (cx, cy, w, h) into top-left / bottom-right corners.
    a_half = box_a[:, 2:4] / 2
    b_half = box_b[:, 2:4] / 2
    a_min, a_max = box_a[:, :2] - a_half, box_a[:, :2] + a_half
    b_min, b_max = box_b[:, :2] - b_half, box_b[:, :2] + b_half

    # Broadcast (A, 1, 2) against (1, B, 2) to get every pairwise overlap.
    overlap_min = torch.max(a_min[:, None, :], b_min[None, :, :])
    overlap_max = torch.min(a_max[:, None, :], b_max[None, :, :])
    overlap_wh = torch.clamp(overlap_max - overlap_min, min=0)
    inter = overlap_wh[..., 0] * overlap_wh[..., 1]

    area_a = (a_max - a_min).prod(dim=1)[:, None]
    area_b = (b_max - b_min).prod(dim=1)[None, :]
    union = area_a + area_b - inter

    # A zero union only happens when the intersection is zero too; clamping
    # the denominator turns that 0/0 into 0 without touching other values.
    return inter / torch.clamp(union, min=torch.finfo(union.dtype).tiny)
def get_target(targets, anchors, in_h, in_w, num_classes):
    """Build the positive-sample targets and the negative mask for one feature map.

    For every ground-truth box, the anchor whose shape has the highest IoU
    with the box (both centred at the origin, no threshold involved) is
    marked positive in the grid cell that contains the box centre.

    Args:
        targets: Sequence of length batch_size. Each element is a tensor of
            shape (num_boxes, 5) holding (center_x, center_y, width, height,
            class_index); the coordinates are multiplied by ``in_w`` /
            ``in_h`` here, so they are presumably normalised to [0, 1].
        anchors: The (width, height) pairs of this feature map's anchors,
            e.g. three pairs per scale.
            NOTE(review): box sizes are compared against ``anchors`` after
            being scaled to grid units, so the anchors presumably need to be
            in grid units as well (pixel anchors divided by the stride). The
            original note listed pixel-space values such as
            [[116, 90], [156, 198], [373, 326]] - confirm with the caller.
        in_h: Feature-map height (e.g. 13, 26 or 52).
        in_w: Feature-map width (e.g. 13, 26 or 52).
        num_classes: Number of classes (20 for VOC, 80 for COCO).

    Returns:
        A tuple ``(positive, negtive)``:

        * ``positive`` has shape (batch_size, len(anchors), in_h, in_w,
          5 + num_classes). At each positive location the last axis holds
          the x/y offset of the centre inside its cell, log(width /
          anchor_width), log(height / anchor_height), an objectness of 1 and
          a 1 at index ``5 + class_index``; everything else is 0.
        * ``negtive`` has shape (batch_size, len(anchors), in_h, in_w) and is
          1 everywhere except at positive locations, where it is 0.

    Boxes with a non-finite coordinate or a non-positive width/height are
    skipped, because their log-scale targets are undefined.
    """
    bs = len(targets)
    num_anchors = len(anchors)
    positive = torch.zeros(bs, num_anchors, in_h, in_w, 5 + num_classes, requires_grad=False)
    negtive = torch.ones(bs, num_anchors, in_h, in_w, requires_grad=False)

    # Anchors as origin-centred (0, 0, w, h) boxes. They are identical for
    # every image, so they are built once outside the batch loop.
    anchor_wh = torch.as_tensor(anchors, dtype=torch.float32).detach().cpu().reshape(-1, 2)
    anchor_shapes = torch.cat((torch.zeros(num_anchors, 2), anchor_wh), dim=1)
    anchor_sizes = anchor_wh.tolist()

    for b, target in enumerate(targets):
        target = torch.as_tensor(target, dtype=torch.float32).detach().cpu()
        if target.numel() == 0:
            # Image without ground-truth boxes: nothing becomes positive.
            continue

        # Scale the box geometry from normalised units to grid units.
        grid_scale = target.new_tensor([in_w, in_h, in_w, in_h])
        boxes = target[:, :4] * grid_scale
        classes = target[:, 4].long()

        # Match on shape only: place every box at the origin as well.
        gt_shapes = torch.cat((torch.zeros(boxes.size(0), 2), boxes[:, 2:4]), dim=1)
        best_ns = torch.argmax(calculate_iou(gt_shapes, anchor_shapes), dim=-1)

        for (gx, gy, gw, gh), c, best_n in zip(boxes.tolist(), classes.tolist(), best_ns.tolist()):
            if not all(map(math.isfinite, (gx, gy, gw, gh))) or gw <= 0 or gh <= 0:
                # log(w / anchor_w) is undefined for degenerate (e.g. padded) boxes.
                continue

            # Clamp so that a centre lying exactly on the right/bottom border
            # (or slightly outside the image) still maps to a valid cell
            # instead of raising or wrapping around to the opposite side.
            i = min(max(math.floor(gx), 0), in_w - 1)
            j = min(max(math.floor(gy), 0), in_h - 1)
            anchor_w, anchor_h = anchor_sizes[best_n]

            positive[b, best_n, j, i, 0] = gx - i
            positive[b, best_n, j, i, 1] = gy - j
            positive[b, best_n, j, i, 2] = math.log(gw / anchor_w)
            positive[b, best_n, j, i, 3] = math.log(gh / anchor_h)
            positive[b, best_n, j, i, 4] = 1
            positive[b, best_n, j, i, c + 5] = 1
            negtive[b, best_n, j, i] = 0

    return positive, negtive