[论文翻译]Spatial Transformer Networks(STN)

基本是把原文翻译了一遍。
看完之后感兴趣的话,可以点链接查看我对这篇论文的个人理解。
或者查看如何用pytorch在MNIST数据集上实现简单的STN——完整代码解读&运行。
[论文翻译]Spatial Transformer Networks(STN)_第1张图片

0 摘要

CNN在处理经过空间变换的图像时不具有invariance的性质。

本文提出了一种可训练、可微的Spatial Transformer(ST),它可以被插入到CNN的任何一个地方。

STN可以教正图像的translation, scale, rotation以及更一般的warping,取得了较好的分类效果。

1 Intro

CNN在一系列图像任务上取得了巨大的成就……(略)

local max-pooling layers使得CNN具有一定的translation invariance,但是池化层通常比较小(比如2*2 pixal),因此CNN只在深层的特征层才具有translation invariance。对于大范围的transformation,CNN实际上不具有invariance。

ST直接从task中进行学习,不需要额外的supervison。

池化层的感受野(receptive field)是固定(fixed)且局部(local)的的,但ST的行为取决于单个数据样本(意思是对于不同input image,ST的行为是随之改变的),因此是非局部的(non-locally)、动态的。

值得注意的是,ST可以用反向传播机制(back-propagation,BP)进行端到端(end-to-end)训练。

[论文翻译]Spatial Transformer Networks(STN)_第2张图片图1 (a)经过旋转、缩放、平移的MNIST数据集 (b)ST预测的,在原图上应该截取的区域 (c)ST的输出(d)后续全连接层数字识别的结果

ST插入到CNN中有利于提高CNN在很多任务上的性能,比如:
(1)image classification
(2)co-localisation
(3)spatial attention

2 前人相关工作

[论文翻译]Spatial Transformer Networks(STN)_第3张图片
图2 ST的结构。输入的特征图U被本地网络处理得到参数theta,然后经过网格生成器得到采样器,映射到原图U上,从而得到输出V。

3 Spatial Transformers

ST是一个可微模块,它在单个前向传递过程中对特征映射应用空间转换,其中转换取决于特定的输入,产生单个输出特征映射。对于多通道输入,相同的warping被应用在每个通道上。为了简单起见,在本节中,我们考虑单个transformer的单个transform和单个output,但是我们可以推广到多个变换,如实验所示。

Spatial Transformer由三个部分组成:

  • localisation network
    输入feature map的经过localisation network中的一系列的隐藏层后会得到一个预测参数。
  • grid generator
    在grid generator中,接着,参数会被用于创建一个sampling grid。接下来这个sampling grid会作用在input map上。
  • sampler
    最后,input map和sampling grid会被输入到sampler中,从而产生output map。

下面将详细讲解三个部分。

3.1 Localisation Network

输入feature map长为H,宽为W,通道数目为C
U ∈ R H × W × C U \in {R^{H \times W \times C}} URH×W×C
T θ {T_\theta} Tθ作用在input feature map上,得到参数 θ = f 1 o c ( U ) \theta={f_{1oc}}\left( U \right) θ=f1oc(U) θ \theta θ的size取决于设计好的转换方式(例如,在仿射变换中, θ \theta θ是6维的)

Localisation Network的函数 f 1 o c ( ) {f_{1oc}}\left( {} \right) f1oc()可以是任何形式,比如全连接层或卷积层。无论是那种形式,最后都需要包括一个regression layer来产生转换参数 θ \theta θ

3.2 Parameterised Sampling Grid

为了对input feature map执行变形,每个输出的pixel都是通过将sampling kernel作用在input feature map的特定位置上得到的(下一节会详细讲解)。通过逐点(by pixel)的方式,我们得到了一个generic feature map(generic feature map不一定是某种图片)。通常来说,generic feature map上的pixels会处在一个grid上。(注:grid用G表示, G = { G i } G = \left\{ {{G_i}} \right\} G={Gi},Gi为一个个网格点。 G i = { x i t , y i t } {G_i} = \left\{ {x_i^t,y_i^t} \right\} Gi={xit,yit})generic feature map上的pixels组成了output feature map V V V,其中 V ∈ R H ′ × W ′ × C V \in {R^{H' \times W' \times C}} VRH×W×C H ′ H' H为grid的高, W ′ W' W为grid的宽, C C C是通道数,保持不变。
[论文翻译]Spatial Transformer Networks(STN)_第4张图片
图3 两个利用参数化采样网格从输入U得到输出V的例子 (a)由恒等换换控制的gird,G=T_I(G) (b)theta控制的仿射变换T_theta(G)

为了解释清楚这件事,我们先假设 T θ {T_\theta } Tθ是由参数 θ {\theta } θ决定的2D仿射变换,仿射变换矩阵是 A θ {A_\theta } Aθ(接下来还会讲解其他形式的变换)。在仿射变换的例子中,

其中 ( x i t , y i t ) \left( {x_i^t,y_i^t} \right) (xit,yit)为output feature map的grid。 ( x i s , y i s ) \left( {x_i^s,y_i^s} \right) (xis,yis)是source feature map。我们使用归一化的坐标,例如将 − 1 ≤ x i t , y i t ≤ 1 { - 1 \le x_i^t,y_i^t \le 1} 1xit,yit1作为output的坐标边界,将 − 1 ≤ x i s , y i s ≤ 1 { - 1 \le x_i^s,y_i^s \le 1} 1xis,yis1作为inpt的坐标边界。这种source/target的转换与采样和图形学(graphics)中的texture mapping和coordinates是一样的。

下图中的变换允许对input feature map进行裁剪(cropping)、平移、旋转、缩放、扭曲(skew)。这个变换允许裁剪,因为如果transformation是contraction,则经过映射的grid会将位于面积小于 x s t , y s t x_s^t,y_s^t xst,yst范围的平行四边形中。这种transformation作用在grid上的特点的效果如图3所示。
[论文翻译]Spatial Transformer Networks(STN)_第5张图片
通过施加特定的约束,得到一些特殊的变化方法 T θ {T_\theta } Tθ,例如下图的 A θ A_\theta Aθ采用了注意力机制。
在这里插入图片描述
通过控制 s , t x , t y s,tx,ty s,tx,ty,这种 A θ A_\theta Aθ可以实现裁剪、平移、各向同性的伸缩。 T θ {T_\theta } Tθ也可以更一般化,例如含有8个参数的平面投影变换、分段仿射变换或薄板样条变换。

事实上, T θ {T_\theta } Tθ可以是任何参数化的形式,只要它相对于参数是可微的——这样一来才能通过 【from 采样点 T θ ( G i ) {T_\theta }\left( {{G_i}} \right) Tθ(Gi) to θ \theta θ 】进行反向传播(感觉这里比较绕哈哈)。如果transformation的参数是结构化、低维的,那么Localisation Network的复杂度就会降低。例如, T θ = M θ B {T_\theta } = {M_\theta }B Tθ=MθB是注意力、仿射、投影和薄板样条扩展集,其中B是目标grid的representation, M θ {M_\theta } Mθ是由 θ \theta θ参数化的矩阵。

3.3 Differentiable Image Sampling

为了将ST作用在input feature map上,必须将上一步获得的采样点。 T θ ( G ) {T_\theta }\left( G \right) Tθ(G)和input feature map U U U输入采样器,从而产生output feature map V V V。公式如下所示:
在这里插入图片描述
V i C V_i^C ViC指将输出图V向量化后,图上第i个像素点的值,即 ( x i t , y i t ) \left( {x_i^t,y_i^t} \right) (xit,yit)的值。 V i C V_i^C ViC的长度是 H ′ W ′ H'W' HW

U m n C U_{mn}^C UmnC是输入在(m,n)处、通道C的值。

k ( ) k\left( {} \right) k()是采样核(sampling kernel),而 Φ x {\Phi _x} Φx Φ y {\Phi _y} Φy是采样核函数的参数。

注意,采样过程对每个通道C都是等同的,因此可以保证通道间的spatial consistency。

理论上,任何采样核 k ( ) k\left( {} \right) k()都可以被使用,只要相对于 ( x i s , y i s ) \left( {x_i^s,y_i^s} \right) (xis,yis)的梯度可以被确定。例如,使用整数采样核:
[论文翻译]Spatial Transformer Networks(STN)_第6张图片
其中 ⌊ x i s + 0.5 ⌋ \left\lfloor {x_i^s + 0.5} \right\rfloor xis+0.5将x取整为最近的一个整数, δ ( ) \delta \left( {} \right) δ()是 Kronecker delta function。这个采样核相当于将距离 ( x i s , y i s ) \left( {x_i^s,y_i^s} \right) (xis,yis)最近的pixel的值传递给 ( x i t , y i t ) \left( {x_i^t,y_i^t} \right) (xit,yit)

类似地,双线性采样核也可以被使用:
在这里插入图片描述
为了反向传播计算loss,我们可以计算相对于U和G的梯度。对于双线性插值,偏导数可以表示为:

[论文翻译]Spatial Transformer Networks(STN)_第7张图片
这套机制不仅允许loss梯度反向传递到input feature map,也可以传递到参数 θ \theta θ和Localisation Network。由于采样函数的不连续性,必须使用次梯度。通过忽略所有输入位置的和,值查看kernel的区域,这种采样机制可以在GPU上很快地执行。

3.4 Spatial Transformer Networks

localisation network, grid generator, 以及 sampler 的组合形成了一个 spatial transformer(ST)。这种独立自足的模块可以以任何数量、在任何位置插入CNN,从而形成 spatial transformer networks(STN)。简单使用时,加入ST的CNN计算上依然非常快,并且不会降低训练速度,甚至会因为下采样的缘故加速注意力模型的训练速度。

将ST放置在CNN中允许网络学习如何主动地转换特征映射,以帮助网络在训练期间最小化网络的overall loss function。如何对训练样本转换的知识被压缩在ST模块的网络权重中。在一些任务中,将 θ \theta θ提取出来是很有用的,因为它显式地对transformation进行了编码。

ST也可以对feature map进行降采样或过采样,因为输出维度W’和H’可以和输入维度W和H不一样。然而,当采样核是固定且只有small spatial transformer能力的时候,降采样会导致混叠效应。

[论文翻译]Spatial Transformer Networks(STN)_第8张图片
表1 不同模型对于不同数据集的错误率。MNIST数据集经过了不同的处理,符号含义如下:R是旋转,T是平移,S是缩放,E是弹性形变。所有的模型的参数theta数量是一样的
[论文翻译]Spatial Transformer Networks(STN)_第9张图片
图4 一些CNN分类错误,但STN分类正确的侧视图。(a)是输入图 (b)是transformer预测的变换,gird被可视化。©是输出

最后,在CNN中含有复数个ST是可能的。将多个ST放在CNN更深的层允许Localisation Network有更抽象的表达能力。ST的局限性在于,ST的并行的数量限制了网络能够识别的物体的数量。

4 Experiments

这一节中我们将ST运用在多个有监督学习任务上。4.1中我们测试了MNIST手写数据集,4.2中我们测试了识别街景房屋编码,4.3中我们运用多个并行ST在CUB-200-2011鸟类数据集上进行测试。进一步在MNIST的实验,以及共定位实验可以在附录A中找到。

4.1 Distorted MNIST

[论文翻译]Spatial Transformer Networks(STN)_第10张图片
(前面一大堆描述略)……实验中可以看出, 薄板样条变换是最强大的,它能将弹性形变后的数字转换为数字的原型,以减小误差,并且在简单的样本上也没有过拟合。有趣的是,对input的所有转换都把数字转换为了“标准的”直立的pose——这是训练数据中的平均pose。

4.2 Street View House Numbers

[论文翻译]Spatial Transformer Networks(STN)_第11张图片
(略)

4.3 Fine-Grained Classification

(略)

5 Conclusion

本文引入了一种新的神经网络模块——the spatial transformer。这个模块可以该模块可以被放入网络中,并执行特征的显式空间转换,为神经网络建模数据开辟了新的途径,并以端到端的方式学习,而不对损失函数进行任何更改。此外,ST的回归变换参数可作为输出,并可用于后续任务。 虽然我们在本工作中只探索前馈网络,但早期的实验表明,ST在递归模型(recurrent model)中是强大的,便于进行对象参考帧的分离,且易于扩展到涉及3D转换的任务。

6 附录

(略,以后有空再补充)

你可能感兴趣的:(神经网络)