Pytorch中Tensor和numpy数组的互相转化

Pytorch中Tensor和Numpy数组的相互转化分为两种,第一种转化前后的对象共享相同的内存区域(即修改其中另外一个也会改变);第二种是二者并不共享内存区域。

共享内存区域的转化

这种涉及到numpy()from_numpy()两个函数。

使用numpy()函数可以将Tensor转化为Numpy数组:

a=torch.ones(5)
b=a.numpy()
print(type(a))
print(type(b))

输出:



与这个过程相反,使用from_numpy()函数可以将Numpy数组转化为Tensor:

a=np.ones(5)
b=torch.from_numpy(a)

不共享内存区域的转化

这种方式通过torch.tensor()将Numpy转化为Tensor,此时会进行数据的拷贝,返回的Tensor与原来的数据并不共享内存。

函数原型:

torch.tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False) → Tensor

注意事项:

torch.tensor()总会进行数据拷贝,如果我们有一个Tensor数据并且希望避免数据拷贝,那么需要使用torch.Tensor.requires_grad_()或torch.Tensor.detach()。如果我们有一个Numpy ndarray数据并且希望避免数据拷贝,那么需要使用torch.as_tensor()。

When data is a tensor x, torch.tensor() reads out ‘the data’ from whatever it is passed, and constructs a leaf variable. Therefore torch.tensor(x) is equivalent to x.clone().detach() and torch.tensor(x, requires_grad=True) is equivalent to x.clone().detach().requires_grad_(True). The equivalents using clone() and detach() are recommended.

参数:

  • data:(array_like) – Initial data for the tensor. Can be a list, tuple, NumPy ndarray, scalar, and other types.
  • dtype (torch.dtype, optional) – the desired data type of returned tensor. Default: if None, infers data type from data.
  • device (torch.device, optional) – the desired device of returned tensor. Default: if None, uses the current device for the default tensor type (see torch.set_default_tensor_type()). device will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
  • requires_grad (bool, optional) – If autograd should record operations on the returned tensor. Default: False.
  • pin_memory (bool, optional) – If set, returned tensor would be allocated in the pinned memory. Works only for CPU tensors. Default: False.

使用示例:

#根据输入数据创建tensor
torch.tensor([[0.1,1.2],[2.2,3.1],[4.9,5.2]])

#创建tensor时指定数据类型及device
torch.tensor([[0.1111,0.2222,0.3333]],dtype=torch.float64,device=torch.device('cuda:0'))

#创建一个空tensor
torch.tensor([])

你可能感兴趣的:(pytorch,pytorch,python,深度学习)