论文阅读 | RAFT: Recurrent All-Pairs Field Transforms for Optical Flow

RAFT: Recurrent All-Pairs Field Transforms for Optical Flow

ECCV2020光流任务best paper
论文地址:【here】
代码地址:【here】

介绍

光流是对两张相邻图像中的逐像素运动的一种估计。目前碰到的一些困难包括:物体的快速运动,遮挡、运动模糊和缺乏纹理信息的一些图案。
目前深度学习的方法在维持传统方法达到的性能的情况下,有着更快的推理速度。目前需要考虑的问题是:如何设计一个深度学习的光流估计网络,实现更好表现,更易训练和更好的泛化到不同场景。

Recurrent All-Pairs Field Transforms (RAFT)框架有如下优势:

  • SOTA精度
  • 更强泛化
  • 更高效率

RAFT的主要结构:

  • feature encoder(蓝色部分) +context encoder(灰色部分)
  • 一个全像素区域的a correlation layer,同时带多尺度池化
  • a recurrent GRU-based update operator
    论文阅读 | RAFT: Recurrent All-Pairs Field Transforms for Optical Flow_第1张图片

网络架构

  1. Feature encoder:

卷积网络,做了8倍下采样,两张图共享一个网络权重

  1. context encoder:

和feature encoder 一样的网络结构,只作用在左图,作为后续GRU的参数和左图特征

  1. correlation volume生成-相似度的计算

拿Feature encoder得到的两张8倍下采样图后的特征,通过逐像素间的特征相乘再求和可以得到一个逐像素间的相似度,利用的是余弦相似度的计算方式。,

  1. Correlation Pyramid生成

由于correlation volume用于生成cost volume,即相邻像素区域之间的一个相似度(correlation volume是全局像素间的一个相似度),需要对correlation volume进行领域取值才能得到cost volume。
correlation volume: H * W * H * W
cost volume: H * W * delta h * delta w

这样导致如果要搜寻更远空间(larger displacement)内的对应像素,delta h * delta w 会很大,导致占用很大的计算资源

于是本文根据这样的缺点,提出一种相关性金字塔Correlation Pyramid:
即构建了四个不同大小的correlation volume,通过对原始大小的correlation volume 池化得到尺寸为H * W * H/2 * W/2, H * W * H/4 * W/4,以此类推的Correlation Pyramid
论文阅读 | RAFT: Recurrent All-Pairs Field Transforms for Optical Flow_第2张图片
途中阐释的图correlation volume的构建过程,即C3的correlation volume得到的是image2右图中一个方格内所有的像素点与左图image1某一个像素点的匹配相似度。
论文阅读 | RAFT: Recurrent All-Pairs Field Transforms for Optical Flow_第3张图片
构建这样一个金字塔的correlation volume,目的是为了实现不同范围的搜寻空间。在最小的 H * W * H/8 * W/8 correlation volume的上,同样的半径范围r,对应原图的搜寻半径范围是8r.

构建Correlation Pyramid代码如下:

        corr = CorrBlock.corr(fmap1, fmap2)

        batch, h1, w1, dim, h2, w2 = corr.shape
        corr = corr.reshape(batch*h1*w1, dim, h2, w2)
        
        self.corr_pyramid.append(corr)
        for i in range(self.num_levels-1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)
  1. Correlation Lookup

这个步骤也就是上一个节,第3节中提到的correlation volume生成cost volume的过程。
具体操作为,在x维度上,生成一个索引图,H * W * (2r+1),存储每个对应的像素点的相邻坐标索引,用这个索引在Correlation Pyramid中取值,得到4个,尺寸为H * W * (2r+1)的cost volume,最后在特征层做特征连接合并不同范围位移的cost volume, 得到一个金字塔范围的cost volume。在y的维度上做同样的操作
在这里插入图片描述

代码如下

        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
            dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
            delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)

            centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
            delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
            coords_lvl = centroid_lvl + delta_lvl

            corr = bilinear_sampler(corr, coords_lvl)
            corr = corr.view(batch, h1, w1, -1)
            out_pyramid.append(corr)

        out = torch.cat(out_pyramid, dim=-1)
  1. 迭代更新过程
    RAFT采用GRU不断迭代更新光流,先将光流初始化0,再不断通过计算的cost volume迭代更新光流,再用将新得到的光流与cost volume优化新的光流
    论文阅读 | RAFT: Recurrent All-Pairs Field Transforms for Optical Flow_第4张图片
    这里的光流用于直接查找 cost volume,因此是绝对值,最后的值要与最初的光流相减

  2. upsample过程
    由于整个过程都是再8倍下采样分辨率下,因此最后做了一个upsample.
    upsample用mask学习周围邻域的分布权重情况,做加权mask的upsample.
    在这里插入图片描述

    def upsample_flow(self, flow, mask):
        """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
        N, _, H, W = flow.shape
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        mask = torch.softmax(mask, dim=2)

        up_flow = F.unfold(8 * flow, [3,3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, 8*H, 8*W)
  1. 损失函数直接用L1损失

实验

精度
论文阅读 | RAFT: Recurrent All-Pairs Field Transforms for Optical Flow_第5张图片

效率
论文阅读 | RAFT: Recurrent All-Pairs Field Transforms for Optical Flow_第6张图片

总结

本文的优势:精度好、效率高,在不同数据集上表现都好

你可能感兴趣的:(论文阅读,论文阅读,计算机视觉,人工智能,python,笔记)