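While developing a model with PyTorch recently, I found that a large share of the bugs I ran into came from converting between data formats. This post records how to convert between list, numpy ndarray, and torch tensor, for later reference.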
ndarray = np.array(list)
# -*- encoding:utf-8 -*-
import numpy as np
a = [1, 2, 3]
print(' type :{0} value: {1}'.format(type(a), a))
ndarray = np.array(a)
print(' type :{0} value: {1}'.format(type(ndarray), ndarray))
Output:
type :<class 'list'> value: [1, 2, 3]
type :<class 'numpy.ndarray'> value: [1 2 3]
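Since dtype mismatches are behind many of these conversion bugs, here is a small sketch of how np.array picks the element type and how to set it yourself. The variable names (ints, floats, matrix) are only for illustration.

```python
import numpy as np

a = [1, 2, 3]

# np.array infers an integer dtype from a list of Python ints.
ints = np.array(a)

# Passing dtype explicitly fixes the element type at creation time.
floats = np.array(a, dtype=np.float32)

# A nested list with equal-length rows becomes a 2-D array.
matrix = np.array([[1, 2, 3], [4, 5, 6]])

print(ints.dtype, floats.dtype, matrix.shape)
```

The exact integer width that NumPy infers depends on the platform and NumPy version, so pass dtype explicitly when the width matters.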
list = ndarray.tolist()
# -*- encoding:utf-8 -*-
import numpy as np
ndarray = np.array([1, 2, 3])  # list to ndarray
print(' type :{0} value: {1}'.format(type(ndarray), ndarray))
list = ndarray.tolist()  # ndarray to list
print(' type :{0} value: {1}'.format(type(list), list))
output:
type :<class 'numpy.ndarray'> value: [1 2 3]
type :<class 'list'> value: [1, 2, 3]
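One detail that is easy to miss: ndarray.tolist() and the built-in list() are not equivalent. The sketch below (variable names are illustrative) shows that tolist() converts the elements to plain Python numbers, while list() keeps them as NumPy scalars.

```python
import numpy as np

ndarray = np.array([1, 2, 3])

native = ndarray.tolist()  # elements become plain Python ints
wrapped = list(ndarray)    # elements stay NumPy scalar objects

print(type(native[0]), type(wrapped[0]))

# tolist() also converts nested dimensions into nested lists.
nested = np.array([[1, 2], [3, 4]]).tolist()
print(nested)
```

This matters, for example, when the result is passed to json.dumps, which does not accept NumPy integer scalars.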
tensor=torch.Tensor(list)
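Note: if the elements of the list are ints, torch.Tensor(list) converts them to float. To get an integer tensor, specify the type explicitly.
Commonly used Tensor types are the 32-bit float torch.FloatTensor, the 64-bit float torch.DoubleTensor, the 16-bit integer torch.ShortTensor, the 32-bit integer torch.IntTensor, and the 64-bit integer torch.LongTensor.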
# -*- encoding:utf-8 -*-
import torch
a = [1, 2, 3]
print(' type :{0} value: {1}'.format(type(a), a))
tensor = torch.Tensor(a)  # float by default
print(' type :{0} value: {1}'.format(type(tensor), tensor))
tensor = torch.IntTensor(a)  # convert to int
print(' type :{0} value: {1}'.format(type(tensor), tensor))
output:
type :<class 'list'> value: [1, 2, 3]
type :<class 'torch.Tensor'> value: tensor([1., 2., 3.])
type :<class 'torch.Tensor'> value: tensor([1, 2, 3], dtype=torch.int32)
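As an alternative to the typed constructors above, the lowercase torch.tensor function infers the dtype from the data and also accepts an explicit dtype argument. This is a minimal sketch rather than the only way to do it, and the variable names are illustrative.

```python
import torch

a = [1, 2, 3]

# torch.tensor infers the dtype from the data: Python ints become int64.
inferred = torch.tensor(a)

# An explicit dtype overrides the inference.
as_float = torch.tensor(a, dtype=torch.float32)
as_int32 = torch.tensor(a, dtype=torch.int32)

# An existing tensor can be converted to another dtype with .to().
converted = inferred.to(torch.float64)

print(inferred.dtype, as_float.dtype, as_int32.dtype, converted.dtype)
```

With this form the int-to-float surprise described in the note does not occur, because the integer list stays an integer tensor unless a float dtype is requested.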
Convert to numpy first, then to a list:
list = tensor.numpy().tolist()
# -*- encoding:utf-8 -*-
import torch
a = [1, 2, 3]
print(' type :{0} value: {1}'.format(type(a), a))
tensor = torch.Tensor(a)
print(' type :{0} value: {1}'.format(type(tensor), tensor))
list = tensor.numpy().tolist()
print(' type :{0} value: {1}'.format(type(list), list))
output:
type :<class 'list'> value: [1, 2, 3]
type :<class 'torch.Tensor'> value: tensor([1., 2., 3.])
type :<class 'torch.Tensor'> value: [1.0, 2.0, 3.0]
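The detour through numpy is not strictly required: a tensor also has its own tolist() method. The sketch below compares the two routes on the same data, and shows item() for pulling out a single value. Variable names are illustrative.

```python
import torch

tensor = torch.tensor([1.0, 2.0, 3.0])

# Direct conversion, without going through numpy.
direct = tensor.tolist()

# The two-step route from the example above, for comparison.
via_numpy = tensor.numpy().tolist()

print(direct, via_numpy)

# For a single-element tensor, item() returns a plain Python number.
scalar = torch.tensor(5).item()
print(type(scalar), scalar)
```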
ndarray = tensor.numpy()
*A tensor on the GPU cannot be converted to numpy directly; move it to the CPU first:
ndarray = tensor.cpu().numpy()
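Here is a hedged sketch of that conversion which runs with or without a GPU. The detach() call is my own addition, not something the note above covers: numpy() also refuses tensors that require gradients, so the sketch calls detach() first to handle both cases in one pattern.

```python
import numpy as np
import torch

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tensor = torch.tensor([1.0, 2.0, 3.0], device=device, requires_grad=True)

# detach() drops the autograd history (assumption: gradients are not needed
# in the numpy result); cpu() moves the data to host memory if it is on the GPU.
ndarray = tensor.detach().cpu().numpy()

print(type(ndarray), ndarray)
```

On a tensor that is already on the CPU, cpu() returns the tensor itself, so the same line is safe in both settings.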
tensor = torch.from_numpy(ndarray)
# -*- encoding:utf-8 -*-
import numpy as np
import torch
a = [1, 2, 3]
print(' type :{0} value: {1}'.format(type(a), a))
ndarray = np.array([1, 2, 3])  # list to ndarray
print(' type :{0} value: {1}'.format(type(ndarray), ndarray))
tensor = torch.from_numpy(ndarray)
print(' type :{0} value: {1}'.format(type(tensor), tensor))
output:
type :<class 'list'> value: [1, 2, 3]
type :<class 'numpy.ndarray'> value: [1 2 3]
type :<class 'torch.Tensor'> value: tensor([1, 2, 3])
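One behavior of torch.from_numpy that can cause subtle bugs: the resulting tensor shares memory with the source array, so modifying one changes the other. The sketch below contrasts it with torch.tensor, which copies the data. Variable names are illustrative.

```python
import numpy as np
import torch

ndarray = np.array([1, 2, 3])

shared = torch.from_numpy(ndarray)  # shares memory with ndarray
copied = torch.tensor(ndarray)      # makes an independent copy

# Modify the source array in place.
ndarray[0] = 100

print(shared)  # reflects the change
print(copied)  # still holds the original values
```

If the array will be modified later and the tensor should not follow, the copying form is the safer choice.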