PyTorch 中的目标检测学习笔记

1. 坐标转换

将边界框在角点格式 [x1, y1, x2, y2]（左上角与右下角坐标）和中心格式 [cx, cy, w, h]（中心点坐标与宽、高）之间互相转换。

def xy_to_cxcy(xy):
    """Convert boxes from corner form to center form.

    Args:
        xy: tensor whose last dimension has size 4 and holds boxes as
            ``[x1, y1, x2, y2]`` (top-left corner, bottom-right corner).
            Usually shape ``(n, 4)``; extra leading dimensions such as
            ``(batch, n, 4)`` are also accepted.

    Returns:
        Tensor of the same shape holding ``[cx, cy, w, h]``.
    """
    # Slice the last axis with ``...`` so both (n, 4) and batched inputs work.
    top_left = xy[..., :2]
    bottom_right = xy[..., 2:]
    center = (top_left + bottom_right) / 2
    size = bottom_right - top_left
    return torch.cat([center, size], dim=-1)

def cxcy_to_xy(cxcy):
    """Convert boxes from center form to corner form (inverse of ``xy_to_cxcy``).

    Args:
        cxcy: tensor whose last dimension has size 4 and holds boxes as
            ``[cx, cy, w, h]``. Usually shape ``(n, 4)``; extra leading
            dimensions such as ``(batch, n, 4)`` are also accepted.

    Returns:
        Tensor of the same shape holding ``[x1, y1, x2, y2]``.
    """
    # Slice the last axis with ``...`` so both (n, 4) and batched inputs work.
    center = cxcy[..., :2]
    half_size = cxcy[..., 2:] / 2
    return torch.cat([center - half_size, center + half_size], dim=-1)

2. IoU（交并比）计算

IoU = 交集面积 / 并集面积，其中并集面积 = 面积A + 面积B − 交集面积；边界框均采用 [x1, y1, x2, y2] 格式。

def find_intersection(set_1, set_2):
    """Compute the pairwise intersection areas of two sets of boxes.

    Args:
        set_1: tensor of shape (n1, 4), boxes as [x1, y1, x2, y2].
        set_2: tensor of shape (n2, 4), boxes as [x1, y1, x2, y2].

    Returns:
        Tensor of shape (n1, n2); entry (i, j) is the overlap area of
        set_1[i] and set_2[j], or 0 when the boxes do not overlap.
    """
    # Broadcast (n1, 1, 2) against (1, n2, 2) to compare every pair at once:
    # the overlap's top-left is the larger of the two top-left corners and
    # its bottom-right is the smaller of the two bottom-right corners.
    lower_bounds = torch.max(set_1[:, :2].unsqueeze(1), set_2[:, :2].unsqueeze(0))  # (n1, n2, 2)
    upper_bounds = torch.min(set_1[:, 2:].unsqueeze(1), set_2[:, 2:].unsqueeze(0))  # (n1, n2, 2)
    # Disjoint pairs produce negative extents; clamp them to zero area.
    intersection_dims = torch.clamp(upper_bounds - lower_bounds, min=0)  # (n1, n2, 2)
    return intersection_dims[:, :, 0] * intersection_dims[:, :, 1]  # (n1, n2)

def find_jaccard_overlap(set_1, set_2):
    """Compute the pairwise IoU (Jaccard overlap) of two sets of boxes.

    Args:
        set_1: tensor of shape (n1, 4), boxes as [x1, y1, x2, y2].
        set_2: tensor of shape (n2, 4), boxes as [x1, y1, x2, y2].

    Returns:
        Tensor of shape (n1, n2); entry (i, j) is
        intersection_area / union_area for set_1[i] and set_2[j].
        Pairs whose union area is zero get IoU 0 instead of NaN.
    """
    intersection = find_intersection(set_1, set_2)  # (n1, n2)
    areas_set_1 = (set_1[:, 2] - set_1[:, 0]) * (set_1[:, 3] - set_1[:, 1])  # (n1)
    areas_set_2 = (set_2[:, 2] - set_2[:, 0]) * (set_2[:, 3] - set_2[:, 1])  # (n2)
    # Inclusion-exclusion: |A ∪ B| = |A| + |B| - |A ∩ B|, broadcast to (n1, n2).
    union = areas_set_1.unsqueeze(1) + areas_set_2.unsqueeze(0) - intersection  # (n1, n2)
    # For valid boxes a zero union implies a zero intersection; dividing by 1
    # there yields IoU 0 instead of 0/0 = NaN, which would otherwise propagate
    # into any downstream max/argmax over the IoU matrix.
    return intersection / union.masked_fill(union == 0, 1)  # (n1, n2)

你可能感兴趣的:(PyTorch,目标检测,pytorch,目标检测)