Returns a contiguous tensor containing the same data as self tensor.
返回一个与原始tensor相同元素数据的 “连续”tensor类型
If self tensor is contiguous, this function returns the self tensor.
如果原始tensor本身就是连续的,则返回原始tensor
定义本身有两个重要的点:
对原始tensor进行复制
返回contiguous“类型”的一个tensor
Tensor.contiguous()函数不会对原始数据进行任何修改,而仅仅对其进行复制,并在内存空间上进行对齐,即在内存空间上,tensor元素的内存地址保持连续。
这么做的目的是,在对tensor元素进行转换和维度变换等操作之后,元素地址在内存空间中保证连续性,在后续利用指针对tensor元素进行读取时,能够减少读取便利,提高内存空间优化。
import torch
src_t = torch.randn((2,3))
print(src_t.shape)
print(src_t.is_contiguous())
输出:
>>> torch.Size([2, 3])
>>> True
可以看出,在利用torch.randn函数进行tensor创建时,获取的tensor元素地址是连续内存空间保存的。那么,如果对创建的tensor进行transpose变换操作:
trans_t = src_t.transpose(0,1)
print(trans_t.shape)
print(trans_t.is_contiguous())
输出:
>>> torch.Size([3, 2])
>>> False
我们发现经过transpose变换以后,tensor变成非连续保存类型(uncontiguous)。
那么,变成这种非连续保存类型会造成什么样的后果呢?
简单的以view函数为例:
trans_t.view(-1,3)
当尝试对uncontiguous类型tensor进行维度变换时,就会出现下面错误:
Traceback (most recent call last):
File "" , line 1, in <module>
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
错误提示告诉我们,至少有一个维度数据在内存空间上跨越了两个连续子空间!此时,我们输出trans_t的连续保存类型是什么:
print(trans_t.is_contiguous())
>>> False
因此,为了能够实现对张量trans_t的维度变换,需要先对tensor进行contiguous内存地址对齐操作,然后再进行view操作:
print(trans_t.shape)
trans_t.contiguous().view(-1,3)
print(trans_t.shape)
>>> torch.Size([3, 2])
>>> torch.Size([2, 3])
总结一下,为了保证代码的可读性和严谨性,当对tensor进行维度变化时,常需要配合contiguous函数使用,但是哪些函数会造成原始tensor变的uncontiguous呢?
transpose()
narrow()
expand()
有其他函数,我会进一步补充,有错误欢迎指正,谢谢!