PointNet++ Segmentation Code Walkthrough

I've recently been reading PointConv and noticed that its architecture borrows from and modifies the design of PointNet++. I was too lazy to study PointNet++ carefully before, so now I have to go back to it.

The overall PointNet++ architecture:
[Figure 1: PointNet++ overall architecture]
Let's look at the encoder first. It is built from SetAbstraction (SA) modules, each consisting of three parts: a sampling layer, a grouping layer, and a PointNet layer. The figure shows two SetAbstraction modules, but in the code the author actually uses 3 to 4 of them depending on the task, which is understandable: in principle, the more of them there are, the finer the extracted features.
The sampling layer uses farthest point sampling (FPS); there is plenty of material on it online, so I won't go into detail. Grouping is also straightforward: it uses ball query (query_ball). For point clouds with non-uniform density, the paper additionally proposes multi-scale grouping (MSG) and multi-resolution grouping (MRG); see the paper for those. On to the code.
Here, points holds all per-point features besides xyz, e.g. RGB values or normals, while xyz holds only the coordinates.
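
Before the module itself: the SA layer leans on a few helpers (farthest_point_sample, query_ball_point, sample_and_group, plus square_distance and index_points, which the decoder also uses later). For reference, here is a minimal sketch of them, written to match the popular PyTorch reimplementation this post's code comes from; treat it as a sketch rather than the verbatim upstream source. sample_and_group_all (used only when group_all=True) is omitted. All later snippets assume these imports and helpers.

import torch
import torch.nn as nn
import torch.nn.functional as F


def square_distance(src, dst):
    # squared Euclidean distance between every pair of points: [B, N, M]
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    return dist


def index_points(points, idx):
    # gather points[b, idx[b]] per batch element; idx may be [B, S] or [B, S, K]
    B = points.shape[0]
    view_shape = [B] + [1] * (idx.dim() - 1)
    repeat_shape = [1] + list(idx.shape[1:])
    batch_indices = torch.arange(B, device=points.device).view(view_shape).repeat(repeat_shape)
    return points[batch_indices, idx, :]


def farthest_point_sample(xyz, npoint):
    # iteratively pick the point farthest from everything chosen so far; returns [B, npoint] indices
    device = xyz.device
    B, N, _ = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long, device=device)
    distance = torch.full((B, N), 1e10, device=device)
    farthest = torch.randint(0, N, (B,), dtype=torch.long, device=device)
    batch_indices = torch.arange(B, device=device)
    for i in range(npoint):
        centroids[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
    return centroids


def query_ball_point(radius, nsample, xyz, new_xyz):
    # for each query point, indices of up to nsample neighbors within radius (padded with the first found neighbor)
    device = xyz.device
    B, N, _ = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long, device=device).view(1, 1, N).repeat([B, S, 1])
    group_idx[square_distance(new_xyz, xyz) > radius ** 2] = N  # mark points outside the ball
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    mask = group_idx == N
    group_idx[mask] = group_first[mask]
    return group_idx


def sample_and_group(npoint, radius, nsample, xyz, points):
    # FPS to pick npoint centroids, ball query to group neighbors, then center the local coordinates
    B, N, C = xyz.shape
    new_xyz = index_points(xyz, farthest_point_sample(xyz, npoint))       # [B, npoint, 3]
    idx = query_ball_point(radius, nsample, xyz, new_xyz)                 # [B, npoint, nsample]
    grouped_xyz = index_points(xyz, idx) - new_xyz.view(B, npoint, 1, C)  # local coordinates
    if points is not None:
        new_points = torch.cat([grouped_xyz, index_points(points, idx)], dim=-1)
    else:
        new_points = grouped_xyz
    return new_xyz, new_points

With those helpers in place, the SetAbstraction module itself: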

class PointNetSetAbstraction(nn.Module):
    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
        super(PointNetSetAbstraction, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel
        self.group_all = group_all

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)

        if self.group_all:
            new_xyz, new_points = sample_and_group_all(xyz, points)
        else:
            new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
        # new_xyz: sampled points position data, [B, npoint, C]
        # new_points: sampled points data, [B, npoint, nsample, C+D]
        new_points = new_points.permute(0, 3, 2, 1)  # [B, C+D, nsample, npoint]
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points = F.relu(bn(conv(new_points)))

        new_points = torch.max(new_points, 2)[0]  # max pool over the nsample dimension -> [B, D', npoint]
        new_xyz = new_xyz.permute(0, 2, 1)
        return new_xyz, new_points
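
As a quick, hypothetical shape check (matching the first SA layer of the segmentation model shown further below, with random tensors standing in for real data):

sa = PointNetSetAbstraction(npoint=1024, radius=0.1, nsample=32,
                            in_channel=9 + 3, mlp=[32, 32, 64], group_all=False)
xyz = torch.rand(6, 3, 2048)      # coordinates, [B, C, N]
points = torch.rand(6, 9, 2048)   # extra features, [B, D, N]
new_xyz, new_points = sa(xyz, points)
print(new_xyz.shape)              # torch.Size([6, 3, 1024])
print(new_points.shape)           # torch.Size([6, 64, 1024])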

It's worth noting that in the PointNet layer the author simply uses 1×1 convolutions in place of a full PointNet. This can be viewed as the original PointNet in stripped-down form, but whether removing the STN (the input/feature transform network) hurts performance is still an open question.
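
It's easy to convince yourself that a 1×1 Conv2d over the grouped tensor [B, C, nsample, npoint] is exactly a per-point shared MLP. A small self-contained check (the names here are illustrative only):

import torch
import torch.nn as nn

B, C, K, S = 2, 6, 32, 128                  # batch, channels, nsample, npoint
x = torch.randn(B, C, K, S)                 # grouped features, as fed to mlp_convs above
conv = nn.Conv2d(C, 16, kernel_size=1)      # the "PointNet layer" convolution
lin = nn.Linear(C, 16)
with torch.no_grad():                       # copy conv weights into the linear layer
    lin.weight.copy_(conv.weight.view(16, C))
    lin.bias.copy_(conv.bias)
y_conv = conv(x)                                         # [B, 16, K, S]
y_lin = lin(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)   # same map applied point-wise
print(torch.allclose(y_conv, y_lin, atol=1e-6))          # True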
Now for the part I most wanted to write down: the decoder for segmentation (the classification decoder is too simple to dwell on). First, why does the network need this interpolation step at all? Because segmentation is ultimately a per-point classification task, and we need enough local information for it to be accurate; after several SA modules, however, the receptive field has grown very large, close to global, so we have to recover more precise local information. Second, how do we get it? The author takes the encoded point cloud (N'×d) and interpolates it back to the denser levels. What does this interpolation look like? In the paper, the author gives an inverse distance weighted average over the k nearest neighbors:

f^{(j)}(x) = ( Σ_{i=1}^{k} w_i(x) · f_i^{(j)} ) / ( Σ_{i=1}^{k} w_i(x) ),  where w_i(x) = 1 / d(x, x_i)^p,  with p = 2, k = 3.

See the paper for the full discussion. Now the code:

class PointNetFeaturePropagation(nn.Module):
    def __init__(self, in_channel, mlp):
        super(PointNetFeaturePropagation, self).__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm1d(out_channel))
            last_channel = out_channel

    def forward(self, xyz1, xyz2, points1, points2):
        """
        Input:
            xyz1: input points position data, [B, C, N]
            xyz2: sampled input points position data, [B, C, S]
            points1: input points data, [B, D, N]
            points2: input points data, [B, D, S]
        Return:
            new_points: upsampled points data, [B, D', N]
        """
        xyz1 = xyz1.permute(0, 2, 1)        # [6, 64, 3]
        xyz2 = xyz2.permute(0, 2, 1)        # [6, 16, 3]

        points2 = points2.permute(0, 2, 1)  # [6, 16, 512]
        B, N, C = xyz1.shape
        _, S, _ = xyz2.shape

        if S == 1:
            interpolated_points = points2.repeat(1, N, 1)
        else:
            dists = square_distance(xyz1, xyz2)
            dists, idx = dists.sort(dim=-1)
            dists, idx = dists[:, :, :3], idx[:, :, :3]  # [B, N, 3]: the three nearest sparse-level points for each dense-level point

            dist_recip = 1.0 / (dists + 1e-8)                  # inverse-distance weights
            norm = torch.sum(dist_recip, dim=2, keepdim=True)  # normalizing denominator
            weight = dist_recip / norm
            interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)

        if points1 is not None:
            points1 = points1.permute(0, 2, 1)
            new_points = torch.cat([points1, interpolated_points], dim=-1)
        else:
            new_points = interpolated_points

        new_points = new_points.permute(0, 2, 1)
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points = F.relu(bn(conv(new_points)))
        return new_points
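
A quick, hypothetical shape check mirroring the fp4 stage of the model below (random tensors stand in for real features):

fp = PointNetFeaturePropagation(in_channel=256 + 512, mlp=[256, 256])
l3_xyz, l4_xyz = torch.rand(6, 3, 64), torch.rand(6, 3, 16)
l3_points, l4_points = torch.rand(6, 256, 64), torch.rand(6, 512, 16)
out = fp(l3_xyz, l4_xyz, l3_points, l4_points)
print(out.shape)  # torch.Size([6, 256, 64])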

Let's step through it using the author's example: the input is a batch of 6 point clouds, each consisting of 2048 points with 9 features per point. The forward pass:

class get_model(nn.Module):
    def __init__(self, num_classes):
        super(get_model, self).__init__()
        self.sa1 = PointNetSetAbstraction(1024, 0.1, 32, 9 + 3, [32, 32, 64], False)
        self.sa2 = PointNetSetAbstraction(256, 0.2, 32, 64 + 3, [64, 64, 128], False)
        self.sa3 = PointNetSetAbstraction(64, 0.4, 32, 128 + 3, [128, 128, 256], False)
        self.sa4 = PointNetSetAbstraction(16, 0.8, 32, 256 + 3, [256, 256, 512], False)
        self.fp4 = PointNetFeaturePropagation(768, [256, 256])
        self.fp3 = PointNetFeaturePropagation(384, [256, 256])
        self.fp2 = PointNetFeaturePropagation(320, [256, 128])
        self.fp1 = PointNetFeaturePropagation(128, [128, 128, 128])
        self.conv1 = nn.Conv1d(128, 128, 1)
        self.bn1 = nn.BatchNorm1d(128)
        self.drop1 = nn.Dropout(0.5)
        self.conv2 = nn.Conv1d(128, num_classes, 1)

    def forward(self, xyz):
        l0_points = xyz
        l0_xyz = xyz[:, :3, :]

        l1_xyz, l1_points = self.sa1(l0_xyz, l0_points)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
        l4_xyz, l4_points = self.sa4(l3_xyz, l3_points)

        l3_points = self.fp4(l3_xyz, l4_xyz, l3_points, l4_points)
        l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)
        l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
        l0_points = self.fp1(l0_xyz, l1_xyz, None, l1_points)

        x = self.drop1(F.relu(self.bn1(self.conv1(l0_points))))
        x = self.conv2(x)
        x = F.log_softmax(x, dim=1)
        x = x.permute(0, 2, 1)
        return x, l4_points
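
As a sanity check on the shapes (a hypothetical run; num_classes=13 is my assumption, matching the 13 classes of S3DIS, which is also where the 9-feature, 2048-point setup comes from):

model = get_model(num_classes=13)
pts = torch.rand(6, 9, 2048)        # batch of 6, 9 features per point, 2048 points
seg_logits, l4_points = model(pts)
print(seg_logits.shape)             # torch.Size([6, 2048, 13])
print(l4_points.shape)              # torch.Size([6, 512, 16])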

After the four SA modules, l4_xyz is 6×3×16 and l4_points is 6×512×16, while l3_xyz is 6×3×64 and l3_points is 6×256×64. As feature extraction proceeds, the number of points shrinks while the feature dimension grows. Interpolation then works as follows: for each point coordinate in l3, select the three nearest points in l4 and compute that point's feature from those three points' features; the result is then concatenated with the features already attached to the l3 points (a skip connection). This brings the point count back up while combining local and global information.
How does the code implement this? First it computes the (squared) Euclidean distances between the points of l3 and l4 with square_distance(xyz1, xyz2), giving a 6×64×16 matrix. It then sorts them with dists.sort(dim=-1), which returns the distances in ascending order together with idx, the index of the point each distance belongs to. In each 64×16 slice, one row of 16 values holds the distances from one of the 64 l3 points to all 16 l4 points, so taking the first three entries after sorting, dists, idx = dists[:, :, :3], idx[:, :, :3], yields the three closest l4 points (and their indices) for every l3 point.
Next, weights are computed from these three distances: each weight is the reciprocal of a distance divided by the sum of the three reciprocals. The weights are multiplied with the three points' features and summed, torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2), where index_points fetches the features of the indexed points. The result is concatenated with points1 to give the final point set. After interpolation, the features are passed through the "unit pointnet" to further extract and fuse them; note that although the figure labels it a unit pointnet, it is really just a 1×1 convolution, with no max pooling. Repeating this through all FP layers finally yields the classification for every point.
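
And a tiny toy check of the weighting step, reusing square_distance and index_points from the helper sketch near the top (shapes mirror the l3/l4 example):

feats2 = torch.rand(1, 16, 512)      # sparse-level features, [B, S, D] (like l4_points after permute)
xyz1 = torch.rand(1, 64, 3)          # dense-level coordinates (like l3)
xyz2 = torch.rand(1, 16, 3)          # sparse-level coordinates (like l4)
d = square_distance(xyz1, xyz2)      # [1, 64, 16] squared distances
d, idx = d.sort(dim=-1)
d, idx = d[:, :, :3], idx[:, :, :3]  # three nearest sparse points per dense point
w = 1.0 / (d + 1e-8)
w = w / w.sum(dim=2, keepdim=True)   # the three weights sum to 1 for every dense point
interp = torch.sum(index_points(feats2, idx) * w.unsqueeze(-1), dim=2)
print(interp.shape)                  # torch.Size([1, 64, 512])
print(torch.allclose(w.sum(dim=2), torch.ones(1, 64)))  # True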
