CenterNet (Objects as Points) Code Reading Notes

Contents

Preface

Main function: main.py

Data

datasets/dataset_factory.py

datasets/dataset/coco.py

datasets/sample/ctdet.py

utils/image.py

Model

models/model.py

models/networks/pose_dla_dcn.py

Training

trains/train_factory.py

trains/ctdet.py

models/losses.py

models/utils.py

trains/base_trainer.py

Testing: demo.py

detectors/detector_factory.py

detectors/ctdet.py

detectors/base_detector.py

models/decode.py

multi pose


Preface

CenterNet is an object detection method based on keypoint (center point) detection; for a walkthrough of the paper see https://blog.csdn.net/LXX516/article/details/106251090. The code is at https://github.com/xingyizhou/CenterNet. The source code is like a supermarket: it offers a lot of functionality, so I only annotate the parts I consider important, since the end goal is to move them onto my own data and applications.

Main function: main.py

main.py has three main parts: data preparation, model construction, and the training loop (including the loss computation).

from opts import opts
from models.model import create_model, load_model, save_model
from datasets.dataset_factory import get_dataset
from trains.train_factory import train_factory
# assume the task is ctdet and the dataset is coco
opt = opts().parse()
# prepare the data
Dataset = get_dataset(opt.dataset, opt.task)
opt = opts().update_dataset_info_and_set_heads(opt, Dataset)
# build the model
model = create_model(opt.arch, opt.heads, opt.head_conv)
# print(opt.arch, opt.heads, opt.head_conv)
# dla_34 {"hm": 80, "wh": 2, "reg": 2} 256
# for detection the network predicts a heatmap and regresses width/height
# and offsets, so there are 3 heads on top of the backbone
# training
Trainer = train_factory[opt.task]
trainer = Trainer(opt, model, optimizer)
trainer.set_device(opt.gpus, opt.chunk_sizes, opt.device)
# ...
log_dict_train, _ = trainer.train(epoch, train_loader)

Data

datasets/dataset_factory.py

from .sample.ctdet import CTDetDataset
from .dataset.coco import COCO
def get_dataset(dataset, task):
	class Dataset(dataset_factory[dataset], _sample_factory[task]):
		pass
	return Dataset

The Dataset returned to the main program inherits from two parent classes, chosen by the dataset and the task: here "coco" selects COCO and "ctdet" selects CTDetDataset.
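
For reference, a sketch of the two factory dicts (only the entries used here; the repository registers several more datasets and tasks):

dataset_factory = {"coco": COCO}           # dataset name -> IO/annotation class
_sample_factory = {"ctdet": CTDetDataset}  # task name -> sampler class with __getitem__

# get_dataset("coco", "ctdet") therefore builds class Dataset(COCO, CTDetDataset):
# COCO supplies __init__ and the image list, CTDetDataset supplies __getitem__.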

datasets/dataset/coco.py

class COCO(data.Dataset):
	# class-level attributes: num_classes, default_resolution, mean, std
	num_classes = 80
	default_resolution = [512, 512]
	mean = np.array([0.40789654, 0.44719302, 0.47026115], dtype=np.float32).reshape(1, 1, 3)
	std = np.array([0.28863828, 0.27408164, 0.27809835], dtype=np.float32).reshape(1, 1, 3)
	def __init__(self, opt, split):
		super(COCO, self).__init__()
		self.data_dir = os.path.join(opt.data_dir, "coco")
		self.img_dir = os.path.join(self.data_dir, "{}2017".format(split))
		# ...
		self.max_objs = 128  # assume each image contains at most 128 objects
		# ...
		# load the annotations and record the image ids in self.images
		self.coco = coco.COCO(self.annot_path)
		self.images = self.coco.getImgIds()
		self.num_samples = len(self.images)

datasets/sample/ctdet.py

class CTDetDataset(data.Dataset):
	# ...
	# this class mainly implements __getitem__ so the DataLoader can use it during training
	def __getitem__(self, index):
		# ...
		# read the image and its annotations
		num_objs = min(len(anns), self.max_objs)  # number of objects in the image, capped at 128
		height, width = img.shape[0], img.shape[1]
		c = np.array([img.shape[1] / 2., img.shape[0] / 2.], dtype=np.float32)  # image center
		s = max(img.shape[0], img.shape[1]) * 1.0  # longest side
		input_h, input_w = self.opt.input_h, self.opt.input_w  # the fixed network input size

		flipped = False
		if self.split == "train":
			if not self.opt.not_rand_crop:
				# scale the long side and randomly crop: randomly move the center point
				# within the image shrunk by `border` on each side; border is 128 when
				# the image side exceeds 256, otherwise 64
				s = s * np.random.choice(np.arange(0.6, 1.4, 0.1))
				w_border = self._get_border(128, img.shape[1])
				h_border = self._get_border(128, img.shape[0])
				c[0] = np.random.randint(low=w_border, high=img.shape[1] - w_border)
				c[1] = np.random.randint(low=h_border, high=img.shape[0] - h_border)
			else:
				# scale and shift: pick a new center and long side from the given
				# factors (how exactly does this differ from the crop branch?)
				sf = self.opt.scale
				cf = self.opt.shift
				c[0] += s * np.clip(np.random.randn() * cf, -2 * cf, 2 * cf)
				c[1] += s * np.clip(np.random.randn() * cf, -2 * cf, 2 * cf)
				s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)

			if np.random.random() < self.opt.flip:
				# horizontal flip
				flipped = True
				img = img[:, ::-1, :]
				c[0] = width - c[0] - 1

		# with these parameters fixed, apply the affine transform to the input
		trans_input = get_affine_transform(c, s, 0, [input_w, input_h])
		inp = cv2.warpAffine(img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)
		inp = (inp.astype(np.float32) / 255.)
		# color augmentation
		if self.split == "train" and not self.opt.no_color_aug:
			color_aug(self._data_rng, inp, self._eig_val, self._eig_vec)
		inp = (inp - self.mean) / self.std
		inp = inp.transpose(2, 0, 1)

		# prepare the affine transform to the output resolution
		output_h = input_h // self.opt.down_ratio
		output_w = input_w // self.opt.down_ratio
		num_classes = self.num_classes
		trans_output = get_affine_transform(c, s, 0, [output_w, output_h])
		# ...

		# build the ground-truth targets returned by the dataset: heatmap,
		# width/height, offsets, etc.; see Table 1
		gt_det = []
		for k in range(num_objs):  # iterate over the actual objects
			if flipped:
				bbox[[0, 2]] = width - bbox[[2, 0]] - 1
			# transform the boxes to output coordinates; the transforms are built by
			# helper functions, so in practice you only need parameters such as the
			# center c and long side s - handy when porting to your own dataset
			bbox[:2] = affine_transform(bbox[:2], trans_output)
			bbox[2:] = affine_transform(bbox[2:], trans_output)
			bbox[[0, 2]] = np.clip(bbox[[0, 2]], 0, output_w - 1)
			bbox[[1, 3]] = np.clip(bbox[[1, 3]], 0, output_h - 1)
			h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]

			if h > 0 and w > 0:
				radius = gaussian_radius((math.ceil(h), math.ceil(w)))  # the gaussian radius depends on the object size
				radius = max(0, int(radius))
				radius = self.opt.hm_gauss if self.opt.mse_loss else radius
				ct = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2], dtype=np.float32)  # the new center
				ct_int = ct.astype(np.int32)
				draw_gaussian(hm[cls_id], ct_int, radius)  # splat the gaussian onto the channel of the object's class
				wh[k] = 1. * w, 1. * h  # [128, 2]; slots beyond the actual objects stay 0
				ind[k] = ct_int[1] * output_w + ct_int[0]  # center location as a flat index h * W + w
				reg[k] = ct - ct_int  # error introduced by rounding to integer pixels
				reg_mask[k] = 1  # mark the slots that contain an object
				cat_spec_wh[k, cls_id * 2: cls_id * 2 + 2] = wh[k]  # class-specific wh (not shared across classes)
				cat_spec_mask[k, cls_id * 2: cls_id * 2 + 2] = 1

				if self.opt.dense_wh:  # False by default
					draw_dense_reg(dense_wh, hm.max(axis=0), ct_int, wh[k], radius)
				gt_det.append([ct[0] - w / 2, ct[1] - h / 2, ct[0] + w / 2, ct[1] + h / 2, 1, cls_id])  # gt box in output coords (used for debugging)
		ret = {"input": inp, "hm": hm, "reg_mask": reg_mask, "ind": ind, "wh": wh}
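
The random-crop branch above uses _get_border to keep the sampled center at least `border` pixels away from the image edge; a sketch of its logic, consistent with the comment above:

def _get_border(self, border, size):
	# shrink the border until it covers less than half of `size`, so that
	# np.random.randint(low=border, high=size - border) has a valid range;
	# for size > 256 this yields 128, for 128 < size <= 256 it yields 64
	i = 1
	while size - border // i <= border // i:
		i *= 2
	return border // i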

Table 1: Ground-truth variables returned by the dataset

| Name | Meaning | Shape | Type |
| --- | --- | --- | --- |
| hm | center heatmap | [C, H, W] | float32 |
| wh | object width and height | [128, 2] | float32 |
| dense_wh | dense map regressing wh directly | [2, H, W] | float32 |
| reg | offset caused by downsampling and rounding | [128, 2] | float32 |
| ind | center location as a flat index (h * W + w) | [128] | int64 |
| reg_mask | marks which of the fixed 128 slots hold a keypoint | [128] | uint8 |
| cat_spec_wh | class-specific wh (not shared across classes; unused here) | [128, C*2] | float32 |
| cat_spec_mask | mask for cat_spec_wh | [128, C*2] | uint8 |

utils/image.py

M = cv2.getAffineTransform(src, dst):

  •     src: coordinates of three points in the source image
  •     dst: coordinates of the same three points after the transform
  •     M: the affine matrix solved from the three point pairs; use it with cv2.warpAffine() to transform the source image

So get_affine_transform() works as follows:

  •     take c and s, which encode the scaling, cropping, translation, rotation, etc. applied to the source image
  •     take the target (output) image information, i.e. its w and h
  •     from these, derive (by rules I haven't fully unpicked) 3 corresponding point pairs between input and output, and solve for the transform matrix trans
  •     the source image can then be warped with cv2.warpAffine() and trans
  •     a point [x, y] can be transformed with affine_transform(), which is simply the matrix product of trans with the homogeneous coordinate
def get_affine_transform(center, scale, rot, output_size, shift=np.array([0, 0], dtype=np.float32), inv=0):
	if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
		scale = np.array([scale, scale], dtype=np.float32)
	scale_tmp = scale  # [s, s]
	src_w = scale_tmp[0]  # s
	dst_w = output_size[0]
	dst_h = output_size[1]

	rot_rad = np.pi * rot / 180  # 0 here
	src_dir = get_dir([0, src_w * -0.5], rot_rad)
	dst_dir = np.array([0, dst_w * -0.5], np.float32)

	src = np.zeros((3, 2), dtype=np.float32)
	dst = np.zeros((3, 2), dtype=np.float32)
	src[0, :] = center + scale_tmp * shift  # translation?
	src[1, :] = center + src_dir + scale_tmp * shift  # translation + rotation?
	dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
	dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5], np.float32) + dst_dir

	src[2:, :] = get_3rd_point(src[0, :], src[1, :])
	dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
	# the 3 corresponding point pairs between input and output
	if inv:
		trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
	else:
		trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
	return trans

def affine_transform(pt, t):
	new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32).T  # homogeneous coordinate
	new_pt = np.dot(t, new_pt)  # matrix product
	return new_pt[:2]
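
A minimal round-trip sketch (made-up image size): map a point into the 512x512 input space and back with inv=1.

import numpy as np

c = np.array([320., 240.], dtype=np.float32)   # center of a 640x480 image
s = 640.0                                      # its longest side
trans = get_affine_transform(c, s, 0, [512, 512])
trans_inv = get_affine_transform(c, s, 0, [512, 512], inv=1)

pt = np.array([100., 200.], dtype=np.float32)
pt_in = affine_transform(pt, trans)            # image coords -> input coords
pt_back = affine_transform(pt_in, trans_inv)   # and back
assert np.allclose(pt, pt_back, atol=1e-3)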

Model

models/model.py

The model used is dla34.

from .networks.pose_dla_dcn import get_pose_net as get_dla_dcn
def create_model(arch, heads, head_conv):
	num_layers = int(arch[arch.find("_") + 1:]) if "_" in arch else 0
	arch = arch[:arch.find("_")] if "_" in arch else arch
	get_model = _model_factory[arch]# pose_dla_dcn-->get_pose_net
	model = get_model(num_layers=num_layers, heads=heads, head_conv=head_conv)
	return model
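
For example, with the default arch string (hypothetical call):

# arch = "dla_34" -> num_layers = 34, arch = "dla"
model = create_model("dla_34", {"hm": 80, "wh": 2, "reg": 2}, 256)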

models/networks/pose_dla_dcn.py

# network structure details: there is a 4x downsampling from input to output
def get_pose_net(num_layers, heads, head_conv=256, down_ratio=4):
	model = DLASeg("dla{}".format(num_layers), heads,  # DLASeg
		pretrained=True,  # True
		down_ratio=down_ratio,
		final_kernel=1,
		last_level=5,
		head_conv=head_conv)
	return model
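
Inside DLASeg.__init__, every entry of heads becomes a small convolutional branch on top of the upsampled features; a sketch of the head construction (for dla34 with down_ratio=4 the feature map feeding the heads has 64 channels):

import torch.nn as nn

# sketch: one branch per head, e.g. heads = {"hm": 80, "wh": 2, "reg": 2}
for head, classes in heads.items():
	fc = nn.Sequential(
		nn.Conv2d(64, head_conv, kernel_size=3, padding=1, bias=True),
		nn.ReLU(inplace=True),
		nn.Conv2d(head_conv, classes, kernel_size=1, bias=True))
	if "hm" in head:
		fc[-1].bias.data.fill_(-2.19)  # bias prior so the initial sigmoid output is ~0.1
	self.__setattr__(head, fc)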

Training

trains/train_factory.py

train_factory = {
"exdet": ExdetTrainer, 
"ddd": DddTrainer,
"ctdet": CtdetTrainer,#任务
"multi_pose": MultiPoseTrainer, 
}

trains/ctdet.py

Inherits from BaseTrainer and mainly overrides a few parent methods, most importantly the loss.

class CtdetTrainer(BaseTrainer):
	def __init__(self, opt, model, optimizer=None):
		super(CtdetTrainer, self).__init__(opt, model, optimizer=optimizer)
	def _get_losses(self, opt):
		loss_states = ["loss", "hm_loss", "wh_loss", "off_loss"]
		loss = CtdetLoss(opt)
		return loss_states, loss

class CtdetLoss(torch.nn.Module):  # an nn.Module whose forward takes the network outputs and computes the losses
	def __init__(self, opt):
		super(CtdetLoss, self).__init__()
		# ... (self.crit, self.crit_reg and self.crit_wh are set here; see Table 2)
		self.opt = opt

	def forward(self, outputs, batch):
		opt = self.opt
		hm_loss, wh_loss, off_loss = 0, 0, 0  # heatmap, size and offset losses; see Table 2
		for s in range(opt.num_stacks):  # deep supervision: loss over multiple stages
			output = outputs[s]
			# ...
			hm_loss += self.crit(output["hm"], batch["hm"]) / opt.num_stacks
			if opt.wh_weight > 0:
				if opt.dense_wh:
					# ...
				else:
					wh_loss += self.crit_reg(output["wh"], batch["reg_mask"],
						batch["ind"], batch["wh"]) / opt.num_stacks

			if opt.reg_offset and opt.off_weight > 0:
				off_loss += self.crit_reg(output["reg"], batch["reg_mask"], batch["ind"], batch["reg"]) / opt.num_stacks

		loss = opt.hm_weight * hm_loss + opt.wh_weight * wh_loss + \
			opt.off_weight * off_loss  # weighted sum of the losses
		loss_stats = {"loss": loss, "hm_loss": hm_loss, "wh_loss": wh_loss, "off_loss": off_loss}
		return loss, loss_stats

Table 2: Loss functions used

| Name | Class |
| --- | --- |
| self.crit | FocalLoss |
| self.crit_reg | RegL1Loss |
| self.crit_wh | RegL1Loss |

models/losses.py

class RegL1Loss(nn.Module):
	def __init__(self):
		super(RegL1Loss, self).__init__()

	def forward(self, output, mask, ind, target):
		# gather the outputs at the ground-truth keypoint locations (recorded in ind) -> (N, 128, C)
		pred = _transpose_and_gather_feat(output, ind)
		mask = mask.unsqueeze(2).expand_as(pred).float()
		# loss = F.l1_loss(pred * mask, target * mask, reduction="elementwise_mean")
		loss = F.l1_loss(pred * mask, target * mask, size_average=False)  # * mask keeps only the slots that contain an object
		loss = loss / (mask.sum() + 1e-4)  # mask.sum() counts the keypoints (an image has several)
		return loss
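
For the heatmap, self.crit is the penalty-reduced pixel-wise focal loss (FocalLoss wraps _neg_loss in losses.py); a sketch consistent with the paper, assuming pred is the sigmoid output clamped away from 0 and 1:

import torch

def _neg_loss(pred, gt):
	# pred, gt: [N, C, H, W]; gt is the gaussian-splatted heatmap
	pos_inds = gt.eq(1).float()          # exact center pixels
	neg_inds = gt.lt(1).float()
	neg_weights = torch.pow(1 - gt, 4)   # reduce the penalty near centers

	pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
	neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds

	num_pos = pos_inds.sum()
	pos_loss = pos_loss.sum()
	neg_loss = neg_loss.sum()
	# normalize by the number of positives (objects)
	return -neg_loss if num_pos == 0 else -(pos_loss + neg_loss) / num_pos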

models/utils.py

def _gather_feat(feat, ind, mask=None):
	# feat: [N, H*W, C]
	dim = feat.size(2)  # C
	# [N, 128] --> [N, 128, 1] --> [N, 128, C]
	ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
	# pick the predictions at the ground-truth keypoint locations from feat
	# (i.e. the output map); used mainly for wh and reg
	feat = feat.gather(1, ind)
	if mask is not None:
		mask = mask.unsqueeze(2).expand_as(feat)
		feat = feat[mask]
		feat = feat.view(-1, dim)
	return feat  # the gathered outputs (N, 128, C)

def _transpose_and_gather_feat(feat, ind):
	feat = feat.permute(0, 2, 3, 1).contiguous()  # [N, C, H, W] --> [N, H, W, C]
	feat = feat.view(feat.size(0), -1, feat.size(3))  # --> [N, H*W, C]
	feat = _gather_feat(feat, ind)
	return feat  # the outputs at the keypoint locations recorded in ind -> (N, 128, C)
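
A tiny shape check for the gather helpers (made-up sizes):

import torch

N, C, H, W, K = 2, 2, 4, 4, 3
feat = torch.randn(N, C, H, W)           # e.g. the predicted "wh" map
ind = torch.randint(0, H * W, (N, K))    # ground-truth flat indices y * W + x

out = _transpose_and_gather_feat(feat, ind)
assert out.shape == (N, K, C)
# out[n, k] == feat[n, :, ind[n, k] // W, ind[n, k] % W]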

trains/base_trainer.py

class ModelWithLoss(torch.nn.Module):
	def __init__(self, model, loss):
		super(ModelWithLoss, self).__init__()
		self.model = model
		self.loss = loss

	def forward(self, batch):
		# compute the outputs and the loss
		outputs = self.model(batch["input"])
		loss, loss_stats = self.loss(outputs, batch)
		return outputs[-1], loss, loss_stats

class BaseTrainer(object):
	def __init__(self, opt, model, optimizer=None):
		self.opt = opt
		self.optimizer = optimizer
		self.loss_stats, self.loss = self._get_losses(opt)
		self.model_with_loss = ModelWithLoss(model, self.loss)  # ModelWithLoss is itself a module: it runs the network and computes the loss

	def run_epoch(self, phase, epoch, data_loader):  # the core method: run one training epoch
		model_with_loss = self.model_with_loss
		if phase == "train":
			model_with_loss.train()
		else:
			if len(self.opt.gpus) > 1:
				model_with_loss = self.model_with_loss.module
			model_with_loss.eval()
			torch.cuda.empty_cache()

		opt = self.opt
		results = {}
		#...
		for iter_id, batch in enumerate(data_loader):
		#...
			for k in batch:
				if k != "meta":
					batch[k] = batch[k].to(device=opt.device, non_blocking=True)    
			output, loss, loss_stats = model_with_loss(batch)
			loss = loss.mean()

	def train(self, epoch, data_loader):
		return self.run_epoch("train", epoch, data_loader)

Testing: demo.py

detectors/detector_factory.py

detector_factory = {
	"exdet": ExdetDetector,
	"ddd": DddDetector,
	"ctdet": CtdetDetector,
	"multi_pose": MultiPoseDetector,
}

detectors/ctdet.py

Inherits from BaseDetector and mainly overrides process(), which runs the network on an image and extracts the predictions [bboxes, scores, clses] of shape [N, K, 6], where K is the number of strongest center responses kept from the output; see the paper for details.

def process(self, images, return_time=False):

	# ...
	dets = ctdet_decode(hm, wh, reg=reg, cat_spec_wh=self.opt.cat_spec_wh, K=self.opt.K)
	# detections = torch.cat([bboxes, scores, clses], dim=2) -> [N, K, 6]

detectors/base_detector.py

class BaseDetector(object):
	def __init__(self, opt):
		# ...

	# input preprocessing (I haven't fully worked this part out)
	def pre_process(self, image, scale, meta=None):
		height, width = image.shape[0:2]
		new_height = int(height * scale)
		new_width = int(width * scale)
		if self.opt.fix_res:
			inp_height, inp_width = self.opt.input_h, self.opt.input_w
			c = np.array([new_width / 2., new_height / 2.], dtype=np.float32)
			s = max(height, width) * 1.0
		else:
			# (x | pad) + 1 rounds up to the next multiple of pad + 1 (e.g. 128 when pad = 127)
			inp_height = (new_height | self.opt.pad) + 1
			inp_width = (new_width | self.opt.pad) + 1
			c = np.array([new_width // 2, new_height // 2], dtype=np.float32)
			s = np.array([inp_width, inp_height], dtype=np.float32)

	def run(self, image_or_path_or_tensor, meta=None):
		# figure out what kind of input we were given
		pre_processed = False
		if isinstance(image_or_path_or_tensor, np.ndarray):
			image = image_or_path_or_tensor
		elif type(image_or_path_or_tensor) == type(''):
			image = cv2.imread(image_or_path_or_tensor)
		else:
			image = image_or_path_or_tensor['image'][0].numpy()
			pre_processed_images = image_or_path_or_tensor
			pre_processed = True

		# the main work is process(), overridden in ctdet.py, which predicts and presents the results
		output, dets, forward_time = self.process(images, return_time=True)

models/decode.py

def _nms(heat, kernel=3):
	# find the 3x3-neighbourhood maximum via max pooling
	pad = (kernel - 1) // 2
	hmax = nn.functional.max_pool2d(
		heat, (kernel, kernel), stride=1, padding=pad)
	keep = (hmax == heat).float()  # is the current point the neighbourhood maximum?
	return heat * keep
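
A quick check with a made-up 3x3 single-class heatmap: only the two local maxima survive.

import torch

heat = torch.tensor([[[[0.1, 0.9, 0.2],
                       [0.3, 0.4, 0.2],
                       [0.1, 0.2, 0.8]]]])  # [1, 1, 3, 3]
print(_nms(heat))
# only 0.9 and 0.8 remain; every other entry is zeroed
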
def _topk(scores, K=40):
	batch, cat, height, width = scores.size()
	# torch.topk over scores viewed as [N, C, H*W]: the top-K values per class channel and their flat positions
	topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K)
	# per-channel coordinates of the top-K points, each [N, C, K]
	topk_inds = topk_inds % (height * width)
	topk_ys = (topk_inds / width).int().float()
	topk_xs = (topk_inds % width).int().float()
	# then take the top K across all channels: [N, C*K] --> [N, K]; the channel index gives the class
	topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)
	topk_clses = (topk_ind / K).int()
	# ([N, C*K, 1], [N, K]) --> [N, K]: the flat positions of the selected points in
	# [N, C, H*W] (i.e. the pixel coordinate as h * W + w), plus their x/y coordinates
	# recovered from that flat index
	topk_inds = _gather_feat(
		topk_inds.view(batch, -1, 1), topk_ind).view(batch, K)
	topk_ys = _gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K)
	topk_xs = _gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K)

	return topk_score, topk_inds, topk_clses, topk_ys, topk_xs

def ctdet_decode(heat, wh, reg=None, cat_spec_wh=False, K=100):
	batch, cat, height, width = heat.size()
	# perform nms on heatmaps
	heat = _nms(heat)  # keep a point only if it is the neighbourhood maximum, otherwise zero it
	# returns, for the top K keypoints: confidence, flat index, channel id (= the class),
	# and x/y coordinates; inds is then used to index wh, reg, etc. from the output maps
	scores, inds, clses, ys, xs = _topk(heat, K=K)
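
The remainder of ctdet_decode (elided here) mirrors multi_pose_decode below; a condensed sketch, assuming reg is provided and cat_spec_wh=False:

	reg = _transpose_and_gather_feat(reg, inds).view(batch, K, 2)
	xs = xs.view(batch, K, 1) + reg[:, :, 0:1]  # sub-pixel center refinement
	ys = ys.view(batch, K, 1) + reg[:, :, 1:2]
	wh = _transpose_and_gather_feat(wh, inds).view(batch, K, 2)
	clses = clses.view(batch, K, 1).float()
	scores = scores.view(batch, K, 1)
	bboxes = torch.cat([xs - wh[..., 0:1] / 2, ys - wh[..., 1:2] / 2,
	                    xs + wh[..., 0:1] / 2, ys + wh[..., 1:2] / 2], dim=2)
	return torch.cat([bboxes, scores, clses], dim=2)  # [N, K, 6]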

multi pose

For human pose estimation, keypoint coordinates are first obtained from each detected center plus the per-joint offsets regressed at that center (reg_kps). Keypoints are also predicted directly from the joint heatmaps (hm_kps: the K strongest responses per joint channel, [N, 17, K, 2]). Between the two sets of K keypoints, taking reg_kps as the reference, the decoder finds for each reg_kps point the closest point in hm_kps; this is the paper's step of matching each center-regressed keypoint to the nearest point detected on the directly regressed joint heatmaps.

The final output chooses between reg_kps and hm_kps: reg_kps is used whenever the directly detected point fails a check (it falls outside the bbox, its confidence is below the threshold, or its distance to the center-regressed keypoint exceeds 0.3 times the bbox size). The result is [bboxes, scores, kps, clses] of shape [N, K, 40].

The variables involved (max_objs is 32 for multi_pose):

Predictions based on center offsets:

| Returned name | Name in dataset | Meaning | Shape | Output channels | Loss | Result name | Result meaning |
| --- | --- | --- | --- | --- | --- | --- | --- |
| hm | hm | center heatmap | [1, H, W] | 1 | FocalLoss | | |
| wh | wh | width and height | [32, 2] | 2 | RegL1Loss | bboxes | object bboxes |
| hps | kps | joint offsets relative to the center | [32, 2*17] | 2*17 | RegWeightedL1Loss (close to RegL1Loss) | kps / reg_kps | joints of the top-K centers: center coordinate plus per-joint offset, i.e. keypoints regressed from the center |
| reg | reg | center offset | [32, 2] | 2 | RegL1Loss | | |
| ind | ind | flat locations of the object centers | [32] | | | | |
| reg_mask | reg_mask | whether each of the 32 slots holds an object | [32] | | | | |
| hps_mask | kps_mask | whether each joint is annotated | [32, 2*17] | | | | |

Predictions based on direct joint heatmaps:

| Returned name | Name in dataset | Meaning | Shape | Output channels | Loss | Result name | Result meaning |
| --- | --- | --- | --- | --- | --- | --- | --- |
| hm_hp | hm_hp | joint keypoints predicted directly as heatmaps | [17, H, W] | 17 | FocalLoss | hm_kps | joint coordinates from the joint heatmaps: the K strongest responses per channel |
| hp_offset | hp_offset | sub-pixel offset of the heatmap joints | [32*17, 2] | 2 | RegL1Loss | | |
| hp_ind | hp_ind | flat locations of the joints | [32*17] | | | | |
| hp_mask | hp_mask | whether each joint exists | [32*17] | | | | |

def multi_pose_decode(
    heat, wh, kps, reg=None, hm_hp=None, hp_offset=None, K=100):
  batch, cat, height, width = heat.size()
  num_joints = kps.shape[1] // 2
  # heat = torch.sigmoid(heat)
  # perform nms on heatmaps
  heat = _nms(heat)
  scores, inds, clses, ys, xs = _topk(heat, K=K)

  kps = _transpose_and_gather_feat(kps, inds)  # gather the 17 joint offsets of the top-K centers
  kps = kps.view(batch, K, num_joints * 2)
  kps[..., ::2] += xs.view(batch, K, 1).expand(batch, K, num_joints)  # add the center x to the x offsets
  kps[..., 1::2] += ys.view(batch, K, 1).expand(batch, K, num_joints)  # add the center y to the y offsets
  if reg is not None:
    reg = _transpose_and_gather_feat(reg, inds)
    reg = reg.view(batch, K, 2)
    xs = xs.view(batch, K, 1) + reg[:, :, 0:1]  # refine the center with its sub-pixel offset
    ys = ys.view(batch, K, 1) + reg[:, :, 1:2]
  else:
    xs = xs.view(batch, K, 1) + 0.5
    ys = ys.view(batch, K, 1) + 0.5
  wh = _transpose_and_gather_feat(wh, inds)
  wh = wh.view(batch, K, 2)
  clses = clses.view(batch, K, 1).float()
  scores = scores.view(batch, K, 1)

  bboxes = torch.cat([xs - wh[..., 0:1] / 2,
                      ys - wh[..., 1:2] / 2,
                      xs + wh[..., 0:1] / 2,
                      ys + wh[..., 1:2] / 2], dim=2)
  if hm_hp is not None:
      hm_hp = _nms(hm_hp)
      thresh = 0.1
      kps = kps.view(batch, K, num_joints, 2).permute(
          0, 2, 1, 3).contiguous()  # b x J x K x 2
      reg_kps = kps.unsqueeze(3).expand(batch, num_joints, K, K, 2)
      hm_score, hm_inds, hm_ys, hm_xs = _topk_channel(hm_hp, K=K)  # b x J x K
      # returns [N, C, K] (C = 17) rather than detection's [N, K]: each joint channel is handled separately
      if hp_offset is not None:
          hp_offset = _transpose_and_gather_feat(
              hp_offset, hm_inds.view(batch, -1))
          hp_offset = hp_offset.view(batch, num_joints, K, 2)
          hm_xs = hm_xs + hp_offset[:, :, :, 0]
          hm_ys = hm_ys + hp_offset[:, :, :, 1]
      else:
          hm_xs = hm_xs + 0.5
          hm_ys = hm_ys + 0.5

      mask = (hm_score > thresh).float()
      hm_score = (1 - mask) * -1 + mask * hm_score  # set entries with mask 0 to -1
      hm_ys = (1 - mask) * (-10000) + mask * hm_ys  # set entries with mask 0 to -10000
      hm_xs = (1 - mask) * (-10000) + mask * hm_xs
      hm_kps = torch.stack([hm_xs, hm_ys], dim=-1).unsqueeze(
          2).expand(batch, num_joints, K, K, 2)  # replicate the coordinates along a new axis
      dist = (((reg_kps - hm_kps) ** 2).sum(dim=4) ** 0.5)  # distance between every center-regressed joint and every heatmap-detected joint
      min_dist, min_ind = dist.min(dim=3)  # [N, C, K, K] --> [N, C, K]: for each reg_kps joint, the nearest hm_kps point
      hm_score = hm_score.gather(2, min_ind).unsqueeze(-1)  # b x J x K x 1: confidence of those nearest heatmap points
      min_dist = min_dist.unsqueeze(-1)
      min_ind = min_ind.view(batch, num_joints, K, 1, 1).expand(
          batch, num_joints, K, 1, 2)  # [N, C, K, 1, 2]; the two entries of the last dim are identical
      hm_kps = hm_kps.gather(3, min_ind)  # coordinates of the nearest heatmap points [N, C, K, 2]: the paper's matching of each center-regressed joint to the closest directly detected joint
      hm_kps = hm_kps.view(batch, num_joints, K, 2)
      l = bboxes[:, :, 0].view(batch, 1, K, 1).expand(batch, num_joints, K, 1)
      t = bboxes[:, :, 1].view(batch, 1, K, 1).expand(batch, num_joints, K, 1)
      r = bboxes[:, :, 2].view(batch, 1, K, 1).expand(batch, num_joints, K, 1)
      b = bboxes[:, :, 3].view(batch, 1, K, 1).expand(batch, num_joints, K, 1)
      # constrain the keypoints to the bbox
      mask = (hm_kps[..., 0:1] < l) + (hm_kps[..., 0:1] > r) + \
             (hm_kps[..., 1:2] < t) + (hm_kps[..., 1:2] > b) + \
             (hm_score < thresh) + (min_dist > (torch.max(b - t, r - l) * 0.3))
      mask = (mask > 0).float().expand(batch, num_joints, K, 2)
      kps = (1 - mask) * hm_kps + mask * kps  # choose per joint: keep hm_kps unless it fails a check (outside the bbox, confidence below the threshold, or farther than 0.3 x the bbox size from the center-regressed joint), otherwise fall back to the center-regressed kps
      kps = kps.permute(0, 2, 1, 3).contiguous().view(
          batch, K, num_joints * 2)  # [N, K, 2*C]
  detections = torch.cat([bboxes, scores, kps, clses], dim=2)
    
  return detections