While debugging a network today I ran into the following bug. Here is a simplified version of it:
In [2]: import torch
In [3]: a = torch.tensor([1,1,1,1,1])
In [4]: a
Out[4]: tensor([1, 1, 1, 1, 1])
In [5]: a[1] = 20.567
In [6]: a
Out[6]: tensor([ 1, 20, 1, 1, 1])
In [8]: a = torch.cat((a,torch.tensor([1.])),0)
In [9]: a
Out[9]: tensor([ 1., 20., 1., 1., 1., 1.])
In [10]: a[2] = 20.11
In [11]: a
Out[11]: tensor([ 1.0000, 20.0000, 20.1100, 1.0000, 1.0000, 1.0000])
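A minimal sketch of how to avoid the silent truncation, assuming the goal is to store fractional values: declare a floating-point dtype when the tensor is created, or convert before writing. The names `a_int`, `a_float` and `a_fixed` are illustrative only.

```python
import torch

# Integer literals make torch.tensor infer int64, so an assigned float is truncated.
a_int = torch.tensor([1, 1, 1, 1, 1])
a_int[1] = 20.567
print(a_int.dtype, a_int.tolist())  # torch.int64 [1, 20, 1, 1, 1]

# Declaring a floating-point dtype up front keeps the fractional part.
a_float = torch.tensor([1, 1, 1, 1, 1], dtype=torch.float32)
a_float[1] = 20.567
print(a_float.dtype, a_float[1].item())  # float32 approximation of 20.567

# A simple guard: convert to float (a new tensor) before writing fractional values.
a_fixed = a_int if a_int.is_floating_point() else a_int.float()
a_fixed[1] = 20.567
print(a_fixed.dtype, a_fixed[1].item())
```

Checking `is_floating_point()` before assignment is cheap and catches the case where a tensor was accidentally built from integer literals.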
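Note: a tensor's dtype is fixed by the data it is created from. `torch.tensor([1,1,1,1,1])` contains only integers, so `a` is an int64 tensor and the assigned value 20.567 was silently truncated to 20. After `torch.cat` with the float tensor `torch.tensor([1.])`, `a` became a float tensor, so 20.11 kept its fractional part. The dtypes PyTorch provides are: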
Data type | dtype | Tensor type |
---|---|---|
32-bit floating point | torch.float32 or torch.float | torch.*.FloatTensor |
64-bit floating point | torch.float64 or torch.double | torch.*.DoubleTensor |
16-bit floating point | torch.float16 or torch.half | torch.*.HalfTensor |
8-bit integer (unsigned) | torch.uint8 | torch.*.ByteTensor |
8-bit integer (signed) | torch.int8 | torch.*.CharTensor |
16-bit integer (signed) | torch.int16 or torch.short | torch.*.ShortTensor |
32-bit integer (signed) | torch.int32 or torch.int | torch.*.IntTensor |
64-bit integer (signed) | torch.int64 or torch.long | torch.*.LongTensor |
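To check the table on your own installation, the sketch below creates a one-element tensor for each dtype and records the legacy type name returned by `Tensor.type()` (called with no argument) together with the element size in bits. The dtype list is simply copied from the table; `rows` is an illustrative name.

```python
import torch

dtypes = [torch.float32, torch.float64, torch.float16, torch.uint8,
          torch.int8, torch.int16, torch.int32, torch.int64]

rows = {}
for dt in dtypes:
    t = torch.zeros(1, dtype=dt)
    # Tensor.type() with no argument returns the legacy name, e.g. 'torch.FloatTensor'.
    # element_size() is bytes per element; multiply by 8 for bits.
    rows[dt] = (t.type(), t.element_size() * 8)
    print(dt, *rows[dt])

# The short names in the dtype column refer to the same dtypes.
assert torch.float == torch.float32 and torch.double == torch.float64
assert torch.half == torch.float16 and torch.short == torch.int16
assert torch.int == torch.int32 and torch.long == torch.int64
```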
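Converting an existing tensor to another dtype:
import torch
a = torch.tensor((1, 2, 3, 4))  # int64, inferred from the integer literals
a = a.float()                   # -> torch.float32
a = a.int()                     # -> torch.int32
a = a.half()                    # -> torch.float16
a = a.type(torch.uint8)         # -> torch.uint8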
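Note that there are no `Tensor.uint8()` or `Tensor.int8()` conversion methods. To convert to these types, use `tensor.type(torch.uint8)` and `tensor.type(torch.int8)`, or the equivalent shorthands `tensor.byte()` and `tensor.char()` (`tensor.to(torch.uint8)` also works).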
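A short sketch comparing the conversion routes for the two 8-bit types; `x` and the result names are illustrative. `.type(dtype)`, `.to(dtype)` and the named shorthands `.byte()` / `.char()` all return a tensor with the requested dtype.

```python
import torch

x = torch.tensor((1, 2, 3, 4))  # int64

u1 = x.type(torch.uint8)  # explicit .type() call
u2 = x.byte()             # shorthand for uint8
u3 = x.to(torch.uint8)    # general-purpose conversion

i1 = x.type(torch.int8)
i2 = x.char()             # shorthand for int8 ("char"), not a string type
i3 = x.to(torch.int8)

# The other named shorthands map to dtypes from the table.
print(x.float().dtype, x.double().dtype, x.half().dtype)  # torch.float32 torch.float64 torch.float16
print(x.short().dtype, x.int().dtype, x.long().dtype)     # torch.int16 torch.int32 torch.int64
```

`.to()` is the most general of these, since the same method also moves tensors between devices.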
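import torch
a = torch.tensor((1, 2, 3, 4, 5), dtype=torch.float)   # torch.float is an alias for float32
b = torch.tensor((1, 2, 3, 4, 5), dtype=float)         # Python's built-in float maps to float64
c = torch.tensor((1, 2, 3, 4, 5), dtype=torch.double)  # torch.double is an alias for float64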
In [35]: a
Out[35]: tensor([1., 2., 3., 4., 5.])
In [36]: a.dtype
Out[36]: torch.float32
In [37]: b
Out[37]: tensor([1., 2., 3., 4., 5.], dtype=torch.float64)
In [38]: b.dtype
Out[38]: torch.float64
In [43]: c
Out[43]: tensor([1., 2., 3., 4., 5.], dtype=torch.float64)
In [44]: c.dtype
Out[44]: torch.float64
In [45]: c = torch.tensor((1,2,3,4,5),dtype = torch.float64)
In [46]: c
Out[46]: tensor([1., 2., 3., 4., 5.], dtype=torch.float64)
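Note: when creating a tensor, `dtype=float`, `dtype=torch.double` and `dtype=torch.float64` all mean the same thing, a 64-bit float. By contrast, `dtype=torch.float` is 32-bit (`torch.float32`), as the output for `a` shows.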
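A quick sketch confirming the equivalence and the contrast with `torch.float`. It also checks that float literals without an explicit `dtype` get PyTorch's default dtype (float32 unless it has been changed), not float64 as Python's `float` might suggest. The variable names are illustrative.

```python
import torch

values = (1, 2, 3, 4, 5)
t_py_float = torch.tensor(values, dtype=float)          # Python float -> float64
t_double = torch.tensor(values, dtype=torch.double)
t_float64 = torch.tensor(values, dtype=torch.float64)
t_float = torch.tensor(values, dtype=torch.float)       # 32-bit, not 64-bit

print(t_py_float.dtype, t_double.dtype, t_float64.dtype, t_float.dtype)
# torch.float64 torch.float64 torch.float64 torch.float32

# Without a dtype, float literals use the default dtype rather than float64.
print(torch.get_default_dtype(), torch.tensor([1.5]).dtype)
```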
In [59]: a = torch.FloatTensor(3,4,5)
In [60]: a.dtype
Out[60]: torch.float32
In [61]: a = torch.Tensor(3,4,5)
In [62]: a.dtype
Out[62]: torch.float32
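This shows that, with PyTorch's default settings, `torch.Tensor` and `torch.FloatTensor` produce tensors of the same data type, `torch.float32`.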
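Note that `torch.Tensor(3, 4, 5)` and `torch.FloatTensor(3, 4, 5)` interpret the integers as a shape and return uninitialized memory. A sketch of the more explicit equivalents, `torch.empty` and `torch.zeros` with a `dtype` argument, assuming the default dtype has not been changed; the variable names are illustrative.

```python
import torch

legacy_a = torch.FloatTensor(3, 4, 5)  # uninitialized, float32
legacy_b = torch.Tensor(3, 4, 5)       # uninitialized, default tensor type
print(legacy_a.dtype, legacy_b.dtype, legacy_b.type())

# Explicit alternatives: shape and dtype are both spelled out.
empty = torch.empty(3, 4, 5, dtype=torch.float32)  # still uninitialized
zeros = torch.zeros(3, 4, 5, dtype=torch.float32)  # filled with 0.0
print(empty.shape, zeros.dtype, zeros.sum().item())
```

Using `torch.zeros` (or filling the tensor afterwards) avoids reading whatever garbage values happen to be in the uninitialized memory.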