Pytorch 中的 permute() 函数

作用

交换 tensor 中的维度

用法

将对应维度的序号交换就会交换对应的维度

示例

import torch

a = torch.rand(2, 3, 4)
print(a)
print(a.shape)

# 将第一维和第二维进行交换, 第零维不动
b = a.permute(0, 2, 1)
print(b)
print(b.shape)

>>tensor([[[0.1135, 0.1757, 0.4028, 0.4548],
         [0.4652, 0.9984, 0.8759, 0.8631],
         [0.9614, 0.8819, 0.5834, 0.7719]],

        [[0.5982, 0.6116, 0.2923, 0.1457],
         [0.8527, 0.2529, 0.1352, 0.6022],
         [0.9118, 0.1686, 0.9508, 0.1597]]])
>>torch.Size([2, 3, 4])
>>tensor([[[0.1135, 0.4652, 0.9614],
         [0.1757, 0.9984, 0.8819],
         [0.4028, 0.8759, 0.5834],
         [0.4548, 0.8631, 0.7719]],

        [[0.5982, 0.8527, 0.9118],
         [0.6116, 0.2529, 0.1686],
         [0.2923, 0.1352, 0.9508],
         [0.1457, 0.6022, 0.1597]]])
>>torch.Size([2, 4, 3])

你可能感兴趣的:(Pytorch,中的各种函数,Pytorch)