PyTorch 笔记(11)— Tensor内部存储结构(头信息区 Tensor,存储区 Storage)

1. Tensor 内部存储结构

tensor 数据结构如下图所示,tensor 分为头信息区(Tensor)和存储区 (Storage),信息区主要保存着 Tensor 的形状(size)、步长(stride)、数据类型(type )等信息,而真正的数据则保存成连续的数组。

由于数据动辄成千数万,因此信息区占用内存较少,主要内存占用取决于 tensor 中元素数目,即存储区大小。

PyTorch 笔记(11)— Tensor内部存储结构(头信息区 Tensor,存储区 Storage)_第1张图片

2. 存储区

一般来说,一个 tensor 有着与之对应的 storagestorage 是在 data 之上封装的接口。不同 tensor 的头信息一般不同,但却可能使用相同的 storage

In [25]: a = t.arange(0, 6)

In [26]: a
Out[26]: tensor([0, 1, 2, 3, 4, 5])

In [27]: a.storage()
Out[27]: 
 0
 1
 2
 3
 4
 5
[torch.LongStorage of size 6]

In [28]: a.storage_offset()
Out[28]: 0

In [29]: a.storage_type()
Out[29]: torch.LongStorage

In [30]: b = a.view(2,3)

In [31]: b.storage()
Out[31]: 
 0
 1
 2
 3
 4
 5
[torch.LongStorage of size 6]

In [32]:

可以通过 Pythonid 值来判断它们在内存中的地址是否相等。

In [32]: id(a) == id(b)
Out[32]: False

In [33]: id(a.storage) == id(b.storage)
Out[33]: True

In [34]: 

可以看到 abstorage 是相等的。

a 改变,b 也跟着改变,因为它们是共享 storage 的。

In [34]: a[1] = 100

In [35]: b
Out[35]: 
tensor([[  0, 100,   2],
        [  3,   4,   5]])

数据和存储区的差异,c 的数据是切片之后的值,而存储区仍然是整体的值。

In [36]: c = a[2:]

In [37]: c
Out[37]: tensor([2, 3, 4, 5])

In [43]: c.data
Out[43]: tensor([2, 3, 4, 5])

In [38]: c.storage()
Out[38]: 
 0
 100
 2
 3
 4
 5
[torch.LongStorage of size 6]

In [39]: c.data_ptr(), a.data_ptr()
Out[39]: (132078864, 132078848)

In [40]: c.data_ptr() - a.data_ptr()
Out[40]: 16

data_ptr() 返回 tensor 首元素的内存地址,ca 的内存地址相差 16,也就是两个元素,即每个元素占用 8 个字节(LongStorage)。

c[0] 的内存地址对应 a[2] 内存地址。

In [44]: c[0] = -100

In [45]: a
Out[45]: tensor([   0,  100, -100,    3,    4,    5])

a.storageb.storagec.storage 三者的内存地址是相等的。

In [49]: id(a.storage) == id(b.storage) == id(c.storage)
Out[49]: True

In [51]: a.storage_offset(), b.storage_offset(), c.storage_offset()
Out[51]: (0, 0, 2)

In [52]: 

查看偏移位。

In [55]: d = b[::2, ::2]

In [56]: d
Out[56]: tensor([[   0, -100]])

In [57]: b
Out[57]: 
tensor([[   0,  100, -100],
        [   3,    4,    5]])

In [58]: id(d.storage) == id(b.storage)
Out[58]: True

In [59]: b.stride()
Out[59]: (3, 1)

In [60]: d.stride()
Out[60]: (6, 2)

In [61]: d.is_contiguous()
Out[61]: False

由此可见,绝大多数操作并不修改 tensor 的数据,只是修改了 tensor 的头信息,这种做法更节省内存,同时提升了处理速度。此外,有些操作会导致 tensor 不连续,这时需要调用 torch.contiguous 方法将其变成连续的数据,该方法会复制数据到新的内存,不在与原来的数据共享 storage

高级索引一般不共享 storage ,而普通索引共享 storage ,是因为普通索引可以通过修改 tensoroffsetstridesize 实现,不修改 storage 的数据,而高级索引则不行。

你可能感兴趣的:(PyTorch)