ensor的索引、切片和拼接
- 格式: torch.cat(tensors, dim=0, out=None) → Tensor
- 解释:在指定维度上拼接两个tensor
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497]])
>>> torch.cat((x, x, x), 0) # 在dim=0上拼接,也就是行方向
tensor([[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497],
[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497],
[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497]])
>>> torch.cat((x, x, x), 1) # 在dim=1上拼接
tensor([[ 0.6580, -1.0969, -0.4614, 0.6580, -1.0969, -0.4614, 0.6580,
-1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497, -0.1034,
-0.5790, 0.1497]])
- 格式:torch.stack(tensors, dim=0, out=None) → Tensor
- 解释:沿着一个新的维度对张量进行拼接。序列中的tensors必须具有相同的size。
- **直白的说:**它可以将二维tensor变三维tensor,三维变4维
- 与torch.cat()的区别。stack()属于扩张再拼接的函数。通常用于NLP和CV领域
注意: 下图及实例来自torch.stack()的官方解释,详解以及例子
# 假设是时间步T1的输出
T1 = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
# 假设是时间步T2的输出
T2 = torch.tensor([[10, 20, 30],
[40, 50, 60],
[70, 80, 90]])
print(torch.stack((T1,T2),dim=0))
print(torch.stack((T1,T2),dim=0).shape)
print(‘==================================’)
print(torch.stack((T1,T2),dim=1))
print(torch.stack((T1,T2),dim=1).shape)
print(‘==================================’)
print(torch.stack((T1,T2),dim=2))
print(torch.stack((T1,T2),dim=2).shape)
print(‘==================================’)
print(torch.stack((T1,T2),dim=3).shape)
输出:
tensor([[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9]],
[[10, 20, 30],
[40, 50, 60],
[70, 80, 90]]])
torch.Size([2, 3, 3])
==================================
tensor([[[ 1, 2, 3],
[10, 20, 30]],
[[ 4, 5, 6],
[40, 50, 60]],
[[ 7, 8, 9],
[70, 80, 90]]])
torch.Size([3, 2, 3])
==================================
tensor([[[ 1, 10],
[ 2, 20],
[ 3, 30]],
[[ 4, 40],
[ 5, 50],
[ 6, 60]],
[[ 7, 70],
[ 8, 80],
[ 9, 90]]])
torch.Size([3, 3, 2])
==================================
'选择的dim>len(outputs),所以报错'
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
- 格式: torch.chunk(input, chunks, dim=0) → List of Tensors
- chunks只能是整数
- 解释:可以看做是torch.cat()的反向操作,分割tensor
注意:
如果沿着指定轴不能整数切分,那么最后一个块将会是最小的
>>> c
tensor([[0.9387, 0.5666, 0.3289, 0.7775, 0.5938],
[0.0968, 0.0961, 0.6976, 0.9121, 0.0796],
[0.4676, 0.7772, 0.2398, 0.5254, 0.9906],
[0.0588, 0.7729, 0.2259, 0.6438, 0.8299]], dtype=torch.float64)
>>> cc = torch.chunk(c,2,1) # 很明显size_c = (4,5)不能争分
>>> cc # 因此cc的两个块分别为(4,3)和(4,2)
(tensor([[0.9387, 0.5666, 0.3289],
[0.0968, 0.0961, 0.6976],
[0.4676, 0.7772, 0.2398],
[0.0588, 0.7729, 0.2259]], dtype=torch.float64), tensor([[0.7775, 0.5938],
[0.9121, 0.0796],
[0.5254, 0.9906],
[0.6438, 0.8299]], dtype=torch.float64))
- 格式: torch.split(tensor, split_size_or_sections, dim=0)
- split_size_or_sections 可以是整数,也可以是一个列表。如果是int,则先尽可能整除,然后最后一个最小;如果是列表,则将分割为len(list)个块,并按照list的元素进行分配
- 解释:可以看做是torch.cat()的反向操作,分割tensor
>>> c
tensor([[0.9387, 0.5666, 0.3289, 0.7775, 0.5938],
[0.0968, 0.0961, 0.6976, 0.9121, 0.0796],
[0.4676, 0.7772, 0.2398, 0.5254, 0.9906],
[0.0588, 0.7729, 0.2259, 0.6438, 0.8299]], dtype=torch.float64)
# 是一个列表
>>>cc = torch.split(c,[2,2,1],1)
>>> cc
(tensor([[0.9387, 0.5666],
[0.0968, 0.0961],
[0.4676, 0.7772],
[0.0588, 0.7729]], dtype=torch.float64), tensor([[0.3289, 0.7775],
[0.6976, 0.9121],
[0.2398, 0.5254],
[0.2259, 0.6438]], dtype=torch.float64), tensor([[0.5938],
[0.0796],
[0.9906],
[0.8299]], dtype=torch.float64))
# 是一个整数
>>> cc=torch.split(c,2,1)
>>> cc
(tensor([[0.9387, 0.5666],
[0.0968, 0.0961],
[0.4676, 0.7772],
[0.0588, 0.7729]], dtype=torch.float64), tensor([[0.3289, 0.7775],
[0.6976, 0.9121],
[0.2398, 0.5254],
[0.2259, 0.6438]], dtype=torch.float64), tensor([[0.5938],
[0.0796],
[0.9906],
[0.8299]], dtype=torch.float64))
- 格式:torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor
- 解释:从原tensor中按照指定轴dim和索引index获取数据
- index是一个tensor。输出的维度与index的维度一定时相同的
实例:摘自知乎:图解PyTorch中的torch.gather函数
先定义一个原始tensor:
>>> a = torch.arange(3, 12).view(3, 3)
>>> a
tensor([[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
指定index
index = torch.tensor([[2,1,0]])
在dim=0方向采集
>>> b = torch.gather(a,0,index)
>>> b
tensor([[9, 7, 5]])
- 格式:torch.squeeze(input, dim=None, out=None) → Tensor
- 解释:大小为1的维度都被删除
官网实例
如果 input 的size=(A×1×B×C×1×D)
,则返回的tensor size= (A×B×C×D)
如果指定dim上的维度不为1,则返回不变
>>> x = torch.zeros(2, 1, 2, 1, 2)
>>> x.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x)
>>> y.size()
torch.Size([2, 2, 2])
>>> y = torch.squeeze(x, 0)
>>> y.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x, 1)
>>> y.size()
torch.Size([2, 2, 1, 2])
- 格式:torch.unsqueeze(input, dim, out=None) → Tensor
- 解释:在指定轴添加1。
>>> x = torch.tensor([1, 2, 3, 4])
>>> torch.unsqueeze(x, 0)
tensor([[ 1, 2, 3, 4]])
>>> torch.unsqueeze(x, 1)
tensor([[ 1],
[ 2],
[ 3],
[ 4]])
举例
>>> t = torch.tensor([[[1, 2],
[3, 4]],
[[5, 6],
[7, 8]]])
>>>t.size # (2,2,2)
>>> torch.flatten(t) # 默认从dim=0开始
tensor([1, 2, 3, 4, 5, 6, 7, 8])
>>> torch.flatten(t,start_dim = 1) #
tensor([[1, 2, 3, 4],
[5, 6, 7, 8]])
torch官网:https://pytorch.org/docs/1.2.0/torch.html