All of the code has been uploaded to my GitHub repository: https://github.com/zgcr/pytorch-ImageNet-CIFAR-COCO-VOC-training
If you find it useful, please give it a star!
All of the code below has been tested on PyTorch 1.4 and confirmed to run correctly.
Besides the NMS post-processing method used in Implementing RetinaNet from Scratch (Final), I also tried two other NMS post-processing methods. One is Fast NMS, proposed in YOLACT (https://arxiv.org/pdf/1904.02689.pdf); the other is torchvision's built-in torchvision.ops.nms(). Note that the ops in torchvision.ops do not support TorchScript.
Fast NMS replaces the iterative computation of traditional NMS with a single pass of matrix computation. Traditional NMS first sorts all boxes by classification score in descending order and then iterates: each iteration keeps the box with the highest remaining classification score, computes the IoU between that box and every other remaining box, and removes every box whose IoU with it exceeds the threshold. This repeats until no candidate boxes are left, as in the sketch below.
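The following is a minimal, self-contained sketch of this traditional iterative procedure (written purely for illustration; it is not necessarily line-for-line identical to the implementation in that earlier article):

import torch

def traditional_nms(boxes, scores, iou_threshold):
    """
    boxes:[N,4],4:x_min,y_min,x_max,y_max
    scores:[N]
    returns the indexes of the kept boxes
    """
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    order = torch.argsort(scores, descending=True)
    keep = []
    while order.numel() > 0:
        # keep the highest-scoring remaining box
        best = order[0]
        keep.append(best.item())
        if order.numel() == 1:
            break
        rest = order[1:]
        # IoU between the kept box and every remaining box
        lt = torch.max(boxes[best, :2], boxes[rest, :2])
        rb = torch.min(boxes[best, 2:], boxes[rest, 2:])
        wh = (rb - lt).clamp(min=0)
        overlap = wh[:, 0] * wh[:, 1]
        ious = overlap / (areas[best] + areas[rest] - overlap)
        # drop every box whose IoU with the kept box exceeds the threshold
        order = rest[ious <= iou_threshold]
    return torch.tensor(keep, dtype=torch.long)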
Fast NMS also sorts all boxes by classification score in descending order, then computes the IoU between every pair of boxes, giving a symmetric matrix. The matrix is made upper triangular, with the main diagonal (each box's IoU with itself) also set to 0. Taking the maximum IoU along dimension 0 of this matrix and comparing it with the filter threshold, every box whose value exceeds the threshold is removed. In effect, starting from the highest-scoring box, each box looks up the largest IoU it has with any higher-scoring box and is filtered out if that IoU exceeds the threshold. Because the matrix is upper triangular, filtering a lower-scoring box never affects the higher-scoring boxes before it; note, however, that a box that has itself been suppressed can still suppress later boxes, which is why Fast NMS tends to remove slightly more boxes than traditional NMS. A toy example of the matrix trick follows.
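To make the matrix trick concrete, here is a tiny standalone example on three made-up boxes (the boxes, their score order, and the 0.5 threshold are all hypothetical; the actual implementation follows below). The second box overlaps heavily with the first, higher-scoring box, so it is the only one suppressed:

import torch

# three hypothetical boxes already sorted by score (descending), format: x_min,y_min,x_max,y_max
boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.],
                      [20., 20., 30., 30.]])

areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
lt = torch.max(boxes[:, None, :2], boxes[:, :2])
rb = torch.min(boxes[:, None, 2:], boxes[:, 2:])
wh = (rb - lt).clamp(min=0)
overlap = wh[:, :, 0] * wh[:, :, 1]
ious = overlap / (areas[:, None] + areas - overlap)   # symmetric [3,3] IoU matrix

ious = torch.triu(ious, diagonal=1)   # keep only IoUs with higher-scoring boxes
max_iou = ious.max(dim=0)[0]          # per box: largest IoU with any higher-scoring box
keep = max_iou < 0.5
print(keep)                           # tensor([ True, False,  True])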
The Fast NMS implementation is as follows:
# Replace the NMS method in the original RetinaDecoder with the two functions below.
def nms(self, one_image_scores, one_image_classes, one_image_pred_bboxes):
    """
    one_image_scores:[anchor_nums],classification predict scores
    one_image_classes:[anchor_nums],class indexes for predict scores
    one_image_pred_bboxes:[anchor_nums,4],4:x_min,y_min,x_max,y_max
    """
    device = one_image_scores.device
    final_scores = (-1) * torch.ones(
        (self.max_detection_num, ), device=device)
    final_classes = (-1) * torch.ones(
        (self.max_detection_num, ), device=device)
    final_pred_bboxes = (-1) * torch.ones(
        (self.max_detection_num, 4), device=device)

    if one_image_scores.shape[0] == 0:
        return final_scores, final_classes, final_pred_bboxes

    # sort boxes by classification score in descending order
    sorted_one_image_scores, sorted_one_image_scores_indexes = torch.sort(
        one_image_scores, descending=True)
    sorted_one_image_classes = one_image_classes[
        sorted_one_image_scores_indexes]
    sorted_one_image_pred_bboxes = one_image_pred_bboxes[
        sorted_one_image_scores_indexes]

    # pairwise IoU matrix, shape:[anchor_nums,anchor_nums]
    ious = self.box_iou(sorted_one_image_pred_bboxes,
                        sorted_one_image_pred_bboxes)
    # keep only the upper triangle (diagonal excluded), so each box is only
    # compared with higher-scoring boxes
    ious = torch.triu(ious, diagonal=1)
    # for each box, the largest IoU with any higher-scoring box
    keep = ious.max(dim=0)[0]
    keep = keep < self.nms_threshold

    keep_scores = sorted_one_image_scores[keep]
    keep_classes = sorted_one_image_classes[keep]
    keep_pred_bboxes = sorted_one_image_pred_bboxes[keep]

    final_detection_num = min(self.max_detection_num, keep_scores.shape[0])
    final_scores[0:final_detection_num] = keep_scores[0:final_detection_num]
    final_classes[0:final_detection_num] = keep_classes[0:final_detection_num]
    final_pred_bboxes[0:final_detection_num, :] = keep_pred_bboxes[
        0:final_detection_num, :]

    return final_scores, final_classes, final_pred_bboxes

def box_iou(self, boxes1, boxes2):
    """
    boxes1:[N, 4]
    boxes2:[M, 4]
    ious:[N, M]
    """
    area1 = (boxes1.t()[2] - boxes1.t()[0]) * (boxes1.t()[3] - boxes1.t()[1])
    area2 = (boxes2.t()[2] - boxes2.t()[0]) * (boxes2.t()[3] - boxes2.t()[1])

    # boxes1[:, None, :2] shape:[N,1,2], boxes2[:, :2] shape:[M,2]
    overlap_area_left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2])
    overlap_area_right_bot = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])
    overlap_area_sizes = (overlap_area_right_bot -
                          overlap_area_left_top).clamp(min=0)
    overlap_area = overlap_area_sizes[:, :, 0] * overlap_area_sizes[:, :, 1]

    ious = overlap_area / (area1[:, None] + area2 - overlap_area)

    return ious
The NMS post-processing done by torchvision.ops.nms is exactly the same as the NMS method I described in Implementing RetinaNet from Scratch (Final), but since it is implemented in C++, it is faster than my implementation.
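For reference, torchvision.ops.nms takes an [N,4] tensor of boxes in x_min,y_min,x_max,y_max format, an [N] tensor of scores, and an IoU threshold, and returns the indexes of the kept boxes sorted by descending score. A minimal example with made-up boxes and scores:

import torch
from torchvision.ops import nms

boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.],
                      [20., 20., 30., 30.]])
scores = torch.tensor([0.9, 0.8, 0.7])

# keep holds the indexes of the surviving boxes, highest score first
keep = nms(boxes, scores, iou_threshold=0.5)
print(keep)  # tensor([0, 2]): box 1 overlaps box 0 by more than 0.5 IoU and is removed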
The RetinaDecoder implementation using torchvision.ops.nms is as follows:
import torch
import torch.nn as nn
from torchvision.ops import nms


class RetinaDecoder(nn.Module):
    def __init__(self,
                 image_w,
                 image_h,
                 top_n=1000,
                 min_score_threshold=0.05,
                 nms_threshold=0.5,
                 max_detection_num=100):
        super(RetinaDecoder, self).__init__()
        self.image_w = image_w
        self.image_h = image_h
        self.top_n = top_n
        self.min_score_threshold = min_score_threshold
        self.nms_threshold = nms_threshold
        self.max_detection_num = max_detection_num

    def forward(self, cls_heads, reg_heads, batch_anchors):
        device = cls_heads[0].device
        with torch.no_grad():
            filter_scores, filter_score_classes, filter_reg_heads, filter_batch_anchors = [], [], [], []
            for per_level_cls_head, per_level_reg_head, per_level_anchor in zip(
                    cls_heads, reg_heads, batch_anchors):
                scores, score_classes = torch.max(per_level_cls_head, dim=2)
                # keep only the top_n highest-scoring anchors per level
                if scores.shape[1] >= self.top_n:
                    scores, indexes = torch.topk(scores,
                                                 self.top_n,
                                                 dim=1,
                                                 largest=True,
                                                 sorted=True)
                    score_classes = torch.gather(score_classes, 1, indexes)
                    per_level_reg_head = torch.gather(
                        per_level_reg_head, 1,
                        indexes.unsqueeze(-1).repeat(1, 1, 4))
                    per_level_anchor = torch.gather(
                        per_level_anchor, 1,
                        indexes.unsqueeze(-1).repeat(1, 1, 4))

                filter_scores.append(scores)
                filter_score_classes.append(score_classes)
                filter_reg_heads.append(per_level_reg_head)
                filter_batch_anchors.append(per_level_anchor)

            filter_scores = torch.cat(filter_scores, axis=1)
            filter_score_classes = torch.cat(filter_score_classes, axis=1)
            filter_reg_heads = torch.cat(filter_reg_heads, axis=1)
            filter_batch_anchors = torch.cat(filter_batch_anchors, axis=1)

            batch_scores, batch_classes, batch_pred_bboxes = [], [], []
            for per_image_scores, per_image_score_classes, per_image_reg_heads, per_image_anchors in zip(
                    filter_scores, filter_score_classes, filter_reg_heads,
                    filter_batch_anchors):
                pred_bboxes = self.snap_tx_ty_tw_th_reg_heads_to_x1_y1_x2_y2_bboxes(
                    per_image_reg_heads, per_image_anchors)

                # filter out low-score boxes
                score_classes = per_image_score_classes[
                    per_image_scores > self.min_score_threshold].float()
                pred_bboxes = pred_bboxes[
                    per_image_scores > self.min_score_threshold].float()
                scores = per_image_scores[
                    per_image_scores > self.min_score_threshold].float()

                one_image_scores = (-1) * torch.ones(
                    (self.max_detection_num, ), device=device)
                one_image_classes = (-1) * torch.ones(
                    (self.max_detection_num, ), device=device)
                one_image_pred_bboxes = (-1) * torch.ones(
                    (self.max_detection_num, 4), device=device)

                if scores.shape[0] != 0:
                    # sort boxes by score in descending order
                    sorted_scores, sorted_indexes = torch.sort(
                        scores, descending=True)
                    sorted_score_classes = score_classes[sorted_indexes]
                    sorted_pred_bboxes = pred_bboxes[sorted_indexes]

                    # torchvision.ops.nms returns the indexes of the kept boxes
                    keep = nms(sorted_pred_bboxes, sorted_scores,
                               self.nms_threshold)
                    keep_scores = sorted_scores[keep]
                    keep_classes = sorted_score_classes[keep]
                    keep_pred_bboxes = sorted_pred_bboxes[keep]

                    final_detection_num = min(self.max_detection_num,
                                              keep_scores.shape[0])
                    one_image_scores[0:final_detection_num] = keep_scores[
                        0:final_detection_num]
                    one_image_classes[0:final_detection_num] = keep_classes[
                        0:final_detection_num]
                    one_image_pred_bboxes[
                        0:final_detection_num, :] = keep_pred_bboxes[
                            0:final_detection_num, :]

                one_image_scores = one_image_scores.unsqueeze(0)
                one_image_classes = one_image_classes.unsqueeze(0)
                one_image_pred_bboxes = one_image_pred_bboxes.unsqueeze(0)

                batch_scores.append(one_image_scores)
                batch_classes.append(one_image_classes)
                batch_pred_bboxes.append(one_image_pred_bboxes)

            batch_scores = torch.cat(batch_scores, axis=0)
            batch_classes = torch.cat(batch_classes, axis=0)
            batch_pred_bboxes = torch.cat(batch_pred_bboxes, axis=0)

            # batch_scores shape:[batch_size,max_detection_num]
            # batch_classes shape:[batch_size,max_detection_num]
            # batch_pred_bboxes shape:[batch_size,max_detection_num,4]
            return batch_scores, batch_classes, batch_pred_bboxes

    def snap_tx_ty_tw_th_reg_heads_to_x1_y1_x2_y2_bboxes(
            self, reg_heads, anchors):
        """
        snap reg heads to pred bboxes
        reg_heads:[anchor_nums,4],4:[tx,ty,tw,th]
        anchors:[anchor_nums,4],4:[x_min,y_min,x_max,y_max]
        """
        anchors_wh = anchors[:, 2:] - anchors[:, :2]
        anchors_ctr = anchors[:, :2] + 0.5 * anchors_wh

        device = anchors.device
        factor = torch.tensor([[0.1, 0.1, 0.2, 0.2]]).to(device)
        reg_heads = reg_heads * factor

        pred_bboxes_wh = torch.exp(reg_heads[:, 2:]) * anchors_wh
        pred_bboxes_ctr = reg_heads[:, :2] * anchors_wh + anchors_ctr

        pred_bboxes_x_min_y_min = pred_bboxes_ctr - 0.5 * pred_bboxes_wh
        pred_bboxes_x_max_y_max = pred_bboxes_ctr + 0.5 * pred_bboxes_wh

        pred_bboxes = torch.cat(
            [pred_bboxes_x_min_y_min, pred_bboxes_x_max_y_max], axis=1)
        pred_bboxes = pred_bboxes.int()

        # clip boxes to the image boundary
        pred_bboxes[:, 0] = torch.clamp(pred_bboxes[:, 0], min=0)
        pred_bboxes[:, 1] = torch.clamp(pred_bboxes[:, 1], min=0)
        pred_bboxes[:, 2] = torch.clamp(pred_bboxes[:, 2],
                                        max=self.image_w - 1)
        pred_bboxes[:, 3] = torch.clamp(pred_bboxes[:, 3],
                                        max=self.image_h - 1)

        # pred bboxes shape:[anchor_nums,4]
        return pred_bboxes
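To illustrate the expected input and output shapes, here is a rough usage sketch of the RetinaDecoder above; the number of levels, the 80 classes, the anchor counts, and the random anchors are all placeholder values rather than the shapes produced by the real network:

import torch

# hypothetical per-level head outputs for a batch of 2 images and 80 classes
anchor_nums = [100, 50, 25, 12, 6]
cls_heads = [torch.rand(2, n, 80) for n in anchor_nums]
reg_heads = [torch.randn(2, n, 4) * 0.1 for n in anchor_nums]

batch_anchors = []
for n in anchor_nums:
    ctr = torch.rand(2, n, 2) * 600       # random anchor centers
    wh = torch.rand(2, n, 2) * 60 + 10    # random anchor widths/heights
    batch_anchors.append(torch.cat([ctr - wh / 2, ctr + wh / 2], dim=2))

decoder = RetinaDecoder(image_w=667, image_h=667)
batch_scores, batch_classes, batch_pred_bboxes = decoder(
    cls_heads, reg_heads, batch_anchors)
# batch_scores:[2,100], batch_classes:[2,100], batch_pred_bboxes:[2,100,4]
print(batch_scores.shape, batch_classes.shape, batch_pred_bboxes.shape)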
All rows in the table below are for the same model, i.e. ResNet50-RetinaNet-myresize667-fastdecode from Implementing RetinaNet from Scratch (Final); fastdecode is the post-processing method described in that article. Testing used batch=1 and resize=667. Speed is the total time needed to run all images in COCO2017_val divided by the number of images, in ms per image.
All tests were run on a single GTX 1070 Max-Q.
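As a rough illustration of how such a per-image latency number can be measured (the model here is assumed to return cls_heads, reg_heads and batch_anchors as in this series, and the dataloader is assumed to yield image tensors; both are placeholders rather than the actual evaluation script):

import time
import torch

def measure_speed(model, decoder, dataloader, device):
    # average end-to-end inference time in ms per image, assuming batch_size=1
    model.eval()
    total_time, total_images = 0.0, 0
    with torch.no_grad():
        for images in dataloader:
            images = images.to(device)
            if device.type == 'cuda':
                torch.cuda.synchronize()
            start = time.time()
            cls_heads, reg_heads, batch_anchors = model(images)
            decoder(cls_heads, reg_heads, batch_anchors)
            if device.type == 'cuda':
                torch.cuda.synchronize()
            total_time += time.time() - start
            total_images += images.shape[0]
    return total_time / total_images * 1000.0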
Network | NMS method | epoch12 mAP | speed (ms/image) |
---|---|---|---|
ResNet50-RetinaNet-myresize667 | fastdecode | 0.293 | 154 |
ResNet50-RetinaNet-myresize667 | fast nms | 0.282 | 128 |
ResNet50-RetinaNet-myresize667 | torchvision.ops.nms() | 0.293 | 118 |
As the table shows, torchvision.ops.nms() is the fastest, and its metrics are exactly the same as those of my fastdecode implementation. Fast NMS is also faster than my fastdecode, but since YOLACT applies it to instance segmentation, using it for detection causes a relatively large mAP drop (0.293 to 0.282).