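All the code has been uploaded to my GitHub repository: https://github.com/zgcr/pytorch-ImageNet-CIFAR-COCO-VOC-training
If you find it useful, please give it a star!
All of the code below has been tested with PyTorch 1.4 and verified to work correctly.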
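First, be clear that FCOS indeed does not use explicit anchors (prior boxes) the way RetinaNet does. FCOS treats every point on the feature map of each FPN level as a sample. Whether a sample is positive or negative depends on whether it lies inside or outside a ground-truth box (note that FCOS has no ignored samples). In this sense FCOS really is anchor free. However, both ground truth assignment and inference still need the (x, y) coordinates obtained by mapping each feature map point back onto the input image, so FCOS is not completely "free". More precisely, FCOS is a "point based" object detector: we can view it as a detector with exactly one implicit anchor at each feature map point.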
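As a rough illustration of the "point based" view, here is a minimal sketch of mapping feature map locations back to input image coordinates. It assumes the mapping from the FCOS paper: location (x, y) on a level with stride s corresponds to (floor(s/2) + x*s, floor(s/2) + y*s). The helper `feature_map_positions` is hypothetical. The FCOS network in the repository produces `batch_positions` itself, and its exact convention may differ.

```python
def feature_map_positions(feature_h, feature_w, stride):
    """Map every location (x, y) on one FPN level back to input image coordinates.

    Assumes the FCOS-paper convention (floor(s/2) + x*s, floor(s/2) + y*s).
    Points are returned in row-major order (y outer, x inner).
    """
    offset = stride // 2
    positions = []
    for y in range(feature_h):
        for x in range(feature_w):
            positions.append((offset + x * stride, offset + y * stride))
    return positions


# A 2x2 feature map on the stride-8 level (P3)
print(feature_map_positions(2, 2, 8))  # [(4, 4), (12, 4), (4, 12), (12, 12)]
```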
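DETR, released in 2020 (https://arxiv.org/pdf/2005.12872.pdf), treats object detection as a set prediction problem and uses a Transformer to predict a set of boxes. It needs neither NMS nor anchor/point prior coordinates, so it is truly "free". Interested readers can look into it on their own.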
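Given the boxes annotated on an input image, first check every point on the feature maps of all FPN levels. A point that lies outside all annotated boxes becomes a negative sample. Some of the remaining points may lie inside several annotated boxes at once. Next, for each point and each annotated box, take the maximum of l, t, r, b (the distances from the point to the box's left, top, right and bottom edges). Find which of the value ranges below this maximum falls into, and assign the box to that point on the feature map of the FPN level that owns the range.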
```python
# value ranges assigned to P3, P4, P5, P6, P7, from left to right
INF = 100000000
mi = [[-1, 64], [64, 128], [128, 256], [256, 512], [512, INF]]
```
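The per-point rule above can be sketched in plain Python for one point and one box. This is an illustrative loop version, not the vectorized code used below. It assumes boxes are given as (x1, y1, x2, y2) in input image pixels, and the function names are hypothetical. It mirrors the strict comparisons of the vectorized code: the point must satisfy min(l, t, r, b) > 0 and low < max(l, t, r, b) < high.

```python
INF = 100000000
# value ranges for P3, P4, P5, P6, P7
MI = [[-1, 64], [64, 128], [128, 256], [256, 512], [512, INF]]


def ltrb(point, box):
    """Distances from point (x, y) to the left, top, right, bottom edges of box (x1, y1, x2, y2)."""
    x, y = point
    x1, y1, x2, y2 = box
    return x - x1, y - y1, x2 - x, y2 - y


def is_positive_on_level(point, box, level_range):
    """True if `box` is a candidate target for `point` on the level owning `level_range`."""
    l, t, r, b = ltrb(point, box)
    if min(l, t, r, b) <= 0:  # point is outside (or on the border of) the box
        return False
    low, high = level_range
    return low < max(l, t, r, b) < high


def matching_level(point, box, mi=MI):
    """Index into mi (0 -> P3, ..., 4 -> P7) of the level this box is assigned to, or None."""
    for level, level_range in enumerate(mi):
        if is_positive_on_level(point, box, level_range):
            return level
    return None


print(matching_level((100, 100), (50, 60, 180, 150)))  # 1 -> P4, since max(l, t, r, b) = 80
```

Because both comparisons are strict, a point whose max(l, t, r, b) is exactly a boundary value such as 64 matches no level.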
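After this step, the vast majority of points are assigned to only one box. Some points, however, still fall inside two boxes (this happens when two annotated boxes are of similar size). For these points we compute the areas of the overlapping boxes and always assign the point to the box with the smallest area. In the implementation below I assign labels for these samples with matrix operations. An image usually has only a few dozen to two or three hundred positive samples. Even so, assigning their labels with a for loop would slow training down considerably, so keep this in mind.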
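For a single point, the smallest-area rule can be written as a plain loop. This sketch is for readability only; the vectorized code below does the same for all positive points at once. `pick_smallest_box` is a hypothetical helper, and it assumes the candidate boxes have already passed the inside and level-range checks. Like the code below, it computes the area as (x2 - x1) * (y2 - y1).

```python
def box_area(box):
    """Area of a box given as (x1, y1, x2, y2)."""
    x1, y1, x2, y2 = box
    return (x2 - x1) * (y2 - y1)


def pick_smallest_box(candidate_boxes):
    """Return the index of the smallest-area candidate box, or None if there are none."""
    if not candidate_boxes:
        return None
    return min(range(len(candidate_boxes)), key=lambda i: box_area(candidate_boxes[i]))


# a point inside both boxes is given to the smaller one (index 1)
print(pick_smallest_box([(0, 0, 100, 100), (10, 10, 60, 60)]))
```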
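For the classification labels, 0 denotes a negative sample and 1 to 80 denote the 80 positive classes. The l, t, r, b and centerness labels are computed exactly as in the formulas of the FCOS paper, without modification.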
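As a worked example of the label layout, the sketch below builds the 6-value target [l, t, r, b, class_index, centerness] for one positive point. It uses the FCOS centerness formula sqrt(min(l, r)/max(l, r) * min(t, b)/max(t, b)) and the +1 class shift described above. It assumes 0-based dataset class ids (0 to 79); `make_point_target` is a hypothetical helper.

```python
import math


def centerness(l, t, r, b):
    """FCOS centerness: 1.0 at the box center, decaying toward 0 near the edges."""
    return math.sqrt((min(l, r) / max(l, r)) * (min(t, b) / max(t, b)))


def make_point_target(point, box, class_id):
    """6-value target for a positive point: l, t, r, b, class_index (1..80), centerness.

    `class_id` is assumed to be a 0-based dataset label (0..79); index 0 is
    reserved for the negative (background) class, hence the +1.
    """
    x, y = point
    x1, y1, x2, y2 = box
    l, t, r, b = x - x1, y - y1, x2 - x, y2 - y
    return [l, t, r, b, class_id + 1, centerness(l, t, r, b)]


NEGATIVE_TARGET = [0, 0, 0, 0, 0, 0]  # negative points keep an all-zero target row

print(make_point_target((50, 50), (0, 0, 100, 100), 4))  # [50, 50, 50, 50, 5, 1.0]
```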
The ground truth assignment is implemented as follows:
```python
def get_batch_position_annotations(self, cls_heads, reg_heads,
                                   center_heads, batch_positions,
                                   annotations):
    """
    Assign a ground truth target for each position on feature map
    """
    device = annotations.device
    batch_mi = []
    for reg_head, mi in zip(reg_heads, self.mi):
        mi = torch.tensor(mi).to(device)
        B, H, W, _ = reg_head.shape
        per_level_mi = torch.zeros(B, H, W, 2).to(device)
        per_level_mi = per_level_mi + mi
        batch_mi.append(per_level_mi)

    cls_preds, reg_preds, center_preds, all_points_position, all_points_mi = [], [], [], [], []
    for cls_pred, reg_pred, center_pred, per_level_position, per_level_mi in zip(
            cls_heads, reg_heads, center_heads, batch_positions, batch_mi):
        cls_pred = cls_pred.view(cls_pred.shape[0], -1, cls_pred.shape[-1])
        reg_pred = reg_pred.view(reg_pred.shape[0], -1, reg_pred.shape[-1])
        center_pred = center_pred.view(center_pred.shape[0], -1,
                                       center_pred.shape[-1])
        per_level_position = per_level_position.view(
            per_level_position.shape[0], -1, per_level_position.shape[-1])
        per_level_mi = per_level_mi.view(per_level_mi.shape[0], -1,
                                         per_level_mi.shape[-1])
        cls_preds.append(cls_pred)
        reg_preds.append(reg_pred)
        center_preds.append(center_pred)
        all_points_position.append(per_level_position)
        all_points_mi.append(per_level_mi)

    cls_preds = torch.cat(cls_preds, axis=1)
    reg_preds = torch.cat(reg_preds, axis=1)
    center_preds = torch.cat(center_preds, axis=1)
    all_points_position = torch.cat(all_points_position, axis=1)
    all_points_mi = torch.cat(all_points_mi, axis=1)

    batch_targets = []
    for per_image_position, per_image_mi, per_image_annotations in zip(
            all_points_position, all_points_mi, annotations):
        per_image_annotations = per_image_annotations[
            per_image_annotations[:, 4] >= 0]
        points_num = per_image_position.shape[0]

        if per_image_annotations.shape[0] == 0:
            # 6:l,t,r,b,class_index,center-ness_gt
            per_image_targets = torch.zeros([points_num, 6], device=device)
        else:
            annotaion_num = per_image_annotations.shape[0]
            per_image_gt_bboxes = per_image_annotations[:, 0:4]
            candidates = torch.zeros([points_num, annotaion_num, 4],
                                     device=device)
            candidates = candidates + per_image_gt_bboxes.unsqueeze(0)
            # (x, y) -> (x, y, x, y) so l,t,r,b can be computed in one step
            per_image_position = per_image_position.unsqueeze(1).repeat(
                1, annotaion_num, 2)
            candidates[:, :, 0:2] = per_image_position[:, :, 0:2] - candidates[:, :, 0:2]
            candidates[:, :, 2:4] = candidates[:, :, 2:4] - per_image_position[:, :, 2:4]

            candidates_min_value, _ = candidates.min(axis=-1, keepdim=True)
            sample_flag = (candidates_min_value[:, :, 0] > 0).int().unsqueeze(-1)
            # get all negative reg targets which points ctr out of gt box
            candidates = candidates * sample_flag

            # get all negative reg targets which assign ground truth not in range of mi
            candidates_max_value, _ = candidates.max(axis=-1, keepdim=True)
            per_image_mi = per_image_mi.unsqueeze(1).repeat(1, annotaion_num, 1)
            m1_negative_flag = (candidates_max_value[:, :, 0] >
                                per_image_mi[:, :, 0]).int().unsqueeze(-1)
            candidates = candidates * m1_negative_flag
            m2_negative_flag = (candidates_max_value[:, :, 0] <
                                per_image_mi[:, :, 1]).int().unsqueeze(-1)
            candidates = candidates * m2_negative_flag

            final_sample_flag = candidates.sum(axis=-1).sum(axis=-1)
            final_sample_flag = final_sample_flag > 0
            positive_index = (final_sample_flag == True).nonzero().squeeze(dim=-1)

            # if no positive sample is assigned
            if len(positive_index) == 0:
                del candidates
                # 6:l,t,r,b,class_index,center-ness_gt
                per_image_targets = torch.zeros([points_num, 6], device=device)
            else:
                positive_candidates = candidates[positive_index]
                del candidates

                sample_box_gts = per_image_annotations[:, 0:4].unsqueeze(0)
                sample_box_gts = sample_box_gts.repeat(
                    positive_candidates.shape[0], 1, 1)
                sample_class_gts = per_image_annotations[:, 4].unsqueeze(-1).unsqueeze(0)
                sample_class_gts = sample_class_gts.repeat(
                    positive_candidates.shape[0], 1, 1)

                # 6:l,t,r,b,class_index,center-ness_gt
                per_image_targets = torch.zeros([points_num, 6], device=device)

                if positive_candidates.shape[1] == 1:
                    # if only one candidate for each positive sample
                    # assign l,t,r,b,class_index,center_ness_gt ground truth
                    # class_index value from 1 to 80 represents 80 positive classes
                    # class_index value 0 represents negative class
                    positive_candidates = positive_candidates.squeeze(1)
                    sample_class_gts = sample_class_gts.squeeze(1)
                    per_image_targets[positive_index, 0:4] = positive_candidates
                    per_image_targets[positive_index, 4:5] = sample_class_gts + 1
                    l = per_image_targets[positive_index, 0:1]
                    t = per_image_targets[positive_index, 1:2]
                    r = per_image_targets[positive_index, 2:3]
                    b = per_image_targets[positive_index, 3:4]
                    per_image_targets[positive_index, 5:6] = torch.sqrt(
                        (torch.min(l, r) / torch.max(l, r)) *
                        (torch.min(t, b) / torch.max(t, b)))
                else:
                    # if a positive point sample has several object candidates,
                    # choose the smallest-area candidate as its ground truth
                    gts_w_h = sample_box_gts[:, :, 2:4] - sample_box_gts[:, :, 0:2]
                    gts_area = gts_w_h[:, :, 0] * gts_w_h[:, :, 1]
                    positive_candidates_value = positive_candidates.sum(axis=2)

                    # set the area of every negative candidate to INF so that
                    # .min() never picks a negative candidate
                    INF = 100000000
                    inf_tensor = torch.ones_like(gts_area) * INF
                    gts_area = torch.where(
                        torch.eq(positive_candidates_value, 0.),
                        inf_tensor, gts_area)

                    # get the smallest object candidate index
                    _, min_index = gts_area.min(axis=1)
                    # row index of every positive point, on the same device as min_index
                    candidate_indexes = torch.arange(positive_candidates.shape[0],
                                                     device=device)
                    final_candidate_reg_gts = positive_candidates[candidate_indexes, min_index, :]
                    final_candidate_cls_gts = sample_class_gts[candidate_indexes, min_index]

                    # assign l,t,r,b,class_index,center_ness_gt ground truth
                    per_image_targets[positive_index, 0:4] = final_candidate_reg_gts
                    per_image_targets[positive_index, 4:5] = final_candidate_cls_gts + 1
                    l = per_image_targets[positive_index, 0:1]
                    t = per_image_targets[positive_index, 1:2]
                    r = per_image_targets[positive_index, 2:3]
                    b = per_image_targets[positive_index, 3:4]
                    per_image_targets[positive_index, 5:6] = torch.sqrt(
                        (torch.min(l, r) / torch.max(l, r)) *
                        (torch.min(t, b) / torch.max(t, b)))

        per_image_targets = per_image_targets.unsqueeze(0)
        batch_targets.append(per_image_targets)

    batch_targets = torch.cat(batch_targets, axis=0)
    batch_targets = torch.cat([batch_targets, all_points_position], axis=2)
    # batch_targets shape:[batch_size, points_num, 8],8:l,t,r,b,class_index,center-ness_gt,point_ctr_x,point_ctr_y
    return cls_preds, reg_preds, center_preds, batch_targets
```
The classification loss is focal loss, computed exactly as in RetinaNet, except that the samples are points instead of anchors.
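To make the formula concrete, here is a scalar sketch of the per-class focal loss term that the code below evaluates element-wise: FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t). It uses the same defaults as the `FCOSLoss` constructor (alpha = 0.25, gamma = 2) and the same clamp of the probability to [epsilon, 1 - epsilon]. The function name is hypothetical.

```python
import math


def focal_loss_term(p, y, alpha=0.25, gamma=2.0, epsilon=1e-4):
    """Focal loss for one sigmoid output p in (0, 1) and one binary label y in {0, 1}."""
    p = min(max(p, epsilon), 1.0 - epsilon)  # same clamp as the PyTorch code
    p_t = p if y == 1 else 1.0 - p
    alpha_t = alpha if y == 1 else 1.0 - alpha
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


# an easy negative contributes far less than a hard positive
print(focal_loss_term(0.01, 0), focal_loss_term(0.1, 1))
```

Per image, the code sums this term over all points and all 80 classes, then divides by the number of positive points.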
The classification loss is implemented as follows:
```python
def compute_one_image_focal_loss(self, per_image_cls_preds,
                                 per_image_targets):
    """
    compute one image focal loss(cls loss)
    per_image_cls_preds:[points_num,num_classes]
    per_image_targets:[points_num,8]
    """
    per_image_cls_preds = torch.clamp(per_image_cls_preds,
                                      min=self.epsilon,
                                      max=1. - self.epsilon)
    num_classes = per_image_cls_preds.shape[1]

    # generate 80 binary ground truth classes for each point
    loss_ground_truth = F.one_hot(per_image_targets[:, 4].long(),
                                  num_classes=num_classes + 1)
    loss_ground_truth = loss_ground_truth[:, 1:]
    loss_ground_truth = loss_ground_truth.float()

    alpha_factor = torch.ones_like(per_image_cls_preds) * self.alpha
    alpha_factor = torch.where(torch.eq(loss_ground_truth, 1.),
                               alpha_factor, 1. - alpha_factor)
    pt = torch.where(torch.eq(loss_ground_truth, 1.), per_image_cls_preds,
                     1. - per_image_cls_preds)
    focal_weight = alpha_factor * torch.pow((1. - pt), self.gamma)

    bce_loss = -(
        loss_ground_truth * torch.log(per_image_cls_preds) +
        (1. - loss_ground_truth) * torch.log(1. - per_image_cls_preds))

    one_image_focal_loss = focal_weight * bce_loss
    one_image_focal_loss = one_image_focal_loss.sum()
    positive_points_num = per_image_targets[
        per_image_targets[:, 4] > 0].shape[0]
    # according to the original paper, divide the focal loss by the number of positive samples
    one_image_focal_loss = one_image_focal_loss / positive_points_num

    return one_image_focal_loss
```
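The FCOS paper uses IoU loss for regression; here I use GIoU loss directly. The regression loss is still computed only on positive samples, and both the predicted box and the ground-truth box contain the sample point, so the two boxes always intersect. GIoU's main advantage over IoU, providing a useful gradient when the boxes do not overlap, therefore never comes into play. The two losses are not strictly identical, though. The extra term (enclose_area - union_area) / enclose_area is zero only when the enclosing box equals the union, so GIoU loss still adds a penalty whenever the two boxes are offset from each other.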
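The sketch below decodes boxes from a point and its (l, t, r, b) distances the same way the code does (x1 = x - l, y1 = y - t, x2 = x + r, y2 = y + b). It then compares IoU loss with GIoU loss for two boxes that both contain the point. It uses continuous coordinates without the +1 pixel convention of the code below, and the function names are hypothetical.

```python
def decode(point, distances):
    """Box (x1, y1, x2, y2) from a point (x, y) and distances (l, t, r, b)."""
    x, y = point
    l, t, r, b = distances
    return x - l, y - t, x + r, y + b


def iou_and_giou(box_a, box_b):
    """IoU and GIoU of two boxes given as (x1, y1, x2, y2)."""
    inter_w = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
    inter_h = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    overlap = inter_w * inter_h
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - overlap
    # area of the smallest box enclosing both boxes
    enclose = (max(box_a[2], box_b[2]) - min(box_a[0], box_b[0])) * \
              (max(box_a[3], box_b[3]) - min(box_a[1], box_b[1]))
    iou = overlap / union
    giou = iou - (enclose - union) / enclose
    return iou, giou


point = (10.0, 10.0)
pred = decode(point, (8.0, 8.0, 2.0, 2.0))  # (2, 2, 12, 12)
gt = decode(point, (2.0, 2.0, 8.0, 8.0))    # (8, 8, 18, 18)
iou, giou = iou_and_giou(pred, gt)
print("IoU loss:", 1 - iou, "GIoU loss:", 1 - giou)
```

Here the boxes intersect, yet GIoU loss is larger than IoU loss and even exceeds 1. The `clamp(max=1.0)` in the code below would cap it.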
The regression loss is implemented as follows:
```python
def compute_one_image_giou_loss(self, per_image_reg_preds,
                                per_image_targets):
    """
    compute one image giou loss(reg loss)
    per_image_reg_preds:[points_num,4]
    per_image_targets:[points_num,8]
    """
    # only use positive points sample to compute reg loss
    device = per_image_reg_preds.device
    per_image_reg_preds = per_image_reg_preds[per_image_targets[:, 4] > 0]
    per_image_targets = per_image_targets[per_image_targets[:, 4] > 0]
    positive_points_num = per_image_targets.shape[0]

    if positive_points_num == 0:
        return torch.tensor(0.).to(device)

    center_ness_targets = per_image_targets[:, 5]

    # decode l,t,r,b around the point center into x1,y1,x2,y2 boxes
    pred_bboxes_xy_min = per_image_targets[:, 6:8] - per_image_reg_preds[:, 0:2]
    pred_bboxes_xy_max = per_image_targets[:, 6:8] + per_image_reg_preds[:, 2:4]
    gt_bboxes_xy_min = per_image_targets[:, 6:8] - per_image_targets[:, 0:2]
    gt_bboxes_xy_max = per_image_targets[:, 6:8] + per_image_targets[:, 2:4]
    pred_bboxes = torch.cat([pred_bboxes_xy_min, pred_bboxes_xy_max], axis=1)
    gt_bboxes = torch.cat([gt_bboxes_xy_min, gt_bboxes_xy_max], axis=1)

    overlap_area_top_left = torch.max(pred_bboxes[:, 0:2], gt_bboxes[:, 0:2])
    overlap_area_bot_right = torch.min(pred_bboxes[:, 2:4], gt_bboxes[:, 2:4])
    overlap_area_sizes = torch.clamp(overlap_area_bot_right - overlap_area_top_left,
                                     min=0)
    overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:, 1]

    # widths and heights of predicted and gt boxes
    pred_bboxes_w_h = pred_bboxes[:, 2:4] - pred_bboxes[:, 0:2] + 1
    gt_bboxes_w_h = gt_bboxes[:, 2:4] - gt_bboxes[:, 0:2] + 1
    # compute pred_bboxes_area and gt_bboxes_area
    pred_bboxes_area = pred_bboxes_w_h[:, 0] * pred_bboxes_w_h[:, 1]
    gt_bboxes_area = gt_bboxes_w_h[:, 0] * gt_bboxes_w_h[:, 1]

    # compute union_area
    union_area = pred_bboxes_area + gt_bboxes_area - overlap_area
    union_area = torch.clamp(union_area, min=1e-4)
    # compute ious between predicted boxes and gt boxes
    ious = overlap_area / union_area

    enclose_area_top_left = torch.min(pred_bboxes[:, 0:2], gt_bboxes[:, 0:2])
    enclose_area_bot_right = torch.max(pred_bboxes[:, 2:4], gt_bboxes[:, 2:4])
    enclose_area_sizes = torch.clamp(enclose_area_bot_right - enclose_area_top_left,
                                     min=0)
    enclose_area = enclose_area_sizes[:, 0] * enclose_area_sizes[:, 1]
    enclose_area = torch.clamp(enclose_area, min=1e-4)

    gious_loss = 1. - ious + (enclose_area - union_area) / enclose_area
    gious_loss = torch.clamp(gious_loss, min=-1.0, max=1.0)
    # use center_ness_targets as the weight of gious loss
    gious_loss = gious_loss * center_ness_targets
    gious_loss = gious_loss.sum() / positive_points_num
    gious_loss = 2. * gious_loss

    return gious_loss
```
The final multiplication by 2 balances the magnitude of the regression loss against the other losses.
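Centerness is optimized with a BCE loss. The centerness targets are soft values between 0 and 1, so the BCE loss cannot reach zero: even a perfect prediction leaves a loss equal to the entropy of the target. In practice the loss drops a little early in training and then stays flat for a long time. This is normal and nothing to worry about.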
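A quick numeric check of why the centerness loss plateaus: with a soft target t, the BCE -(t log p + (1 - t) log(1 - p)) is minimized at p = t, where it equals the entropy of t rather than zero. `bce` is a hypothetical scalar helper, not part of the loss code.

```python
import math


def bce(target, pred):
    """Binary cross-entropy for one soft target in [0, 1] and one prediction in (0, 1)."""
    return -(target * math.log(pred) + (1.0 - target) * math.log(1.0 - pred))


for t in (0.3, 0.5, 0.9):
    # even a perfect prediction (pred == target) leaves a positive loss
    print(t, bce(t, t))
```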
The centerness loss is implemented as follows:
```python
def compute_one_image_center_ness_loss(self, per_image_center_preds,
                                       per_image_targets):
    """
    compute one image center_ness loss(center ness loss)
    per_image_center_preds:[points_num,1]
    per_image_targets:[points_num,8]
    """
    # only use positive points sample to compute center_ness loss
    device = per_image_center_preds.device
    per_image_center_preds = per_image_center_preds[
        per_image_targets[:, 4] > 0]
    per_image_targets = per_image_targets[per_image_targets[:, 4] > 0]
    positive_points_num = per_image_targets.shape[0]

    if positive_points_num == 0:
        return torch.tensor(0.).to(device)

    center_ness_targets = per_image_targets[:, 5:6]

    center_ness_loss = -(
        center_ness_targets * torch.log(per_image_center_preds) +
        (1. - center_ness_targets) *
        torch.log(1. - per_image_center_preds))
    center_ness_loss = center_ness_loss.sum() / positive_points_num

    return center_ness_loss
```
The complete loss code is as follows:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

INF = 100000000


class FCOSLoss(nn.Module):
    def __init__(self,
                 image_w,
                 image_h,
                 strides=[8, 16, 32, 64, 128],
                 mi=[[-1, 64], [64, 128], [128, 256], [256, 512], [512, INF]],
                 alpha=0.25,
                 gamma=2.,
                 epsilon=1e-4):
        super(FCOSLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.image_w = image_w
        self.image_h = image_h
        self.strides = strides
        self.mi = mi

    def forward(self, cls_heads, reg_heads, center_heads, batch_positions,
                annotations):
        """
        compute cls loss, reg loss and center-ness loss in one batch
        """
        cls_preds, reg_preds, center_preds, batch_targets = self.get_batch_position_annotations(
            cls_heads, reg_heads, center_heads, batch_positions, annotations)

        cls_preds = torch.sigmoid(cls_preds)
        reg_preds = torch.exp(reg_preds)
        center_preds = torch.sigmoid(center_preds)
        # centerness targets (batch_targets[:, :, 5]) are already in [0, 1]
        # from the FCOS formula, so they are used as BCE targets as-is

        device = annotations.device
        cls_loss, reg_loss, center_ness_loss = [], [], []
        valid_image_num = 0
        for per_image_cls_preds, per_image_reg_preds, per_image_center_preds, per_image_targets in zip(
                cls_preds, reg_preds, center_preds, batch_targets):
            positive_points_num = (
                per_image_targets[per_image_targets[:, 4] > 0]).shape[0]
            if positive_points_num == 0:
                cls_loss.append(torch.tensor(0.).to(device))
                reg_loss.append(torch.tensor(0.).to(device))
                center_ness_loss.append(torch.tensor(0.).to(device))
            else:
                valid_image_num += 1
                one_image_cls_loss = self.compute_one_image_focal_loss(
                    per_image_cls_preds, per_image_targets)
                one_image_reg_loss = self.compute_one_image_giou_loss(
                    per_image_reg_preds, per_image_targets)
                one_image_center_ness_loss = self.compute_one_image_center_ness_loss(
                    per_image_center_preds, per_image_targets)
                cls_loss.append(one_image_cls_loss)
                reg_loss.append(one_image_reg_loss)
                center_ness_loss.append(one_image_center_ness_loss)

        cls_loss = sum(cls_loss) / valid_image_num
        reg_loss = sum(reg_loss) / valid_image_num
        center_ness_loss = sum(center_ness_loss) / valid_image_num

        return cls_loss, reg_loss, center_ness_loss

    def compute_one_image_focal_loss(self, per_image_cls_preds,
                                     per_image_targets):
        """
        compute one image focal loss(cls loss)
        per_image_cls_preds:[points_num,num_classes]
        per_image_targets:[points_num,8]
        """
        per_image_cls_preds = torch.clamp(per_image_cls_preds,
                                          min=self.epsilon,
                                          max=1. - self.epsilon)
        num_classes = per_image_cls_preds.shape[1]

        # generate 80 binary ground truth classes for each point
        loss_ground_truth = F.one_hot(per_image_targets[:, 4].long(),
                                      num_classes=num_classes + 1)
        loss_ground_truth = loss_ground_truth[:, 1:]
        loss_ground_truth = loss_ground_truth.float()

        alpha_factor = torch.ones_like(per_image_cls_preds) * self.alpha
        alpha_factor = torch.where(torch.eq(loss_ground_truth, 1.),
                                   alpha_factor, 1. - alpha_factor)
        pt = torch.where(torch.eq(loss_ground_truth, 1.), per_image_cls_preds,
                         1. - per_image_cls_preds)
        focal_weight = alpha_factor * torch.pow((1. - pt), self.gamma)

        bce_loss = -(
            loss_ground_truth * torch.log(per_image_cls_preds) +
            (1. - loss_ground_truth) * torch.log(1. - per_image_cls_preds))

        one_image_focal_loss = focal_weight * bce_loss
        one_image_focal_loss = one_image_focal_loss.sum()
        positive_points_num = per_image_targets[
            per_image_targets[:, 4] > 0].shape[0]
        # according to the original paper, divide the focal loss by the number of positive samples
        one_image_focal_loss = one_image_focal_loss / positive_points_num

        return one_image_focal_loss

    def compute_one_image_giou_loss(self, per_image_reg_preds,
                                    per_image_targets):
        """
        compute one image giou loss(reg loss)
        per_image_reg_preds:[points_num,4]
        per_image_targets:[points_num,8]
        """
        # only use positive points sample to compute reg loss
        device = per_image_reg_preds.device
        per_image_reg_preds = per_image_reg_preds[per_image_targets[:, 4] > 0]
        per_image_targets = per_image_targets[per_image_targets[:, 4] > 0]
        positive_points_num = per_image_targets.shape[0]

        if positive_points_num == 0:
            return torch.tensor(0.).to(device)

        center_ness_targets = per_image_targets[:, 5]

        # decode l,t,r,b around the point center into x1,y1,x2,y2 boxes
        pred_bboxes_xy_min = per_image_targets[:, 6:8] - per_image_reg_preds[:, 0:2]
        pred_bboxes_xy_max = per_image_targets[:, 6:8] + per_image_reg_preds[:, 2:4]
        gt_bboxes_xy_min = per_image_targets[:, 6:8] - per_image_targets[:, 0:2]
        gt_bboxes_xy_max = per_image_targets[:, 6:8] + per_image_targets[:, 2:4]
        pred_bboxes = torch.cat([pred_bboxes_xy_min, pred_bboxes_xy_max], axis=1)
        gt_bboxes = torch.cat([gt_bboxes_xy_min, gt_bboxes_xy_max], axis=1)

        overlap_area_top_left = torch.max(pred_bboxes[:, 0:2], gt_bboxes[:, 0:2])
        overlap_area_bot_right = torch.min(pred_bboxes[:, 2:4], gt_bboxes[:, 2:4])
        overlap_area_sizes = torch.clamp(overlap_area_bot_right - overlap_area_top_left,
                                         min=0)
        overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:, 1]

        # widths and heights of predicted and gt boxes
        pred_bboxes_w_h = pred_bboxes[:, 2:4] - pred_bboxes[:, 0:2] + 1
        gt_bboxes_w_h = gt_bboxes[:, 2:4] - gt_bboxes[:, 0:2] + 1
        # compute pred_bboxes_area and gt_bboxes_area
        pred_bboxes_area = pred_bboxes_w_h[:, 0] * pred_bboxes_w_h[:, 1]
        gt_bboxes_area = gt_bboxes_w_h[:, 0] * gt_bboxes_w_h[:, 1]

        # compute union_area
        union_area = pred_bboxes_area + gt_bboxes_area - overlap_area
        union_area = torch.clamp(union_area, min=1e-4)
        # compute ious between predicted boxes and gt boxes
        ious = overlap_area / union_area

        enclose_area_top_left = torch.min(pred_bboxes[:, 0:2], gt_bboxes[:, 0:2])
        enclose_area_bot_right = torch.max(pred_bboxes[:, 2:4], gt_bboxes[:, 2:4])
        enclose_area_sizes = torch.clamp(enclose_area_bot_right - enclose_area_top_left,
                                         min=0)
        enclose_area = enclose_area_sizes[:, 0] * enclose_area_sizes[:, 1]
        enclose_area = torch.clamp(enclose_area, min=1e-4)

        gious_loss = 1. - ious + (enclose_area - union_area) / enclose_area
        gious_loss = torch.clamp(gious_loss, min=-1.0, max=1.0)
        # use center_ness_targets as the weight of gious loss
        gious_loss = gious_loss * center_ness_targets
        gious_loss = gious_loss.sum() / positive_points_num
        gious_loss = 2. * gious_loss

        return gious_loss

    def compute_one_image_center_ness_loss(self, per_image_center_preds,
                                           per_image_targets):
        """
        compute one image center_ness loss(center ness loss)
        per_image_center_preds:[points_num,1]
        per_image_targets:[points_num,8]
        """
        # only use positive points sample to compute center_ness loss
        device = per_image_center_preds.device
        per_image_center_preds = per_image_center_preds[
            per_image_targets[:, 4] > 0]
        per_image_targets = per_image_targets[per_image_targets[:, 4] > 0]
        positive_points_num = per_image_targets.shape[0]

        if positive_points_num == 0:
            return torch.tensor(0.).to(device)

        center_ness_targets = per_image_targets[:, 5:6]

        center_ness_loss = -(
            center_ness_targets * torch.log(per_image_center_preds) +
            (1. - center_ness_targets) *
            torch.log(1. - per_image_center_preds))
        center_ness_loss = center_ness_loss.sum() / positive_points_num

        return center_ness_loss

    def get_batch_position_annotations(self, cls_heads, reg_heads,
                                       center_heads, batch_positions,
                                       annotations):
        """
        Assign a ground truth target for each position on feature map
        """
        device = annotations.device
        batch_mi = []
        for reg_head, mi in zip(reg_heads, self.mi):
            mi = torch.tensor(mi).to(device)
            B, H, W, _ = reg_head.shape
            per_level_mi = torch.zeros(B, H, W, 2).to(device)
            per_level_mi = per_level_mi + mi
            batch_mi.append(per_level_mi)

        cls_preds, reg_preds, center_preds, all_points_position, all_points_mi = [], [], [], [], []
        for cls_pred, reg_pred, center_pred, per_level_position, per_level_mi in zip(
                cls_heads, reg_heads, center_heads, batch_positions, batch_mi):
            cls_pred = cls_pred.view(cls_pred.shape[0], -1, cls_pred.shape[-1])
            reg_pred = reg_pred.view(reg_pred.shape[0], -1, reg_pred.shape[-1])
            center_pred = center_pred.view(center_pred.shape[0], -1,
                                           center_pred.shape[-1])
            per_level_position = per_level_position.view(
                per_level_position.shape[0], -1, per_level_position.shape[-1])
            per_level_mi = per_level_mi.view(per_level_mi.shape[0], -1,
                                             per_level_mi.shape[-1])
            cls_preds.append(cls_pred)
            reg_preds.append(reg_pred)
            center_preds.append(center_pred)
            all_points_position.append(per_level_position)
            all_points_mi.append(per_level_mi)

        cls_preds = torch.cat(cls_preds, axis=1)
        reg_preds = torch.cat(reg_preds, axis=1)
        center_preds = torch.cat(center_preds, axis=1)
        all_points_position = torch.cat(all_points_position, axis=1)
        all_points_mi = torch.cat(all_points_mi, axis=1)

        batch_targets = []
        for per_image_position, per_image_mi, per_image_annotations in zip(
                all_points_position, all_points_mi, annotations):
            per_image_annotations = per_image_annotations[
                per_image_annotations[:, 4] >= 0]
            points_num = per_image_position.shape[0]

            if per_image_annotations.shape[0] == 0:
                # 6:l,t,r,b,class_index,center-ness_gt
                per_image_targets = torch.zeros([points_num, 6], device=device)
            else:
                annotaion_num = per_image_annotations.shape[0]
                per_image_gt_bboxes = per_image_annotations[:, 0:4]
                candidates = torch.zeros([points_num, annotaion_num, 4],
                                         device=device)
                candidates = candidates + per_image_gt_bboxes.unsqueeze(0)
                # (x, y) -> (x, y, x, y) so l,t,r,b can be computed in one step
                per_image_position = per_image_position.unsqueeze(1).repeat(
                    1, annotaion_num, 2)
                candidates[:, :, 0:2] = per_image_position[:, :, 0:2] - candidates[:, :, 0:2]
                candidates[:, :, 2:4] = candidates[:, :, 2:4] - per_image_position[:, :, 2:4]

                candidates_min_value, _ = candidates.min(axis=-1, keepdim=True)
                sample_flag = (candidates_min_value[:, :, 0] > 0).int().unsqueeze(-1)
                # get all negative reg targets which points ctr out of gt box
                candidates = candidates * sample_flag

                # get all negative reg targets which assign ground truth not in range of mi
                candidates_max_value, _ = candidates.max(axis=-1, keepdim=True)
                per_image_mi = per_image_mi.unsqueeze(1).repeat(1, annotaion_num, 1)
                m1_negative_flag = (candidates_max_value[:, :, 0] >
                                    per_image_mi[:, :, 0]).int().unsqueeze(-1)
                candidates = candidates * m1_negative_flag
                m2_negative_flag = (candidates_max_value[:, :, 0] <
                                    per_image_mi[:, :, 1]).int().unsqueeze(-1)
                candidates = candidates * m2_negative_flag

                final_sample_flag = candidates.sum(axis=-1).sum(axis=-1)
                final_sample_flag = final_sample_flag > 0
                positive_index = (final_sample_flag == True).nonzero().squeeze(dim=-1)

                # if no positive sample is assigned
                if len(positive_index) == 0:
                    del candidates
                    # 6:l,t,r,b,class_index,center-ness_gt
                    per_image_targets = torch.zeros([points_num, 6], device=device)
                else:
                    positive_candidates = candidates[positive_index]
                    del candidates

                    sample_box_gts = per_image_annotations[:, 0:4].unsqueeze(0)
                    sample_box_gts = sample_box_gts.repeat(
                        positive_candidates.shape[0], 1, 1)
                    sample_class_gts = per_image_annotations[:, 4].unsqueeze(-1).unsqueeze(0)
                    sample_class_gts = sample_class_gts.repeat(
                        positive_candidates.shape[0], 1, 1)

                    # 6:l,t,r,b,class_index,center-ness_gt
                    per_image_targets = torch.zeros([points_num, 6], device=device)

                    if positive_candidates.shape[1] == 1:
                        # if only one candidate for each positive sample
                        # assign l,t,r,b,class_index,center_ness_gt ground truth
                        # class_index value from 1 to 80 represents 80 positive classes
                        # class_index value 0 represents negative class
                        positive_candidates = positive_candidates.squeeze(1)
                        sample_class_gts = sample_class_gts.squeeze(1)
                        per_image_targets[positive_index, 0:4] = positive_candidates
                        per_image_targets[positive_index, 4:5] = sample_class_gts + 1
                        l = per_image_targets[positive_index, 0:1]
                        t = per_image_targets[positive_index, 1:2]
                        r = per_image_targets[positive_index, 2:3]
                        b = per_image_targets[positive_index, 3:4]
                        per_image_targets[positive_index, 5:6] = torch.sqrt(
                            (torch.min(l, r) / torch.max(l, r)) *
                            (torch.min(t, b) / torch.max(t, b)))
                    else:
                        # if a positive point sample has several object candidates,
                        # choose the smallest-area candidate as its ground truth
                        gts_w_h = sample_box_gts[:, :, 2:4] - sample_box_gts[:, :, 0:2]
                        gts_area = gts_w_h[:, :, 0] * gts_w_h[:, :, 1]
                        positive_candidates_value = positive_candidates.sum(axis=2)

                        # set the area of every negative candidate to INF so that
                        # .min() never picks a negative candidate
                        INF = 100000000
                        inf_tensor = torch.ones_like(gts_area) * INF
                        gts_area = torch.where(
                            torch.eq(positive_candidates_value, 0.),
                            inf_tensor, gts_area)

                        # get the smallest object candidate index
                        _, min_index = gts_area.min(axis=1)
                        # row index of every positive point, on the same device as min_index
                        candidate_indexes = torch.arange(positive_candidates.shape[0],
                                                         device=device)
                        final_candidate_reg_gts = positive_candidates[candidate_indexes, min_index, :]
                        final_candidate_cls_gts = sample_class_gts[candidate_indexes, min_index]

                        # assign l,t,r,b,class_index,center_ness_gt ground truth
                        per_image_targets[positive_index, 0:4] = final_candidate_reg_gts
                        per_image_targets[positive_index, 4:5] = final_candidate_cls_gts + 1
                        l = per_image_targets[positive_index, 0:1]
                        t = per_image_targets[positive_index, 1:2]
                        r = per_image_targets[positive_index, 2:3]
                        b = per_image_targets[positive_index, 3:4]
                        per_image_targets[positive_index, 5:6] = torch.sqrt(
                            (torch.min(l, r) / torch.max(l, r)) *
                            (torch.min(t, b) / torch.max(t, b)))

            per_image_targets = per_image_targets.unsqueeze(0)
            batch_targets.append(per_image_targets)

        batch_targets = torch.cat(batch_targets, axis=0)
        batch_targets = torch.cat([batch_targets, all_points_position], axis=2)
        # batch_targets shape:[batch_size, points_num, 8],8:l,t,r,b,class_index,center-ness_gt,point_ctr_x,point_ctr_y
        return cls_preds, reg_preds, center_preds, batch_targets


if __name__ == '__main__':
    from fcos import FCOS
    net = FCOS(resnet_type="resnet50")
    image_h, image_w = 600, 600
    cls_heads, reg_heads, center_heads, batch_positions = net(
        torch.randn(3, 3, image_h, image_w))
    annotations = torch.FloatTensor([[[113, 120, 183, 255, 5],
                                      [13, 45, 175, 210, 2]],
                                     [[11, 18, 223, 225, 1],
                                      [-1, -1, -1, -1, -1]],
                                     [[-1, -1, -1, -1, -1],
                                      [-1, -1, -1, -1, -1]]])
    loss = FCOSLoss(image_w, image_h)
    cls_loss, reg_loss, center_loss = loss(cls_heads, reg_heads, center_heads,
                                           batch_positions, annotations)
    print("cls_loss:", cls_loss, "reg_loss:", reg_loss, "center_loss:", center_loss)
```