示例：用 bbox_overlaps 计算 anchors 与 gt_boxes（gts）之间两两的 IoU。
输入: anchors: (N, 4) 的浮点型 torch.Tensor，每行为 (x1, y1, x2, y2)
gt_boxes: (K, 4) 的浮点型 torch.Tensor，格式同上
输出: overlaps: (N, K) 的 torch.Tensor，overlaps[i][j] 为 anchors[i] 与 gt_boxes[j] 的 IoU
最直接的想法是两重循环（逐对计算，速度慢）：
for i in range(N):
    for j in range(K):
        overlaps[i][j] = IoU(anchors[i], gt_boxes[j])
矩阵（向量化）思维：插入新的维度并 expand，让每个 anchor 与每个 gt box 一一配对，从而去掉循环：
# anchors: (N, 4) -> boxes: (N, K, 4)
# gt_boxes: (K, 4) -> query_boxes: (N, K, 4)
def bbox_overlaps(anchors, gt_boxes):
    """Compute the pairwise IoU between every anchor and every gt box.

    Boxes are rows of ``(x1, y1, x2, y2)``. Widths and heights are computed as
    ``x2 - x1 + 1`` / ``y2 - y1 + 1`` (inclusive-pixel convention).

    Args:
        anchors: ``(N, 4)`` float ``torch.Tensor``.
        gt_boxes: ``(K, 4)`` float ``torch.Tensor``.

    Returns:
        ``(N, K)`` tensor ``overlaps`` where ``overlaps[i, j]`` is the IoU of
        ``anchors[i]`` and ``gt_boxes[j]``.

    Note:
        Inputs must be ``torch.Tensor``, not ``numpy.ndarray``: the code relies
        on ``.size(0)``, ``.view`` and ``.expand``.
    """
    import torch  # local import: this snippet has no module-level import block

    N = anchors.size(0)
    K = gt_boxes.size(0)

    # Areas shaped (1, K) and (N, 1) so that their sum broadcasts to (N, K).
    gt_boxes_area = ((gt_boxes[:, 2] - gt_boxes[:, 0] + 1) *
                     (gt_boxes[:, 3] - gt_boxes[:, 1] + 1)).view(1, K)
    anchors_area = ((anchors[:, 2] - anchors[:, 0] + 1) *
                    (anchors[:, 3] - anchors[:, 1] + 1)).view(N, 1)

    # Pair every anchor with every gt box (expand returns a view, no copy):
    # anchors (N, 4) -> boxes (N, K, 4); gt_boxes (K, 4) -> query_boxes (N, K, 4)
    boxes = anchors.view(N, 1, 4).expand(N, K, 4)
    query_boxes = gt_boxes.view(1, K, 4).expand(N, K, 4)

    # Intersection width/height; a negative extent means no overlap, so clamp to 0.
    iw = (torch.min(boxes[:, :, 2], query_boxes[:, :, 2]) -
          torch.max(boxes[:, :, 0], query_boxes[:, :, 0]) + 1).clamp(min=0)
    ih = (torch.min(boxes[:, :, 3], query_boxes[:, :, 3]) -
          torch.max(boxes[:, :, 1], query_boxes[:, :, 1]) + 1).clamp(min=0)
    inter = iw * ih

    # Union: (N, 1) + (1, K) broadcasts to (N, K), minus the (N, K) intersection.
    ua = anchors_area + gt_boxes_area - inter
    return inter / ua