pytorch中的内存泄漏问题解决方案

现象

发现代码在训练的过程中内存占用量越来越大,直至被系统内核kill掉

溯源

通过不断地注释删除代码,将问题的根源定位到dataloader上。发现只要在dataset中的getitem函数中调用了init中加载的大型变量,就会内存报错。例如:


class GQA(Dataset):
	    def __init__(self, **args):
	    # 该文件是一个很大的文件,约30G
        with open('questions/{}_inputs.json'.format(self.split), 'r') as f:
            self.data = json.load(f)

	    def __getitem__(self, index):
	        entry = self.data[index]
	        return entry[0]

运行以上代码dataloader在调用dataset中的getitem函数且numworker>0时,就会发生内存泄漏。

解决

在init函数中的大型变量从list转化为numpy类型,问题便可以解决。改进后的代码如下:

class GQA(Dataset):
	    def __init__(self, **args):
        with open('questions/{}_inputs.json'.format(self.split), 'r') as f:
            self.data = json.load(f)
            # 添加这一行,其中dtype=object是为了适应列表中元素维度不一致的warning
            self.data = numpy.array(self.data, dtype = object)
            

	    def __getitem__(self, index):
	        entry = self.data[index]
	        return entry[0]

改完之后内存泄漏就消失了。同样的,返回值也需要用numpy.array()函数进行处理。

你可能感兴趣的:(pytorch,python,人工智能)