1. 维度拼接
(1)cat:根据指定维度进行数据的合并。
通过指定维度拼接时,需要保证其它维度的size是相同的。
示例代码:
import torch
a = torch.rand(4, 32, 8)
b = torch.rand(5, 32, 8)
# 在第一个维度进行拼接
print(torch.cat([a, b], dim=0).shape)
输出结果:
torch.Size([9, 32, 8]
(2)stack:在指定维度的位置前插入新的维度。
stack需要保证两个Tensor的shape是一致的,这就像是有两类东西,它们的其它属性都是一样的(比如男的一张表,女的一张表)。使用stack时候要指定一个维度位置,在那个位置前会插入一个新的维度,因为是两类东西合并过来所以这个新的维度size是2,通过指定这个维度是0或者1来选择性别是男还是女。
示例代码:
c = torch.rand(4, 3, 32, 32)
d = torch.rand(4, 3, 32, 32)
print(torch.stack([c, d], dim=2).shape)
print(torch.stack([c, d], dim=0).shape)
输出结果:
torch.Size([4, 3, 2, 32, 32])
torch.Size([2, 4, 3, 32, 32])
2. 维度拆分
(1)split:按照指定维度的size进行拆分
对一个Tensor而言,要拆分的那个维度的size就是"这个维度的总长度"了,可以指定拆分后的几个Tensor各取多少长度,或者指定每个Tensor取多少长度。
示例代码:
import torch
a = torch.rand(2, 4, 3, 32, 32)
# 对0号维度拆分,拆分后每个Tensor取长度1
a1, a2 = a.split(1, dim=0)
print(a1.shape, a2.shape)
b = torch.rand(4, 3, 32, 32)
# 对1号维度拆分,拆分后第一个维度取2,第二个维度取1
b1, b2 = b.split([2, 1], dim=1)
print(b1.shape, b2.shape)
输出结果:
torch.Size([1, 4, 3, 32, 32]) torch.Size([1, 4, 3, 32, 32])
torch.Size([4, 2, 32, 32]) torch.Size([4, 1, 32, 32])
(2)chunk:按照份数等量拆分
给定在指定的维度上要拆分的份数,就会按照指定的份数尽量等量地进行拆分。
示例代码:
c = torch.rand(7, 4)
# 拆分成4份
c1, c2, c3, c4 = c.chunk(4, dim=0)
print(c1.shape, c2.shape, c3.shape, c4.shape
输出结果:
torch.Size([2, 4]) torch.Size([2, 4]) torch.Size([2, 4]) torch.Size([1, 4])