Pytorch训练过程中GPU显存不断增加的解决方案

Pytorch训练过程中显存不断增加原因之一

在使用pytorch利用测试集进行网络预测时,给网络输入数据,默认会构建计算图,构建计算图是为了方便后续的反向传播进行梯度计算,如果只是为了利用网络进行预测,则不需要构建完整的计算图。构建完整计算图会增加计算和累积内存消耗,导致所占GPU显存越来越大。

解决方案
在测试代码处于如下命令下:

with torch.no_grad():

例如:

with torch.no_grad():
	prediction = net(images)
	loss = loss_func(prediction , label) / batch_size

你可能感兴趣的:(pytorch,深度学习,人工智能)