The STN network in VoxelMorph

First, why do we need a spatial transform at all? We want to warp the moving image M toward the fixed image F, and the transformation cannot be obtained in a single step; the network has to be trained iteratively, measuring the remaining gap between the two images so the optimization can close it. There are plenty of blog posts online explaining the mathematics of spatial transformers that are worth reading.
This post walks through VoxelMorph's spatial transformer function. The original is written for 3-D volumes; for my own purposes I adapted it to 2-D, but the principle should be the same, so please keep that in mind.
Let's look at the source code first:


import torch
import torch.nn as nn
import torch.nn.functional as nnf


class SpatialTransformer(nn.Module):
    """
    N-D Spatial Transformer
    """

    def __init__(self, size, mode='bilinear'):    # e.g. size = [128, 128] for 2-D
        super().__init__()

        self.mode = mode

        # create sampling grid (the identity transform)
        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(vectors)
        grid = torch.stack(grids)
        grid = torch.unsqueeze(grid, 0)
        grid = grid.type(torch.FloatTensor)

        # a buffer is saved with the module but never touched by the optimizer
        self.register_buffer('grid', grid)

    def forward(self, src, flow):
        # absolute sampling locations = identity grid + displacement field
        new_locs = self.grid + flow
        shape = flow.shape[2:]

        # normalize coordinates to [-1, 1], as grid_sample expects
        for i in range(len(shape)):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)

        # move channels dim to last position;
        # grid_sample expects (x, y) / (x, y, z) ordering, hence the reversal
        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
            new_locs = new_locs[..., [1, 0]]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)
            new_locs = new_locs[..., [2, 1, 0]]

        return nnf.grid_sample(src, new_locs, align_corners=True, mode=self.mode)

Initialization: building the sampling grid

The first line

vectors = [torch.arange(0, s) for s in size]
#result:
#[tensor([  0,   1,   2,  ..., 125, 126, 127]),
# tensor([  0,   1,   2,  ..., 125, 126, 127])]

This builds a list of two 1-D tensors, each of length 128, [tensor([0, …, 127]), tensor([0, …, 127])]: one coordinate vector per spatial dimension.

The second line

 grids = torch.meshgrid(vectors)
#result:
#(tensor([[  0,   0,   0,  ...,   0,   0,   0],
#        [  1,   1,   1,  ...,   1,   1,   1],
#        [  2,   2,   2,  ...,   2,   2,   2],
#        ...,
#        [125, 125, 125,  ..., 125, 125, 125],
#        [126, 126, 126,  ..., 126, 126, 126],
#        [127, 127, 127,  ..., 127, 127, 127]]), tensor([[  0,   1,   2,  ..., 125, 126, 127],
#        [  0,   1,   2,  ..., 125, 126, 127],
#        [  0,   1,   2,  ..., 125, 126, 127],
#        ...,
#        [  0,   1,   2,  ..., 125, 126, 127],
#        [  0,   1,   2,  ..., 125, 126, 127],
#        [  0,   1,   2,  ..., 125, 126, 127]]))

I read this as generating a 2-D coordinate system: a tuple of two 128×128 tensors that together give the coordinates of every image position. One tensor varies vertically (0 to 127 down the rows) and the other horizontally (0 to 127 across the columns), matching the familiar x/y axes. So you can think of this line as building a grid with 128×128 vertices (a 127×127 mesh of cells), each vertex corresponding to one pixel coordinate, which also helps when thinking about interpolation. (Newer PyTorch versions expect an explicit indexing='ij' argument to reproduce this layout without a warning.) See torch.meshgrid.
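To see the layout concretely, here is a small sketch of the same construction at a toy 3×3 size (my own choice for illustration; the post uses 128×128). The indexing='ij' argument is needed on newer PyTorch to reproduce the old default behaviour:

```python
import torch

# A toy 3x3 version of the grid construction, so the layout is easy to see.
size = [3, 3]
vectors = [torch.arange(0, s) for s in size]

# indexing='ij' matches the old default used by the source code.
yy, xx = torch.meshgrid(vectors[0], vectors[1], indexing='ij')

print(yy)  # rows:    0,0,0 / 1,1,1 / 2,2,2  (vertical coordinate)
print(xx)  # columns: 0,1,2 / 0,1,2 / 0,1,2  (horizontal coordinate)
```

Each output tensor covers the whole image; reading off (yy[i][j], xx[i][j]) gives exactly (i, j), the pixel's own coordinates.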

The third line

grid = torch.stack(grids)   #          torch.Size([2, 128, 128])
#result:
#tensor([[[  0,   0,   0,  ...,   0,   0,   0],
#         [  1,   1,   1,  ...,   1,   1,   1],
#         [  2,   2,   2,  ...,   2,   2,   2],
#         ...,
#         [125, 125, 125,  ..., 125, 125, 125],
#         [126, 126, 126,  ..., 126, 126, 126],
#         [127, 127, 127,  ..., 127, 127, 127]],
#
#        [[  0,   1,   2,  ..., 125, 126, 127],
#         [  0,   1,   2,  ..., 125, 126, 127],
#         [  0,   1,   2,  ..., 125, 126, 127],
#         ...,
#         [  0,   1,   2,  ..., 125, 126, 127],
#         [  0,   1,   2,  ..., 125, 126, 127],
#         [  0,   1,   2,  ..., 125, 126, 127]]])

This stacks the tuple along a new dimension (effectively turning the tuple into a single tensor), producing a 3-D tensor of size 2×128×128: the leading 2 indexes the two coordinate maps from the tuple, and 128×128 is the size of each map. See torch.stack.

The fourth line

grid = torch.unsqueeze(grid, 0)    #torch.Size([1, 2, 128, 128])   B,C,H,W
#result:
#tensor([[[[  0,   0,   0,  ...,   0,   0,   0],
#          [  1,   1,   1,  ...,   1,   1,   1],
#          [  2,   2,   2,  ...,   2,   2,   2],
#          ...,
#          [125, 125, 125,  ..., 125, 125, 125],
#          [126, 126, 126,  ..., 126, 126, 126],
#          [127, 127, 127,  ..., 127, 127, 127]],

#         [[  0,   1,   2,  ..., 125, 126, 127],
#          [  0,   1,   2,  ..., 125, 126, 127],
#          [  0,   1,   2,  ..., 125, 126, 127],
#          ...,
#          [  0,   1,   2,  ..., 125, 126, 127],
#          [  0,   1,   2,  ..., 125, 126, 127],
#          [  0,   1,   2,  ..., 125, 126, 127]]]])

This adds a leading batch dimension, producing a 4-D tensor that matches the layout of 2-D image batches (B, C, H, W), which makes downstream processing convenient. See torch.unsqueeze.
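The stacking and unsqueezing steps can be sketched at the same toy 3×3 size (sizes are my own choice for illustration):

```python
import torch

# Stack the two coordinate maps into one tensor, then add a batch
# dimension, mirroring the source code.
vectors = [torch.arange(0, 3), torch.arange(0, 3)]
grids = torch.meshgrid(vectors[0], vectors[1], indexing='ij')

grid = torch.stack(grids)        # shape [2, 3, 3]: channel 0 = y, channel 1 = x
grid = torch.unsqueeze(grid, 0)  # shape [1, 2, 3, 3]: B, C, H, W

print(grid.shape)  # torch.Size([1, 2, 3, 3])
```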

The fifth line

grid = grid.type(torch.FloatTensor)

Casts the grid to floating point.

The sixth line

self.register_buffer('grid', grid)

Registers grid as a buffer: it lives in the module's state (saved with state_dict, moved by .to(device)), but it is not a parameter, so the optimizer never changes it. The identity grid stays fixed; what gets optimized is the non-rigid displacement field. See register_buffer.
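A minimal sketch of the buffer behaviour, using a made-up module name for illustration:

```python
import torch
import torch.nn as nn

# Hypothetical module that registers a buffer, nothing else.
class WithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('grid', torch.zeros(2, 3, 3))

m = WithBuffer()

# The buffer appears in state_dict (so it is saved/loaded and moved by
# .to(device)), but not in parameters(), so an optimizer built from
# m.parameters() will never update it.
print(list(m.state_dict().keys()))  # ['grid']
print(list(m.parameters()))         # []
```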

Forward pass: forward

The first line

new_locs = self.grid + flow

This adds the displacement field flow, predicted by the network, to the identity grid built in the initializer, giving the absolute sampling location for every output pixel (the final warped grid).
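A small numeric example of this addition, on a toy 4×4 grid with a single nonzero flow value (sizes are my own choice for illustration):

```python
import torch

# Identity grid, shape [1, 2, 4, 4]; channel 0 is the y-coordinate.
grid = torch.stack(
    torch.meshgrid(torch.arange(4), torch.arange(4), indexing='ij')
).unsqueeze(0).float()

# A displacement at pixel (1, 1) saying "read the value 2 rows further down".
flow = torch.zeros_like(grid)
flow[0, 0, 1, 1] = 2.0

new_locs = grid + flow
print(new_locs[0, :, 1, 1])  # tensor([3., 1.]): output pixel (1,1) samples from (y=3, x=1)
```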

The second line

shape = flow.shape[2:]

Gets the spatial dimensions of the images currently being registered (skipping the batch and channel dimensions).

The third statement (the for loop)

for i in range(len(shape)):
    new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)
# This is the original code, written with 3-D volumes in mind. Using it
# unchanged on my 2-D images, only the first two columns of pixels ended up
# with values; after some experimenting I changed it to the following
# (my images are 128x128, so every column of coordinates is mapped into [-1, 1]):
#        for i in range(128):
#            new_locs[..., i] = 2 * (new_locs[..., i] / 127 - 0.5)

This loop normalizes the grid values to [-1, 1], the coordinate range that grid_sample resamples over.
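As a sanity check on the normalization, here is the unmodified loop run on a 2-D 128×128 grid with zero flow assumed:

```python
import torch

# Identity grid for a 128x128 image, shape [1, 2, 128, 128].
shape = (128, 128)
new_locs = torch.stack(
    torch.meshgrid(torch.arange(shape[0]), torch.arange(shape[1]), indexing='ij')
).unsqueeze(0).float()

# The normalization loop from the source code.
for i in range(len(shape)):
    new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)

# Pixel coordinate 0 maps to -1, coordinate 127 maps to +1.
print(new_locs.min().item(), new_locs.max().item())  # -1.0 1.0
```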

The fourth statement (the if branch)

 if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
            new_locs = new_locs[..., [1, 0]]

permute moves the channel dimension to the last position, and the second line reverses the channel order from (y, x) to (x, y), which is the order grid_sample expects. This matches my experiments: if that second line is commented out, the output image comes out rotated 90 degrees clockwise. See torch.permute.
The last line

return nnf.grid_sample(src, new_locs, align_corners=True, mode=self.mode)

This feeds the moving image M and the warped sampling grid into grid_sample, which returns the final registered image. See grid_sample.
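Putting the forward pass together: with a zero displacement field, the whole pipeline (grid + flow, normalize, permute, channel flip, grid_sample) should return the input image unchanged. A self-contained sketch of that check, at a toy 8×8 size of my own choosing:

```python
import torch
import torch.nn.functional as nnf

H = W = 8
src = torch.rand(1, 1, H, W)      # "moving image" M
flow = torch.zeros(1, 2, H, W)    # zero displacement field

# Identity grid, shape [1, 2, H, W].
grid = torch.stack(
    torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
).unsqueeze(0).float()

new_locs = grid + flow
for i, s in enumerate((H, W)):
    new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (s - 1) - 0.5)

# Channels last, then (y, x) -> (x, y) for grid_sample.
new_locs = new_locs.permute(0, 2, 3, 1)[..., [1, 0]]
out = nnf.grid_sample(src, new_locs, align_corners=True, mode='bilinear')

print(torch.allclose(out, src, atol=1e-5))  # True
```

With align_corners=True, the normalized coordinate -1 maps exactly to pixel 0 and +1 to pixel H-1, so zero flow samples every pixel from itself.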
