获取 Tensor
的元素个数 ,a.numel()
等价 a.nelement()
In [1]: import torch as t
In [5]: a = t.Tensor(2,3)
In [6]: a
Out[6]:
tensor([[1.2116e+30, 4.5695e-41, 2.1064e-36],
[0.0000e+00, nan, 4.5695e-41]])
In [7]: a.numel()
Out[7]: 6
In [8]: a.nelement()
Out[8]: 6
查看 Tensor
的形状,Tensor.size()
返回 torch.Size()
对象, Tensor.shape
等价于 Tensor.size()
In [13]: c.size()
Out[13]: torch.Size([2])
In [14]: b.size()
Out[14]: torch.Size([2, 2])
In [15]: a.size()
Out[15]: torch.Size([2, 3])
In [19]: c.shape
Out[19]: torch.Size([2])
In [20]: b.shape
Out[20]: torch.Size([2, 2])
In [21]: a.shape
Out[21]: torch.Size([2, 3])
通过 tensor.view
方法可以调整 tensor
的形状,但必须保证调整前后的元素总数保持一致,view
不会修改自身的数据,返回的新 tensor
与源 tensor
共享内存,即更改其中一个,另外一个也跟着改变。
In [1]: import torch as t
In [31]: a = t.arange(1,6)
In [32]: a
Out[32]: tensor([1, 2, 3, 4, 5])
In [33]: a = t.arange(0,6)
In [34]: a
Out[34]: tensor([0, 1, 2, 3, 4, 5])
In [35]: b = a.view(2,3)
In [36]: b
Out[36]:
tensor([[0, 1, 2],
[3, 4, 5]])
In [37]: c = a.view(-1, 3) # 某一维度为 -1 时会自动计算它的大小
In [38]: c
Out[38]:
tensor([[0, 1, 2],
[3, 4, 5]])
In [39]: c[0,1] = 100
In [40]: a
Out[40]: tensor([ 0, 100, 2, 3, 4, 5])
In [41]: b
Out[41]:
tensor([[ 0, 100, 2],
[ 3, 4, 5]])
In [42]: c
Out[42]:
tensor([[ 0, 100, 2],
[ 3, 4, 5]])
In [43]:
添加或减少某一维度,可以使用 squeeze
和 unsqueeze
函数。
In [43]: b
Out[43]:
tensor([[ 0, 100, 2],
[ 3, 4, 5]])
In [45]: b.shape
Out[45]: torch.Size([2, 3])
In [46]: d = b.unsqueeze(1) # 在第一维下标从0开始,增加1
In [47]: d
Out[47]:
tensor([[[ 0, 100, 2]],
[[ 3, 4, 5]]])
In [48]: d.shape
Out[48]: torch.Size([2, 1, 3])
In [49]: d.squeeze(1)
Out[49]:
tensor([[ 0, 100, 2],
[ 3, 4, 5]])
In [50]: d.squeeze(1).shape
Out[50]: torch.Size([2, 3])
resize
是另一种可用来调整 size
的方法,但与 view
不同,它可以修改 tensor
的尺寸,如果新尺寸超过了源尺寸,会自动分配新的内存空间,而如果新尺寸小于源尺寸,则之前的数据依旧会被保存。
In [51]: b
Out[51]:
tensor([[ 0, 100, 2],
[ 3, 4, 5]])
In [52]: b.resize_(1,3)
Out[52]: tensor([[ 0, 100, 2]])
In [53]: b
Out[53]: tensor([[ 0, 100, 2]])
In [54]: b.resize_(3,3)
Out[54]:
tensor([[ 0, 100, 2],
[ 3, 4, 5],
[2314885530818447916, 2331492554444382240, 2318280896059485744]])
In [55]: