PyTorch碎片:深刻透彻理解Torch中Tensor.contiguous()函数

1.函数定义

Returns a contiguous tensor containing the same data as self tensor.

返回一个与原始tensor相同元素数据的 “连续”tensor类型

If self tensor is contiguous, this function returns the self tensor.

如果原始tensor本身就是连续的,则返回原始tensor

2.定义理解

定义本身有两个重要的点:

  • 对原始tensor进行复制

  • 返回contiguous“类型”的一个tensor

Tensor.contiguous()函数不会对原始数据进行任何修改,而仅仅对其进行复制,并在内存空间上进行对齐,即在内存空间上,tensor元素的内存地址保持连续。

这么做的目的是,在对tensor元素进行转换和维度变换等操作之后,元素地址在内存空间中保证连续性,在后续利用指针对tensor元素进行读取时,能够减少读取便利,提高内存空间优化。

3.数据案例分析

import torch
src_t = torch.randn((2,3))
print(src_t.shape)
print(src_t.is_contiguous())

输出:

>>> torch.Size([2, 3])
>>> True

可以看出,在利用torch.randn函数进行tensor创建时,获取的tensor元素地址是连续内存空间保存的。那么,如果对创建的tensor进行transpose变换操作:

trans_t = src_t.transpose(0,1) 
print(trans_t.shape)
print(trans_t.is_contiguous())

输出:

>>> torch.Size([3, 2])
>>> False

我们发现经过transpose变换以后,tensor变成非连续保存类型(uncontiguous)。

那么,变成这种非连续保存类型会造成什么样的后果呢?

简单的以view函数为例:

trans_t.view(-1,3)

当尝试对uncontiguous类型tensor进行维度变换时,就会出现下面错误:

Traceback (most recent call last):
File "", line 1, in <module>
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

错误提示告诉我们,至少有一个维度数据在内存空间上跨越了两个连续子空间!此时,我们输出trans_t的连续保存类型是什么:

print(trans_t.is_contiguous())
>>> False

因此,为了能够实现对张量trans_t的维度变换,需要先对tensor进行contiguous内存地址对齐操作,然后再进行view操作:

print(trans_t.shape)
trans_t.contiguous().view(-1,3)
print(trans_t.shape)
>>> torch.Size([3, 2])
>>> torch.Size([2, 3])

4.总结

总结一下,为了保证代码的可读性和严谨性,当对tensor进行维度变化时,常需要配合contiguous函数使用,但是哪些函数会造成原始tensor变的uncontiguous呢?

  • transpose()

  • narrow()

  • expand()

有其他函数,我会进一步补充,有错误欢迎指正,谢谢!

你可能感兴趣的:(PyTorch碎片,contiguous,pytorch,深度学习)