姿态识别中的loss计算方法总结

1 stack_hourglass

stackhourglass是采用的annotation为h5格式文件,所以在数据读取,判断过程与json格式的文件会有所不同,并且是heatmap比较的方法。

1.1 train

利用h5py.File文件解析功能,解析h5文件,并将结果分别存放于center、scale、part、visible、normalize列表中。

根据output_res和关键点个数利用GenerateHeatmap函数求得heatmaps(8,16,64,64)。计算每个样本nstack个的预测值得到combined_hm_preds(8,8,16,64,64)。其中第一个8是nstack值,第二个为batchsize。估计将每个nstack的loss后得到combined_loss(8,8)。通过torch.mean计算combined_loss所有元素的平均值得到所有nstack的loss总值,并记录。最后loss.backward反向传播,更新模型。

1.2 test

在计算test的loss时是直接采用的:

error = np.linalg.norm(p[0]['keypoints'][j, :2]-g[0, j, :2]) / normalize

然后通过error和bound(阈值,一般为0.5)比较大小,得到满足要求的结果个数。

2 HRNet

HRNet是采用的annotation为json格式文件,并且是heatmap比较的方法。

利用json.load(anno_file)函数读取json文件,存储在anno列表中。anno元素个数为22246,即所有图片数量。通过转换将数据分为image、center、scale、joints_3d、joints_3d_vis、filename和imgnum存入到列表gt_db中。

通过仿射变换将db中的数据存储到joints和joints_vis数组中。将joints, joints_vis数组通过generate_target函数生成heatmap并保存到target, target_weight中。

train和test计算loss方法相同,具体如下所示:

outputs、target:16,16,64,64。target_weight:16,16,1。

将outputs按照16个关节点分为heatmaps_pred列表,每个元素为(16,1,4096)。target同理变换为heatmaps_gt列表。如果用到target_weight,则将两列表与target_weight相乘,然后进行下一步的对比。对比采用的nn.MSELoss(reduction='mean')方法。

3 PRTR

PRTR结合了transformer,所以它的outputs、target经过特殊复杂的处理,对提高识别的准确性有所帮助。

你可能感兴趣的:(姿态识别,深度学习,计算机视觉)