pytorch 训练过程内存泄露/显存泄露debug记录:dataloader和dataset导致的泄露

背景

微调 mask-rcnn 代码,用的是 torchvision.models.detection.maskrcnn_resnet50_fpn 代码,根据该代码的注释,输入应该是:
images, targets=None
(List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]

所以我写的 dataset 是这样的:

def _load_ann(self):
	transformed_anns = {'boxes': boxes : List[List],
						'labels': categories: List[int],
						'masks': masks: List[str]}
	self.anns[filename] = transformed_anns

def __getitem__(self, item) -> (Tensor, Optional[Dict[str, Tensor]]):
	img_name = self.img_list[item]
	img = cv2.imread(os.path.join(self.root_path, self.split, img_name))
	if self.split == 'train':
		return img, self.anns[img_name]

大概思路是:先把所有的标注信息读入内存,然后按照 img_name 把标注信息(也就是 target )取出来

这里有个令人纠结的地方:__get_item__ 返回的到底是什么格式的数据?一开始我是直接把 boxes labels masks 都直接写成 tensor 返回,但是一次性把所有的 masks 都读到内存,太大了。再加上看了这个 pytorch内存泄露-dataloader - 知乎 ,这里建议 dataset 的 __get_item__ 返回的都是 python 的基础数据类型,所以我就改成了上面的样子。其实返回什么类型的都行,只要在 dataloader 的 collate_fn 方法里面都转成可以送入模型的数据形式就行了。

因为 dataset 是上面的写法,所以对应的 collate_fn 写法是:

def collate_fn(datas: List[Tuple[Tensor, Dict]]):
	imgs = []
	targets = []
	for data in datas:
		img, target = data
		imgs.append(transforms.ToTensor()(img))

		target['boxes'] = torch.tensor(target['boxes'], dtype=torch.float)
		target['labels'] = torch.tensor(target['labels'], dtype=torch.int64)

		masks = target['masks']
		masks = [cv2.imread(mask, 0) for mask in masks]
		masks = np.stack(masks, axis=0)
		masks = masks / 255
		masks = masks.astype(np.uint8)
		target['masks'] = torch.from_numpy(masks)

		targets.append(target)
	return imgs, targets

错误排查及解决方法

把所有的数据送入 model 的代码都注释掉,只保留如下代码:

for e in range(epoch):
	for i, (imgs, targets) in enumerate(train_dataloader):
		imgs = [img.to(device) for img in imgs]
		targets = [_to_device(target, device) for target in targets]

watch -n 1 nvidia-smi 监控显存占用,发现一直在涨。毫无疑问肯定是 dataloader 导致的显存泄露

然后就是排查,到底是谁?是谁想害朕??

排查方法是:分别注释掉 imgs / boxes / labels / masks ,观察注释掉谁的时候不会显存泄露。

发现,是 masks 导致的内存泄露。

但是这很怪啊,明明 masks 和 imgs 是一样的数据类型,为什么前者会显存泄露,但是后者不会?于是我把 masks 单独拿出来,像 imgs 一样放在 list 里面,不会内存泄露。但是一旦把 imgs 嵌套放在 targets 这个 dict 里面,就会显存泄露 orz

于是,既然是 masks 没有释放,所以我加一句:

for e in range(epoch):
	for i, (imgs, targets) in enumerate(train_dataloader):
		imgs = [img.to(device) for img in imgs]
		targets = [_to_device(target, device) for target in targets]
		# ... 传入模型的计算
		for target in targets:
			del target['masks']

但是没用,还是泄露。然后查了 pytorch 怎么释放 tensor,发现要主动调用 torch.cuda.empty_cache() 才会释放,所以我又加了一句:

for e in range(epoch):
	for i, (imgs, targets) in enumerate(train_dataloader):
		imgs = [img.to(device) for img in imgs]
		targets = [_to_device(target, device) for target in targets]
		# ... 传入模型的计算
		for target in targets:
			del target['masks']
		torch.cuda.empty_cache()

这回没有显存泄露了。

但是出现了新的问题,在 epoch=2 的时候报错 targets 没有 masks 这个 key;然后我 debug 发现,由 dataloader 取到的数据 label 和 boxes 在 collate_fn 之前就已经是 tensor 状态了,再往前倒,发现 dataset.anns 里面的数据居然被改了!这实在是太荒谬了。

所以我把 __get_item 改成:

def __getitem__(self, item) -> (Tensor, Optional[Dict[str, Tensor]]):
	img_name = self.img_list[item]
	img = cv2.imread(os.path.join(self.root_path, self.split, img_name))
	if self.split == 'train':
		return img, deepcopy(self.anns[img_name])

这样就没问题了

总结

  1. 查找内存泄露/显存泄露的位置:
    • 把数据送入模型的代码全部注释掉,观察显存是否上涨;上涨说明内存泄露出现在 dataloader(出现在非 dataloader 地方的最常见的显存泄露原因是,loss 打印/统计的时候没有写 loss.item()
    • 把不同的 data 组成部分注释掉,观察具体是哪个 data 导致的内存泄露
  2. pytorch 释放内存的方法:把 tensor 读到 gpu 就会有显存占用,一般可以自动释放,但是显存泄露的时候就没法释放。找到没有及时释放的代码位置,然后首先 del tensor 标记删除,随后需要调用 torch.cuda.empty_cache() 才能真正释放
  3. dataset 的 __get_item__ 方法注意,如果要返回内部维护的 list 类型的数据的话,不要直接返回该数据切片,而是返回 deepcopy() 防止内部维护的数据被外部修改

你可能感兴趣的:(debug,pytorch,人工智能,python)