pytorch with torch.no_grad() 功能函数详解

torch.no_grad() 是一个上下文管理器,被该语句 wrap 起来的部分将不会track 梯度。

用Anaconda3 虚拟环境测试一下功能;

>>> import torch
>>> k=torch.tensor([1.1], requires_grad=True)
>>> k
tensor([1.1000], requires_grad=True)
>>> h=k*2
>>> h
tensor([2.2000], grad_fn=)
>>>

可以看到不被wrap的情况下,grad_fn 为 addbackward 表示这个add 操作被track了

>>> with torch.no_grad():
...     h.mul_(2)
...
tensor([4.4000], grad_fn=)

在被包裹的情况下可以看到 grad_fn 还是为 add,mul 操作没有被 track. 但是注意,乘法操作是被执行了的

所以如果有不想被track的计算部分可以通过这么一个上下文管理器包裹起来。这样可以执行计算,但该计算不会在反向传播中被记录。

扩展:
同样还可以用 torch.set_grad_enabled()来实现不计算梯度。
例如:

def eval():
	torch.set_grad_enabled(False)
	...	# your test code
	torch.set_grad_enabled(True)

 

你可能感兴趣的:(pytorch,pytorch,人工智能)