tf.unstack在动态单层双向循环神经网络的搭建中出现过,这里记录下方便自己记忆
功能:将输入value按照指定axis(维度)拼接(从0开始),输出新的张量
举个例子,假设value1.shape为(2,3,4),value2.shape也为(2,3,4)
如果axis=0,那么拼接后张量的shape为(4,3,4)
如果axis=1,那么拼接后张量的shape为(2,6,4)
如果axis=2,那么拼接后张量的shape为(2,3,8)
import tensorflow as tf
import numpy as np
X = tf.constant(np.array(range(24)).reshape(2, 3, 4))
Y = tf.constant(np.array(range(24, 48)).reshape(2, 3, 4))
Z0 = tf.concat([X, Y], 0)
Z1 = tf.concat([X, Y], 1)
Z2 = tf.concat([X, Y], 2)
with tf.Session() as sess:
ts = [X, Y, Z0, Z1, Z2]
xs = sess.run([X, Y, Z0, Z1, Z2])
for t, x in zip(ts, xs):
print(t, '\n', x, '\n')
将输出手动美化后的结果如下,依次是X,Y,Z0,Z1,Z2
Tensor("Const:0", shape=(2, 3, 4), dtype=int64)
[[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]
[[12 13 14 15]
[16 17 18 19]
[20 21 22 23]]]
Tensor("Const_1:0", shape=(2, 3, 4), dtype=int64)
[[[24 25 26 27]
[28 29 30 31]
[32 33 34 35]]
[[36 37 38 39]
[40 41 42 43]
[44 45 46 47]]]
Tensor("concat:0", shape=(4, 3, 4), dtype=int64)
[[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]
[[12 13 14 15]
[16 17 18 19]
[20 21 22 23]]
[[24 25 26 27]
[28 29 30 31]
[32 33 34 35]]
[[36 37 38 39]
[40 41 42 43]
[44 45 46 47]]]
Tensor("concat_1:0", shape=(2, 6, 4), dtype=int64)
[[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]
[24 25 26 27]
[28 29 30 31]
[32 33 34 35]]
[[12 13 14 15]
[16 17 18 19]
[20 21 22 23]
[36 37 38 39]
[40 41 42 43]
[44 45 46 47]]]
Tensor("concat_2:0", shape=(2, 3, 8), dtype=int64)
[[[ 0 1 2 3 24 25 26 27]
[ 4 5 6 7 28 29 30 31]
[ 8 9 10 11 32 33 34 35]]
[[12 13 14 15 36 37 38 39]
[16 17 18 19 40 41 42 43]
[20 21 22 23 44 45 46 47]]]
怎么理解呢?
如果axis=0,那么从第0维开始看,将第0维内的元素看做待拼接元素,那么X和Y的第0维内就各有2个待拼接元素,如下图所示,然后将这4个元素,先按在当前维度内的顺序拼接,再按从X到Y的顺序拼接,每个数字的下标位置(a,b,c)中,b,c不发生改变,只有a按照重新编排的位置对号入座,比如:
24 (0,0,0) ==> (2,0,0)
36 (1,0,0) ==> (3,0,0)
如果axis=1,那么从第1维开始看,将第一维内的元素看做待拼接元素,那么X和Y的第一维内各有3个元素,当然这里有个"第1维1"和"第1维2",如下图所示,然后分别将"第1维1"和"第1维2"中的元素先按在当前维度内的顺序拼接,再按从X到Y的顺序拼接,"第1维1"将作为第0维的第1个元素,第1维2"将作为第0维的第2个元素,因为针对第1维做拼接,那么拼接后的各个数字的下标(a,b,c)中,a,c不发生改变,而b将会因为重新编排的位置对号入座,比如:
24 (0,0,0) ==> (0,3,0)
36 (1,0,0) ==> (1,3,0)
如果axis=2,那么从第2维开始看,将第2维内的元素看做待拼接元素,那么X和Y的第2维内各有4个元素,当然这里有"第2维1"、“第2维2”、“第2维3”,如下图所示,然后分别将"第2维1"、“第2维2”、"第2维3"中的元素先按在当前维度内的顺序拼接,再按从X到Y的顺序拼接,因为针对第2维做拼接,那么拼接后的各个数字的下标(a,b,c)中,a,b不发生改变,而c将会因为重新编排的位置对号入座,比如:
24 (0,0,0) ==> (0,0,4)
36 (1,0,0) ==> (1,0,4)
这里要注意的是,除了将被拼接的维度外,其他的维度数必须一致
比如,value1.shape为(2,3,4),value2.shape为(1,3,4)
那么,只能指定axis=0,也就是将第0维内的元素拼接,拼接后的shape为(3,3,4)
如果,指定axis=1或axis=2,那么就会抛出一个ValueError的报错,Dimension 0 in both shapes must be equal, but are 2 and 1. Shapes are [2] and [1]. 其意思就是value1.shape[0] 和 value2.shape[0] 必须相等,然而却是 [2] 和 [1]。