Pytorch中‘内存共享’和‘内存连续’特性总结

文章目录

  • 前言
  • 1、前置基础知识
    • 1.1.Tensor的结构
    • 1.2.内存共享和内存连续API介绍
  • 2、内存连续性
    • 2.1.维度变换操作(transpose, permute)
    • 2.2.view和reshape
    • 2.3.维度拼接:cat和stack op
    • 2.4. squeeze()和unsqueeze()
    • 2.5. expand 和 repeat
    • 2.6. numpy和from_numpy内存共享
    • 2.7.切片
  • 总结


前言

 本文旨在记录pytorch的API如何影响Tensor运算的‘内存共享性’和‘内存连续性’。’内存共享‘可以理解为浅拷贝;’内存连续’就是Tensor在信息区的内存空间上的连续性。 本文会结合代码介绍pytorch中的op是如何影响这两个性质的。

1、前置基础知识

1.1.Tensor的结构

 因为涉及到Tensor的性质,因此,本节先简单回顾下Tensor的数据结构,Tensor包含信息区和存储区。信息区包含Tensor的一些维度信息(比如一个Tensor的shape=(2,3),变成(3,2),张量内容没变,变得只是我们看待这个张量的视角);存储区则是存储着数据。
Pytorch中‘内存共享’和‘内存连续’特性总结_第1张图片

深拷贝自然会同时拷贝两个区的内容;而维度变换操作往往仅影响信息区的内容,是为了减少张量计算中频繁的拷贝操作

1.2.内存共享和内存连续API介绍

 大家可先扫一眼下面的代码:这里简单介绍两个API,is_contiguous()能够判断一个Tensor的**信息区**上是否‘内存连续’;.data_ptr()能够返回张量在内存空间上的地址,可用于判断两个张量是否‘内存共享’。

# case 1: share contiguous And deepCopy?
x = torch.tensor([1,2,3], dtype=torch.float32)
y = x     # shallow copy 

print(x.is_contiguous(), y.is_contiguous())   # True, True
print(x.data_ptr() == y.data_ptr())           # True

# 若y发生额外的运算,此时pytorch会额外开辟新的内存,即转化成深拷贝!
y = y + 1                                     # x = [1,2,3], y = [2,3,4]
print(x.data_ptr() == y.data_ptr())           # False

 我说下结论:说到底是python的语言特性。1)大多数赋值操作 = 全是浅拷贝,比如(y = x),因此,张量x和y内存连续且内存共享。也就是说:由于发生的是浅拷贝,即当我们对y做了某些op后,对应的x的值也会发生变化。2)但千万不能对y做运算(比如y = y +1),此时就由浅拷贝转化成了深拷贝,即python内部会自动开辟一块新的内存来存储y,即此时x和y各自内存连续但已经不共享内存了。(可能会forward的计算图产生影响)。
 本文会在第2部分介绍一些pytorch中哪些op会对Tensor的内存连续性产生影响;在第3部分介绍pytorch中哪些op会对Tensor的内存共享性产生影响。

2、内存连续性

2.1.维度变换操作(transpose, permute)

# -------------- transpose op ------------------ #
# transpose op
x = torch.arange(0,6).view(2,3)
print(x.is_contiguous())            # True
y = x                               # shallow copy
y = y.transpose(0,1)      # 张量的信息区发生变更,但存储区没发生变更,x和y共用一块存储区
             
# True: dont destropy ;False: transpose op destroy share contigouse
print(x.is_contiguous, y.is_contiguous()) 
print(x.data_ptr() == y.data_ptr()) # True : share meomery
# -------------- permute op -------------------- #
# permute op
x = torch.arange(0,6).view(2,3)
print(x.is_contiguous())            # True
y = x                               # shallow copy
y = y.permute(1,0)

# True: dont destropy ;False: transpose op destroy share contigouse
print(x.is_contiguous, y.is_contiguous())  
print(x.data_ptr() == y.data_ptr()) # True : share meomery

 上述代码是pytorch中两个常用的维度变换op:transpose和permute。从上述代码可以看出,二者都会破坏了原始张量的内存连续性,更准确的说是破坏了信息区的内存连续性。但由于y是由x浅拷贝过来的,所以y和x共用一块存储区。
例外!!!这里有个例外就是存在dim=1的Tensor:若某Tensor的shape=(1,2,3),则调用transpose()/permute()时只有在交换后的维度的非0相对dim没变情况下才不会破坏信息区的内存连续性,即is_contiguous() == True;若破坏了非0dim的相对位置,则is_contiguous() == False。举个例子:比如交换后的shape变成(1,3,2)/(3,2,1)/(3,1,2),则破坏了内存连续性;若交换后shape变成(2,3,1)/(2,1,3),则依旧内存连续。

2.2.view和reshape

 pytorch中另外两个常用的维度变化操作就是:view和reshape。先贴两段code,看是如何影响内存连续性和内存共享性的。

# ------- view op need contiguous----- #
x = torch.arange(0,6).view(2,3)
y = x.permute(1,0)
# Error: permute导致信息区的内存不连续,view操作会报错
y = y.view(2,3)
# -------- reshape op ---------------- #
x = torch.arange(0,6).view(2,3)
y = x.transpose(0,1)     # y的信息区不连续
y = y.reshape(2,3)       # 效果 == y.contiguous().view(2,3)
print(x.is_contiguous(), y.is_contiguous()) # true, true
print(x.data_ptr() == y.data_ptr())         # false

 长话短说:在调用transpose和permute操作后,会破坏张量在信息区的内存连续性。而view操作需要张量的内存连续,否则会报错!而reshape则可以无脑使用:1)若Tensor本来内存连续,则调用reshape操作相当于调用view,并不会深拷贝源张量;2)若Tensor内存不连续,则reshape操作会首先深拷贝一份张量使其连续,然后在进行view操作。其效果等同于.contiguous().view(2,3)。
总的来说:view op不会深拷贝张量但需要内存连续;reshape op在张量内存不连续情况下会发生深拷贝!还有别忘了:.contiguous()方法会对张量进行深拷贝。

2.3.维度拼接:cat和stack op

# -------- torch.cat op ----------- #
x = torch.tensor([[1,2,3]], dtype=torch.float32)
y = torch.tensor([[4,5,6]], dtype=torch.float32)
z = torch.cat((x, y), dim=0)
print(z.data_ptr() == x.data_ptr()) # False
# -------- torch.stack op --------- #
v = torch.stack((x,y),dim = 0)
print(v.data_ptr() == x.data_ptr()) # False

 这两个比较容易理解,拼接产生新的张量自然会开辟新的内存,且内存连续。

2.4. squeeze()和unsqueeze()

# -------- torch.squeeze op ----------- #
x = torch.tensor([[1,2,3]])
y = x.squeeze()
print(y.data_ptr() == x.data_ptr()) # True
# -------- torch.unsqueeze op --------- #
z = x.unsqueeze(0)
print(z.data_ptr() == x.data_ptr()) # True

 一句话:squeeze()和unsqueeze()共享内存。而内存连续性则和源张量保持一致,即x是内存连续,则y和z也是内存连续;x不连续则y和z也不一致。

2.5. expand 和 repeat

x = torch.arange(0,6).view(1,2,3)
y = x.permute(0,2,1)
print(y.is_contiguous())   # False
y = y.expand(size=(2,3,2))
print(y.is_contiguous())   # False
print(y.data_ptr() == x.data_ptr()) # True

 有了前面基础,这两个op就容易了,都是复制张量。expand内存共享;而repeat会深拷贝,内存不共享。内存连续性和源张量保持一致。
 这里注意下expand,即内存共享,也就是说,pytorch调用expand后实际上并没有在内存中开辟新的内存存储数据。你将y的值进行修改的结果会同时把x的值也更改掉。

# ------------- expand op ----------------- #
x = torch.arange(0,6).view(1,2,3)
y = x
y = y.expand(size=(2,2,3))
print(y.data_ptr() == x.data_ptr()) # True

y[1][0][0] = 100
print(x[0][0])  # [100,1,2]
print(y[0][0])	# [100,1,2]
# ------------ repeat op ------------------ #
x = torch.arange(0,6).view(1,2,3)
y = x
y = y.repeat(repeats=(2,1,1))
print(y.data_ptr() == x.data_ptr()) # True

y[1][0][0] = 100
print(x[0][0])  # [0,1,2]
print(y[1][0])  # [100,1,2]

2.6. numpy和from_numpy内存共享

2.7.切片

x = torch.arange(0,6).view(1,2,3)
y = x[0][0]
print(y.data_ptr() == x.data_ptr()) # True

y[0] = 100
print(x)  # [100,1,2]

 浅拷贝,改变y的值会同时改变x的值。

总结

 写这种文章阅读量不高,但是这些看似不起眼的知识往往会造成意想不到的错误。本文讲的是一些Tensor的偏底层的知识,而在用pytorch搭建神经网络过程中何时采用深浅拷贝,比如clone和detach等op,会对网络的梯度训练产生何种影响呢?敬请期待后续文章。

你可能感兴趣的:(pytorch源码解读,pytorch,python,深度学习)