Pytorch 维度拼接与维度拆分

1. 维度拼接

(1)cat:根据指定维度进行数据的合并。

通过指定维度拼接时,需要保证其它维度的size是相同的。

示例代码:

import torch

a = torch.rand(4, 32, 8)
b = torch.rand(5, 32, 8)
# 在第一个维度进行拼接
print(torch.cat([a, b], dim=0).shape)

输出结果:

torch.Size([9, 32, 8]

(2)stack:在指定维度的位置前插入新的维度。

stack需要保证两个Tensor的shape是一致的,这就像是有两类东西,它们的其它属性都是一样的(比如男的一张表,女的一张表)。使用stack时候要指定一个维度位置,在那个位置前会插入一个新的维度,因为是两类东西合并过来所以这个新的维度size是2,通过指定这个维度是0或者1来选择性别是男还是女。

示例代码:

c = torch.rand(4, 3, 32, 32)
d = torch.rand(4, 3, 32, 32)
print(torch.stack([c, d], dim=2).shape)
print(torch.stack([c, d], dim=0).shape)

输出结果:

torch.Size([4, 3, 2, 32, 32])
torch.Size([2, 4, 3, 32, 32])

2. 维度拆分

(1)split:按照指定维度的size进行拆分
对一个Tensor而言,要拆分的那个维度的size就是"这个维度的总长度"了,可以指定拆分后的几个Tensor各取多少长度,或者指定每个Tensor取多少长度。

示例代码:

import torch

a = torch.rand(2, 4, 3, 32, 32)
# 对0号维度拆分,拆分后每个Tensor取长度1
a1, a2 = a.split(1, dim=0)  
print(a1.shape, a2.shape)

b = torch.rand(4, 3, 32, 32)
# 对1号维度拆分,拆分后第一个维度取2,第二个维度取1
b1, b2 = b.split([2, 1], dim=1)  
print(b1.shape, b2.shape)

输出结果:

torch.Size([1, 4, 3, 32, 32]) torch.Size([1, 4, 3, 32, 32])
torch.Size([4, 2, 32, 32]) torch.Size([4, 1, 32, 32])

(2)chunk:按照份数等量拆分

给定在指定的维度上要拆分的份数,就会按照指定的份数尽量等量地进行拆分。

示例代码:

c = torch.rand(7, 4)
# 拆分成4份
c1, c2, c3, c4 = c.chunk(4, dim=0)
print(c1.shape, c2.shape, c3.shape, c4.shape

输出结果:

torch.Size([2, 4]) torch.Size([2, 4]) torch.Size([2, 4]) torch.Size([1, 4])

 

你可能感兴趣的:(深度学习,pytorch)