[PyTorch] non-local的patch版本实现

MM 18有篇文章《Non-locally Enhanced Encoder-Decoder Network for Single Image De-raining》,这篇文章把两大吃显存利器——non-local 和 densely connection一起给用了(绝望脸),如果要控制显存使用的话,就对实现要求比较高了。里面提到了一个控制 non-local 运算的方法,将feature map划分成patch,在patch内进行non-local操作,而不是原来的全局non-local。

思路

这里实现一下这个patch版本的non-local。一般的non-local实现,请参考:Github 传送门。
思路很简单,就是把patch的索引作为一个特殊的batch索引,原来的non-local运算会逐batch中的样本进行,现在就是逐batch中的样本、逐每个样本中的patch进行了。

图1. 新的索引

上图中,索引的第1个位置表示batch,第2、3个位置表示patch,将1~3个位置的索引看成整体,作为这个特殊的“batch”的索引。此时,patch版本的 non-local 就和一般的 non-local 没有太大区别了。

实现

那么,如何实现这样的一个新索引呢?
假如有输入图像 ,首先需要将最后两个表示位置的索引分解成四个索引,两个表示块的位置,两个表示块中元素的位置,例如要将行分解成m块、列分解成n块,就得到,使用 view 方法就能实现。如果对这个实现有疑惑,可以参考附录中的例子。
然后进行转置(或者说是交换索引的位置),得到。这里使用前三个索引,表示具体某个patch(batch中某个feature map的某个patch)。最后,使用这个新的Tensor来进行non-local的操作即可,方法类似,仅仅是在前面多了两个索引。

# implementation in PyTorch
# x=>(b, c, m, h/m, n, w/n)
# e.g. nb_patches = [2, 2]
b, c, h, w = x.size()
x = x.view(b, c,
           nb_patches[0], h / nb_patches[0],
           nb_patches[1], w / nb_patches[1])
# x=>(b, m, n, h/m, w/n, c)->(b, m, n, h/m*w/n, c)
x = x.permute(0, 2, 4, 3, 5, 1).contiguous()
x = x.view(b, nb_patches[0], nb_patches[1], -1, c)

附录

图2. view 用法举例

如果不确定view的使用,可以举个简单的例子,如上图的的一个方块 ,显然大小就是。上图里中间一列,左边表示元素,右边表示原来的元素索引。如果使用 view 就是在新的 Tensor 中按顺序排列旧的 Tensor 中的元素。按如下代码重新排列元素。图中右边一列,就是使用了 view 后的新索引。

X = X.view(1, 1, 2, 4/2, 2, 4/2)

所以,对于左上角的patch,其索引是,右上角则是,也就是使用view得到新的 Tensor的第3、5个位置的索引。

待填坑: Tensor中的contiguous方法

你可能感兴趣的:([PyTorch] non-local的patch版本实现)