Spatial Transformer Networks(STN)是一种空间注意力模型,可以通过学习对输入数据进行空间变换,从而增强网络的对图像变形、旋转等几何变换的鲁棒性。STN 可以在端到端的训练过程中自适应地学习变换参数,无需人为设置变换方式和参数。
STN 的基本结构包括三个部分:定位网络(Localization Network)、网格生成器(Grid Generator)和采样器(Sampler)。定位网络通常由卷积层、全连接层和激活函数构成,用于学习输入数据的空间变换参数。网格生成器用于生成采样网格,采样器则根据采样网格对输入数据进行采样。整个 STN 模块可以插入到任意位置,用于提高网络的对图像变形、旋转等几何变换的鲁棒性。
在 STN 中,定位网络通常由一个多层感知器(MLP)和一些辅助层(如卷积层、全连接层和激活函数)构成。MLP 的输出用于计算变换参数(如平移、旋转和缩放等),从而生成采样网格。采样器通常由双线性插值、最近邻插值和反卷积等方法实现,用于对输入数据进行采样。
STN 的优点在于,它可以学习对输入数据进行任意复杂的空间变换,从而提高网络的对图像变形、旋转等几何变换的鲁棒性。此外,STN 可以与其他深度学习模型结合使用,从而提高整个系统的性能。例如,在图像分类任务中,可以将 STN 插入到卷积神经网络中,用于对输入图像进行空间变换,增强网络对图像变形、旋转等几何变换的鲁棒性。
STN注意力模块pytorch实现:
import torch
import torch.nn as nn
import torch.nn.functional as F
class STN(nn.Module):
def __init__(self):
super(STN, self).__init__()
# 定义本地化网络,用于估计空间变换的参数
self.localization = nn.Sequential(
nn.Conv2d(1, 8, kernel_size=7), # 输入通道数为 1,输出通道数为 8,卷积核大小为 7
nn.MaxPool2d(2, stride=2), # 最大池化层,核大小为 2,步长为 2
nn.ReLU(True), # ReLU 激活函数
nn.Conv2d(8, 10, kernel_size=5), # 输入通道数为 8,输出通道数为 10,卷积核大小为 5
nn.MaxPool2d(2, stride=2), # 最大池化层,核大小为 2,步长为 2
nn.ReLU(True) # ReLU 激活函数
)
# 定义空间变换网络,用于预测空间变换的参数
self.fc_loc = nn.Sequential(
nn.Linear(10 * 3 * 3, 32), # 全连接层,输入维度为 10 * 3 * 3,输出维度为 32
nn.ReLU(True), # ReLU 激活函数
nn.Linear(32, 3 * 2) # 全连接层,输入维度为 32,输出维度为 3 * 2
)
# 初始化空间变换网络的权重和偏置
self.fc_loc[2].weight.data.zero_()
self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
def forward(self, x):
# 使用本地化网络对输入图像进行特征提取
xs = self.localization(x)
# 将特征张量展开成一维张量
xs = xs.view(-1, 10 * 3 * 3)
# 使用空间变换网络预测空间变换的参数
theta = self.fc_loc(xs)
# 将一维张量转换成二维张量,用于执行仿射变换
theta = theta.view(-1, 2, 3)
# 使用仿射变换对输入图像进行空间变换
grid = F.affine_grid(theta, x.size())
x = F.grid_sample(x, grid)
return x
以上代码中,STN
类继承自 PyTorch 的 nn.Module
类,是一个包含了本地化网络和空间变换网络的模块。具体来说,STN
模块包含以下组件:
self.localization
:本地化网络,用于对输入图像进行特征提取,提取出用于估计空间变换参数的特征向量。self.fc_loc
:空间变换网络,用于根据本地化网络提取的特征向量预测空间变换的参数。self.fc_loc[2].weight.data.zero_()
和 self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
:用于初始化空间变换网络的权重和偏置,其中权重矩阵初始化为零矩阵,偏置向量初始化为一个 torch.tensor
对象,其元素为 [1,0,0,0,1,0][1,0,0,0,1,0],表示初始的空间变换为一个单位矩阵。forward
方法:模块的前向传播过程。首先使用本地化网络对输入图像进行特征提取,然后将特征张量展开成一维张量,使用空间变换网络预测空间变换的参数。接着将一维张量转换成二维张量,用于执行仿射变换,并使用仿射变换对输入图像进行空间变换,最后返回变换后的图像张量。STN模块在模型中添加:
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.stn = STN()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
def forward(self, x):
# 使用 STN 对输入图像进行空间变换
x = self.stn(x)
# 经过卷积和池化层处理
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2(x), 2))
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)