姿态估计2-08:PVNet(6D姿态估计)-源码无死角解析(4)-RANSAC投票机制

以下链接是个人关于PVNet(6D姿态估计) 所有见解,如有错误欢迎大家指出,我会第一时间纠正。有兴趣的朋友可以加微信:a944284742相互讨论技术。若是帮助到了你什么,一定要记得点赞!因为这是对我最大的鼓励。
姿态估计2-00:PVNet(6D姿态估计)-目录-史上最新无死角讲解

前言

通过前面的博客,我们可以知道主干网络为作者修改过的lib/networks/pvnet/resnet18.py,修改的内容在上篇博客中有具体提到,该网络结构简单,我就不做详细的介绍了,但是其中有个比较重要的地方,我们还是需要重点分析的,在lib/networks/pvnet/resnet18.py代码中,我们可以看到如下部分:

    def decode_keypoint(self, output):
        vertex = output['vertex'].permute(0, 2, 3, 1)
        # vn_2 = 9*2 =18
        b, h, w, vn_2 = vertex.shape
        # 把x,y值分离出来,分别占用一个纬度
        vertex = vertex.view(b, h, w, vn_2//2, 2)
        # 获得前景(目标物体)对应的mask
        mask = torch.argmax(output['seg'], 1)
        # 如果使用了不确定性的pnp
        if cfg.test.un_pnp:
            # 基于RANSAC进行投票选举关键点
            mean = ransac_voting_layer_v3(mask, vertex, 512, inlier_thresh=0.99)
            # 获得关键点的概率分布
            kpt_2d, var = estimate_voting_distribution_with_mean(mask, vertex, mean)
            output.update({'mask': mask, 'kpt_2d': kpt_2d, 'var': var})
        else:
            kpt_2d = ransac_voting_layer_v3(mask, vertex, 128, inlier_thresh=0.99, max_num=100)
            output.update({'mask': mask, 'kpt_2d': kpt_2d})

其中ransac_voting_layer_v3函数是比较重要的。

ransac_voting_layer_v3

对于该函数的注释如下:

def ransac_voting_layer_v3(mask, vertex, round_hyp_num, inlier_thresh=0.999, confidence=0.99, max_iter=20,
                           min_num=5, max_num=30000):
    '''
    :param mask:      [b,h,w]
    :param vertex:    [b,h,w,vn,2]
    :param round_hyp_num:
    :param inlier_thresh:
    :return: [b,vn,2]
    '''
    b, h, w, vn, _ = vertex.shape
    batch_win_pts = []
    # 分别对每张图片进行处理
    for bi in range(b):
        #
        hyp_num = 0
        # 获得当前图片对应的mask
        cur_mask = (mask[bi]).byte()
        # 计算前景mask的和
        foreground_num = torch.sum(cur_mask)

        # if too few points, just skip it
        # 如果其前景的像数太少,则设置win_pts为0,并且continue跳过该图像的处理
        if foreground_num < min_num:
            win_pts = torch.zeros([1, vn, 2], dtype=torch.float32, device=mask.device)
            batch_win_pts.append(win_pts)  # [1,vn,2]
            continue

        # if too many inliers, we randomly down sample it
        # 如果前景层的像素点太多了
        if foreground_num > max_num:
            # 随机选取一定数目的像素点,得到新的mask
            selection = torch.zeros(cur_mask.shape, dtype=torch.float32, device=mask.device).uniform_(0, 1)
            selected_mask = (selection < (max_num / foreground_num.float())).byte()
            cur_mask *= selected_mask

        # 获得进行筛选之后的mask前景坐标
        coords = torch.nonzero(cur_mask).float()  # [tn,2]
        # 把x,y的坐标位置互换一下(估计是为了后续矩阵相乘方便)
        coords = coords[:, [1, 0]]
        # 获得当前图片,经过筛选的mask对应像素的vector方向
        direct = vertex[bi].masked_select(torch.unsqueeze(torch.unsqueeze(cur_mask, 2), 3))  # [tn,vn,2]
        direct = direct.view([coords.shape[0], vn, 2])
        # tn表示随机选取mask前景像素的数目
        tn = coords.shape[0]

        # 记录索引
        idxs = torch.zeros([round_hyp_num, vn, 2], dtype=torch.int32, device=mask.device).random_(0, direct.shape[0])
        # 记录所有vector的方向
        all_win_ratio = torch.zeros([vn], dtype=torch.float32, device=mask.device)
        # 记录所有vector的坐标
        all_win_pts = torch.zeros([vn, 2], dtype=torch.float32, device=mask.device)

        cur_iter = 0
        # 每次循环随机选择round_hyp_num数目的direct(样本)进行投票.
        while True:
            # generate hypothesis
            # 这里的cur_hyp_pts对应论文中的 hypothesis = hki,
            # 根据direct,coords随机生成N组关键点,hn代表论文中的N,cur_hyp_pts
            cur_hyp_pts = ransac_voting.generate_hypothesis(direct, coords, idxs)  # [hn,vn,2]

            # voting for hypothesis
            # 基于RANSAC策略进行投票,
            # tn 为前景像素的数目
            # vn=9,表示9个关键点
            # round_hyp_num表示局内点(样本)的数目,(具体查看RANSAC算法)
            # coords:经过筛选的前景坐标
            # direct:论文中的vertex
            # cur_hyp_pts为关键点可能存在的位置
            # cur_inlier 保存每个像素被投票为关键点票数
            cur_inlier = torch.zeros([round_hyp_num, vn, tn], dtype=torch.uint8, device=mask.device)
            ransac_voting.voting_for_hypothesis(direct, coords, cur_hyp_pts, cur_inlier, inlier_thresh)  # [hn,vn,tn]

            # find max,对投票的结果进行计数
            cur_inlier_counts = torch.sum(cur_inlier, 2)                   # [hn,vn]
            # 获得每个关键最多的票数,以及对应索引
            cur_win_counts, cur_win_idx = torch.max(cur_inlier_counts, 0)  # [vn]
            # 根据索引获得其关键点的坐标位置
            cur_win_pts = cur_hyp_pts[cur_win_idx, torch.arange(vn)]
            # 获得投票的数目/总的投票人数目(一人只持有一张票)=关键点获得票数的占比
            cur_win_ratio = cur_win_counts.float() / tn

            # update best point,如果出现了占比更高的投票,则对all_win_pts进行更新
            larger_mask = all_win_ratio < cur_win_ratio
            all_win_pts[larger_mask, :] = cur_win_pts[larger_mask, :]
            all_win_ratio[larger_mask] = cur_win_ratio[larger_mask]

            #
            hyp_num += round_hyp_num
            cur_iter += 1
            cur_min_ratio = torch.min(all_win_ratio)
            if (1 - (1 - cur_min_ratio ** 2) ** hyp_num) > confidence or cur_iter > max_iter:
                break

        # compute mean intersection again
        normal = torch.zeros_like(direct)   # [tn,vn,2]
        # x,y的坐标互换
        normal[:, :, 0] = direct[:, :, 1]
        normal[:, :, 1] = -direct[:, :, 0]
        # 0表示局外,1表示局内
        all_inlier = torch.zeros([1, vn, tn], dtype=torch.uint8, device=mask.device)
        all_win_pts = torch.unsqueeze(all_win_pts, 0)  # [1,vn,2]

        # 再一次假设交叉点(关键点)
        ransac_voting.voting_for_hypothesis(direct, coords, all_win_pts, all_inlier, inlier_thresh)  # [1,vn,tn]
        # 后续的操作本人不是很了解,感觉all_win_pts以及获得了关键点的位置,不知道为何还要有如下操作
        # 估计是为了剔除局外的投票者

        # coords [tn,2] normal [vn,tn,2]
        all_inlier = torch.squeeze(all_inlier.float(), 0)              # [vn,tn]
        # 矢量x,y的坐标互换
        normal = normal.permute(1, 0, 2)                                # [vn,tn,2]
        # 把局外投票者的方向全部清零
        normal = normal*torch.unsqueeze(all_inlier, 2)                 # [vn,tn,2] outlier is all zero
        # 注意,这里的coords表示投票者vector的坐标,normal表示其方向
        b = torch.sum(normal*torch.unsqueeze(coords, 0), 2)             # [vn,tn]


        # 剔除局外投票者,重新进行投票,获得精确的结果
        # 获得ATA矩阵,以及ATB矩阵
        ATA = torch.matmul(normal.permute(0, 2, 1), normal)              # [vn,2,2]
        ATb = torch.sum(normal*torch.unsqueeze(b, 2), 1)                # [vn,2]
        # try:
        # 根据ATA以及ATb矩阵求得坐标值
        # [vn,2,2] * [vn,2,1] = [vn,2,1]
        all_win_pts = torch.matmul(b_inv(ATA), torch.unsqueeze(ATb, 2)) # [vn,2,1]
        # except:
        #    __import__('ipdb').set_trace()
        batch_win_pts.append(all_win_pts[None,:,:, 0])

    batch_win_pts = torch.cat(batch_win_pts)
    return batch_win_pts

领读

对于RANSAC的理解,大家可以参考一下这篇博客:
RANSAC算法理解:https://blog.csdn.net/robinhjwy/article/details/79174914
总的来说,主要步骤如下:

1.像素筛选:如果前景像素太少,则该张图像忽略,如果前景像素太多,则随机剔除部分像素
2.使用ransac_voting.generate_hypothesis获得假设,即关键点可能存在的位置。
3.随机抽取样本(direct)对假设出来的关键点位置进行投票。
4.循环执行2,3步骤,直到投票的置信度达到标准
5.使用循环迭代出来的最好模型(摒弃了局外样本),再一次去生成假设坐标,并进行投票。

总的来说,就是一直假设坐标,投票,刷新最高票数坐标,假设坐标,投票,刷新最高票数坐标…一直这样下去,知道票数达到标准才停止。

你可能感兴趣的:(姿态估计2-08:PVNet(6D姿态估计)-源码无死角解析(4)-RANSAC投票机制)