我们已经知道pytorch中,tensor的实际数据是以一维数组(storage)的方式存于某个连续的内存中的。
而且,pytorch的tensor是以“行优先”进行存储的。
1. tensor的连续性
所谓tensor连续(contiguous),指的是tensor的storage元素排列顺序与其按行优先时的元素排列顺序相同。如下图所示:
之所以会出现不连续现象,本质上是由于pytorch中不同tensor可能共用同一个storage导致的。
pytorch的很多操作都会导致tensor不连续,比如tensor.transpose()(tensor.t())、tensor.narrow()、tensor.expand()。
以转置为例,因为转置操作前后共用同一个storage,但显然转置后的tensor按照行优先排列成1维后与原storage不同了,因此转置后结果属于不连续(见下例)。
2. tensor.is_contiguous()
tensor.is_contiguous()用于判断tensor是否连续,我们以转置操作为例进行说明。
>>>a = torch.tensor([[1,2,3],[4,5,6]])
>>>print(a)
tensor([[1, 2, 3],
[4, 5, 6]])
>>>print(a.storage())
1
2
3
4
5
6
[torch.LongStorage of size 6]
>>>print(a.is_contiguous()) #a是连续的
True
>>>b = a.t() #b是a的转置
>>>print(b)
tensor([[1, 4],
[2, 5],
[3, 6]])
>>>print(b.storage())
1
2
3
4
5
6
[torch.LongStorage of size 6]
>>>print(b.is_contiguous()) #b是不连续的
False
# 之所以出现b不连续,是因为转置操作前后是共用同一个storage的
>>>print(a.storage().data_ptr())
>>>print(b.storage().data_ptr())
2638924341056
2638924341056
3. tensor不连续的后果
tensor不连续会导致某些操作无法进行,比如view()就无法进行。接上例:
因为b是不连续的,所以对其进行view()操作会报错;
b.view(3,2)没报错,因为b本身的shape就是(3,2)。
>>>b.view(2,3)
RuntimeError Traceback (most recent call last)
>>>b.view(1,6)
RuntimeError Traceback (most recent call last)
>>>b.view(-1)
RuntimeError Traceback (most recent call last)
>>>b.view(3,2)
tensor([[1, 4],
[2, 5],
[3, 6]])
4. tensor.contiguous()
返回一个与原始tensor有相同元素的 “连续”tensor,如果原始tensor本身就是连续的,则返回原始tensor。
注意:tensor.contiguous()函数不会对原始数据做任何修改,他不仅返回一个新tensor,还为这个新tensor创建了一个新的storage,在这个storage上,该新的tensor是连续的。
还是接上例:
>>>c = b.contiguous()
# 形式上两者一样
>>>print(b)
>>>print(c)
tensor([[1, 4],
[2, 5],
[3, 6]])
tensor([[1, 4],
[2, 5],
[3, 6]])
# 显然storage已经不是同一个了
>>>print(b.storage())
>>>print(c.storage())
1
2
3
4
5
6
[torch.LongStorage of size 6]
1
4
2
5
3
6
[torch.LongStorage of size 6]
False
# b不连续,c是连续的
>>>print(b.is_contiguous())
False
>>>print(c.is_contiguous())
True
#此时执行c.view()不会出错
>>>c.view(2,3)
tensor([[1, 4, 2],
[5, 3, 6]])