实现VNet网络中的concatenation

在VNet网络中有一个skip connection的操作,就是将encoding和decoding的特征进行级联,但是级联的条件是feature map的大小要一致,所以经常采取的措施是将encoding中的feature map裁剪成跟decoding的feature map一样的大小,这里用的是tf.slice(input_tensor,begin,size)

参数解释:

input_tensor是输入的tensor,就是被裁剪的feature map

begin是每一个维度的起始位置,这个下面详细说

size相当于每个维度拿几个元素出来

下面举一个例子进行说明:

参考文献:https://www.jianshu.com/p/71e6ef6c121b

t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]]])

tf.slice(t, [1, 0, 0], [1, 1, 3])

输出:[[[3, 3, 3]]],最后的shape跟size一样,size为(1,1,3)

首先了解一下shape的概念,将t进行分解:

t = [A, B, C]   #这是第一维度

A = [i, j], B = [k, l], C = [m, n]  #这是第二维度

i = [1, 1, 1], j = [2, 2, 2], k = [3, 3 ,3], l = [4, 4, 4], m = [5, 5, 5], n = [6, 6, 6]  # 这是第三维度

对于t来说,最外面括号里有3个东西,分别是A, B, C。这三个东西每个里面有两个东西, 分别是i和j, k和l, m和n。

它们里面每一个又有3个数字,所以t的shape是[3,2,3]

开始裁剪:

tf.slice(t, [1, 0, 0], [1, 1, 3])  # begin = [1, 0, 0],size=[1,1,3]

参数解释:

begin和size的意义是从左至右,begin的意思是起始位置,其中的每一个数字代表一个维度,那么[1, 0, 0]的意思是在3个维度中,每个维度从哪里算起

第一维度是[A, B, C],begin里[1, 0, 0]是1,也就是从B算起。其次第二维度里B = [k, l](注意啊,我这里只写了B = [k, l],可不代表只有B有用,如果size里第一个数字是2的话,B和C都会被取的),begin里第二个数是0,也就是从k算起,第三维度k = [3, 3 ,3],begin里第三个数是0,就是从第一个3算起,而size的意思是每个维度的大小,也就是每个维度取几个元素,size的大小是最后输出的tensor的shape。

size里第一个是1,意思是在第一个维度取1个元素。t = [A, B, C] begin是从B起算,取一个那就是B,那么第一维度结果就是[B]

size第二个也是1,第二维度B = [k, l], begin是从k起算,取一个是k,那么第二维度结果是[[k]]。

size第三个是3,第三维度k = [3, 3 ,3],begin里起算是第一个3,三个3取3个数,那就要把三个3都取了,所以输出:[[[3, 3, 3]]]

再看一个例子:

t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]]])

tf.slice(t, [1, 0, 0], [1, 2, 3])

begin还是[1, 0 ,0],size第一个维度取一个,还是[B],然后size第二个维度不是1了,是2,意思是取两个。还记得B = [k, l]吗?现在不是只要k了,是k和l都要,size第三维度取3个,也就是说针对k = [3, 3 ,3]和l = [4, 4, 4]都分别取3个元素,最后输出:[[[3, 3, 3], [4, 4, 4]]]

第三个例子:

t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]]])

tf.slice(t, [1, 0, 0], [-1, -1, -1])

如果size输入值是-1的话,在那个维度所有的数都会输出。上面的例子中,begin是[1, 0, 0]。三个维度都是-1的话,那么输出结果: 第一维度是[B,C],第二维度是[[k, l], [m, n]], 第三维度是[[[3,3,3], [4,4,4]], [[5,5,5], [6,6,6]]]

 

最后裁剪和concatenate的函数:

def crop_and_concat(x1, x2): #x1是encoding中的feature map,x2是decoding中的feature map
    x1_shape = tf.shape(x1)
    x2_shape = tf.shape(x2)
    # offsets for the top left corner of the crop
    offsets = [0, (x1_shape[1] - x2_shape[1]) // 2, #起始维度
               (x1_shape[2] - x2_shape[2]) // 2, (x1_shape[3] - x2_shape[3]) // 2, 0]
    size = [-1, x2_shape[1], x2_shape[2], x2_shape[3], -1] #切片大小
    x1_crop = tf.slice(x1, offsets, size)
    return tf.concat([x1_crop, x2], 4)

这样裁剪和级联有一个弊端:

网络的输入必须是16倍数的大小

 

你可能感兴趣的:(深度学习)