动手学深度学习——数据操作之ndarray与tensor间的转换

为什么可以转换

无论使用哪个深度学习框架,它的张量类(在MXNet中为ndarray, 在PyTorch和TensorFlow中为tensor)都与Numpy的ndarray类似。 但深度学习框架又比Numpy的ndarray多一些重要功能:

首先,GPU很好地支持加速计算,而NumPy仅支持CPU计算;
其次,这些模块下的张量类支持自动微分;

这些功能使得张量类更适合深度学习。

如何实现转换

将深度学习框架定义的张量转换为Numpy张量(ndarray)很容易,反之也同样容易。
以PyTorch为例,转换期间Torch类的张量和Numpy的数组底层内存共享,原地操作更改一个张量也会同时更改另一个张量。

import torch
import numpy
A = torch.arange(12, dtype=torch.float32).reshape((3,4))
B = A.detach().numpy()  # tensor转换为ndarray
C = torch.from_numpy(B) # ndarray转换为tensor
type(A),type(B),type(C)

结果:
(torch.Tensor , numpy.ndarray , torch.Tensor)

print(A)
print(B)
print(C)
B += 5
print(A)
print(B)
print(C)

上半部分输出:

tensor([[ 0., 1., 2., 3.],
     [ 4., 5., 6., 7.],
     [ 8., 9., 10., 11.]])

[[ 0. 1. 2. 3.]
[ 4. 5. 6. 7.]
[ 8. 9. 10. 11.]]

tensor([[ 0., 1., 2., 3.],
     [ 4., 5., 6., 7.],
     [ 8., 9., 10., 11.]])


下半部分输出:

tensor([[ 5., 6., 7., 8.],
     [ 9., 10., 11., 12.],
     [13., 14., 15., 16.]])

[[ 5. 6. 7. 8.]
[ 9. 10. 11. 12.]
[13. 14. 15. 16.]]

tensor([[ 5., 6., 7., 8.],
     [ 9., 10., 11., 12.],
     [13., 14., 15., 16.]])


可以看到当我们修改B中数据时,A与D中数据被一同修改了,也就是说使用使用pytorch框架下的方法进行tensor和numpy的转换时,底层内存共享了。

易混淆的地方

注意上面的numpy转化为tensor的操作要与下面使用numpy生成tensor的操作区分!

D = torch.tensor(B)     # 利用ndarray生成tensor
print(B)
print(D)
B += 5
print(B)
print(D)

上半部分输出:

[[ 0. 1. 2. 3.]
[ 4. 5. 6. 7.]
[ 8. 9. 10. 11.]]

tensor([[ 0., 1., 2., 3.],
     [ 4., 5., 6., 7.],
     [ 8., 9., 10., 11.]])


下半部分输出:

[[ 5. 6. 7. 8.]
[ 9. 10. 11. 12.]
[13. 14. 15. 16.]]

tensor([[ 0., 1., 2., 3.],
     [ 4., 5., 6., 7.],
     [ 8., 9., 10., 11.]])


我们注意到D中数据却没有发生变化,torch.tensor()是用于tensor对象的创建,类似于numpy中的numpy.array()。上述操作实质是利用ndarray对象B产生了一个新的tensor对象D并开辟了一块新的内存空间

验证内存共享

我们可以简单验证一下上述四个对象的内存使用情况,tensor对象可以通过storage()方法获取真实数据的存储区:

print(id(A.storage()),A.storage().data_ptr())
print(id(B),B.ctypes.data)
print(id(C.storage()),C.storage().data_ptr())
print(id(D.storage()),D.storage().data_ptr())

输出结果:
2124245213504 2125774289536
2124245171536 2125774289536
2124208164672 2125774287616
2124245213824 2125774289536

我们可以看到,ABCD四个对象的id号不相同,说明他们是内存中四个不同的变量。而ABD这三个对象的数据指针指向了内存中的同一地址,说明他们使用的数据是内存中的同一区域,而对象C则使用了不同的内存区域。


(PS:4月30日更新)

你可能感兴趣的:(深度学习笔记,pytorch,深度学习,python)