Contents
Introduction
Main function: main.py
Data
datasets/dataset_factory.py
datasets/dataset/coco.py
datasets/sample/ctdet.py
utils/image.py
Model
models/model.py
models/networks/pose_dla_dcn.py
Training
trains/train_factory.py
trains/ctdet.py
models/losses.py
models/utils.py
Testing: demo.py
detectors/detector_factory.py
detectors/ctdet.py
detectors/base_detector.py
models/decode.py
multi pose
CenterNet is an object detection method based on keypoint (center point) detection; for a walkthrough of the paper see https://blog.csdn.net/LXX516/article/details/106251090. The code is at https://github.com/xingyizhou/CenterNet. The repository is like a supermarket, packed with functionality, so I briefly annotate the parts I consider most important, since the ultimate goal is to move these pieces onto my own data and applications.
The core of the main function main.py has three parts: data preparation, model construction, and the training loop (loss computation).
from opts import opts
from models.model import create_model, load_model, save_model
from datasets.dataset_factory import get_dataset
from trains.train_factory import train_factory
#'Assume the task is ctdet and the dataset is coco'
#Prepare the data
Dataset = get_dataset(opt.dataset, opt.task)
opt = opts().update_dataset_info_and_set_heads(opt, Dataset)
# Build the model
model = create_model(opt.arch, opt.heads, opt.head_conv)
#print(opt.arch, opt.heads, opt.head_conv)
#dla_34 {"hm": 80, "wh": 2, "reg": 2} 256,
#对于目标检测而言,需要预测热力图,回归长宽度和偏移,所以在backbone之上有3个分支
#训练
Trainer = train_factory[opt.task]
trainer = Trainer(opt, model, optimizer)
trainer.set_device(opt.gpus, opt.chunk_sizes, opt.device)
#...
log_dict_train, _ = trainer.train(epoch, train_loader)
from .sample.ctdet import CTDetDataset
from .dataset.coco import COCO
def get_dataset(dataset, task):
class Dataset(dataset_factory[dataset], _sample_factory[task]):
pass
return Dataset
The Dataset returned to the main program inherits from two parent classes chosen by the dataset and the task, here "coco": COCO and "ctdet": CTDetDataset.
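A minimal usage sketch, mirroring main.py (the opt object is assumed to come from opts().parse() as in the repository):
import torch
from datasets.dataset_factory import get_dataset
# Build the combined class: COCO supplies the data source, CTDetDataset the __getitem__ logic
Dataset = get_dataset('coco', 'ctdet')
train_loader = torch.utils.data.DataLoader(
    Dataset(opt, 'train'),  # split is 'train' or 'val'
    batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)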
class COCO(data.Dataset):
#Class-level attributes: num_classes, default_resolution, mean, std
num_classes = 80
default_resolution = [512, 512]
mean = np.array([0.40789654, 0.44719302, 0.47026115],dtype=np.float32).reshape(1, 1, 3)
std = np.array([0.28863828, 0.27408164, 0.27809835],dtype=np.float32).reshape(1, 1, 3)
def __init__(self, opt, split):
super(COCO, self).__init__()
self.data_dir = os.path.join(opt.data_dir, "coco")
self.img_dir = os.path.join(self.data_dir, "{}2017".format(split))
#...
self.max_objs = 128#'assume each image contains at most 128 objects'
#...
#'Load the image list through the COCO API and record the image ids in self.images'
self.coco = coco.COCO(self.annot_path)
self.images = self.coco.getImgIds()
self.num_samples = len(self.images)
class CTDetDataset(data.Dataset):
#...
#'The main job of this class is to implement __getitem__ so it can be loaded by a DataLoader during training'
def __getitem__(self, index):
#...
#'Read the image and its annotations'
num_objs = min(len(anns), self.max_objs) #'the number of annotations equals the number of objects in the image, capped at the maximum of 128'
height, width = img.shape[0], img.shape[1]
c = np.array([img.shape[1] / 2., img.shape[0] / 2.], dtype=np.float32) #'center point of the image'
s = max(img.shape[0], img.shape[1]) * 1.0 #'longest side'
input_h, input_w = self.opt.input_h, self.opt.input_w #'network input size'
flipped = False
if self.split == "train":
if not self.opt.not_rand_crop:
#'Long-side scaling plus random cropping: the center point is shifted randomly, constrained to stay at least border pixels away from each image edge; border is 128 when the image side exceeds 256, otherwise 64'
s = s * np.random.choice(np.arange(0.6, 1.4, 0.1))
w_border = self._get_border(128, img.shape[1])
h_border = self._get_border(128, img.shape[0])
c[0] = np.random.randint(low=w_border, high=img.shape[1] - w_border)
c[1] = np.random.randint(low=h_border, high=img.shape[0] - h_border)
else:
#'Scale and shift: jitter the long side and the center point with Gaussian noise controlled by opt.scale / opt.shift'
#'(unlike the crop branch above, which resamples the center uniformly inside the border)'
sf = self.opt.scale
cf = self.opt.shift
c[0] += s * np.clip(np.random.randn()*cf, -2*cf, 2*cf)
c[1] += s * np.clip(np.random.randn()*cf, -2*cf, 2*cf)
s = s * np.clip(np.random.randn()*sf + 1, 1 - sf, 1 + sf)
if np.random.random() < self.opt.flip:
#'horizontal flip'
flipped = True
img = img[:, ::-1, :]
c[0] = width - c[0] - 1
#'With these parameters fixed, apply the affine transform to the input'
trans_input = get_affine_transform(c, s, 0, [input_w, input_h])
inp=cv2.warpAffine(img, trans_input,(input_w,input_h), flags=cv2.INTER_LINEAR)
inp = (inp.astype(np.float32) / 255.)
#Data augmentation (color jitter)
if self.split == "train" and not self.opt.no_color_aug:
color_aug(self._data_rng, inp, self._eig_val, self._eig_vec)
inp = (inp - self.mean) / self.std
inp = inp.transpose(2, 0, 1)
#'Prepare the affine transform for the output (ground-truth) plane'
output_h = input_h // self.opt.down_ratio
output_w = input_w // self.opt.down_ratio
num_classes = self.num_classes
trans_output = get_affine_transform(c, s, 0, [output_w, output_h])
#...
#'Build the ground-truth targets to return: the heatmap, object width/height, offsets, etc., as listed in Table 1'
gt_det = []
for k in range(num_objs):#'loop over the objects actually present'
if flipped:
bbox[[0, 2]] = width - bbox[[2, 0]]-1
#'Apply the output-plane affine transform to the bbox; the transforms are handled by helper functions, so when reusing this on your own dataset you only need to supply parameters such as the center point c and the long side s'
bbox[:2] = affine_transform(bbox[:2], trans_output)
bbox[2:] = affine_transform(bbox[2:], trans_output)
bbox[[0, 2]] = np.clip(bbox[[0, 2]], 0, output_w - 1)
bbox[[1, 3]] = np.clip(bbox[[1, 3]], 0, output_h - 1)
h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
if h > 0 and w > 0:
radius = gaussian_radius((math.ceil(h), math.ceil(w)))#'the Gaussian radius of the heatmap is determined by the object size'
radius = max(0, int(radius))
radius = self.opt.hm_gauss if self.opt.mse_loss else radius
ct = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2], dtype=np.float32)#'center of the transformed bbox'
ct_int = ct.astype(np.int32)
draw_gaussian(hm[cls_id], ct_int, radius)#'draw the Gaussian heatmap on the channel of the corresponding class'
wh[k] = 1. * w, 1. * h#[128,2] #'if there are fewer than 128 objects, the unused slots stay 0'
ind[k] = ct_int[1] * output_w + ct_int[0]#'flattened center position, h*W + w'
reg[k] = ct - ct_int#'sub-pixel error caused by rounding'
reg_mask[k] = 1#'mark the slots that contain an object'
cat_spec_wh[k, cls_id * 2: cls_id * 2 + 2] = wh[k]#per-class width/height, not shared across classes
cat_spec_mask[k, cls_id * 2: cls_id * 2 + 2] = 1
if self.opt.dense_wh:#False by default
draw_dense_reg(dense_wh, hm.max(axis=0), ct_int, wh[k], radius)#'write wh[k] into a dense regression map around the center (only used with dense_wh)'
gt_det.append([ct[0] - w / 2, ct[1] - h / 2, ct[0] + w / 2, ct[1] + h / 2, 1, cls_id])#'record the ground-truth box [x1, y1, x2, y2, 1, cls_id] for debugging/evaluation'
ret = {"input": inp, "hm": hm, "reg_mask": reg_mask, "ind": ind, "wh": wh}
Table 1. Ground-truth targets returned by __getitem__
Name | Purpose | Shape | Type |
---|---|---|---|
hm | center heatmap | [C,H,W] | float32 |
wh | object width and height | [128,2] | float32 |
dense_wh | dense width/height regression map | [2,H,W] | float32 |
reg | offset caused by rounding after downsampling | [128,2] | float32 |
ind | flattened center position (h*W + w) | [128] | int64 |
reg_mask | whether each of the fixed 128 slots holds a keypoint | [128] | uint8 |
cat_spec_wh | per-class (non-shared) width/height targets; not used by default | [128,C*2] | float32 |
cat_spec_mask | mask for cat_spec_wh | [128,C*2] | uint8 |
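To make the target construction concrete, here is a minimal sketch of how the arrays above are filled for a single object, assuming the helpers gaussian_radius and draw_umich_gaussian from utils/image.py (the bbox values are purely illustrative):
import math
import numpy as np
from utils.image import gaussian_radius, draw_umich_gaussian
num_classes, output_h, output_w, max_objs = 80, 128, 128, 128
hm = np.zeros((num_classes, output_h, output_w), dtype=np.float32)
wh = np.zeros((max_objs, 2), dtype=np.float32)
reg = np.zeros((max_objs, 2), dtype=np.float32)
ind = np.zeros((max_objs,), dtype=np.int64)
reg_mask = np.zeros((max_objs,), dtype=np.uint8)
# one object, bbox already transformed to output-plane coordinates
bbox, cls_id, k = np.array([10., 20., 50., 60.]), 3, 0
h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
radius = max(0, int(gaussian_radius((math.ceil(h), math.ceil(w)))))
ct = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2], dtype=np.float32)
ct_int = ct.astype(np.int32)
draw_umich_gaussian(hm[cls_id], ct_int, radius)  # Gaussian peak on the class channel
wh[k] = w, h
ind[k] = ct_int[1] * output_w + ct_int[0]  # flattened position h*W + w
reg[k] = ct - ct_int  # sub-pixel offset lost to rounding
reg_mask[k] = 1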
OpenCV's M = cv2.getAffineTransform(src, dst) computes the 2x3 affine matrix that maps three source points onto three destination points. So the job of get_affine_transform() is to construct such a triple of corresponding points from the center c, the scale s, the rotation and the output size, and return the resulting matrix:
def get_affine_transform(center,scale,rot,output_size,shift=np.array([0, 0], dtype=np.float32),inv=0):
if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
scale = np.array([scale, scale], dtype=np.float32)
scale_tmp = scale#[s,s]
src_w = scale_tmp[0]#s
dst_w = output_size[0]
dst_h = output_size[1]
rot_rad = np.pi * rot / 180#0
src_dir = get_dir([0, src_w * -0.5], rot_rad)
dst_dir = np.array([0, dst_w * -0.5], np.float32)
src = np.zeros((3, 2), dtype=np.float32)
dst = np.zeros((3, 2), dtype=np.float32)
src[0, :] = center + scale_tmp * shift#'first pair: the (shifted) center point'
src[1, :] = center + src_dir + scale_tmp * shift#'second pair: a point half the long side above the center, rotated by rot'
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5], np.float32) + dst_dir
src[2:, :] = get_3rd_point(src[0, :], src[1, :])
dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
#the third pair is constructed perpendicular to the first two, giving 3 corresponding points in the input and output planes
if inv:
trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
else:
trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
return trans
def affine_transform(pt, t):
new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32).T
new_pt = np.dot(t, new_pt)#matrix-vector product
return new_pt[:2]
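A quick sanity check of these two functions (a small sketch; the numbers are purely illustrative): the image center should land on the center of the output plane.
import numpy as np
c = np.array([320., 240.], dtype=np.float32)  # center of a 640x480 image
s = 640.0                                     # longest side
trans = get_affine_transform(c, s, 0, [128, 128])
print(affine_transform(c, trans))                     # ~[64., 64.], the output-plane center
print(affine_transform(np.array([0., 240.]), trans))  # ~[0., 64.], half a long side left of the center maps to the left edge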
The backbone used here is DLA-34.
from .networks.pose_dla_dcn import get_pose_net as get_dla_dcn
def create_model(arch, heads, head_conv):
num_layers = int(arch[arch.find("_") + 1:]) if "_" in arch else 0
arch = arch[:arch.find("_")] if "_" in arch else arch
get_model = _model_factory[arch]# pose_dla_dcn-->get_pose_net
model = get_model(num_layers=num_layers, heads=heads, head_conv=head_conv)
return model
#Details of the network structure: there is a 4x downsampling between the input and the output
def get_pose_net(num_layers, heads, head_conv=256, down_ratio=4):
model = DLASeg("dla{}".format(num_layers), heads, #DLASeg
pretrained=True,#True
down_ratio=down_ratio,
final_kernel=1,
last_level=5,
head_conv=head_conv)
return model
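Each entry of heads (hm, wh, reg) becomes a small branch on top of the upsampled DLA features. A sketch of what one such branch looks like when head_conv > 0 (the input channel count of 64 for dla_34 is an assumption here, not taken from the excerpt above):
import torch.nn as nn
def make_head(in_channels=64, head_conv=256, num_output=80):
    # 3x3 conv + ReLU + 1x1 conv; num_output is 80 for "hm" and 2 for "wh"/"reg"
    return nn.Sequential(
        nn.Conv2d(in_channels, head_conv, kernel_size=3, padding=1, bias=True),
        nn.ReLU(inplace=True),
        nn.Conv2d(head_conv, num_output, kernel_size=1, bias=True))
heads = {"hm": 80, "wh": 2, "reg": 2}
branches = {name: make_head(num_output=c) for name, c in heads.items()}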
train_factory = {
"exdet": ExdetTrainer,
"ddd": DddTrainer,
"ctdet": CtdetTrainer,#任务
"multi_pose": MultiPoseTrainer,
}
CtdetTrainer inherits from the parent class BaseTrainer and mainly overrides a few of its methods, most importantly the one that defines the loss.
class CtdetTrainer(BaseTrainer):
def __init__(self, opt, model, optimizer=None):
super(CtdetTrainer, self).__init__(opt, model, optimizer=optimizer)
def _get_losses(self, opt):
loss_states = ["loss", "hm_loss", "wh_loss", "off_loss"]
loss = CtdetLoss(opt)
return loss_states, loss
class CtdetLoss(torch.nn.Module):#'implemented as an nn.Module so the loss can be computed in a forward pass; it is wrapped together with the network in ModelWithLoss below'
def __init__(self, opt):
super(CtdetLoss, self).__init__()
def forward(self, outputs, batch):
opt = self.opt
hm_loss, wh_loss, off_loss = 0, 0, 0#'accumulate the heatmap loss, the size-regression loss and the offset loss separately, see Table 2'
for s in range(opt.num_stacks): #'deep supervision: average the loss over the intermediate outputs (num_stacks is 1 for DLA)'
output = outputs[s]
# '…'
hm_loss += self.crit(output["hm"], batch["hm"]) / opt.num_stacks
if opt.wh_weight > 0:
if opt.dense_wh:
pass  # the dense width/height branch is omitted in this excerpt
else:
wh_loss += self.crit_reg(output["wh"], batch["reg_mask"],
batch["ind"], batch["wh"]) / opt.num_stacks
if opt.reg_offset and opt.off_weight > 0:
off_loss += self.crit_reg(output["reg"], batch["reg_mask"],batch["ind"], batch["reg"]) / opt.num_stacks
loss = opt.hm_weight * hm_loss + opt.wh_weight * wh_loss + \
opt.off_weight * off_loss#weighted sum of the three losses
loss_stats = {"loss": loss, "hm_loss": hm_loss,"wh_loss": wh_loss, "off_loss": off_loss}
return loss, loss_stats
Table 2. Loss criteria used in CtdetLoss
Name | Loss type |
---|---|
self.crit | FocalLoss |
self.crit_reg | RegL1Loss |
self.crit_wh | RegL1Loss |
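FocalLoss here is the penalty-reduced pixel-wise focal loss from the paper. A minimal sketch of its core computation (alpha = 2, beta = 4), assuming pred is a sigmoid output already clamped away from 0 and 1:
import torch
def center_focal_loss(pred, gt):
    # pred, gt: [N, C, H, W]; gt is 1 exactly at object centers and decays with a Gaussian around them
    pos_inds = gt.eq(1).float()
    neg_inds = gt.lt(1).float()
    neg_weights = torch.pow(1 - gt, 4)  # down-weight pixels close to a ground-truth peak
    pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
    neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds
    num_pos = pos_inds.sum()
    if num_pos == 0:
        return -neg_loss.sum()
    return -(pos_loss.sum() + neg_loss.sum()) / num_pos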
class RegL1Loss(nn.Module):
def __init__(self):
super(RegL1Loss, self).__init__()
def forward(self, output, mask, ind, target):
#'gather the predictions at the ground-truth keypoint positions recorded in ind, giving (N, 128, C)'
pred = _transpose_and_gather_feat(output, ind)
mask = mask.unsqueeze(2).expand_as(pred).float()
# loss = F.l1_loss(pred * mask, target * mask, reduction="elementwise_mean")
loss = F.l1_loss(pred * mask, target * mask, size_average=False)#'multiplying by mask keeps only the slots that actually contain an object'
loss = loss / (mask.sum() + 1e-4)#'each image has several keypoints; mask.sum() is the total number of keypoints'
return loss
def _gather_feat(feat, ind, mask=None):
# feat [N,H*W,C]
dim = feat.size(2) #C
#'[N,128]-->[N,128,1]-->[N,128,C]'
ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
#'gather from feat (the output map) the predictions at the ground-truth keypoint positions; used mainly for wh and reg'
feat = feat.gather(1, ind)
if mask is not None:
mask = mask.unsqueeze(2).expand_as(feat)
feat = feat[mask]
feat = feat.view(-1, dim)
return feat#'the gathered outputs, shape (N, 128, C)'
def _transpose_and_gather_feat(feat, ind):
feat = feat.permute(0, 2, 3, 1).contiguous()# '[N,C,H,W]-->[N,H,W,C]'
feat = feat.view(feat.size(0), -1, feat.size(3))
#'[N,C,H,W]-->[N,H*W,C]'
feat = _gather_feat(feat, ind)
return feat#'the outputs at the ground-truth keypoint positions recorded in ind, shape (N, 128, C)'
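A tiny worked example of what the gathering does (an illustrative sketch using the two helpers defined above):
import torch
feat = torch.arange(2 * 4 * 3, dtype=torch.float32).view(1, 2, 4, 3)  # [N=1, C=2, H=4, W=3]
ind = torch.tensor([[7, 0]])  # flattened positions h*W + w, i.e. (h=2, w=1) and (h=0, w=0)
out = _transpose_and_gather_feat(feat, ind)  # shape [1, 2, 2]
# out[0, 0] equals feat[0, :, 2, 1] and out[0, 1] equals feat[0, :, 0, 0]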
trains/base_trainer.py
class ModelWithLoss(torch.nn.Module):
def __init__(self, model, loss):
super(ModelWithLoss, self).__init__()
self.model = model
self.loss = loss
def forward(self, batch):
#'compute the network outputs and the loss'
outputs = self.model(batch["input"])
loss, loss_stats = self.loss(outputs, batch)
return outputs[-1], loss, loss_stats
class BaseTrainer(object):
def __init__(self, opt, model, optimizer=None):
self.opt = opt
self.optimizer = optimizer
self.loss_stats, self.loss = self._get_losses(opt)
self.model_with_loss = ModelWithLoss(model, self.loss)#'ModelWithLoss wraps the network so that the forward pass computes both the outputs and the loss'
def run_epoch(self, phase, epoch, data_loader): #'the key method: runs the network for one epoch'
model_with_loss = self.model_with_loss
if phase == "train":
model_with_loss.train()
else:
if len(self.opt.gpus) > 1:
model_with_loss = self.model_with_loss.module
model_with_loss.eval()
torch.cuda.empty_cache()
opt = self.opt
results = {}
#...
for iter_id, batch in enumerate(data_loader):
#...
for k in batch:
if k != "meta":
batch[k] = batch[k].to(device=opt.device, non_blocking=True)
output, loss, loss_stats = model_with_loss(batch)
loss = loss.mean()
def train(self, epoch, data_loader):
return self.run_epoch("train", epoch, data_loader)
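In the training phase the full run_epoch then performs the usual backward pass right after loss = loss.mean(); a sketch of that standard step (the actual code also updates loss meters and a progress bar):
if phase == 'train':
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()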
detector_factory = {'exdet': ExdetDetector, 'ddd': DddDetector,
'ctdet': CtdetDetector,'multi_pose': MultiPoseDetector, }
CtdetDetector inherits from the parent class BaseDetector and mainly overrides the process() function, which runs the model on an image and decodes the predictions into [bboxes, scores, clses] of shape [N,K,6], where K is the number of strongest-response center points kept from the output; see the paper for details.
def process(self, images, return_time=False):
#'...'
dets = ctdet_decode(hm, wh, reg=reg, cat_spec_wh=self.opt.cat_spec_wh, K=self.opt.K)#detections = torch.cat([bboxes, scores, clses], dim=2)[N,K,6]
class BaseDetector(object):
def __init__(self, opt):
#'...'
#'Input preprocessing: rescale the image by scale and compute the center c and size s for the affine warp to the network input resolution'
def pre_process(self, image, scale, meta=None):
height, width = image.shape[0:2]
new_height = int(height * scale)
new_width = int(width * scale)
if self.opt.fix_res:
inp_height, inp_width = self.opt.input_h, self.opt.input_w
c = np.array([new_width / 2., new_height / 2.], dtype=np.float32)
s = max(height, width) * 1.0
else:
inp_height = (new_height | self.opt.pad) + 1#'pad up to the next multiple of (pad + 1); e.g. pad = 127 turns 300 into 384'
inp_width = (new_width | self.opt.pad) + 1
c = np.array([new_width // 2, new_height // 2], dtype=np.float32)
s = np.array([inp_width, inp_height], dtype=np.float32)
def run(self, image_or_path_or_tensor, meta=None):
#'figure out whether the input is an image array, a file path, or an already pre-processed batch'
pre_processed = False
if isinstance(image_or_path_or_tensor, np.ndarray):
image = image_or_path_or_tensor
elif type(image_or_path_or_tensor) == type (''):
image = cv2.imread(image_or_path_or_tensor)
else:
image = image_or_path_or_tensor['image'][0].numpy()
pre_processed_images = image_or_path_or_tensor
pre_processed = True
#'the core is the process() call, which produces the predictions to be post-processed and shown; it is the method overridden in ctdet.py'
output, dets, forward_time = self.process(images, return_time=True)
def _nms(heat, kernel=3):
#'use max pooling to find the maximum within each 3x3 neighborhood'
pad = (kernel - 1) // 2
hmax = nn.functional.max_pool2d(
heat, (kernel, kernel), stride=1, padding=pad)
keep = (hmax == heat).float()#'keep a point only if it is the local maximum itself'
return heat * keep
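A quick check of this max-pooling NMS (a small sketch using the _nms defined above): only scores that are local maxima of their 3x3 neighborhood survive.
import torch
heat = torch.tensor([[[[0.1, 0.2, 0.1],
                       [0.2, 0.9, 0.2],
                       [0.1, 0.2, 0.8]]]])  # [1, 1, 3, 3]
print(_nms(heat))  # only the 0.9 peak remains; 0.8 is suppressed because 0.9 lies in its 3x3 neighborhood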
def _topk(scores, K=40):
batch, cat, height, width = scores.size()
#'torch.topk over scores viewed as [N,C,H*W]: return the K largest values per channel and their positions'
topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K)
#'the coordinates of the K strongest points for every image and channel, each of shape [N,C,K], storing y and x separately'
topk_inds = topk_inds % (height * width)
topk_ys = (topk_inds / width).int().float()
topk_xs = (topk_inds % width).int().float()
#'a second top-k across all channels to decide which channel (class) each point belongs to: [N,C*K]-->[N,K]'
topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)
topk_clses = (topk_ind / K).int()
#'gather ([N,C*K,1],[N,K])-->[N,K]: the flattened positions of the selected points in [N,C,H*W] (the one-dimensional pixel index h*W+w) plus their x/y coordinates recovered from it'
topk_inds = _gather_feat(
topk_inds.view(batch, -1, 1), topk_ind).view(batch, K)
topk_ys = _gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K)
topk_xs = _gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K)
return topk_score, topk_inds, topk_clses, topk_ys, topk_xs
def ctdet_decode(heat, wh, reg=None, cat_spec_wh=False, K=100):
batch, cat, height, width = heat.size()
# perform nms on heatmaps
heat = _nms(heat)#'keep a point only if it is the maximum of its neighborhood, otherwise set it to 0'
#'return, for the K strongest keypoints: the confidence scores, the flattened coordinates, the channel ids (classes) and the x/y coordinates; inds is later used to index wh, reg, etc. from the output maps'
scores, inds, clses, ys, xs = _topk(heat, K=K)
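The excerpt stops after the top-K selection; a sketch of the remaining steps of ctdet_decode (the cat_spec_wh branch is omitted here), which mirror the box part of multi_pose_decode shown further below:
if reg is not None:
    reg = _transpose_and_gather_feat(reg, inds).view(batch, K, 2)
    xs = xs.view(batch, K, 1) + reg[:, :, 0:1]  # refine the centers with the predicted sub-pixel offsets
    ys = ys.view(batch, K, 1) + reg[:, :, 1:2]
else:
    xs = xs.view(batch, K, 1) + 0.5
    ys = ys.view(batch, K, 1) + 0.5
wh = _transpose_and_gather_feat(wh, inds).view(batch, K, 2)
clses = clses.view(batch, K, 1).float()
scores = scores.view(batch, K, 1)
bboxes = torch.cat([xs - wh[..., 0:1] / 2, ys - wh[..., 1:2] / 2,
                    xs + wh[..., 0:1] / 2, ys + wh[..., 1:2] / 2], dim=2)
return torch.cat([bboxes, scores, clses], dim=2)  # [N, K, 6]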
multi pose
For human pose estimation, keypoint coordinates are first obtained from each object's center point plus the regressed per-keypoint offsets (reg_kps). Keypoint coordinates are also obtained from directly regressed keypoint heatmaps (hm_kps), by taking the K strongest responses on each keypoint channel ([N,17,K,2]). Between these two sets of K points, for every pair, with reg_kps as the target, the point in hm_kps closest to each reg_kps is found; this is the step described in the paper as snapping each center-offset prediction to the nearest detection in the directly regressed keypoint heatmap.
The final result chooses between reg_kps and hm_kps: reg_kps is used whenever the heatmap-based point fails any of the following checks, i.e. it lies outside the bbox, its confidence is below the threshold, or its distance to the center-offset keypoint exceeds 0.3 times the bbox size. The output is [bboxes, scores, kps, clses] of shape [N,K,40]. The variables involved are summarized in the table below.
Prediction branch | Returned variable | Name in the dataset | Meaning | Shape | Predicted output channels | Loss | Decoded result | Meaning |
---|---|---|---|---|---|---|---|---|
center-offset-based prediction | hm | hm | center heatmap | 1,H,W | 1 | FocalLoss | | |
 | wh | wh | width and height | 32,2 | 2 | RegL1Loss | bboxes | object bounding boxes |
 | hps | kps | keypoint offsets relative to the center | 32,2*17 | 2*17 | RegWeightedL1Loss (essentially the same as RegL1Loss) | kps/reg_kps | offsets of the 17 keypoints of the K strongest centers added to the center coordinates, i.e. keypoints predicted from the centers |
 | reg | reg | sub-pixel offset of the center | 32,2 | 2 | RegL1Loss | | |
 | ind | ind | flattened positions of the object centers (h*W+w) | 32 | | | | |
 | reg_mask | reg_mask | whether each of the 32 slots contains an object | 32 | | | | |
 | hps_mask | kps_mask | mask for keypoints that are missing or unlabeled | 32,2*17 | | | | |
direct heatmap prediction | hm_hp | hm_hp | keypoint heatmaps (predict keypoints directly) | 17,H,W | 17 | FocalLoss | hm_kps | keypoint coordinates from the directly regressed keypoint heatmaps: the K strongest responses on each channel |
 | hp_offset | hp_offset | sub-pixel offsets of the heatmap keypoints | 32*17,2 | 2 | RegL1Loss | | |
 | hp_ind | hp_ind | flattened positions of the keypoints | 32*17 | | | | |
 | hp_mask | hp_mask | mask for keypoints that are missing | 32*17 | | | | |
def multi_pose_decode(
heat, wh, kps, reg=None, hm_hp=None, hp_offset=None, K=100):
batch, cat, height, width = heat.size()
num_joints = kps.shape[1] // 2
# heat = torch.sigmoid(heat)
# perform nms on heatmaps
heat = _nms(heat)
scores, inds, clses, ys, xs = _topk(heat, K=K)
kps = _transpose_and_gather_feat(kps, inds)#gather the 17 keypoint offsets of the K strongest centers
kps = kps.view(batch, K, num_joints * 2)
kps[..., ::2] += xs.view(batch, K, 1).expand(batch, K, num_joints)#add the center x to the offsets to get absolute keypoint x coordinates
kps[..., 1::2] += ys.view(batch, K, 1).expand(batch, K, num_joints)#same for the y coordinates
if reg is not None:
reg = _transpose_and_gather_feat(reg, inds)
reg = reg.view(batch, K, 2)
xs = xs.view(batch, K, 1) + reg[:, :, 0:1]#refine the center with its predicted sub-pixel offset
ys = ys.view(batch, K, 1) + reg[:, :, 1:2]
else:
xs = xs.view(batch, K, 1) + 0.5
ys = ys.view(batch, K, 1) + 0.5
wh = _transpose_and_gather_feat(wh, inds)
wh = wh.view(batch, K, 2)
clses = clses.view(batch, K, 1).float()
scores = scores.view(batch, K, 1)
bboxes = torch.cat([xs - wh[..., 0:1] / 2,
ys - wh[..., 1:2] / 2,
xs + wh[..., 0:1] / 2,
ys + wh[..., 1:2] / 2], dim=2)
if hm_hp is not None:
hm_hp = _nms(hm_hp)
thresh = 0.1
kps = kps.view(batch, K, num_joints, 2).permute(
0, 2, 1, 3).contiguous() # b x J x K x 2
reg_kps = kps.unsqueeze(3).expand(batch, num_joints, K, K, 2)
hm_score, hm_inds, hm_ys, hm_xs = _topk_channel(hm_hp, K=K) # b x J x K
#_topk_channel returns [N,C,K] (per-channel top-k, C=17) instead of the [N,K] used for detection
if hp_offset is not None:
hp_offset = _transpose_and_gather_feat(
hp_offset, hm_inds.view(batch, -1))
hp_offset = hp_offset.view(batch, num_joints, K, 2)
hm_xs = hm_xs + hp_offset[:, :, :, 0]
hm_ys = hm_ys + hp_offset[:, :, :, 1]
else:
hm_xs = hm_xs + 0.5
hm_ys = hm_ys + 0.5
mask = (hm_score > thresh).float()
hm_score = (1 - mask) * -1 + mask * hm_score#set scores below the threshold to -1
hm_ys = (1 - mask) * (-10000) + mask * hm_ys#push low-confidence points far away so they can never be the nearest match
hm_xs = (1 - mask) * (-10000) + mask * hm_xs
hm_kps = torch.stack([hm_xs, hm_ys], dim=-1).unsqueeze(
2).expand(batch, num_joints, K, K, 2)#replicate the heatmap keypoints so each can be compared with every reg_kps: b x J x K x K x 2
dist = (((reg_kps - hm_kps) ** 2).sum(dim=4) ** 0.5)#distance between every center-offset keypoint (reg_kps) and every heatmap keypoint
min_dist, min_ind = dist.min(dim=3) # b x J x K: for each reg_kps, the nearest of the K heatmap points ([N,C,K,K]-->[N,C,K])
hm_score = hm_score.gather(2, min_ind).unsqueeze(-1) # b x J x K x 1: the confidence of that nearest heatmap point
min_dist = min_dist.unsqueeze(-1)
min_ind = min_ind.view(batch, num_joints, K, 1, 1).expand(
batch, num_joints, K, 1, 2)#[N,C,K,1,2]; the index is duplicated along the last dimension so both coordinates are gathered
hm_kps = hm_kps.gather(3, min_ind)#coordinates of the nearest heatmap keypoint for each reg_kps, [N,C,K,2]: the step the paper describes as snapping each regressed keypoint to the closest heatmap detection
#one matched heatmap keypoint per joint for each of the K centers
hm_kps = hm_kps.view(batch, num_joints, K, 2)
l = bboxes[:, :, 0].view(batch, 1, K, 1).expand(batch, num_joints, K, 1)
t = bboxes[:, :, 1].view(batch, 1, K, 1).expand(batch, num_joints, K, 1)
r = bboxes[:, :, 2].view(batch, 1, K, 1).expand(batch, num_joints, K, 1)
b = bboxes[:, :, 3].view(batch, 1, K, 1).expand(batch, num_joints, K, 1)
#conditions for rejecting a heatmap keypoint: it falls outside the bbox, its score is below the threshold, or it is too far from the regressed keypoint
mask = (hm_kps[..., 0:1] < l) + (hm_kps[..., 0:1] > r) + \
(hm_kps[..., 1:2] < t) + (hm_kps[..., 1:2] > b) + \
(hm_score < thresh) + (min_dist > (torch.max(b - t, r - l) * 0.3))
mask = (mask > 0).float().expand(batch, num_joints, K, 2)
kps = (1 - mask) * hm_kps + mask * kps#choose between hm_kps and kps: if any rejection condition holds (the heatmap point lies outside the bbox, its confidence is below the threshold, or its distance to the center-offset keypoint exceeds 0.3 times the bbox size), fall back to the center-offset kps; otherwise use hm_kps
kps = kps.permute(0, 2, 1, 3).contiguous().view(
batch, K, num_joints * 2)#[N,K,2*C]
detections = torch.cat([bboxes, scores, kps, clses], dim=2)
return detections
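For COCO human pose (num_joints = 17) each candidate therefore has 4 + 1 + 34 + 1 = 40 values, matching the [N,K,40] shape mentioned above. A small sketch of slicing the result apart:
bboxes = detections[..., 0:4]    # x1, y1, x2, y2 in output-plane coordinates
scores = detections[..., 4:5]    # center-heatmap confidence
kps = detections[..., 5:39]      # 17 (x, y) keypoint pairs
clses = detections[..., 39:40]   # class id (always 0 for the single person class)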