做配准工作已经有一段时日,在有idea后就一直在写文章,工作中把每个模块当成黑箱子在用。直到最近深入研究才发现自己的知识有多么浅薄,所以决定从最基层开始慢慢理解配准这个任务,本篇文章写出本人对STN浅显的理解。
class SpatialTransform(nn.Module):
def __init__(self):
super(SpatialTransform, self).__init__()
def forward(self, x,flow,sample_grid):
sample_grid = sample_grid+flow
size_tensor = sample_grid.size() #3D c,h,w
#此处将新坐标系归一化
sample_grid[0,:,:,:,0] = (sample_grid[0,:,:,:,0]-((size_tensor[3]-1)/2))/size_tensor[3]*2
sample_grid[0,:,:,:,1] = (sample_grid[0,:,:,:,1]-((size_tensor[2]-1)/2))/size_tensor[2]*2
sample_grid[0,:,:,:,2] = (sample_grid[0,:,:,:,2]-((size_tensor[1]-1)/2))/size_tensor[1]*2
image = torch.nn.functional.grid_sample(x, sample_grid,mode = 'bilinear')
return image
首先明确,输入x , flow, sample_grid究竟是什么?
x是一张图像,我们假设它的坐标系在X空间内
flow是通过x和y输出的一个位移场,这个位移场的位移被限制在了[-1,1]
sample_grid是一个初始化坐标系,里面存储一张图的坐标
要明确的是,不是所有工作都将位移场限制在[-1,1],例如VoxelMorph1。
最近我看到限制位移场的工作有ICNET,SYM等23。
#ICNet/Code/Models
def outputs(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0,
bias=False, batchnorm=False):
if batchnorm:
layer = nn.Sequential(
nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
nn.BatchNorm3d(out_channels),
nn.Tanh())
else:
layer = nn.Sequential(
nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
nn.Tanh())
return layer
首先在功能上,是用Tanh去实现的,对最终的输出层归一化,将flow限制在[-1,1]中。
然后再通过flow*range_flow去扩大flow的范围。
实际上range_flow做到了两个工作,一是限制了flow的范围(tanh的功劳),二是这个范围还挺大。
基于此,range_flow会使网络输出的flow更加稳定些,但不加tanh+range_flow问题也不会太大。
flow的最终目的是使被它变换的图像A与固定图像B变得相似,那么问题重新回到,它是如何变换图像A的。
sample_grid = sample_grid+flow
sample_grid是恒等变换栅格,在某些微分同胚文章中被定义为初始x。
这里要明确的是sample_grid是没有进行坐标系归一化的,同样的flow在训练之前也不知道自己应该落在什么范围。sample_grid+flow得到形变后的坐标系。
sample_grid+flow得到被flow扭曲后的坐标系。
sample_grid[0,:,:,:,0] = (sample_grid[0,:,:,:,0]-((size_tensor[3]-1)/2))/size_tensor[3]*2
sample_grid[0,:,:,:,1] = (sample_grid[0,:,:,:,1]-((size_tensor[2]-1)/2))/size_tensor[2]*2
sample_grid[0,:,:,:,2] = (sample_grid[0,:,:,:,2]-((size_tensor[1]-1)/2))/size_tensor[1]*2
image = torch.nn.functional.grid_sample(x, sample_grid,mode = 'bilinear')
将扭曲后的坐标系归一化(因为torch.nn.functional.grid_sample需要将坐标系归一化),然后进行插值。
这里不由得发出一个疑问,在最开始就将sample_grid进行归一化不就好了吗?
实际上,如果最开始归一化在加上flow,没有办法保证扭曲后的坐标系在[-1,1]中
可以看到,我们将x(Moving image)的像素值插到扭曲后的坐标系。至此插值结束。
最符合直觉的想法应该是Moving->Fixed的flow,但实际上是Fixed-Moving的Flow。最开始知道这件事我是很震惊的,但通过举例子就很容易发现其中的猫腻。
例如,扭曲前的Moving中的一个点是(5,5),flow在这个点的值为(3,3),也就是说扭曲后这个点变为(8,8)。
假设说这个(8,8)插入Moving的像素值,那就出了问题了,因为Moving中的(8,8)根本不是扭曲前的(5,5),这样插值完还是原图。
所以说扭曲前的那个点一定是Fixed中的点,也就是说flow是Fixed->Moving的flow,这样插入Moving的值才是合理的。
补充一张图作为解释:
Unsupervised Learning for Fast Probabilistic Diffeomorphic Registration ↩︎
Fast Symmetric Diffeomorphic Image Registration with Convolutional Neural Networks ↩︎
Inverse-Consistent Deep Networks for Unsupervised Deformable Image Registration ↩︎