pytorch中的矩阵的转置问题

目录

  • 前言
  • 三阶张量的转置

前言

我在我的pytorch专栏发布了一期pytorch入门之tensor,介绍了torch.tensor()的一些创建方式和常用方法,其中就有矩阵的转置方法----tensor.t()、tensor.transpose()和tensor.permute()。我只是用少量语言和代码介绍了这三种方法的用法,但其中的转置原理没有说清。今天咱们就来絮叨絮叨~

相信学过线性代数的小伙伴对矩阵的转置不会太陌生,对于一个m*n二维矩阵A,它的转置原理如下:

转置前
pytorch中的矩阵的转置问题_第1张图片
转置后
pytorch中的矩阵的转置问题_第2张图片

简单来说,就是把这个矩阵的行数据转换为列数据。tensor.t()就是这样的原理。但对于tensor而言,它可以有许多的维度,二维矩阵用torch.tensor()表示如下:

a = torch.tensor(np.arange(24).reshape([4, 6]))
print(a)

pytorch中的矩阵的转置问题_第3张图片

上图就是一个二维的矩阵,与数学上的矩阵表示有所不同。它的最外边的中括号代表的维度用“0”表示,里面的括号代表的维度用“1”表示。这种表示我们用三阶张量来表示更明显一些。

a = torch.tensor(np.arange(24).reshape([2, 3, 4]))
print(a)
print(a.size(0))
print(a.size(1))
print(a.size(2))

pytorch中的矩阵的转置问题_第4张图片
如果不清楚tensor.size()用法可以看我的pytorch入门之tensor

因为三阶张量的转置无法使用tensor.t(),所以接下来的转置原理的解释都是用tensor.transpose()来演示。

三阶张量的转置

我们先创建一个三阶张量,以这个三阶张量来演示0维、1维和2维两两之间的转置是怎样的

import torch
import numpy as np
example = torch.tensor(np.arange(24).reshape([2, 3, 4]))

注释:这个三阶张量是由2个3行4列的矩阵构成,tensor.transpose()括号里的0,1,2分别对应torch.tensor(np.arange(24).reshape([2, 3, 4]))的[]里面的2,3,4。下面先来展示一下结果

print(example)
print(example.size(0))
print(example.size(1))
print(example.size(2))

pytorch中的矩阵的转置问题_第5张图片

0代表三阶张量最外边的括号,1代表中间的括号,2代表最里层的括号。

  1. tensor.transpose(0, 1)2和3互换,tensor.size() == torch.Size([3,2,4])

    example1 = torch.tensor(np.arange(24).reshape([2, 3, 4]))
    print(example1)
    example2 = example1.transpose(0, 1)
    print(example2)
    

    pytorch中的矩阵的转置问题_第6张图片

  2. tensor.transpose(0, 2)2和4互换,tensor.size() == torch.Size([4,3,2])

    	example1 = torch.tensor(np.arange(24).reshape([2, 3, 4]))
    	print(example1)
    	example2 = example1.transpose(0, 2)
    	print(example2)
    

    pytorch中的矩阵的转置问题_第7张图片

  3. tensor.transpose(1, 2)3和4互换,tensor.size() == torch.Size([2,4,3])

    example1 = torch.tensor(np.arange(24).reshape([2, 3, 4]))
    print(example1)
    example2 = example1.transpose(1, 2)
    print(example2)
    

    pytorch中的矩阵的转置问题_第8张图片

你可能感兴趣的:(#,pytorch,python)