论文地址: https://arxiv.org/abs/1506.02025
这几天看了下stn,大概写一写吧。说实话,这个东西思想倒是蛮有意思的,但是实际用起来效果不好说,至少在我想要应用的场景下效果不怎么样。
这里先写论文的思路,再写一下我做的一些实验与相应的思考。
我们知道,CNN推动了计算机视觉的发展,但是还是有一些缺陷。在“Visualizing and Understanding Convolutional Networks”这篇论文里其实就有所论述,实际上,卷积神经网络有一定的空间不变性,包括平移,缩放等。
但实际上对于一些变换比较大的问题仍然存在瓶颈,STN的就是为了解决这些问题而提出的一个网络。
STN的结构并不复杂,首先还是拿论文里的图来说一下:
左边的U是输入图片(feature_map),在论文中,作者表示stn可以加在卷积网络的任何位置,所以这里可以是输入图像也可以是经过若干层卷积的特征图。
接下来是一个称为Localisation net的网络,这个网络可以是任意的结构,全连接或者卷积网络都是可以的,这个网络的目标就是通过给定的输入,学习一组用于变换的参数。这个参数的数量是6个,也就是说,Localisation net有6个输出。
接下来用这6个参数对原始的输入做线性变换,生成一个新的输出V,这个输出的channel数与U的channel数相同,并且各个channel所做的变换也是相同的。
这个变换并不复杂,不过需要一些线性代数的知识,我们把输入的矩阵的位置用(x, y)表示,我们就可以得到一个(w, h, 2)的坐标矩阵,将这个坐标增加一维,填1,就得到了一个(w, h, 3)的坐标矩阵。至于为什么要填这个1,主要是仿射变换的需要,旋转、平移、缩放是不需要的,具体细节可以去看一下这些操作的矩阵实现,这里就不细说了。然后我们学到的6个参数变换为(3,2)的矩阵,进行矩阵乘法,就可以得到变换完的坐标,同样是(w, h, 2)的矩阵。
获得新的坐标以后,下一步就是进行采样,把对应的内容填充到矩阵V中,这一步不难理解,可以用双线性插值来完成。
上一步说了,要使用采样的方法把变换后的坐标映射到矩阵V上。理论上哪种方法都可以,但是有一点比较关键,因为要训练网络,需要产生梯度,像最邻近插值这种方法,实际上是不能产生梯度的,因为它只是把矩阵U的内容移动到V上面,本身并没有任何变化。所以作者使用了双线性插值的方法,由于一个点的位置,是由其周围四个点的值计算插值而来,就“产生”了新的值,导致梯度的产生,有了网络训练的基本条件。
论文里发了个效果视频:
https://drive.google.com/file/d/0B1nQa_sA3W2iN3RQLXVFRkNXN0k/view
视频里的效果看起来还是很好的,能够把手写数字矫正,看起来还是很神奇的。
我也尝试使用stn训练旋转的文字,不过效果并不好,也不存在说能够把旋转文字矫正这一现象,所以我产生了一些疑惑,这个网络真的能矫正旋转的文字吗?
后来想了一下,这一点其实理论上并不能说通,因为文字的“正”方向实际上是人类的先验知识,这一点实际上并不能从数据和网络上体现,也就是说,网络并不具备这部分知识,它本身并不知道文字的正反, 那又如何能“矫正”文字呢?另一方面,无论卷积网络还是全连接网络,其实也不存在旋转不变性,Localisation net又怎么学习旋转的性质呢?
带着疑问,我做了几个实验,想要复现论文的结果。
找了几个stn的实现,其实大同小异了,尝试了一下:https://github.com/zsdonghao/Spatial-Transformer-Nets的实现。
发现效果并不理想,stn看起来只是找到了文字并把它填充了图片的视野,并没有旋转的效果:
输入:
stn矫正后:
这个实现不知道是不是有些问题,调用了tensorlayer的一些api,不太好说。找了一个pytorch的实现,是pytorch官方教程中的:https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html#depicting-spatial-transformer-networks
这个代码有点问题的,没有对输入数据进行旋转,直接stn,得不到我想要的效果,所以稍微改了一下,把输入数据先旋转,然后再扔进去训。下面是完整代码:
# coding: utf-8
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import cv2
import random
plt.ion()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Training dataset
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(root='.', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])), batch_size=64, shuffle=True, num_workers=4)
# Test dataset
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(root='.', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])), batch_size=128, shuffle=True, num_workers=4)
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
# Spatial transformer localization-network
self.localization = nn.Sequential(
nn.Conv2d(1, 8, kernel_size=7),
nn.MaxPool2d(2, stride=2),
nn.ReLU(True),
nn.Conv2d(8, 10, kernel_size=5),
nn.MaxPool2d(2, stride=2),
nn.ReLU(True)
)
# Regressor for the 3 * 2 affine matrix
self.fc_loc = nn.Sequential(
nn.Linear(10 * 3 * 3, 32),
nn.ReLU(True),
nn.Linear(32, 3 * 2)
)
# Initialize the weights/bias with identity transformation
self.fc_loc[2].weight.data.zero_()
self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
# Spatial transformer network forward function
def stn(self, x):
xs = self.localization(x)
xs = xs.view(-1, 10 * 3 * 3)
theta = self.fc_loc(xs)
theta = theta.view(-1, 2, 3)
grid = F.affine_grid(theta, x.size())
x = F.grid_sample(x, grid)
return x
def forward(self, x):
# transform the input
x = self.stn(x)
# Perform the usual forward pass
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = F.dropout(x, training=self.training)
x = self.fc2(x)
return F.log_softmax(x, dim=1)
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01)
def train(epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(rotate_image_tensors(data))
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % 500 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
def test():
with torch.no_grad():
model.eval()
test_loss = 0
correct = 0
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(rotate_image_tensors(data))
# sum up batch loss
test_loss += F.nll_loss(output, target, size_average=False).item()
# get the index of the max log-probability
pred = output.max(1, keepdim=True)[1]
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
.format(test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
def rotate_image(img):
rows, cols, ch = img.shape
M = cv2.getRotationMatrix2D((cols/2, rows/2), random.randint(-135, -45), 1)
dst = cv2.warpAffine(img, M, (cols, rows))
return dst.reshape((28, 28, 1))
def rotate_image_tensors(image_tensors):
batch_size, c, h, w = image_tensors.shape
sp_imgs = torch.chunk(image_tensors, batch_size, dim=0)
# print(sp_img[0].shape)
rotated_imgs = []
for single_img in sp_imgs:
img = single_img.squeeze(0).numpy().transpose(1, 2, 0)
img = rotate_image(img)
r_img = torch.from_numpy(img.transpose(2, 0, 1))
rotated_imgs.append(r_img)
# print(r_img.shape)
res = torch.stack(rotated_imgs, dim=0)
return res
def convert_image_np(inp):
"""Convert a Tensor to numpy image."""
inp = inp.numpy().transpose((1, 2, 0))
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
inp = std * inp + mean
inp = np.clip(inp, 0, 1)
return inp
def visualize_stn():
with torch.no_grad():
# Get a batch of training data
data = next(iter(test_loader))[0].to(device)
input_tensor = data.cpu()
input_tensor = rotate_image_tensors(input_tensor)
transformed_input_tensor = model.stn(input_tensor).cpu()
# transformed_input_tensor = rotate_image_tensors(transformed_input_tensor)
in_grid = convert_image_np(
torchvision.utils.make_grid(input_tensor))
out_grid = convert_image_np(
torchvision.utils.make_grid(transformed_input_tensor))
# Plot the results side-by-side
f, axarr = plt.subplots(1, 2)
axarr[0].imshow(in_grid)
axarr[0].set_title('Dataset Images')
axarr[1].imshow(out_grid)
axarr[1].set_title('Transformed Images')
# visualize_stn()
for epoch in range(1, 20):
train(epoch)
test()
visualize_stn()
plt.ioff()
plt.savefig('./result.png', format='png')
# plt.show()
接下来是训练结果:
上图是旋转了(-90, 90)度,看起来并不是完全没有效果,目测大部分图都被转到了差不多0,-45,45的角度。
接下来,是旋转了(0, 90)度,这就有趣了,输出的图基本都被转到了45度。
那么旋转了(45, 135)度又如何呢?输出的图都转到了90度。
旋转了(0, 360)度貌似没什么规律了。
上面几个实验基本上也印证了我的猜测。首先,网络肯定是没有这个字被转了多少度这个知识的,至于为什么文字都被转到了范围中间的角度了呢?我有个猜测,首先是和神经网络的特点有关系,记得有个论文猜测,神经网络会先去学简单的特征,然后随着训练的进行,才会学比较困难的特征,对于stn可能也差不多,当把这些字转到中间的时候,对于识别来说更简单。
简单来说,Localisation net比后面的分类网络复杂,那么Localisation net可能会倾向于把文字转到一个或几个固定角度(为什么是中间某个角度,大概是因为转这些更“容易”?)。分类网络如果更复杂,Localisation net比较简单,可能就会出现第一个实验的情况,Localisation net只是把文字这个特征缩放了一下,并没有旋转。感觉由于随机梯度下降这个方法,每次都是微调,导致整个神经网络想用最小的能量达到目标,才导致了这个结果。
那么stn到底有没有用呢?应该还是有一定用处的,但是是不是能用来矫正旋转文字呢?看起来小的角度是可以的,但它能矫正并不是因为它知道哪些是正的,只是因为它恰好会把字旋转到中间那个角度而已。