目录:
-
-
- 1.torch.cat 使用dim对Tensor进行指定拼接
- 2.如何将一个Tensor按指定维度切片?
- 3.按照索引对元素进行聚合
- 4,如何按照索引选择目标数据?
- 5、如何选出满足矩阵条件的元素?
- 6.如何找出矩阵中的非零元素的索引?
- 7.如何将输入张量分割成相同形状的的chunks?
- 8.如何增加一个矩阵的维度(如1维变2维)?
- 9.如何实现Tensor维度之间的转置?
- 10.如何沿着某个维度切片后返回所有切片组成的列表?
1.torch.cat 使用dim对Tensor进行指定拼接
Tensor = torch.ones(2,3)
print(torch.cat([Tensor,Tensor]))
print(torch.cat([Tensor,Tensor],dim=0))
print(torch.cat([Tensor,Tensor],dim=1).shape)
print(torch.stack([Tensor,Tensor],dim=1).shape)
2.如何将一个Tensor按指定维度切片?
x = torch.ones(2,10)
print(torch.chunk(x,5,dim=1))
3.按照索引对元素进行聚合
x = torch.Tensor([[33,66,9],[1,99,88]])
print(x)
result = torch.gather(x,1,torch.LongTensor([[0,1],[1,2]]))
print(result)
print(result.shape)
4,如何按照索引选择目标数据?
x = torch.randn(2,7)
print(x)
y = torch.index_select(x,1,torch.LongTensor([1,3,5]))
print(y)
z = torch.index_select(x,0,torch.LongTensor([1]))
print(z)
5、如何选出满足矩阵条件的元素?
x = torch.randn(2,4)
print(x)
mask = x.ge(0.5)
print(mask)
print(torch.masked_select(x,mask))
6.如何找出矩阵中的非零元素的索引?
x = torch.Tensor([[0.0,1.1],[6.6,0.0]])
print(x)
print(torch.nonzero(x))
----------------------------------------------------------------------
result:
tensor([[0.0000, 1.1000],
[6.6000, 0.0000]])
tensor([[0, 1],
[1, 0]])
7.如何将输入张量分割成相同形状的的chunks?
x = torch.ones(2,5)
print(x)
print(torch.split(x,2,dim=1))
print(torch.split(x,1,dim=0))
print(torch.split(x,[2,3],dim=1))
---------------------------------------------------------------------
result:
tensor([[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]])
(tensor([[1., 1.],
[1., 1.]]), tensor([[1., 1.],
[1., 1.]]), tensor([[1.],
[1.]]))
(tensor([[1., 1., 1., 1., 1.]]), tensor([[1., 1., 1., 1., 1.]]))
(tensor([[1., 1.],
[1., 1.]]), tensor([[1., 1., 1.],
[1., 1., 1.]]))
8.如何增加一个矩阵的维度(如1维变2维)?
x = torch.Tensor([1,2,3,4,5])
print(x.shape)
y = x.unsqueeze(dim=0)
print(y)
print(y.shape)
z = x.unsqueeze(dim=1)
print(z.shape)
-----------------------------------------------------------------------
result:
torch.Size([5])
tensor([[1., 2., 3., 4., 5.]])
torch.Size([1, 5])
torch.Size([5, 1])
9.如何实现Tensor维度之间的转置?
x = torch.randn(2,1)
print(x,x.shape)
y = torch.t(x)
print(y,y.shape)
z = torch.transpose(x,1,0)
print(z,z.shape)
print(x.t())
print(x.transpose(1,0))
-----------------------------------------------------------------------
result:
tensor([[0.6000],
[0.2830]]) torch.Size([2, 1])
tensor([[0.6000, 0.2830]]) torch.Size([1, 2])
tensor([[0.6000, 0.2830]]) torch.Size([1, 2])
tensor([[0.6000, 0.2830]])
tensor([[0.6000, 0.2830]])
10.如何沿着某个维度切片后返回所有切片组成的列表?
x = torch.rand(2,2,2)
print(x)
print(torch.unbind(x,dim=1))
-----------------------------------------------------------------------
result:
tensor([[[0.0420, 0.9931],
[0.5015, 0.7112]],
[[0.2467, 0.9473],
[0.7529, 0.8323]]])
(tensor([[0.0420, 0.9931],
[0.2467, 0.9473]]), tensor([[0.5015, 0.7112],
[0.7529, 0.8323]]))