无法给torch.tensor赋予浮点型数据,如何将tensor的数据类型进行转换

今天在调网络的时候发现了这样的bug,下面简化一下该情况

In [2]: import torch

In [3]: a = torch.tensor([1,1,1,1,1])

In [4]: a
Out[4]: tensor([1, 1, 1, 1, 1])

In [5]: a[1] = 20.567

In [6]: a
Out[6]: tensor([ 1, 20,  1,  1,  1])

In [8]: a = torch.cat((a,torch.tensor([1.])),0)

In [9]: a
Out[9]: tensor([ 1., 20.,  1.,  1.,  1.,  1.])

In [10]: a[2] = 20.11

In [11]: a
Out[11]: tensor([ 1.0000, 20.0000, 20.1100,  1.0000,  1.0000,  1.0000])

这里注意

  1. 当tensor中存储的是整数是时候,改写内部元素的时候即使赋予浮点数。最终到达tensor中的结果还是整数。
  2. torch.cat表示拼接,拼接的时候两个tensor一个为整数,另一个为为浮点数时,整个拼接结果应该所有元素为浮点数!
Data type dtype dtype
32-bit floating point torch.float32 or torch.float torch.*.FloatTensor
64-bit floating point torch.float64 or torch.double torch.*.DoubleTensor
16-bit floating point torch.float16 or torch.half torch.*.HalfTensor+
8-bit integer (unsigned) torch.uint8 torch.*.ByteTensor
8-bit integer (signed) torch.int8 torch.*.CharTensor
16-bit integer (signed) torch.int16 or torch.short torch.*.ShortTensor
32-bit integer (signed) torch.int32 or torch.int torch.*.IntTensor
64-bit integer (signed) torch.int64 or torch.long torch.*.LongTensor
表1:dtype的参数
  • 如何改正呢,请参考这篇文章,这里给出了修正数据类型的方法,修改数据类型可以使用以下方法
a = torch,tensor((1,2,3,4))
a = a.float()
a = a.int()
a = a.half()

a = a.type(torch.uint8)

注意,没有torch.tensor.uint8()torch.tensor.int8()这两种转化方法,只有torch.tensor.type(torch.uint8)torch.tensor.type(torch.int8)

  • 当然我们初始化的时候也可以对其数据类型进行赋值
a = torch.tensor((1,2,3,4,5),dtype =torch.float)
b = torch.tensor((1,2,3,4,5), dtype = float)
c = torch.tensor((1,2,3,4,5),dtype = torch.double)
d = torch

In [35]: a
Out[35]: tensor([1., 2., 3., 4., 5.])

In [36]: a.dtype
Out[36]: torch.float32

In [37]: b
Out[37]: tensor([1., 2., 3., 4., 5.], dtype=torch.float64)

In [38]: b.dtype
Out[38]: torch.float64

In [43]: c
Out[43]: tensor([1., 2., 3., 4., 5.], dtype=torch.float64)

In [44]: c.dtype
Out[44]: torch.float64

In [45]: c = torch.tensor((1,2,3,4,5),dtype = torch.float64)

In [46]: c
Out[46]: tensor([1., 2., 3., 4., 5.], dtype=torch.float64)

这里注意,初始化时dtype = floatdtype = torch.doubledtype = torch.float64表示的是一个含义!

  • 如果要使用torch.Tensor进行赋值的话没有办法制定dtype,但是可以通过特殊的函数进行赋值
In [59]: a = torch.FloatTensor(3,4,5)

In [60]: a.dtype
Out[60]: torch.float32	

In [61]: a = torch.Tensor(3,4,5)

In [62]: a.dtype
Out[62]: torch.float32

这里表明Tensor和FloatTensor生成的数据类型是一致的!

你可能感兴趣的:(笔记,pytorch)