TensorFlow——多维矩阵的转置(transpose)

今天在深度学习第四课的神经风格转移遇到了一个折磨了我很久的东西,就是高纬度矩阵转置,不得不说,一旦维度升高,真的会让人懵逼,废话不多说,开始讲一下我对TensorFlow中transpose()函数的用法。

先看一下官方API:

tf.transpose
transpose(
    a,                          #  a是一个张量
    perm=None,                  # perm就是你对张量怎么转置的规则,即序列改变列表
    name='transpose'
)

Args:
a: A Tensor.
perm: A permutation of the dimensions of a.
name: A name for the operation (optional).
Returns:
A transposed Tensor.

Transposes a. Permutes the dimensions according to perm.

The returned tensor’s dimension i will correspond to the input dimension perm[i]. If perm is not given, it is set to (n-1…0), where n is the rank of the input tensor. Hence by default, this operation performs a regular matrix transpose on 2-D input Tensors.

下面是官方给出的例子

For example:

# 'x' is [[1 2 3]
#         [4 5 6]]
tf.transpose(x) ==> [[1 4]
                     [2 5]
                     [3 6]]

# Equivalently
tf.transpose(x, perm=[1, 0]) ==> [[1 4]
                                  [2 5]
                                  [3 6]]

# 'perm' is more useful for n-dimensional tensors, for n > 2
# 'x' is   [[[1  2  3]
#            [4  5  6]]
#           [[7  8  9]
#            [10 11 12]]]
# Take the transpose of the matrices in dimension-0
tf.transpose(x, perm=[0, 2, 1]) ==> [[[1  4]
                                      [2  5]
                                      [3  6]]

                                     [[7 10]
                                      [8 11]
                                      [9 12]]]

将到这里,我不得不说一下TensorFlow张量的排列方式,在numpy中也是一样的,看下面的一个shape为2*3*2的张量,注意我特意调整的对齐方式以及中括号的嵌套。

[[[1  4]
  [2  5]
  [3  6]]

  [[7 10]
   [8 11]
   [9 12]]]

第一个维度是2,于是最外面的中括号里面套了2个子张量,也就是2个3*2的张量,我们把其中一个提取出来。

[[1  4]
 [2  5]
 [3  6]]

它是一个3*2的张量,按照上面说的方法,最外面的中括号里面套了3个子张量,也就是3个2维的向量了,不过2维的张量我们还是很熟悉的,没必要这样看,这种“看”张量方法很适合在3维甚至更高维度的情况。

我再给一个4维的2*2*2*3的张量,你可以按照我说的方法“看”一下这个张量,很管用的,对吧?起码对我来说是这样的。

[[[[ 1  2  3]
   [ 4  5  6]]

   [[ 7  8  9]
   [10 11 12]]]


 [[[ 1  2  3]
   [ 4  5  6]]

   [[ 7  8  9]
   [10 11 12]]]]

* 声明:我以下说的索引轴都是从0开始计数的

着重讲一下perm这个参数,它是一个列表,列表第i位的数字,表明现在第i个索引轴(axis)对应的是原来张量第perm[i]个索引轴(axis)。

好吧,我打出这句话后,我都不知道是什么意思,举几个栗子就知道了。

  • 二维的栗子
    ‘x’ is

    [[1 2 3]
    [4 5 6]]

    这个张量是二维的,它有两条索引轴(axis),分别是0和1,0索引轴就是我们说的行标,1就是列标,ok。现在我们执行tf.transpose(x,perm=[1,0]) 。

    执行这句话后,现在第0个索引轴对应原来的第1个索引轴,第1个索引轴对应原来的第0个索引轴。比如2这个值,原本的索引应该是[0,1],新的索引成了[1,0];6这个值,原本的索引应该是[1,2],新的索引成了[2,1]

于是结果就自然成了行列转置。


[[1 4]
[2 5]
[3 6]]

  • 三维的栗子
 'x' is  
  [[[1  2  3]
    [4  5  6]]

    [[7  8  9]
     [10 11 12]]]

执行tf.transpose(x, perm=[0, 2, 1]), 执行这句话后,现在第0个索引轴对应原来的第0个索引轴,第1个索引轴对应原来的第2个索引轴,第2个索引轴对应原来的第1个索引轴。比如3这个值,原本的索引应该是[0,0,2],新的索引成了[0,2,0];12这个值,原本的索引应该是[1,1,2],新的索引成了[1,2,1]

也就是说,第1个维度不变,第2个和第3个维度交换位置,再加上我上面说的“看”张量的方法,你会发现其实就是将每个子张量行列转置。

[[[1  4]
  [2  5]
  [3  6]]

  [[7 10]
   [8 11]
   [9 12]]]

如有理解不对的地方,请指出,谢谢~

你可能感兴趣的:(神经网络)