一行代码解决 Pytorch 测试时显存爆满

问题:

同样的 batch size,Pytorch 模型在训练时显存正常,验证、测试时每个batch显存一直逐步增长直到爆满。

解决方法:

使用 with torch.no_grad(),在测试的时候让模型不要保存梯度:

with torch.no_grad():
	for batch, data in test_dataloader():
		test(data) # 你的test代码

这样在每个batch时梯度不会被保存,避免梯度数据堆积消耗显存。

你可能感兴趣的:(python,pytorch,机器学习,python,机器学习,深度学习,神经网络)