torch.no_grad()降低显存占用

这个函数一般在测试的时候使用,用来降低显存占用

  • 在训练的时候我们需要学习weight和bias的梯度,但是在测试的时候我们不需要获取他们的梯度,因此可以节省大量的显存
  • 在训练时我们crop成了一个一个块,例如size为256,在测试时如果将整幅图片都送入网络,有时会out of memory,因此需要no_grad语句

使用很简单,在测试时将测试代码放到他下面即可

with torch.no_grad():   # 可以显著的降低显存
	for i, img in enumerate(tqdm(test_loader), 0):
    	img = img.cuda()
    	output = model(img)

你可能感兴趣的:(PyTorch,深度学习,pytorch,神经网络)