【pytorch】pytorch维度变换函数:cat, stack, transpose, permute, unsqueeze, squeeze

Talk is cheap, Show me the code

cat

cat用于将两个tensor在某个维度上拼接起来
除了要拼接的维度,这两个tensor其它维度的大小应该一致

import torch
torch.manual_seed(666)

x1 = torch.randn(3, 5)
x2 = torch.randn(4, 5)
x_cat = torch.cat((x1, x2), 0) # 将 x1 和 x2 在第0个维度上拼接
print(f'\nx1 shape: {x1.shape}, x2 shape: {x2.shape} \n\ncat x1 and x2: {x_cat.shape}')

输出

x1 shape: torch.Size([3, 5]), x2 shape: torch.Size([4, 5]) 

cat x1 and x2: torch.Size([7, 5])

stack

stack用于将两个tensor在新生成的某个维度拼接起来

x1 = torch.randn(3, 5)
x2 = torch.randn(3, 5)
x_stack = torch.stack((x1, x2), 0) # 将 x1 和 x2 在新生成的第0个维度上拼接
print(f'\nx1 shape: {x1.shape}, x2 shape: {x2.shape} \n\nstack x1 and x2: {x_cat.shape}')

输出

x1 shape: torch.Size([3, 5]), x2 shape: torch.Size([3, 5]) 

stack x1 and x2: torch.Size([2, 3, 5])

transpose

transpose用于将tensor的某两个维度交换

x = torch.randn(3, 5)
x_transpose = x.transpose(0, 1) # 交换第0和第1个维度
print(f'\nx shape before transpose: {x.shape} \n\nx shape after transpose: {x_transpose.shape}')

输出

x shape before transpose: torch.Size([3, 5]) 

x shape after transpose: torch.Size([5, 3])

permute

permute用于将tensor的任意个维度交换

x = torch.randn(4, 5, 6, 7)
x_permute = x.permute(1, 3, 2, 0) # 将第0,1,2,3个维度转换为第1,3,2,0个维度
print(f'\nx shape before permute: {x.shape} \n\nx shape after permute: {x_permute.shape}')

输出

x shape before permute: torch.Size([4, 5, 6, 7]) 

x shape after permute: torch.Size([5, 7, 6, 4])

unsqueeze

unsqueeze用于生成一个新的大小为1的维度

x = torch.randn(3, 5)
x_unsqueeze = x.unsqueeze(0) # 新增一个大小为1的维度
print(f'\nx shape before unsqueeze: {x.shape} \n\nx shape after unsqueeze: {x_unsqueeze.shape}')

输出

x shape before unsqueeze: torch.Size([3, 5]) 

x shape after unsqueeze: torch.Size([1, 3, 5])

squeeze

squeeze用于去除某个大小为1的维度

x = torch.randn(1, 3, 5)
x_squeeze = x.squeeze(0) # 去除第0个维度
print(f'\nx shape before squeeze: {x.shape} \n\nx shape after squeeze: {x_squeeze.shape}')

输出

x shape before squeeze: torch.Size([1, 3, 5]) 

x shape after squeeze: torch.Size([3, 5])

你可能感兴趣的:(机器学习&深度学习,pytorch)