近期阅读了2015年的一篇较为经典的论文"spatial transformer networks(stn)"。本博文是stn阅读心得的记录。在第二小节中,会描述stn的实现细节,包括三大组成构件:localisation network、Grid generator、Sampler。在第三小节中会通过跟踪stn源码(pytorch官方版本)来验证自己的理解正确性。在第四部分作为扩展部分,会尝试从数学角度阐述STN的数学形式并作可导性分析。
图1
spatial transformer networks的提出背景:通常为了使模型在测试阶段spatial invariance, 一种常规的做法是在训练阶段做尽可能丰富的数据扩增操作(eg.shift, crop等)。而stn则是将数据扩增有机的和网络融为一体,达到learnable的效果。从实验结果来看,可较显著的提升(分类)模型的性能。
stn的核心是如图1所示的spatial transformer模块。
名称 | 说明 |
---|---|
U | 输入特征,为spatial transformer的输入 |
V | 输出特征,为spatial transformer的输出 |
localisation net | st模块的三大构件之一,后文会详述 |
Grid generator | st模块的三大构件之一,后文会详述 |
Sampler | st模块的三大构件之一,后文会详述 |
表1
Localisation net的作用是回归仿射变换的参数 θ \theta θ。图3中的公式是仿射变换操作的通式,二维空间上仿射变换的参数为6个,也即localisation net的输入为 N ∗ C ∗ H ∗ W N*C*H*W N∗C∗H∗W的特征图,输出为 N ∗ 6 N*6 N∗6。
图3
Localisation net部分的实现就是以Conv层和Linear层构成 说 明 1 ^{说明1} 说明1,具体如图4所示。这部分比较直观,就不做赘述。
图5
这一部分的作用是建立输出特征图中的坐标与输入特征图中的坐标关系。过程像素级别的操作可以用图6来表示
图6
关于图6中的公式需要注意两点:
以仿射变换的一种特例,顺时针旋转90度为例。
对于输出特征图上位置 ( 0 , 0 ) (0,0) (0,0)处的值’2’来自于输入特征图上的 ( 0 , 2 ) (0,2) (0,2)处。
对于输出特征图上位置 ( 0 , 1 ) (0,1) (0,1)处的值‘3’来自于输入特征图上的 ( 0 , 1 ) (0,1) (0,1)处。
对于输出特征图上位置 ( 0 , 2 ) (0,2) (0,2)处的值‘1’来自于输入特征图上的 ( 0 , 0 ) (0,0) (0,0)处。
按照此规律,可以得到输出特征图上点的所有“来源”。
通过2.2节中描述的Grid generator。可以得到输出特征图上各个value的"来源"矩阵:
而Sampler的过程就是基于该“来源”矩阵取索引处值的过程
STN实际上的Sampler要比这里描述的复杂一些,因为它还会涉及到一个插值操作。回到STN, 在2.1节中,已经讲明, θ \theta θ是网络学习出来的,旋转只是仿射变换的一种特例。因此大概率计算得到的“来源”并不是一个整数。
仍旧以一个实际例子来说明,假如在当前iteration,学习得到的 θ \theta θ为:
那么对于target 特征图中(0,0)处值的来源为source特征图中的(0.3,0.7)。为了处理这种坐标非整数的情形,就需要利用插值:用其附近的四个整数坐标的value来生成。图8展示了二维插值的计算方式示意图。
图8
pytorch已经将stn集成,并提供了stn pytorch tutorials。本部分主要是跟踪其中的代码,来完善并验证上述的理解。
这部分直接贴相关核心代码,细节不再赘述。可以较容易的与图4中的内容对应起来。
xs = self.localization(x)
xs = xs.view(-1, 10 * 3 * 3)
theta = self.fc_loc(xs)
theta = theta.view(-1, 2, 3)
*核心代码段2
# Spatial transformer localization-network
self.localization = nn.Sequential(
nn.Conv2d(1, 8, kernel_size=7),
nn.MaxPool2d(2, stride=2),
nn.ReLU(True),
nn.Conv2d(8, 10, kernel_size=5),
nn.MaxPool2d(2, stride=2),
nn.ReLU(True)
)
# Regressor for the 3 * 2 affine matrix
self.fc_loc = nn.Sequential(
nn.Linear(10 * 3 * 3, 32),
nn.ReLU(True),
nn.Linear(32, 3 * 2)
)
grid = F.affine_grid(theta, x.size())
x = F.grid_sample(x, grid)
这部分以实际模型训练中某一iteration的实际例子来进行说明,此时 θ \theta θ为
x . s i z e ( ) x.size() x.size()为 64 ∗ 1 ∗ 28 ∗ 28 64*1*28*28 64∗1∗28∗28。2.2节以及2.3节中的理解基本是正确的,但pytorch在具体实施的过程中,有两点需要注意:
因此这里的归一化公式为:
x n o r m = ( 2 ∗ x + 1 ) / s − 1 x_{norm} =(2*x+1)/s-1 xnorm=(2∗x+1)/s−1
y n o r m = ( 2 ∗ y + 1 ) / s − 1 y_{norm} =(2*y+1)/s-1 ynorm=(2∗y+1)/s−1
反归一化公式为
x = ( ( x n o r m + 1 ) ∗ s − 1 ) / 2 x=((x_{norm}+1)*s-1)/2 x=((xnorm+1)∗s−1)/2
y = ( ( y n o r m + 1 ) ∗ s − 1 ) / 2 y=((y_{norm}+1)*s-1)/2 y=((ynorm+1)∗s−1)/2
按照2.2中的理解,计算target特征图中(13,5)在source特征图中的来源。
step1:先利用归一化公式操作得到 ( x n o r m , y n o r m ) = ( − 0.03571 , − 0.6071 ) (x_{norm},y_{norm})=(-0.03571,-0.6071) (xnorm,ynorm)=(−0.03571,−0.6071);
step2:与 θ \theta θ相乘,得到输入特征图上的归一化坐标 ( x n , y n ) = ( − 0.1428 , − 0.6859 ) (x_n, y_n)=(-0.1428,-0.6859) (xn,yn)=(−0.1428,−0.6859),与调试的代码结果一致。
step3:对 ( x n , y n ) = ( − 0.1428 , − 0.6859 ) (x_n, y_n)=(-0.1428,-0.6859) (xn,yn)=(−0.1428,−0.6859)反归一化,得到输入特征图上的非归一化坐标(15.4994, 23.10285)。
step4:插值的四个坐标对应的特征值为:
根据图8中公式,可以算得stn输出特征图中x(13,5)处的值为2.4648。而代码打印的结果为2.4819,有一定的误差,但基本与预期相符。
以上基本证明了自己对于stn的如何实施的理解正确性。
第二节,第三节描述了stn的实施细节。但仅仅有这些还不够,我们在设计一个“创新性的”网络结构时,起效的前提或者说理论基础是该模块是differentiable。
在阐述该公式时,先暂时忘却这一公式,看一看按照之前的理解,会如何写这一过程:
上述公式等价于
进一步等价于
再做一点就可以将上述公式中四个sum因子,写成一个通式:
再继续想,我们认为和输出特征图上 i t h ith ith点有关的是4个点是一种很自然的想法。但对于pytorch来讲,需要矩阵的操作,不可能仅仅是4个点。因此上述公式,又要进一步进行转换:
上述公式可以巧妙的将非附近4个点的其他系数计算为0,从而即完成了整个输入特征的计算形式,又达到了实际仅附近4个点参与的效果。
论文中给出的导数公式为:
在对前向公式(5)的已经存在的情况下,得到上述两个偏微分公式并不复杂,因此本小节想讲一讲其他的地方。
公式(6)和公式(7)分别对应图中的圈1和圈2。进一步的可以写出圈3处的求导公式(大概形式):
可以看到 θ \theta θ是可以学习的。且在圈4时,反向传播已经转换为常规的CNN Bp操作了。
本篇论文给我的启发有4点: