flatten()是对多维数据的降维函数。
flatten(),默认缺省参数为0,也就是说flatten()和flatte(0)效果一样。
python里的flatten(dim)表示,从第dim个维度开始展开,将后面的维度转化为一维.也就是说,只保留dim之前的维度,其他维度的数据全都挤在dim这一维。
>>> input = torch.randn(2,3,4)
>>> print(input.flatten().size())
torch.Size([24])
>>> print(input.flatten(1).size())
torch.Size([2, 12])
>>> print(input.flatten(2).size())
torch.Size([2, 3, 4])
>>>
transpose()函数的作用就是调换数组的行列值的索引值,类似于求矩阵的转置:
>>> input = torch.randn(2,3,4)
>>> print(input.transpose(1,2).size())
torch.Size([2, 4, 3])
>>> print(input.transpose(0,2).size())
torch.Size([4, 3, 2])
>>>
permute() 函数一次可以进行多个维度的交换或者可以成为维度重新排列,参数是 0, 1, 2, 3, … ,随着待转换张量的阶数上升参数越来越多,本质上可以理解为多个 transpose() 操作的叠加,因此理解 permute() 函数的关键在于理解 transpose() 函数
>>> import torch
>>> x = torch.Tensor([[[1, 2, 3, 4],
... [5, 6, 7, 8],
... [9, 10, 11, 12]],
...
... [[13, 14, 15, 16],
... [17, 18, 19, 20],
... [21, 22, 23, 24]]]) # 一个结构为 (2, 3, 4) 的 3 阶张量
>>> print(x.shape)
torch.Size([2, 3, 4])
>>> y = x.permute(2, 0, 1) # 对张量 x 进行维度重排
>>> z = x.transpose(0, 1).transpose(0, 2) # 对张量 x 连续交换两次维度
>>> print(y.equal(z))
True
>>> print(y.shape)
torch.Size([4, 2, 3])
>>>
torch.nn.Flatten(start_dim=1, end_dim=- 1)
作用:将连续的维度范围展平为张量。 经常在nn.Sequential()中出现,一般写在某个神经网络模型之后,用于对神经网络模型的输出进行处理,得到tensor类型的数据。
有俩个参数,start_dim和end_dim,分别表示开始的维度和终止的维度,默认值分别是1和-1,其中1表示第一维度,-1表示最后的维度。结合起来看意思就是从第一维度到最后一个维度全部给展平为张量。(注意:数据的维度是从0开始的,也就是存在第0维度,第一维度并不是真正意义上的第一个)
>>> import torch
>>> input = torch.randn(32, 1, 5, 5)
>>> m = torch.nn.Flatten()
>>> output1 = m(input)
>>> output1.size()
torch.Size([32, 25])
>>>
>>> m = torch.nn.Flatten(0, 2)
>>> output = m(input)
>>> output.size()
torch.Size([160, 5])
>>>
torch.nn.Softmax() :将Softmax函数应用于一个n维输入张量,对其进行缩放,使n维输出张量的元素位于[0,1]范围内,总和为1。
参数
dim (int) - Softmax将被计算的维度(因此沿dim的每个切片和为1)。
当dim=0时,指的是在维度0上的元素相加等于1。
当dim=1时,指的是在维度1上的元素相加等于1。
当dim=2时,指的是在维度2上的元素相加等于1。
>>> import torch
>>> a = torch.randn(2,3)
>>> a
tensor([[ 0.3572, -1.2864, 0.5049],
[ 1.4707, -0.3163, -0.2466]])
>>> torch.nn.Softmax(dim = -1)(a) # 在最后一个维度上相加
tensor([[0.4251, 0.0822, 0.4928],
[0.7424, 0.1243, 0.1333]])
>>>