Paper: Objects as Points
Code: xingyizhou/CenterNet
(Earlier sections omitted.)
Let $I \in \mathbb{R}^{H \times W \times 3}$ be an input image of height $H$ and width $W$. The goal is to produce a keypoint heatmap $\hat{Y} \in [0,1]^{\frac{H}{R} \times \frac{W}{R} \times C}$, where $R$ is the output stride and $C$ is the number of keypoint types: $C = 17$ for human pose estimation and $C = 80$ for object detection. The paper uses the default output stride $R = 4$, so the output is downsampled by a factor of $R$ relative to the input. A prediction $\hat{Y}_{x,y,c} = 1$ corresponds to a detected keypoint, while $\hat{Y}_{x,y,c} = 0$ is background.
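As a quick illustration (a minimal sketch; the variable names are mine, not the repo's), a ground-truth keypoint $p$ in input-image coordinates lands at the low-resolution heatmap location $\tilde{p} = \lfloor p / R \rfloor$:

```python
import torch

H, W, R, C = 512, 512, 4, 17              # input size, output stride, keypoint types
heatmap = torch.zeros(C, H // R, W // R)  # Y-hat has shape C x H/R x W/R = 17 x 128 x 128

p = torch.tensor([301.0, 173.0])          # a keypoint (x, y) in input-image coordinates
p_tilde = (p / R).int()                   # its low-resolution location: (75, 43)
heatmap[0, p_tilde[1], p_tilde[0]] = 1.0  # mark the peak for class c = 0
```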
For the ground truth, each keypoint is splatted onto the heatmap with a Gaussian kernel; where two Gaussians of the same class overlap, the element-wise maximum is taken. The training objective is a penalty-reduced, pixel-wise logistic regression with focal loss:
$$L_k = \frac{-1}{N} \sum_{xyc} \begin{cases} (1-\hat{Y}_{xyc})^\alpha \log(\hat{Y}_{xyc}) & \text{if } Y_{xyc} = 1 \\ (1-Y_{xyc})^\beta (\hat{Y}_{xyc})^\alpha \log(1-\hat{Y}_{xyc}) & \text{otherwise} \end{cases}$$

where $N$ is the number of keypoints in the image and $\alpha$, $\beta$ are focal-loss hyperparameters.
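A sketch of this loss in PyTorch, using the paper's values $\alpha = 2$ and $\beta = 4$; the function name and the `eps` guard are mine, not the repo's:

```python
import torch

def penalty_reduced_focal_loss(pred, gt, alpha=2, beta=4, eps=1e-12):
    """pred, gt: B x C x H/R x W/R; pred in (0, 1), gt is the Gaussian-splatted target."""
    pos = gt.eq(1).float()  # exact keypoint locations
    neg = gt.lt(1).float()  # everywhere else
    pos_loss = torch.log(pred + eps) * (1 - pred) ** alpha * pos
    # (1 - gt)^beta reduces the penalty near keypoints, where gt is close to 1
    neg_loss = torch.log(1 - pred + eps) * pred ** alpha * (1 - gt) ** beta * neg
    num_pos = pos.sum().clamp(min=1)  # N = number of keypoints in the image
    return -(pos_loss.sum() + neg_loss.sum()) / num_pos
```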
To recover the discretization error introduced by the output stride, we additionally predict a local offset $\hat{O} \in \mathbb{R}^{\frac{H}{R} \times \frac{W}{R} \times 2}$ for each center point; all classes $c$ share the same offset prediction. The offset is trained with an L1 loss:
$$L_{off} = \frac{1}{N} \sum_{p} \left| \hat{O}_{\tilde{p}} - \left( \frac{p}{R} - \tilde{p} \right) \right|$$
The supervision acts only at keypoint locations $\tilde{p}$; all other locations are ignored.
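A minimal sketch of this loss for a single image, with illustrative names (the actual repo gathers predictions at keypoint indices on the feature map rather than looping in Python):

```python
import torch

def offset_loss(o_hat, keypoints, R=4):
    """o_hat: 2 x H/R x W/R predicted offsets; keypoints: list of (x, y) tensors
    in input-image coordinates. A sketch for one image."""
    loss, n = 0.0, 0
    for p in keypoints:
        p_tilde = (p / R).int()                    # low-resolution location
        target = p / R - p_tilde                   # sub-pixel residual in [0, 1)^2
        pred = o_hat[:, int(p_tilde[1]), int(p_tilde[0])]
        loss = loss + (pred - target).abs().sum()  # L1, only at keypoint locations
        n += 1
    return loss / max(n, 1)
```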
Human pose estimation aims to estimate the 2D locations of $k$ human joints for every person instance in the image ($k = 17$ for COCO). We consider the pose as a $k \times 2$-dimensional property of the center point and parameterize each joint by an offset to the center. We directly regress to the joint offsets $\hat{J} \in \mathbb{R}^{\frac{H}{R} \times \frac{W}{R} \times k \times 2}$ and mask the loss to ignore invisible keypoints.
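A sketch of the masked L1 loss over the regressed joint offsets, assuming a 0/1 visibility mask derived from the COCO annotations; all names are illustrative:

```python
import torch

def joint_offset_loss(j_hat, j_gt, vis_mask):
    """j_hat, j_gt: B x K x (num_joints*2) regressed / ground-truth offsets to the
    center; vis_mask: same shape, 1 where the joint is labeled visible, else 0."""
    loss = ((j_hat - j_gt).abs() * vis_mask).sum()
    return loss / vis_mask.sum().clamp(min=1)  # normalize by visible joints
```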
To refine the keypoints, we further estimate $k$ human joint heatmaps using standard bottom-up multi-human pose estimation, trained with the same focal loss and local offsets as the center detection. We then snap the initial regressed predictions to the closest detected keypoint on the corresponding heatmap; the center-based regression thus acts as a grouping cue that assigns each heatmap keypoint to a person instance.
The grouping code for the multi-pose head in CenterNet:
```python
import torch

# _nms, _topk, _topk_channel, _transpose_and_gather_feat are helper functions
# defined alongside this function in src/lib/models/decode.py of the repo.

def multi_pose_decode(
        heat, wh, kps, reg=None, hm_hp=None, hp_offset=None, K=100):
    batch, cat, height, width = heat.size()
    num_joints = kps.shape[1] // 2
    # heat = torch.sigmoid(heat)
    # keep only local maxima on the center heatmap (NMS via max-pooling)
    heat = _nms(heat)
    # take the top-K center candidates across all classes
    scores, inds, clses, ys, xs = _topk(heat, K=K)

    # regressed joints: add the per-joint offsets to each center location
    kps = _transpose_and_gather_feat(kps, inds)
    kps = kps.view(batch, K, num_joints * 2)
    kps[..., ::2] += xs.view(batch, K, 1).expand(batch, K, num_joints)
    kps[..., 1::2] += ys.view(batch, K, 1).expand(batch, K, num_joints)
    # refine the center locations with the predicted sub-pixel offsets
    if reg is not None:
        reg = _transpose_and_gather_feat(reg, inds)
        reg = reg.view(batch, K, 2)
        xs = xs.view(batch, K, 1) + reg[:, :, 0:1]
        ys = ys.view(batch, K, 1) + reg[:, :, 1:2]
    else:
        xs = xs.view(batch, K, 1) + 0.5
        ys = ys.view(batch, K, 1) + 0.5
    # gather box sizes and assemble (x1, y1, x2, y2) boxes around the centers
    wh = _transpose_and_gather_feat(wh, inds)
    wh = wh.view(batch, K, 2)
    clses = clses.view(batch, K, 1).float()
    scores = scores.view(batch, K, 1)
    bboxes = torch.cat([xs - wh[..., 0:1] / 2,
                        ys - wh[..., 1:2] / 2,
                        xs + wh[..., 0:1] / 2,
                        ys + wh[..., 1:2] / 2], dim=2)
    if hm_hp is not None:
        # snap each regressed joint to the closest peak of the joint heatmaps
        hm_hp = _nms(hm_hp)
        thresh = 0.1
        kps = kps.view(batch, K, num_joints, 2).permute(
            0, 2, 1, 3).contiguous()  # b x J x K x 2
        reg_kps = kps.unsqueeze(3).expand(batch, num_joints, K, K, 2)
        # top-K candidate locations per joint heatmap
        hm_score, hm_inds, hm_ys, hm_xs = _topk_channel(hm_hp, K=K)  # b x J x K
        if hp_offset is not None:
            hp_offset = _transpose_and_gather_feat(
                hp_offset, hm_inds.view(batch, -1))
            hp_offset = hp_offset.view(batch, num_joints, K, 2)
            hm_xs = hm_xs + hp_offset[:, :, :, 0]
            hm_ys = hm_ys + hp_offset[:, :, :, 1]
        else:
            hm_xs = hm_xs + 0.5
            hm_ys = hm_ys + 0.5
        # push low-confidence heatmap candidates far away so they never match
        mask = (hm_score > thresh).float()
        hm_score = (1 - mask) * -1 + mask * hm_score
        hm_ys = (1 - mask) * (-10000) + mask * hm_ys
        hm_xs = (1 - mask) * (-10000) + mask * hm_xs
        hm_kps = torch.stack([hm_xs, hm_ys], dim=-1).unsqueeze(
            2).expand(batch, num_joints, K, K, 2)
        # distance between every regressed joint and every heatmap candidate
        dist = (((reg_kps - hm_kps) ** 2).sum(dim=4) ** 0.5)
        min_dist, min_ind = dist.min(dim=3)  # b x J x K
        hm_score = hm_score.gather(2, min_ind).unsqueeze(-1)  # b x J x K x 1
        min_dist = min_dist.unsqueeze(-1)
        min_ind = min_ind.view(batch, num_joints, K, 1, 1).expand(
            batch, num_joints, K, 1, 2)
        hm_kps = hm_kps.gather(3, min_ind)
        hm_kps = hm_kps.view(batch, num_joints, K, 2)
        l = bboxes[:, :, 0].view(batch, 1, K, 1).expand(batch, num_joints, K, 1)
        t = bboxes[:, :, 1].view(batch, 1, K, 1).expand(batch, num_joints, K, 1)
        r = bboxes[:, :, 2].view(batch, 1, K, 1).expand(batch, num_joints, K, 1)
        b = bboxes[:, :, 3].view(batch, 1, K, 1).expand(batch, num_joints, K, 1)
        # reject a snapped joint if it falls outside the person box, its score is
        # low, or it is too far from the regressed joint; keep the regression then
        mask = (hm_kps[..., 0:1] < l) + (hm_kps[..., 0:1] > r) + \
               (hm_kps[..., 1:2] < t) + (hm_kps[..., 1:2] > b) + \
               (hm_score < thresh) + (min_dist > (torch.max(b - t, r - l) * 0.3))
        mask = (mask > 0).float().expand(batch, num_joints, K, 2)
        kps = (1 - mask) * hm_kps + mask * kps
        kps = kps.permute(0, 2, 1, 3).contiguous().view(
            batch, K, num_joints * 2)
    detections = torch.cat([bboxes, scores, kps, clses], dim=2)
    return detections
```
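A hypothetical call, assuming the helpers (`_nms`, `_topk`, `_topk_channel`, `_transpose_and_gather_feat`) from `src/lib/models/decode.py` are in scope; the shapes follow the COCO multi-pose heads (one person class, 17 joints):

```python
B, H, W, K = 2, 128, 128, 100       # heatmap resolution = input size / R
heat = torch.rand(B, 1, H, W)       # person center heatmap, values in (0, 1)
wh = torch.rand(B, 2, H, W)         # box width / height head
kps = torch.randn(B, 34, H, W)      # 17 joints x (dx, dy) offsets from center
reg = torch.rand(B, 2, H, W)        # center sub-pixel offsets
hm_hp = torch.rand(B, 17, H, W)     # per-joint heatmaps
hp_offset = torch.rand(B, 2, H, W)  # joint sub-pixel offsets

dets = multi_pose_decode(heat, wh, kps, reg, hm_hp, hp_offset, K=K)
print(dets.shape)                   # B x K x 40: bbox(4) + score(1) + kps(34) + class(1)
```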