Pytorch LSTM内存泄漏问题

环境:CentOS 7.9.2009,Python 3.8.12,torch 1.10.1+cu113

先来看一段代码:

# 初始化一个lstm模型,和一个输入,以及手动初始化lstm初始状态h0和c0
rnn = torch.nn.LSTM(640, 512, 1, batch_first=True)
xs = torch.randn([1, 80, 640])
h0, c0 = torch.zeros([1,1,512]), torch.zeros([1,1,512])
​
# 每次手动输入lstm的初始状态,再用计算之后得到的新的隐含状态更新h0和c0
while True:
    _, (h0, c0) = rnn(xs, (h0, c0))

但是上述代码会导致内存不断增长,直到耗尽所有内存,linux主动将其杀掉。

分析(分析没啥用,可以直接看结论)

问题就出在h0和c0,可能我们会觉得,执行rnn的前向过程也只是相当于调用一个函数,用函数的返回值来更新传入的参数是很常规的操作,而且函数接受的实参和返回的值一般会执行拷贝构造,也就是说函数里面的只是一个副本,传出来的也只是一个副本,而且出了作用域,资源就应该释放了,那为什么还会出现这个问题呢?

与C++不同,在python中引用、指针的概念被隐藏了,我们在用一个tensor赋值另一个变量的时候,实际上拿的是它的引用,例如如下例子:

a = torch.Tensor([0,1])
b = a
a[0] += 1
print(b)
# 此时b的值也变为了tensor([1., 1.])

Python有自己的回收机制,Python在内存中存储了每个对象的引用计数reference count。当一个Python对象被引用时其引用计数增加1,当其不再被一个变量引用时则计数减1,如果计数值变成0,那么相应的对象就会消失,分配给该对象的内存就会释放出来用作他用。

所以,我们再来看看最开始的会导致内存泄漏的代码发生了什么:

  • 最开始我们初始化了h0,这时候h0是指向一个全零tensor的引用

  • 经过lstm前向过程,计算得到了一个新的tensor

  • 我们将这个新的tensor赋给h0,也就是说h0不再指向老tensor,而是指向新tensor

正常到这里老tensor就应该释放了,就不应该有内存泄漏了。python的内存泄漏一般是由于循环引用导致,但循环引用如果达到一定的深度,也会触发回收机制回收。而且就上面的分析来看,这其中也不存在循环引用。

结论

实际上这既不是形参实参的问题,也不是循环引用导致的问题,也不是深拷贝浅拷贝的问题。其中的原因和pytorch的动态计算图有关,参与过计算的tensor依然是计算图的一部分。因为每次的输出节点又变为了下一次的输入节点,这样torch就会不断扩大计算图,所以while循环中,计算图会不断变大,内存就会线性增长。而且计算图是一直都在的,这部分内存一直不会释放。

知道了问题的原因,可以使用detach方法手动将其从计算图中脱离出来。

while True:
    _, (h0, c0) = rnn(xs, (h0.detach(), c0.detach()))

PS:严格意义上,这实际上不算是一个内存泄漏问题,只能算是使用方式不对。本来是想将lstm的forward过程当做一个独立的函数使用,但是torch却认为你的输出节点就是下一次的输入,需要不断的扩大计算图。

你可能感兴趣的:(lstm,pytorch,python,内存泄漏)