pytorch中tensor的基本维度变换:transpose/permute/t/cat/chunk/split/stack/squeeze/unsqueeze/expand/repeat

直接从代码中学习tensor的一些维度变换操作:

import torch

torch.manual_seed(2020)

x = torch.rand(1, 2, 3)
print(x)
# tensor([[[0.4869, 0.1052, 0.5883], [0.1161, 0.4949, 0.2824]]])

print('\nview:')
print(x.view(-1, 3).size())      # torch.Size([2, 3])

print('\ntorch.transpose:')
print(torch.transpose(x, 0, 1))
print(x.transpose(0, 1).size())  # torch.Size([2, 1, 3])
print(x.transpose(1, 2).size())  # torch.Size([1, 3, 2])
# transpose要指明待交换的维度
# 与torch.Tensor.permute功能相似,permute需要指明重新排列后的维度顺序
print('Tensor.permute:')
print(x.permute(0, 2, 1).size()) # torch.Size([1, 3, 2])

print('\ntorch.t:')
t = torch.rand(())
print(t, torch.t(t), t.t())  
# tensor(0.5899) tensor(0.5899) tensor(0.5899)
t1d = torch.rand(3)
print(t1d, torch.t(t1d), t1d.t())
# tensor([0.8105, 0.2512, 0.6307]) tensor([0.8105, 0.2512, 0.6307]) tensor([0.8105, 0.2512, 0.6307])
t2d = torch.rand(2, 3)
print(t2d, torch.t(t2d), t2d.t())
# tensor([[0.5403, 0.8033, 0.7781], [0.4966, 0.8888, 0.5570]]) 
# tensor([[0.5403, 0.4966], [0.8033, 0.8888], [0.7781, 0.5570]]) 
# tensor([[0.5403, 0.4966], [0.8033, 0.8888], [0.7781, 0.5570]])
# torch.t()要求输入的tensor的维度小于等于2,对输入的tensor进行转置:
# 当输入维度小于2,对tensor无作用;当输入tensor的维度为2时,相当于transpose(input, 0, 1)

print('\ntorch.cat:')
y = torch.rand(1, 1, 3)
print(torch.cat((x, y), dim=1).size())  # torch.Size([1, 3, 3])
# dim指定待拼接的维度;待拼接的两个向量除了待拼接的维度,其余维度必须相等或为空

print('\ntorch.chunk:')
x_chunks = torch.chunk(x, chunks=2, dim=1)  # x_chunks是一个tuple
print(x_chunks)
# (tensor([[[0.4869, 0.1052, 0.5883]]]), tensor([[[0.1161, 0.4949, 0.2824]]]))
print(x_chunks[0].size(), x_chunks[1].size())
# torch.Size([1, 1, 3]) torch.Size([1, 1, 3])
print(torch.chunk(x, 2, 2))  # 不能整除时,最后一个chunk较小
# (tensor([[[0.4869, 0.1052], [0.1161, 0.4949]]]), 
#  tensor([[[0.5883], [0.2824]]]))
print(torch.chunk(x, 4, 2))  # chunks大于tensor在维度dim上的值时,每个chunk均为1
# (tensor([[[0.4869], [0.1161]]]), 
#  tensor([[[0.1052], [0.4949]]]), 
#  tensor([[[0.5883], [0.2824]]]))
# torch.chunk将tensor在dim维度上划分为chunks块;

print('\ntorch.split:')
z = torch.rand(4, 6, 8)
z_split = torch.split(z, split_size_or_sections=2, dim=1)
print([z_split[i].size() for i in range(len(z_split))])
# [torch.Size([4, 2, 8]), torch.Size([4, 2, 8]), torch.Size([4, 2, 8])]
z_split = torch.split(z, split_size_or_sections=4, dim=1)
print([z_split[i].size() for i in range(len(z_split))])
# [torch.Size([4, 4, 8]), torch.Size([4, 2, 8])]
z_split = torch.split(z, split_size_or_sections=[3, 3], dim=1)
print([z_split[i].size() for i in range(len(z_split))])
# [torch.Size([4, 3, 8]), torch.Size([4, 3, 8])]
# torch.split也是将tensor在指定的维度上分成若干块,不同于torch.chunk的是:
# torch.chunk指定分成几个chunk;torch.split指定每个chunk的大小
# torch.chunk和torch.split可以看作是torch.cat的反面

print('\ntorch.stack:')
a = torch.rand(2, 3, 4)
b = torch.rand(2, 3, 4)
print(torch.stack((a, b), dim=0).size())  # torch.Size([2, 2, 3, 4])
c = torch.rand(2, 3, 4)
print(torch.stack((a, b, c), dim=0).size())  # torch.Size([3, 2, 3, 4])
# torch.stack与torch.cat的区别:前者在新的维度上拼接;后者在已有的维度上拼接

print('\ntorch.squeeze:')
d = torch.rand(1, 2, 3, 1)
print(torch.squeeze(d).size())  # torch.Size([2, 3])
print(torch.squeeze(d, dim=3).size())    # torch.Size([1, 2, 3])
# torch.squeeze去掉大小为1的维度;dim默认为None,去掉所有大小为1的维度;
# 指定dim时,只去掉指定的大小为1的维度;若指定的维度大小不为1,则不起作用
print(torch.unsqueeze(d, dim=0).size())  # torch.Size([1, 1, 2, 3, 1])
print(torch.unsqueeze(d, dim=-1).size()) # torch.Size([1, 2, 3, 1, 1])
# torch.unsqueeze在指定位置增加一个大小为1的维度

print('\ntorch.expand:')
f = torch.tensor([[1], [2]])  # torch.Size([2, 1])
g = torch.rand(2, 2)
print(f.expand(2, 2)) # tensor([[1, 1], [2, 2]])
print(f.expand_as(g)) # tensor([[1, 1], [2, 2]])
# expand和expand_as都是对大小为1的维度进行扩展,扩展方式为复制
# expand的输入是扩展的tensor shape,expand_as输入的是具有目标shape的tensor

print('\ntorch.repeat:')
print(f.repeat(1, 2)) # tensor([[1, 1], [2, 2]])
print(f.repeat(2, 1)) # tensor([[1], [2], [1], [2]])
print(f.repeat(2, 2)) # tensor([[1, 1], [2, 2], [1, 1], [2, 2]])
# repeat将tensor相应维度上的内容赋值指定的次数
# 输入为repeat之后的shape,维度数量必须和要repeat的tensor的维度相同;repeat一次表示不变

你可能感兴趣的:(PyTorch)