pytorch索引、切片、连接和换位

目录:

      • 1.torch.cat 使用dim对Tensor进行指定拼接
      • 2.如何将一个Tensor按指定维度切片?
      • 3.按照索引对元素进行聚合
      • 4,如何按照索引选择目标数据?
      • 5、如何选出满足矩阵条件的元素?
      • 6.如何找出矩阵中的非零元素的索引?
      • 7.如何将输入张量分割成相同形状的的chunks?
      • 8.如何增加一个矩阵的维度(如1维变2维)?
      • 9.如何实现Tensor维度之间的转置?
      • 10.如何沿着某个维度切片后返回所有切片组成的列表?

1.torch.cat 使用dim对Tensor进行指定拼接

Tensor = torch.ones(2,3)
print(torch.cat([Tensor,Tensor]))
print(torch.cat([Tensor,Tensor],dim=0))  # 纵向拼接为一个矩阵
print(torch.cat([Tensor,Tensor],dim=1).shape)  # 横向拼接为一个矩阵
 # torch.stack 方法进行拼接
print(torch.stack([Tensor,Tensor],dim=1).shape) # 拼接为两个矩阵

2.如何将一个Tensor按指定维度切片?

x = torch.ones(2,10)
print(torch.chunk(x,5,dim=1)) # 横向切片,切为5个

3.按照索引对元素进行聚合

x = torch.Tensor([[33,66,9],[1,99,88]])
print(x)
result = torch.gather(x,1,torch.LongTensor([[0,1],[1,2]]))   # 按索引取 维度1 第一行的 1 2两个数  第二行的2 3 两个数
print(result)
print(result.shape)

4,如何按照索引选择目标数据?

x = torch.randn(2,7)
print(x)
# 参数1:原Tensor  参数2:维度   参数3:该维度上的索引
y = torch.index_select(x,1,torch.LongTensor([1,3,5]))
print(y)
z = torch.index_select(x,0,torch.LongTensor([1]))
print(z)

5、如何选出满足矩阵条件的元素?

x = torch.randn(2,4)
print(x)
mask = x.ge(0.5)  # 大于0.5为1  小于0.5 为0
print(mask)
print(torch.masked_select(x,mask))

6.如何找出矩阵中的非零元素的索引?

x = torch.Tensor([[0.0,1.1],[6.6,0.0]])
print(x)
print(torch.nonzero(x))
----------------------------------------------------------------------
result:
tensor([[0.0000, 1.1000],
        [6.6000, 0.0000]])
tensor([[0, 1],
        [1, 0]])

7.如何将输入张量分割成相同形状的的chunks?

x = torch.ones(2,5)
print(x)
print(torch.split(x,2,dim=1))
print(torch.split(x,1,dim=0))
print(torch.split(x,[2,3],dim=1))
---------------------------------------------------------------------
result:
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
(tensor([[1., 1.],
        [1., 1.]]), tensor([[1., 1.],
        [1., 1.]]), tensor([[1.],
        [1.]]))
(tensor([[1., 1., 1., 1., 1.]]), tensor([[1., 1., 1., 1., 1.]]))
(tensor([[1., 1.],
        [1., 1.]]), tensor([[1., 1., 1.],
        [1., 1., 1.]]))

8.如何增加一个矩阵的维度(如1维变2维)?

x = torch.Tensor([1,2,3,4,5])
print(x.shape)
y = x.unsqueeze(dim=0)  # 降维用squeeze()删除dim指定维度
print(y)
print(y.shape)
z = x.unsqueeze(dim=1)
print(z.shape)
-----------------------------------------------------------------------
result:
torch.Size([5])
tensor([[1., 2., 3., 4., 5.]])
torch.Size([1, 5])
torch.Size([5, 1])

9.如何实现Tensor维度之间的转置?

x = torch.randn(2,1)
print(x,x.shape)
y = torch.t(x)
print(y,y.shape)
z = torch.transpose(x,1,0)
print(z,z.shape)
print(x.t())
print(x.transpose(1,0))
-----------------------------------------------------------------------
result:
tensor([[0.6000],
        [0.2830]]) torch.Size([2, 1])
tensor([[0.6000, 0.2830]]) torch.Size([1, 2])
tensor([[0.6000, 0.2830]]) torch.Size([1, 2])
tensor([[0.6000, 0.2830]])
tensor([[0.6000, 0.2830]])

10.如何沿着某个维度切片后返回所有切片组成的列表?

x = torch.rand(2,2,2)
print(x)
print(torch.unbind(x,dim=1))
-----------------------------------------------------------------------
result:
tensor([[[0.0420, 0.9931],
         [0.5015, 0.7112]],

        [[0.2467, 0.9473],
         [0.7529, 0.8323]]])
(tensor([[0.0420, 0.9931],
        [0.2467, 0.9473]]), tensor([[0.5015, 0.7112],
        [0.7529, 0.8323]]))

你可能感兴趣的:(PyTorch的攀登年华,pytorch,python)