原文:I2L-MeshNet: Image-to-Lixel Prediction Network for Accurate 3D Human Pose and Mesh Estimation from a Single RGB Image
代码:Pytorch
以前大多数基于图像的3D人体姿态和Mesh估计的工作都是通过估计mesh模型的参数来实现。但是直接去回归参数是一种高度非线性的映射,本文提出的 I2L-MeshNet (实现从图像到 lixel=线+像素的预测网络)则不是直接回归参数,而是来预测每个mesh顶点坐标在1D热图上的逐像素可能性。(后面lixel都称为线素)
3D人体姿态和Mesh估计目的是在恢复3D人体关节点和Mesh顶点位置。由于复杂的人类关节和2D - 3D模糊性,这是一项非常具有挑战性的任务。SMPL 和 MANO 是应用最广泛的参数化人体和手部Mesh模型,分别可以代表各种人体姿态。最近3D人体姿态和Mesh估计的研究大部分基于模型来进行,从输入图像来估计SMPL/MANO的参数,另一部分则是基于无模型方法,直接估计Mesh顶点坐标,他们通过将Mesh模型中包含的联合回归矩阵乘以估计的Mesh来获得3D姿态。但是这些基于模型以及无模型的3D姿势和Mesh估计工作都破坏了输入图像中像素之间的空间关系,因为输出阶段的FC层。
为了不破坏空间关系,最新的3D姿态估计方法中,不是通过Mesh顶点来定位关节点坐标,而是利用热图来进行,其中热图的每个值代表在输入图像的相应像素位置处的人体关节存在的可能性和深度值。因此,它保留了输入图像中像素之间的空间关系并为预测不确定性来建模。
已知 体素(volume+pixel)被定义为3D空间中的量化单元格,我们将 线素(line+pixel)定义为1D空间中的量化单元格。I2L-MeshNet为每个Mesh顶点坐标在1D热图上估计逐像素可能性,因此该网络基于无模型方法。以前基于热图的3D姿势估计方法都是预测每个关节点的3D热图。其中人体关节数为21,Mesh顶点数则要大得多(例如SMPL为6980,MANO为778)。
为更准确地进行3D人体姿势和Mesh估计,本文将 I2L-MeshNet 设计为由 PoseNet 和 MeshNet 组成的级联网络。 其中 PoseNet 预测每个3D关节点坐标的基于lixel的1D热图;MeshNet 利用PoseNet的输出以及图像特征作为输入来预测3D Mesh顶点坐标的基于lixel的1D热图。由于人体关节的位置提供了有关人体网格顶点位置的粗略但重要的信息,因此将其用于3D的Mesh估计很自然,并且可以大大提高准确性。
如Fig.2所示,I2L-MeshNet 由 PoseNet 和 MeshNet 组成,整个网络结构代码整理如下。
from main.model import *
if __name__ == '__main__':
# ==================init========================
pose_backbone = ResNetBackbone(resnet_type=50)
pose_net = PoseNet(joint_num=21)
pose2feat = Pose2Feat(joint_num=21)
mesh_backbone = ResNetBackbone(resnet_type=50)
mesh_net = MeshNet(vertex_num=778)
param_regressor = ParamRegressor(joint_num=21)
# ===================end========================
img = torch.randn(8, 3, 256, 256).cuda()
model_1 = pose_backbone.cuda()
# print(model_1)
total_params = sum(p.numel() for p in model_1.parameters())
share_img_feat, pose_img_feat = model_1(img)
print(share_img_feat.shape, pose_img_feat.shape) #[2, 64, 64, 64], [2, 2048, 8, 8]
model_2 = pose_net.cuda()
joint_img = model_2(pose_img_feat)
print(joint_img.shape)
joint_heatmap = Model().make_gaussian_heatmap(joint_coord_img=joint_img.detach())
print(joint_heatmap.shape)
model_3 = pose2feat.cuda()
next_img = pose2feat(share_img_feat, joint_heatmap)
print(next_img.shape)
model_4 = mesh_backbone.cuda()
_, pose_img_feat = model_4(next_img, skip_early=True)
print(pose_img_feat.shape)
model_5 = mesh_net.cuda()
mesh_img = model_5(pose_img_feat)
print(mesh_img.shape)
PoseNet是来估计三个基于lixel的1D所有关节点的热图
同时参照上图和下面代码来理解,先将图像送入ResNet50里,将得到的 pose_img_feat 输入到 PoseNet,该网络先上采样3次,那么由原来的 h × w h×w h×w --> 8 h × 8 w 8h×8w 8h×8w,通道数也由原来的 c = 2048 c=2048 c=2048 --> c ′ = 256 c'=256 c′=256,在对上采样后的输出进行相应维度求平均,最后再进行一次1D卷积,通道数也由原来的 c ′ = 256 c'=256 c′=256 --> J = 21 J=21 J=21。不懂可以参照原文如下。(求z则不一样,论文里说用到FC但代码里没有)
def forward(self, img_feat):
img_feat_xy = self.deconv(img_feat) # upsample 3 times
# x axis
img_feat_x = img_feat_xy.mean((2)) # (batch_size, channel, h, w)
heatmap_x = self.conv_x(img_feat_x)
coord_x = self.soft_argmax_1d(heatmap_x)
# y axis
img_feat_y = img_feat_xy.mean((3))
heatmap_y = self.conv_y(img_feat_y)
coord_y = self.soft_argmax_1d(heatmap_y)
# z axis
img_feat_z = img_feat.mean((2,3))[:,:,None]
img_feat_z = self.conv_z_1(img_feat_z)
img_feat_z = img_feat_z.view(-1,256,cfg.output_hm_shape[0])
heatmap_z = self.conv_z_2(img_feat_z)
coord_z = self.soft_argmax_1d(heatmap_z)
joint_coord = torch.cat((coord_x, coord_y, coord_z),2)
return joint_coord
大体结构和 PoseNet 一样,原文如下。
def forward(self, img_feat):
img_feat_xy = self.deconv(img_feat)
# x axis
img_feat_x = img_feat_xy.mean((2))
heatmap_x = self.conv_x(img_feat_x)
coord_x = self.soft_argmax_1d(heatmap_x)
# y axis
img_feat_y = img_feat_xy.mean((3))
heatmap_y = self.conv_y(img_feat_y)
coord_y = self.soft_argmax_1d(heatmap_y)
# z axis
img_feat_z = img_feat.mean((2,3))[:,:,None]
img_feat_z = self.conv_z_1(img_feat_z)
img_feat_z = img_feat_z.view(-1,256,cfg.output_hm_shape[0])
heatmap_z = self.conv_z_2(img_feat_z)
coord_z = self.soft_argmax_1d(heatmap_z)
mesh_coord = torch.cat((coord_x, coord_y, coord_z),2)
return mesh_coord