reshape与view的作用一致,即重塑张量的shape。
但view要求操作张量在内存空间中连续,而reshape则没有此要求,因此reshape适用性更好。
使用时只需保证各维度数字的乘积等于元素个数即可。
a = torch.rand(2, 4)
print(a.size()) # torch.Size([2, 4])
a = a.reshape(1, 2, 4)
print(a.size()) # torch.Size([1, 2, 4])
a = a.reshape(2, 1, 4, 1)
print(a.size()) # torch.Size([2, 1, 4, 1])
a = a.reshape(2, 2, 2)
print(a.size()) # torch.Size([2, 2, 2])
unsqueeze即扩展指定维度。
import torch
t = torch.rand(4, 5)
t1 = t.unsqueeze(0) # 在第0维扩展
print(t1.size()) # torch.Size([1, 4, 5])
t2 = t.unsqueeze(1) # 在第1维扩展
print(t2.size()) # torch.Size([4, 1, 5])
t3 = t.unsqueeze(2) # 在第2维扩展
print(t3.size()) # torch.Size([4, 5, 1])
t4 = t.unsqueeze(-1) # 在第最后一维扩展
print(t4.size()) # torch.Size([4, 5, 1])
# 也可用以下形式
t5 = torch.unsqueeze(t, -1)
print(t5.size()) # torch.Size([4, 5, 1])
squeeze即压缩指定维度。
import torch
t = torch.rand(4, 1, 5, 1)
t1 = t.squeeze() # 压缩所有维
print(t1.size()) # torch.Size([4, 5])
t2 = t.squeeze(1) # 在第1维压缩
print(t2.size()) # torch.Size([4, 5, 1])
t3 = t.squeeze(-1) # 在最后一维压缩
print(t3.size()) # torch.Size([4, 1, 5])
t4 = t.squeeze(0) # 压缩第0维(无法压缩则保持原样)
print(t4.size()) # torch.Size([4, 1, 5, 1])
# 也可写成以下形式
t5 = torch.squeeze(t, -1)
print(t5.size()) # torch.Size([4, 1, 5])
transpose即转置。
import torch
t = torch.rand(2, 3)
print(t.size()) # torch.Size([2, 3])
print(t)
"""
tensor([[0.3861, 0.2016, 0.1332],
[0.2616, 0.9513, 0.6442]])
"""
t1 = t.transpose(0, 1) # 将第0维与第1维进行转置
print(t1.size()) # torch.Size([3, 2])
print(t1)
"""
tensor([[0.3861, 0.2616],
[0.2016, 0.9513],
[0.1332, 0.6442]])
"""
cat即张量拼接。
import torch
a = torch.rand(3, 4, 5)
b = torch.rand(2, 4, 5)
c1 = torch.cat([a, b], dim=0) # 按照第0维进行拼接,其他维度需一致
print(c1.size()) # torch.Size([5, 4, 5])
a = torch.rand(4, 4, 3)
b = torch.rand(4, 4, 2)
c2 = torch.cat([a, b], dim=2) # 按照第2维进行拼接,其他维度需一致
print(c2.size()) # torch.Size([4, 4, 5])
repeat会将指定维度重复,构造出新的张量。
import torch
a = torch.rand(3, 4, 5)
a1 = a.repeat(1, 1, 1) # 各维度保持不变
print(a1.size()) # torch.Size([3, 4, 5])
a2 = a.repeat(1, 2, 1) # 第1维变为两倍
print(a2.size()) # torch.Size([3, 8, 5])
a3 = a.repeat(1, 2, 4) # 第1维变为两倍,第二维变为4倍
print(a3.size()) # torch.Size([3, 8, 20])
a4 = a.repeat(1, 1, 1, 1) # 在第0维扩展
print(a4.size()) # torch.Size([1, 3, 4, 5])
flatten会将指定维度展平。
import torch
a = torch.rand(3, 4, 5)
a1 = a.flatten() # 全部展平
print(a1.size()) # torch.Size([60])
a2 = a.flatten(1, 2) # 将第1维到第2维展平
print(a2.size()) # torch.Size([3, 20])
a3 = a.flatten(0, 1) # 将第0维到第1维展平
print(a3.size()) # torch.Size([12, 5])