torch.expand()用来扩展某张量某维度的值,先看代码示例:
import torch
x = torch.tensor([[1, 2, 3]])
print(x.size())
y = x.expand(2,3)
print(y)
print(y.size())
输出:
torch.Size([1, 3])
tensor([[1, 2, 3],
[1, 2, 3]])
torch.Size([2, 3])
官网文档如下:
通过官方文档,我觉得需要注意的一点是,expand()只能对维度为1的那个维度进行扩张,如果不是1,则无法进行扩展。示例如下:
import torch
x = torch.tensor([[1, 2, 3]])
print(x.size())
y = x.expand(2,4)
print(y)
print(y.size())
输出:
torch.Size([1, 3])
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-27-ac6657df382d> in <module>
1 x = torch.tensor([[1, 2, 3]])
2 print(x.size())
----> 3 y = x.expand(2,4)
4 print(y)
5 print(y.size())
RuntimeError: The expanded size of the tensor (4) must match the existing size (3) at non-singleton dimension 1.
Target sizes: [2, 4]. Tensor sizes: [1, 3]
torch.Tensor.item( )在获取一维张量中的数字的时候使用,且只能使用在一维张量,返回一个数字。
torch.Tensor.tolist( )在取tensor中数据的时候使用,一维张量的时候等同于torch.Tensor.item( )返回一个数字,多维张量的时候使用会返回一个list(嵌套的列表)。
torch.suqeeze( ) 是将一个张量中维度为1的维度删除然后返回,第二个参数可以不指定。
需要注意的是,squeeze只能对于维度为1的维度进行删除
第二个参数如果指定的话,就只对指定的维度进行操作,如果指定维度不是1,就不进行任何操作,如果指定维度是1,就删除此维度。
torch.unsuqeeze( ):是返回一个在指定维度增加了一个维度为1的维度的张量,第二个参数必须指定。
import torch
x = torch.zeros(2, 1, 2, 1, 2)
print(x.size())
y = torch.squeeze(x)
print(y.size())
z = torch.squeeze(x,1)
print(z.size())
输出:
torch.Size([2, 1, 2, 1, 2])
torch.Size([2, 2, 2])
torch.Size([2, 2, 1, 2])
import torch
x = torch.zeros(2, 2, 2)
print(x.size())
y = torch.unsqueeze(x,1)
print(y.size())
z = torch.unsqueeze(x,3)
print(z.size())
输出:
torch.Size([2, 2, 2])
torch.Size([2, 1, 2, 2])
torch.Size([2, 2, 2, 1])