第四章 PyTorch中张量(Tensor)拼接和拆分操作
上文介绍了PyTorch中张量(Tensor)的切片操作,本文主要介绍张量的拆分
和拼接
操作。
函数 | 描述 |
---|---|
torch.cat() |
将张量按照指定维度 关系进行拼接 |
torch.stack() |
将张量按照指定维度 关系进行拼接(用法同cat相同 ) |
# 引入库 import torch # 创建张量 A = torch.arange(9).reshape(1, 3, 3) print(A)
输出结果为:
tensor(
[[[0, 1, 2],
[3, 4, 5],
[6, 7, 8]]])
1、按照维度1进行拼接:
B0 = torch.cat((A, A), dim=0) print(B0)
输出结果为:
tensor([[[0, 1, 2],
[3, 4, 5],
[6, 7, 8]],
[[0, 1, 2],
[3, 4, 5],
[6, 7, 8]]])
1、按照维度2(
行
)进行拼接:B1 = torch.cat((A, A), dim=2) print(B1)
输出结果为:
tensor([[[0, 1, 2],
[3, 4, 5],
[6, 7, 8],
[0, 1, 2],
[3, 4, 5],
[6, 7, 8]]])
1、按照维度3(
列
)进行拼接:B2 = torch.cat((A, A), dim=2) print(B2)
输出结果为:
tensor([[[0, 1, 2, 0, 1, 2],
[3, 4, 5, 3, 4, 5],
[6, 7, 8, 6, 7, 8]]])
函数 | 描述 |
---|---|
torch.chunk() |
将张量分割为特定数量的块(当张量对应维度元素数量不足以拆分时会按照可以拆分数量进行拆分,且会出现不均等拆分情况) |
torch.split() |
将张量分割为特定数量的块,可以指定块的大小 |
注意:
torch.chunk()
:当张量对应维度元素数量不足以拆分时,会按照可以拆分的最大数量
进行拆分,且会出现不均等拆分
情况,且最后一个块最小
下文使用B0进行示例
B0 = tensor([[[0, 1, 2],
[3, 4, 5],
[6, 7, 8]],
[[0, 1, 2],
[3, 4, 5],
[6, 7, 8]]])
1、
torch.chunk()
按照维度1进行拆分:C1, C2 = torch.chunk(B0, 2, dim=1) # 维度1只有三组元素,所以会按照2:1的比例进行拆分 print(C1, C2)
输出结果为:
tensor([[[0, 1, 2],
[3, 4, 5],
[6, 7, 8]]])
tensor([[[0, 1, 2],
[3, 4, 5],
[6, 7, 8]]])
1、
torch.chunk()
按照维度2进行拆分:D1, D2 = torch.chunk(B0, 2, dim=1) # 3表示指定拆分数,但由于不足以拆分,所以只会拆分两组 print(D1, D2)
输出结果为:
tensor([[[0, 1, 2],
[3, 4, 5]],
[[0, 1, 2],
[3, 4, 5]]])
tensor([[[6, 7, 8]],
[[6, 7, 8]]])