# 将[x1,y1,x2,y2]和[cx,cy,w,h]进行互相转换 (convert boxes between corner form [x1, y1, x2, y2] and center-size form [cx, cy, w, h])
def xy_to_cxcy(xy):
    """Convert boxes from corner coordinates to center-size coordinates.

    Args:
        xy: tensor of boxes in ``[x1, y1, x2, y2]`` form (first corner in the
            first two slots, opposite corner in the last two), with the four
            coordinates on the last dimension, e.g. shape ``(n, 4)`` or ``(4,)``.

    Returns:
        Tensor of the same shape in ``[cx, cy, w, h]`` form: ``(cx, cy)`` is
        the midpoint of the two corners, ``(w, h)`` is ``(x2 - x1, y2 - y1)``.
    """
    # Ellipsis indexing and dim=-1 reproduce the original (n, 4) behaviour
    # exactly while also accepting a single box (4,) or batched boxes (..., 4).
    centers = (xy[..., :2] + xy[..., 2:]) / 2
    sizes = xy[..., 2:] - xy[..., :2]
    return torch.cat([centers, sizes], dim=-1)
def cxcy_to_xy(cxcy):
    """Convert boxes from center-size coordinates to corner coordinates.

    Args:
        cxcy: tensor of boxes in ``[cx, cy, w, h]`` form, with the four
            values on the last dimension, e.g. shape ``(n, 4)`` or ``(4,)``.

    Returns:
        Tensor of the same shape in ``[x1, y1, x2, y2]`` form, where
        ``(x1, y1) = (cx - w/2, cy - h/2)`` and ``(x2, y2) = (cx + w/2, cy + h/2)``.
    """
    centers = cxcy[..., :2]
    half_sizes = cxcy[..., 2:] / 2
    # dim=-1 matches the original dim=1 for (n, 4) input and also handles (4,) / (..., 4).
    return torch.cat([centers - half_sizes, centers + half_sizes], dim=-1)
def find_intersection(set_1, set_2):
    """Compute pairwise intersection areas between two sets of boxes.

    Args:
        set_1: tensor of shape (n1, 4), boxes as [x1, y1, x2, y2].
        set_2: tensor of shape (n2, 4), boxes as [x1, y1, x2, y2].

    Returns:
        Tensor of shape (n1, n2); entry [i, j] is the overlap area between
        set_1[i] and set_2[j], or 0 when the boxes do not overlap.
    """
    # Broadcast (n1, 1, 2) against (1, n2, 2) to compare every pair at once:
    # the overlap's lower corner is the element-wise max of the lower corners,
    # its upper corner is the element-wise min of the upper corners.
    lower_bounds = torch.max(set_1[:, :2].unsqueeze(1), set_2[:, :2].unsqueeze(0))  # (n1, n2, 2)
    upper_bounds = torch.min(set_1[:, 2:].unsqueeze(1), set_2[:, 2:].unsqueeze(0))  # (n1, n2, 2)
    # Clamp at 0 so disjoint pairs get zero width/height instead of a negative area.
    intersection_dims = torch.clamp(upper_bounds - lower_bounds, min=0)  # (n1, n2, 2)
    return intersection_dims[:, :, 0] * intersection_dims[:, :, 1]  # (n1, n2)
def find_jaccard_overlap(set_1, set_2):
    """Compute the pairwise Jaccard overlap (IoU) between two sets of boxes.

    Args:
        set_1: boxes as [x1, y1, x2, y2], shape (n1, 4).
        set_2: boxes as [x1, y1, x2, y2], shape (n2, 4).

    Returns:
        Tensor of shape (n1, n2) holding intersection / union for every pair.
    """
    def _box_areas(boxes):
        # width * height for each [x1, y1, x2, y2] row -> (n,)
        return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

    overlap = find_intersection(set_1, set_2)  # (n1, n2)
    # |A u B| = |A| + |B| - |A n B|, broadcast to (n1, n2)
    union_area = _box_areas(set_1)[:, None] + _box_areas(set_2)[None, :] - overlap
    # NOTE(review): a pair whose union area is 0 (both boxes degenerate)
    # yields 0/0 here, i.e. NaN for float tensors — confirm callers expect that.
    iou = overlap / union_area
    return iou