torch与numpy转换内存上的小坑

在学习机器学习中,tensor的各种转换是新手容易遇到的坑 ,我这里记录一下我遇到的一些坑

  1. 将numpy数据类型转换成Tensor

a = torch.ones(5)
b = a.numpy()
a.add_(1)  # 就地版本的add()
print(a)
print(b)
tensor([2., 2., 2., 2., 2.])
[2. 2. 2. 2. 2.]

torch中的add_()是就地版本的add(),这样b的值会随a变化,而若使用add() 则b的值是全1

  1. 将numpy数组转化成Torch的Tensor

import numpy as np
a = np.ones(5)
b = torch.from_numpy(a)  
np.add(a,1,out=a)
print(a)
print(b)
[2. 2. 2. 2. 2.]
tensor([2., 2., 2., 2., 2.], dtype=torch.float64)

torch.from_numpy()会创建一个tensor从numpy转化来的,返回的tensor和之前的narray共享内存,下面是官方的解释:

Creates a Tensor from a numpy.ndarray.

The returned tensor and ndarray share the same memory. Modifications to the tensor will be reflected in the ndarray and vice versa. The returned tensor is not resizable.

  1. 复制tensor数据类型的数据时候

x = torch.arange(12)
y = torch.tensor(x)

这样复制时会报UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor). y = torch.tensor(x)的警告

x = torch.arange(12)
y = torch.as_tensor(x)

使用as_tensor做复制可以不报错,是官方推荐的写法

  1. torch的reshape()是返回的一个view

a = torch.arange(12)
b = a.reshape((3,4))
b[:] = 2
a
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

你可能感兴趣的:(numpy,python,机器学习)