Pytorch入门:Tensor常见操作(重塑、扩展、压缩、转置、拼接、重复、展平)

reshape或view:

reshape与view的作用一致,即重塑张量的shape。

但view要求操作张量在内存空间中连续,而reshape则没有此要求,因此reshape适用性更好。

使用时只需保证各维度数字的乘积等于元素个数即可。

a = torch.rand(2, 4)
print(a.size())     # torch.Size([2, 4])

a = a.reshape(1, 2, 4)
print(a.size())     # torch.Size([1, 2, 4])

a = a.reshape(2, 1, 4, 1)
print(a.size())     # torch.Size([2, 1, 4, 1])

a = a.reshape(2, 2, 2)
print(a.size())     # torch.Size([2, 2, 2])

unsqueeze:

unsqueeze即扩展指定维度。

import torch

t = torch.rand(4, 5)

t1 = t.unsqueeze(0)     # 在第0维扩展
print(t1.size())        # torch.Size([1, 4, 5])

t2 = t.unsqueeze(1)     # 在第1维扩展
print(t2.size())        # torch.Size([4, 1, 5])

t3 = t.unsqueeze(2)     # 在第2维扩展
print(t3.size())        # torch.Size([4, 5, 1])

t4 = t.unsqueeze(-1)    # 在第最后一维扩展
print(t4.size())        # torch.Size([4, 5, 1])

# 也可用以下形式
t5 = torch.unsqueeze(t, -1)
print(t5.size())        # torch.Size([4, 5, 1])

squeeze:

squeeze即压缩指定维度。

import torch

t = torch.rand(4, 1, 5, 1)

t1 = t.squeeze()    # 压缩所有维
print(t1.size())    # torch.Size([4, 5])

t2 = t.squeeze(1)   # 在第1维压缩
print(t2.size())    # torch.Size([4, 5, 1])

t3 = t.squeeze(-1)  # 在最后一维压缩
print(t3.size())    # torch.Size([4, 1, 5])

t4 = t.squeeze(0)   # 压缩第0维(无法压缩则保持原样)
print(t4.size())    # torch.Size([4, 1, 5, 1])

# 也可写成以下形式
t5 = torch.squeeze(t, -1)
print(t5.size())    # torch.Size([4, 1, 5])

transpose:

transpose即转置。

import torch

t = torch.rand(2, 3)
print(t.size())         # torch.Size([2, 3])
print(t)
"""
tensor([[0.3861, 0.2016, 0.1332],
        [0.2616, 0.9513, 0.6442]])
"""

t1 = t.transpose(0, 1)  # 将第0维与第1维进行转置
print(t1.size())        # torch.Size([3, 2])
print(t1)
"""
tensor([[0.3861, 0.2616],
        [0.2016, 0.9513],
        [0.1332, 0.6442]])
"""

cat:

cat即张量拼接。

import torch

a = torch.rand(3, 4, 5)
b = torch.rand(2, 4, 5)

c1 = torch.cat([a, b], dim=0)   # 按照第0维进行拼接,其他维度需一致
print(c1.size())                # torch.Size([5, 4, 5])



a = torch.rand(4, 4, 3)
b = torch.rand(4, 4, 2)

c2 = torch.cat([a, b], dim=2)   # 按照第2维进行拼接,其他维度需一致
print(c2.size())                # torch.Size([4, 4, 5])

repeat:

repeat会将指定维度重复,构造出新的张量。

import torch

a = torch.rand(3, 4, 5)

a1 = a.repeat(1, 1, 1)  # 各维度保持不变
print(a1.size())        # torch.Size([3, 4, 5])

a2 = a.repeat(1, 2, 1)  # 第1维变为两倍
print(a2.size())        # torch.Size([3, 8, 5])

a3 = a.repeat(1, 2, 4)  # 第1维变为两倍,第二维变为4倍
print(a3.size())        # torch.Size([3, 8, 20])

a4 = a.repeat(1, 1, 1, 1)   # 在第0维扩展
print(a4.size())            # torch.Size([1, 3, 4, 5])

flatten:

flatten会将指定维度展平。

import torch

a = torch.rand(3, 4, 5)

a1 = a.flatten()        # 全部展平
print(a1.size())        # torch.Size([60])

a2 = a.flatten(1, 2)    # 将第1维到第2维展平
print(a2.size())        # torch.Size([3, 20])

a3 = a.flatten(0, 1)    # 将第0维到第1维展平
print(a3.size())        # torch.Size([12, 5])

你可能感兴趣的:(Pytorch学习笔记,python,numpy,深度学习)