Talk is cheap, Show me the code
cat
用于将两个tensor在某个维度上拼接起来
除了要拼接的维度,这两个tensor其它维度的大小应该一致
import torch
torch.manual_seed(666)
x1 = torch.randn(3, 5)
x2 = torch.randn(4, 5)
x_cat = torch.cat((x1, x2), 0) # 将 x1 和 x2 在第0个维度上拼接
print(f'\nx1 shape: {x1.shape}, x2 shape: {x2.shape} \n\ncat x1 and x2: {x_cat.shape}')
输出
x1 shape: torch.Size([3, 5]), x2 shape: torch.Size([4, 5])
cat x1 and x2: torch.Size([7, 5])
stack
用于将两个tensor在新生成的某个维度拼接起来
x1 = torch.randn(3, 5)
x2 = torch.randn(3, 5)
x_stack = torch.stack((x1, x2), 0) # 将 x1 和 x2 在新生成的第0个维度上拼接
print(f'\nx1 shape: {x1.shape}, x2 shape: {x2.shape} \n\nstack x1 and x2: {x_cat.shape}')
输出
x1 shape: torch.Size([3, 5]), x2 shape: torch.Size([3, 5])
stack x1 and x2: torch.Size([2, 3, 5])
transpose
用于将tensor的某两个维度交换
x = torch.randn(3, 5)
x_transpose = x.transpose(0, 1) # 交换第0和第1个维度
print(f'\nx shape before transpose: {x.shape} \n\nx shape after transpose: {x_transpose.shape}')
输出
x shape before transpose: torch.Size([3, 5])
x shape after transpose: torch.Size([5, 3])
permute
用于将tensor的任意个维度交换
x = torch.randn(4, 5, 6, 7)
x_permute = x.permute(1, 3, 2, 0) # 将第0,1,2,3个维度转换为第1,3,2,0个维度
print(f'\nx shape before permute: {x.shape} \n\nx shape after permute: {x_permute.shape}')
输出
x shape before permute: torch.Size([4, 5, 6, 7])
x shape after permute: torch.Size([5, 7, 6, 4])
unsqueeze
用于生成一个新的大小为1的维度
x = torch.randn(3, 5)
x_unsqueeze = x.unsqueeze(0) # 新增一个大小为1的维度
print(f'\nx shape before unsqueeze: {x.shape} \n\nx shape after unsqueeze: {x_unsqueeze.shape}')
输出
x shape before unsqueeze: torch.Size([3, 5])
x shape after unsqueeze: torch.Size([1, 3, 5])
squeeze
用于去除某个大小为1的维度
x = torch.randn(1, 3, 5)
x_squeeze = x.squeeze(0) # 去除第0个维度
print(f'\nx shape before squeeze: {x.shape} \n\nx shape after squeeze: {x_squeeze.shape}')
输出
x shape before squeeze: torch.Size([1, 3, 5])
x shape after squeeze: torch.Size([3, 5])