Tensor和Numpy数组之间具有很高的相似性,彼此之间的互操作也非常简单高效。需要注意的是,Numpy和Tensor共享内存。由于Numpy历史悠久,支持丰富的操作,所以当遇到Tensor不支持的操作时,可先转成Numpy数组,处理后再转回tensor,其转换开销很小。
In [90]:
import numpy as np a = np.ones([2, 3],dtype=np.float32) a
Out[90]:
array([[1., 1., 1.], [1., 1., 1.]], dtype=float32)
In [91]:
b = t.from_numpy(a) b
Out[91]:
tensor([[1., 1., 1.], [1., 1., 1.]])
In [92]:
b = t.Tensor(a) # 也可以直接将numpy对象传入Tensor b
Out[92]:
tensor([[1., 1., 1.], [1., 1., 1.]])
In [93]:
a[0, 1]=100 b
Out[93]:
tensor([[ 1., 100., 1.], [ 1., 1., 1.]])
In [94]:
c = b.numpy() # a, b, c三个对象共享内存 c
Out[94]:
array([[ 1., 100., 1.], [ 1., 1., 1.]], dtype=float32)
注意: 当numpy的数据类型和Tensor的类型不一样的时候,数据会被复制,不会共享内存。
In [95]:
a = np.ones([2, 3]) # 注意和上面的a的区别(dtype不是float32) a.dtype
Out[95]:
dtype('float64')
In [96]:
b = t.Tensor(a) # 此处进行拷贝,不共享内存 b.dtype
Out[96]:
torch.float32
In [97]:
c = t.from_numpy(a) # 注意c的类型(DoubleTensor) c
Out[97]:
tensor([[1., 1., 1.], [1., 1., 1.]], dtype=torch.float64)
In [98]:
a[0, 1] = 100 b # b与a不共享内存,所以即使a改变了,b也不变
Out[98]:
tensor([[1., 1., 1.], [1., 1., 1.]])
In [99]:
c # c与a共享内存
Out[99]:
tensor([[ 1., 100., 1.], [ 1., 1., 1.]], dtype=torch.float64)
注意: 不论输入的类型是什么,t.tensor都会进行数据拷贝,不会共享内存
In [100]:
tensor = t.tensor(a)
In [101]:
tensor[0,0]=0 a
Out[101]:
array([[ 1., 100., 1.], [ 1., 1., 1.]])
广播法则(broadcast)是科学运算中经常使用的一个技巧,它在快速执行向量化的同时不会占用额外的内存/显存。 Numpy的广播法则定义如下:
PyTorch当前已经支持了自动广播法则,但是笔者还是建议读者通过以下两个函数的组合手动实现广播法则,这样更直观,更不易出错:
unsqueeze
或者view
,或者tensor[None],:为数据某一维的形状补1,实现法则1expand
或者expand_as
,重复数组,实现法则3;该操作不会复制数组,所以不会占用额外的空间。注意,repeat实现与expand相类似的功能,但是repeat会把相同数据复制多份,因此会占用额外的空间。
In [102]:
a = t.ones(3, 2) b = t.zeros(2, 3,1)
In [103]:
# 自动广播法则 # 第一步:a是2维,b是3维,所以先在较小的a前面补1 , # 即:a.unsqueeze(0),a的形状变成(1,3,2),b的形状是(2,3,1), # 第二步: a和b在第一维和第三维形状不一样,其中一个为1 , # 可以利用广播法则扩展,两个形状都变成了(2,3,2) a+b
Out[103]:
tensor([[[1., 1.], [1., 1.], [1., 1.]], [[1., 1.], [1., 1.], [1., 1.]]])
In [104]:
# 手动广播法则 # 或者 a.view(1,3,2).expand(2,3,2)+b.expand(2,3,2) a[None].expand(2, 3, 2) + b.expand(2,3,2)
Out[104]:
tensor([[[1., 1.], [1., 1.], [1., 1.]], [[1., 1.], [1., 1.], [1., 1.]]])
In [105]:
# expand不会占用额外空间,只会在需要的时候才扩充,可极大节省内存 e = a.unsqueeze(0).expand(10000000000000, 3,2)