tf.stack()与tf.unstack()函数

tf.stack()是一个矩阵拼接函数,即将秩为 R 的张量列表堆叠成一个秩为 (R+1) 的张量。 

import tensorflow as tf
a = tf.constant([[1,2,3],[4,5,6]])
b = tf.constant([[7,8,9],[0,1,7]])
c = tf.stack([a,b],axis = 0)
with tf.Session() as sess:
    result1 = sess.run(c)
    print(result1)
>>>
[[[1 2 3]
  [4 5 6]]

 [[7 8 9]
  [0 1 7]]]

将 values 中的张量列表打包成一个张量,该张量比 values 中的每个张量都高一个秩,通过沿 axis 维度打包。给定一个形状为(A, B, C)的张量的长度 N 的列表;

如果 axis == 0,那么 output 张量将具有形状(N, A, B, C)。如果 axis == 1,那么 output 张量将具有形状(A, N, B, C)。

如果 axis == 2,那么 output 张量将具有形状( A, B, N, C)。如果 axis == 3,那么 output 张量将具有形状(A, B, C, N)。


tf.unstack()是一个拆分矩阵的函数,将秩为 R 的张量的给定维度出栈为秩为 (R-1) 的张量。

通过沿 axis 维度将 num 张量从 value 中分离出来。如果没有指定 num(默认值),则从 value 的形状推断。如果 value.shape[axis] 不知道,则引发 ValueError。

例如,给定一个具有形状 (A, B, C, D) 的张量。

  • 如果 axis == 0,那么 output 中的第 i 个张量就是切片 value[i, :, :, :],并且 output 中的每个张量都具有形状 (B, C, D)。(请注意,出栈的维度已经消失,不像split)。 
  • 如果 axis == 1,那么 output 中的第 i 个张量就是切片 value[:, i, :, :],并且 output 中的每个张量都具有形状 (A, C, D)。 

这与堆栈(stack.)相反

 

转:https://www.w3cschool.cn/tensorflow_python/tensorflow_python-bsky2o7k.html

你可能感兴趣的:(tensorflow)