对于多维张量而言, 约减的方向是一个需要明确的问题。在TensorFlow中, 提供了很多关于约减的函数, 如tf.reduce_sum, tf.reduce_mean, tf.reduce_max, tf.reduce_min等函数, 它们的约减原理都是一样的,即从一大批数据中,不断减少数据量,直到找到满足要求的数据。
下面以tf.reduce_sum()来说明张量的约减方向。原型如下:
tf.reduce_sum(
input_tensor,
axis=None,
keepdims=False,
name=None,
reduction_indices=None)
只有第一个参数input_tensor是必须的。对张量(多维数组)而言,约减是有方向性的。第2个参数axis,决定了约减的轴方向。axis=0,垂直方向上约减, axis=1,水平方向上约减。并且约减可以有先后顺序。因此axis的值可以是一个向量,比如axis=[1,0], 表示先水平方向约减,然后垂直方向上约减。axis默认为None,表示所有维度的张量都会依次约减。
参数keepdims为True, 那么每个维度的张量被约减到长度为1, 即保留了维度信息。
下面给出代码以及运行结果:
x = tf.constant([[1,1,1], [1,1,1]])
a = tf.reduce_sum(x)
b = tf.reduce_sum(x, 0) # 垂直方向上约减
c = tf.reduce_sum(x, 1) # 水平方向上约减
d = tf.reduce_sum(x, 1, keepdims=True) # 每个维度的张量被约减到长度为1, 即保留了维度信息
e = tf.reduce_sum(x, [0, 1]) #先垂直后水平
with tf.Session() as sess:
print('a =', sess.run(a))
print('b =', sess.run(b))
print('c =', sess.run(c))
print('d =', sess.run(d))
print('e =', sess.run(e))
结果为:
a = 6
b = [2 2 2]
c = [3 3]
d = [[3]
[3]]
e = 6
上述的解释虽然直观,但有很大的局限性。这种轴的概念,在维度小于2时,容易理解。且对于0表示垂直方向, 1表示水平方向是人为强加的。当在维度>=3时,就难以找到直观可理解的方向。
更加普适的解释应该按张量括号层次的方式来理解。张量括号由外到内,对应从小到大的维数。
当指定reduce_sum函数的axis=0时,就是在第0个维度的元素之间进行sum操作,也就是除掉最外层括号后对应的两个元素,即[[1,1,1],[2,2,2]],[[3,3,3],[4,4,4]],然后对同一个括号层次下的这两个张量实施加法约减操作,即张量[[1,1,1],[2,2,2]]和
张量[[3,3,3],[4,4,4]]整体相加, 其结果为[[4,4,4],[6,6,6]]。没有被约减的维度,其括号层次保持不变。
类似的,当axis=1时,就是在第1个维度的元素之间进行sum操作,也就是去掉中间层括号对应的元素[1,1,1],[2,2,2]和[3,3,3],[4,4,4]。需要注意的时, 原来在同一个括号层次内的张量两两相加,即[1,1,1]和[2,2,2]向量相加,[3,3,3]和[4,4,4]向量相加。
没有被约减的维度,其括号保持不变,结果得到 [[3,3,3],[7,7,7]]。
当axis=2时,就是除掉最内层的括号,然后在最内层括号的元素之间进行sum操作。即1+1+1=3,2+2+2=6,3+3+3=9,4+4+4=12。实施约减之后,该层次括号消失,其他维度的括号保留。结果得到[[3,6],[9,12]]。
这里为了便于区分,用逗号','将同一层次的不同元素隔开,实际上TensorFlow中,不同元素是用 空格隔开的。 事实上,每一个维度的约减,在实施之后,该维度都会消失。
下面用一个简单的程序来验证上面的描述:
x1 = tf.constant([
[[1,1,1],[2,2,2]],
[[3,3,3],[4,4,4]]
])
z0 = tf.reduce_sum(x1, 0)
z1 = tf.reduce_sum(x1, 1)
z2 = tf.reduce_sum(x1, 2)
z3 = tf.reduce_sum(x1)
with tf.Session() as sess:
print("============>\n", sess.run(z0))
print("============>\n", sess.run(z1))
print("============> \n", sess.run(z2))
print("============>\n ", sess.run(z3))
结果如下:
============>
[[4 4 4]
[6 6 6]]
============>
[[3 3 3]
[7 7 7]]
============>
[[ 3 6]
[ 9 12]]
============>
30