GPU memory leak when training with mixed precision via apex

Problem description:

Using apex to train a model with mixed precision (opt_level O1) greatly reduces training time with almost no loss of accuracy. However, when I used an automatic hyperparameter search tool such as optuna (or Ray Tune), GPU memory usage kept growing as the trials looped, and after a few hyperparameter runs the script crashed with an OOM error. Adding torch.cuda.empty_cache() at the end of each trial did not fix it.

Cause analysis:

At first I assumed I was using optuna incorrectly and leaking memory on each trial, so I restructured my code; it made no difference. Searching the apex repository on GitHub for "memory leak" showed that many people hit the same problem. An NVIDIA engineer explained on GitHub that after calling amp.initialize with a half or mixed precision opt_level, the optimizer's internal parameters are not released correctly when training ends, which causes the leak. A year and a half later, the bug is still there…

https://github.com/NVIDIA/apex/issues/439#issuecomment-522360104
I believe the increase in memory might not be related to amp, but is probably caused by storing the internal parameters of the optimizers

If you leave the default settings as use_amp = False, clean_opt = False, you will see a constant memory usage during the training and an increase after switching to the next optimizer.
Setting clean_opt=True will delete the optimizers and thus clean the additional memory.
However, this cleanup doesn’t seem to work properly using amp at the moment.
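For comparison, when nothing else holds a hidden reference, deleting the optimizer really does free its state, which is what the clean_opt=True path in that comment relies on. A minimal stdlib sketch (the Optimizer class is again just a stand-in):

```python
import gc
import weakref

class Optimizer:
    """Stand-in for an optimizer holding large internal state."""
    def __init__(self):
        self.state = [0.0] * 1000

opt = Optimizer()
ref = weakref.ref(opt)
del opt       # drop the only reference, as clean_opt=True does
gc.collect()
freed = ref() is None  # the state was actually released
```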

Solution:

I finally found the answer in someone else's reply: drop NVIDIA's apex and use the native AMP support (torch.cuda.amp) built into PyTorch 1.6 and later.

If you’re performing multiple convergence runs in the same script, you should use a new GradScaler instance for each run. GradScaler is a lightweight, self-contained object, so you can construct a new one anytime with the usual `scaler = GradScaler()`.

For example, my original pseudocode looked roughly like this:

for lr in [0.1, 0.01]:
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    for epoch in range(1, opt.max_epoch + 1):
        out_linear = model(inputs)
        loss = criterion(out_linear, targets)
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()

This version went OOM after a few trials.

After the change:

for lr in [0.1, 0.01]:
    scaler = torch.cuda.amp.GradScaler()  # a fresh scaler for each run
    for epoch in range(1, opt.max_epoch + 1):
        optimizer.zero_grad()
        with autocast():
            out_linear = model(inputs)
            loss = criterion(out_linear, targets_a) * lam + criterion(out_linear, targets_b) * (1. - lam)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
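For reference, here is a self-contained, CPU-runnable sketch of the same per-run pattern. The tiny linear model, the random data, and the enabled=False flags are illustrative additions so it runs without a GPU; on a CUDA machine you would drop enabled=False and move the model and tensors to the device.

```python
import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast

def run_trial(lr, steps=5):
    """One hyperparameter trial with its own fresh GradScaler."""
    torch.manual_seed(0)
    model = nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    # New scaler per run; on a GPU, drop enabled=False.
    scaler = GradScaler(enabled=False)
    inputs, targets = torch.randn(8, 4), torch.randn(8, 1)
    for _ in range(steps):
        optimizer.zero_grad()
        with autocast(enabled=False):  # on a GPU, drop enabled=False
            loss = criterion(model(inputs), targets)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    return loss.item()

# Each trial builds its own model, optimizer, and scaler, so nothing
# survives the iteration and memory stays flat across trials.
losses = [run_trial(lr) for lr in (0.1, 0.01)]
```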

GitHub apex GPU memory leak issue:
https://github.com/NVIDIA/apex/issues/439
