Pytorch: 设置局部梯度

torch.no_grad()torch.enable_grad(),  torch.set_grad_enabled()

这三个函数对于设置局部梯度和赋能梯度计算。

上代码:

>>> x = torch.zeros(1, requires_grad=True)
>>> with torch.no_grad():
...     y = x * 2
>>> y.requires_grad
False

>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
...     y = x * 2
>>> y.requires_grad
False

>>> torch.set_grad_enabled(True)  # this can also be used as a function
>>> y = x * 2
>>> y.requires_grad
True

>>> torch.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
False

 

你可能感兴趣的:(Pytorch)