tensor的连续性、tensor.is_contiguous()、tensor.contiguous()

我们已经知道pytorch中,tensor的实际数据是以一维数组(storage)的方式存于某个连续的内存中的
而且,pytorch的tensor是以“行优先”进行存储的。

1. tensor的连续性

所谓tensor连续(contiguous),指的是tensor的storage元素排列顺序与其按行优先时的元素排列顺序相同。如下图所示:

新建 Microsoft Visio 绘图.png

之所以会出现不连续现象,本质上是由于pytorch中不同tensor可能共用同一个storage导致的。
pytorch的很多操作都会导致tensor不连续,比如tensor.transpose()(tensor.t())、tensor.narrow()、tensor.expand()。
以转置为例,因为转置操作前后共用同一个storage,但显然转置后的tensor按照行优先排列成1维后与原storage不同了,因此转置后结果属于不连续(见下例)。

2. tensor.is_contiguous()

tensor.is_contiguous()用于判断tensor是否连续,我们以转置操作为例进行说明。

>>>a = torch.tensor([[1,2,3],[4,5,6]])
>>>print(a)
tensor([[1, 2, 3],
        [4, 5, 6]])
>>>print(a.storage())
 1
 2
 3
 4
 5
 6
[torch.LongStorage of size 6]
>>>print(a.is_contiguous()) #a是连续的
True

>>>b = a.t() #b是a的转置
>>>print(b)
tensor([[1, 4],
        [2, 5],
        [3, 6]])
>>>print(b.storage())
 1
 2
 3
 4
 5
 6
[torch.LongStorage of size 6]
>>>print(b.is_contiguous()) #b是不连续的
False

# 之所以出现b不连续,是因为转置操作前后是共用同一个storage的
>>>print(a.storage().data_ptr())
>>>print(b.storage().data_ptr())
2638924341056
2638924341056

3. tensor不连续的后果

tensor不连续会导致某些操作无法进行,比如view()就无法进行。接上例:
因为b是不连续的,所以对其进行view()操作会报错;
b.view(3,2)没报错,因为b本身的shape就是(3,2)。

>>>b.view(2,3)
RuntimeError                              Traceback (most recent call last)
>>>b.view(1,6)
RuntimeError                              Traceback (most recent call last)
>>>b.view(-1)
RuntimeError                              Traceback (most recent call last)

>>>b.view(3,2)
tensor([[1, 4],
        [2, 5],
        [3, 6]])

4. tensor.contiguous()

返回一个与原始tensor有相同元素的 “连续”tensor,如果原始tensor本身就是连续的,则返回原始tensor
注意:tensor.contiguous()函数不会对原始数据做任何修改,他不仅返回一个新tensor,还为这个新tensor创建了一个新的storage,在这个storage上,该新的tensor是连续的。
还是接上例:

>>>c = b.contiguous()

# 形式上两者一样
>>>print(b)
>>>print(c)
tensor([[1, 4],
        [2, 5],
        [3, 6]])
tensor([[1, 4],
        [2, 5],
        [3, 6]])

# 显然storage已经不是同一个了
>>>print(b.storage())
>>>print(c.storage())
 1
 2
 3
 4
 5
 6
[torch.LongStorage of size 6]
 1
 4
 2
 5
 3
 6
[torch.LongStorage of size 6]
False

# b不连续,c是连续的
>>>print(b.is_contiguous())
False
>>>print(c.is_contiguous())
True

#此时执行c.view()不会出错
>>>c.view(2,3)
tensor([[1, 4, 2],
        [5, 3, 6]])

你可能感兴趣的:(tensor的连续性、tensor.is_contiguous()、tensor.contiguous())