从 Pytorch tensor 存储空间的连续性 (contiguous) 说到 4D tensor 的存储格式 (memory_format)

起源:令人一头雾水的 contiguous

刚开始阅读 Pytorch 代码的时候,碰到别的作者处理 tensor,有时候会在后面加上 contiguous,就觉得十分奇怪,不明白 contiguous 的含义,看了相关的解释 之后,好像理解了一点,就这浅薄的理解做点笔记,总结一下。

概念解释:什么是 contiguous

如果让我用一句话概括 tensor 的 contiguous 特性,我会说:“tensor 的元素按维度储存在连续的内存地址中”,什么是按维度呢?以 (H, W) 的2维 tensor (这里我习惯用 H 和 W 表示 spatial size,H 是第0维,W是第1维) 举例,如果新建一个 tensor,计算机在存储的时候是先沿着最后一个维度也就是 W 维度存储,然后再沿着 H 维度存储。也许这个说法不太严谨或不好理解,可以参考 Pytorch 的官方文档,那里面有张存储 4D tensor 的图很形象地展示了这种存储特点。

举例分析:什么时候需要特别注意 tensor 的 contiguous 特性

说一千,道一万,最重要的还是应用场景。最直接的一个问题是:“我们什么时候需要特别注意 contiguous 特性呢?” 如果非要我说一个答案的话,我只能说到目前为止,我只知道在进行 view 操作之前需要注意一下 tensor 的 contiguous 特性,看了很多的解释,主要是说 view 操作要求 tensor 具有相应的 contiguous 特性。好吧,其实我看到的解释是 view 操作要求 tensor 是 contiguous,但我又发现了一些比较特殊的点,所以加了修饰词 “相应的”。

代码示例:

import torch
## 建立一个 size 是 (4, 3) 的 tensor
a = torch.arange(1, 13).view(4, 3)
## 对 a 进行转置操作
a_t = a.t()
print(a.is_contiguous()) ## True
print(a_t.is_contiguous()) ## False
a_t_v = a_t.view(-1) ## 会报错

通过上面的例子我们可以看到,对一个 tensor 进行转置操作之后会改变它的 contiguous 特性,然后再使用 view 操作就会报错。这时候我们就需要在 view 操作之前把转置之后的 tensor 变成 contiguous。

代码更改:

import torch
## 建立一个 size 是 (4, 3) 的 tensor
a = torch.arange(1, 13).view(4, 3)
## 对 a 进行转置操作并把转置之后的 tensor 变连续
a_t = a.t().contiguous()
print(a.is_contiguous()) ## True
print(a_t.is_contiguous()) ## True
a_t_v = a_t.view(-1) ## 正常运行

这时候我们不禁思考,为什么转置之后 tensor 就不连续了呢?tensor.contiguous() 又是怎么使 tensor 变的连续的呢? Pytorch关于 view 操作的官方解释 可以回答这些问题,这里我也可以简单的解释一下。我们也许都知道,Python 为了节省内存,很多操作都是不会重新分配地址的,而是在原地址上进行,这一特性在我使用 numpy 的时候就深刻的感受到了,如果想确保分配一个新的地址,需要使用 copy 或者 clone 等操作。而 Pytorch 与 numpy 一脉相承 (这里需要说一下, Pytorch 的很多操作我们都可以在 numpy 里面找到对应,就像王者荣耀和英雄联盟的关系),很多操作都是在原地址上进行的。在上面的代码示例中, 执行 transpose 操作的时候系统就不会另外分配内存,所以示例中的 a 和 a_t 是共用一块内存的,只不过读取内存的顺序有所不同。因此,对 a 来说,数据是按维度存储在内存里的,但对 a_t 来说,它的维度改变了,数据存的方式却并没有改变,这就解释了为什么 a 是 contiguous,而 a_t 不是 contiguous。再解释对 a_t 使用 tensor.contiguous(),它的作用是为 a_t 申请一块新的内存,并让数据按照 a_t 的维度进行存储。它们之间的差别可以通过读取 tensor 的 stride 来区分 (至于 stride 的含义,参见最开始提到的链接)。

代码示例:

import torch
## 建立一个 size 是 (4, 3) 的 tensor
a = torch.arange(1, 13).view(4, 3)
print(a.stride()) ## (3, 1)
## 对 a 进行转置操作
a_t = a.t()
print(a_t.stride()) ## (1, 3)
a_t = a_t.contiguous()
print(a_t.stride()) ## (4, 1)

上文说到,view 操作要求 tensor 具有相应的 contiguous 特性,其实并不是说 tensor 一定要是 contiguous 才能进行 view 操作,而是 view 操作执行的维度上数据一定是要按这个维度进行存储的(恩,好像有点绕)。

代码示例:

import torch
N, C, H, W = 2, 3, 4, 4
## 创建一个 tensor,格式是 N,C,H,W
a = torch.zeros(N, C, H, W)
## 把 tensor 的格式转换成 N,H,W,C
a_t = a.permute(0, 2, 3, 1) 
## 在 H,W 维度上进行 view 操作
a_t_v_1 = a_t.view(N, -1, C) ## 正常运行
## 在 W,C 维度上进行 view 操作
a_t_v_2 = a_t.view(N, H, -1) ## 会报错

在上面的代码示例中,我们建立了一个格式为 (N, C, H, W) 的 tensor a,然后把它转换成格式为 (N, H, W, C) 的 tensor a_t,此时2个 tensor 共用地址,所以 a_t 不是 contiguous 的,如果这时在 (H, W) 的维度上对 a_t 进行 view 操作,是没有问题的,因为这2个维度上的数据存储是 contiguous 的;但如果在 (W, C) 的维度上进行 view 操作,则会报错,因为不是 contiguous。

拓展提高:Pytorch 4D tensor 的存储格式

在上文中我们已经大概知道了 contiguous 是怎么回事,这一部分来聊一聊 Pytorch 4D tensor 的存储格式。或许我们知道,在深度学习中 Pytorch 需要 tensor 以 (N, C, H, W) 的格式输入,而在 tensorflow 中,输入的 tensor 格式是 (N, H, W, C)。在 Pytorch 的官方文档中对 4D tensor 提供了另一种存储格式,即 torch.channels_last。我们输入一个格式为 (N, C, H, W) 的 tensor,一般来说,系统会先沿着 W 存储,然后再沿着 H 存储,依次进行,这时候的 tensor 是具有 contiguous 特性的。但是我们也可以让这个 tensor 以 channels_last 的方式进行存储,先沿着 C 存储,然后再沿着 W, H, N 依次存储,但以这种方式存储的 tensor 不是 contiguous 的。但请注意,不论以何种方式进行存储,Pytorch 要求 tensor 格式必须是 (N, C, H, W)

代码示例:

import torch
N, C, H, W = 2, 3, 4, 4
images = torch.zeros(2, 3, 4, 4)
print(images.is_contiguous()) ## True
images_cl = images.contiguous(memory_format=torch.channels_last)
print(images_cl.is_contiguous()) ## False
print(images_cl.is_contiguous(memory_format=torch.channels_last)) ## True
images_cl = images_cl.contiguous()
print(images_cl.is_contiguous()) ## True
print(images_cl.is_contiguous(memory_format=torch.channels_last)) ## False

在上面的示例中,我们新建一个格式为 (N, C, H, W) 的 tensor 并且让它按 contiguous 的格式存储(默认存储格式),这时候我们也可以改变它的存储格式为 channels_last,但它就不再是 contiguous 了。如果我们把它变回 contiguous,它又自动不是 channels_last 了。所以一个 4D tensor,它的存储格式可以既不是 contiguous,又不是 channels_last,或者存储格式是2者中的一个,但绝不可能既是 contiguous,又是 channels_last。
你或许觉得这有啥用呢?为什么要专门对 4D tensor 设计2种存储格式?这方面我没有过多探究,但据官方测验 (还是这个链接),对于有些开源网络,使用 channels_last 格式存储的 tensor 能让运行速度提升 22%,所以存在即合理。

总结

上文分析了 tensor 的 contiguous 特性和该特性的应用场景、注意事项,并延伸到 4D tensor 的存储格式。本人水平有限,如果有说的不对的地方,希望列位看官不吝赐教。

你可能感兴趣的:(Pytorch,python,人工智能,编程语言)