PyTorch torch.no_grad()

torch.no_grad() 一般用于神经网络的推理阶段, 表示张量的计算过程中无需计算梯度


torch.no_grad 是一个类, 实现了 __enter__ 和 __exit__ 方法, 在进入环境管理器时记录梯度使能状态以及禁止梯度计算, 退出环境管理器时还原, 它还继承了 _DecoratorContextManager, 拥有装饰器的能力(依然是使用 with 语句)

# 摘自源码
class no_grad(_DecoratorContextManager):
    def __init__(self):
        self.prev = False

    def __enter__(self):
        self.prev = torch.is_grad_enabled()
        torch.set_grad_enabled(False)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch.set_grad_enabled(self.prev)

class _DecoratorContextManager:
    """Allow a context manager to be used as a decorator"""

    def __call__(self, func: F) -> F:
        @functools.wraps(func)
        def decorate_context(*args, **kwargs):
            with self.__class__():
                return func(*args, **kwargs)
        return cast(F, decorate_context)

    def __enter__(self) -> None:
        raise NotImplementedError

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        raise NotImplementedError

另外, torch.no_grad 用于代替旧版本的 volatile=True


import torch

x = torch.tensor([1.0], requires_grad=True)

y_1: torch.Tensor = x * x
y_1.backward()
print("y_1:", y_1.requires_grad, x.grad)

with torch.no_grad():
    y_2 = x * x
    print("y_2:", y_2.requires_grad)


@torch.no_grad()
def demo(x):
    y_3 = x * x
    print("y_3:", y_3.requires_grad)


demo(x)

打印

y_1: True tensor([2.])
y_2: False
y_3: False

y_1 是通常情况, y_1依赖于x, 而x需要求导, 所以y_1也需要求导, y_2 和 y_3 明确无需求导


除了 torch.no_grad() 还有 torch.enable_grad() 明确需要求导以及 torch.set_grad_enabled(mode), 它们均支持环境管理器和装饰器

# 单独使用 torch.set_grad_enabled
torch.set_grad_enabled(False)
y_4 = x * x
print("y_4:", y_4.requires_grad)

torch.set_grad_enabled(True)
y_5 = x * x
print("y_5:", y_5.requires_grad)

结果

y_4: False
y_5: True

底层实现位于 “aten/src/ATen/core/grad_mode.cpp”

thread_local bool GradMode_enabled = true;

bool GradMode::is_enabled() {
  return GradMode_enabled;
}

void GradMode::set_enabled(bool enabled) {
  GradMode_enabled = enabled;
}

你可能感兴趣的:(PyTorch,pytorch)