【torch.no_grad()】

torch.no_grad是一个类,pytorch官网描述如下:

 一个上下文管理器,disable梯度计算。disable梯度计算对于推理是有用的,当你确认不会调用Tensor.backward()的时候。这可以减少计算所用内存消耗。这个模式下,每个计算结果的requires_grad=False,尽管输入的requires_grad=True。

上下文管理器是thread local的,不会影响其它线程的计算。

x = torch.tensor([1.], requires_grad=True)
with torch.no_grad():
    y = x * 2
y.requires_grad  # False

也可以作为装饰器。 

@torch.no_grad()
def doubler(x):
    return x * 2
z = doubler(x)
z.requires_grad  # False

在我们对模型进行验证的时候,可以使用下面两种格式:

model.eval()
with torch.no_grad():
    pass

或者使用装饰器

@torch.no_grad()
def eval():
    ...

你可能感兴趣的:(PyTorch学习,python,深度学习,pytorch,人工智能)