The computation graph problem in PyTorch backpropagation

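Reproducing the problem:

  • When PyTorch computes gradients, it uses the computation graph that was recorded during the forward pass. Once that graph has been built, changing attributes of the tensors that went into it (for example, the requires_grad flag of a weight) does not change the graph, so the computed result, or the error raised, stays the same until the forward pass is run again.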

Worked example:

Build x*w and compute the corresponding MSE loss:

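# -*- coding: utf-8 -*-
import torch

x = torch.ones((1, 3))        # input, shape (1, 3)
w = torch.full([3, 1], 2.)    # weight, shape (3, 1); requires_grad defaults to False
# x*w is an elementwise product that broadcasts to shape (3, 3)
mse = torch.nn.functional.mse_loss(torch.ones(1), x*w)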
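Because x has shape (1, 3) and w has shape (3, 1), the expression x*w is an elementwise product with broadcasting rather than a matrix product. The sketch below checks the resulting shape and the loss value. It expands the all-ones tensor explicitly with torch.ones_like, which is an assumption made here to avoid relying on mse_loss to broadcast a size-1 tensor; the value should match the original call.

```python
import torch

x = torch.ones((1, 3))        # shape (1, 3)
w = torch.full([3, 1], 2.)    # shape (3, 1)

prod = x * w                  # elementwise with broadcasting, not a matrix product
print(prod.shape)             # torch.Size([3, 3])

# Every entry of prod is 2, so every squared error is (1 - 2)^2 = 1
mse = torch.nn.functional.mse_loss(torch.ones_like(prod), prod)
print(mse.item())             # 1.0
print(mse.requires_grad)      # False: neither x nor w tracks gradients
```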

Compute the gradient of mse with respect to w:

torch.autograd.grad(mse, [w])

This raises an error:

Traceback (most recent call last):
  File "C:\data\PyCharm 2021.1.2\plugins\python\helpers\pydev\_pydevd_bundle\pydevd_exec2.py", line 3, in Exec
    exec(exp, global_vars, local_vars)
  File "", line 1, in <module>
  File "C:\data\anaconda\envs\torch\lib\site-packages\torch\autograd\__init__.py", line 225, in grad
    inputs, allow_unused, accumulate_grad=False)
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

The message indicates that the weight w was created with requires_grad=False. So we enable gradient tracking on w and compute the gradient again:

w.requires_grad_()
torch.autograd.grad(mse, [w])

requires_grad_() returns w, now with requires_grad=True, but the gradient computation still raises an error:

tensor([[2.],
        [2.],
        [2.]], requires_grad=True)

Traceback (most recent call last):
  File "C:\data\PyCharm 2021.1.2\plugins\python\helpers\pydev\_pydevd_bundle\pydevd_exec2.py", line 3, in Exec
    exec(exp, global_vars, local_vars)
  File "", line 1, in <module>
  File "C:\data\anaconda\envs\torch\lib\site-packages\torch\autograd\__init__.py", line 225, in grad
    inputs, allow_unused, accumulate_grad=False)
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
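The second failure can be examined by inspecting the loss tensor directly. The following is a minimal sketch that writes the mean squared error by hand (an assumption made to keep the example short) and prints the autograd attributes of the loss before and after the flag on w is flipped.

```python
import torch

x = torch.ones((1, 3))
w = torch.full([3, 1], 2.)

# Hand-written mean squared error, computed while nothing requires grad
mse = ((1.0 - x * w) ** 2).mean()
print(mse.requires_grad, mse.grad_fn)   # False None

w.requires_grad_()                      # changes the flag on w only
print(w.requires_grad)                  # True
print(mse.requires_grad, mse.grad_fn)   # still False None
```

The loss was produced without a grad_fn, so there is no recorded path from it back to w, whatever w's flag says now.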

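The error still says that there is no gradient information. The reason is that the computation graph was recorded at the moment mse was computed. At that point neither x nor w required gradients, so mse was created without a grad_fn. Changing the requires_grad flag of w afterwards does not update a graph that already exists, which is why the same error is raised. To compute the gradient, we have to rebuild the graph by recomputing mse: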

mse = torch.nn.functional.mse_loss(torch.ones(1), x*w)
torch.autograd.grad(mse, [w])

Output:

(tensor([[0.6667],
        [0.6667],
        [0.6667]]),)

Complete test code:

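# -*- coding: utf-8 -*-
import torch

x = torch.ones((1, 3))
w = torch.full([3, 1], 2.)
mse = torch.nn.functional.mse_loss(torch.ones(1), x*w)

# 1) w does not require grad, so this raises RuntimeError
try:
    torch.autograd.grad(mse, [w])
except RuntimeError as e:
    print("first attempt:", e)

# 2) Enabling requires_grad afterwards does not change the existing graph
w.requires_grad_()
try:
    torch.autograd.grad(mse, [w])
except RuntimeError as e:
    print("second attempt:", e)

# 3) Recompute the loss to rebuild the graph; the gradient is now available
mse = torch.nn.functional.mse_loss(torch.ones(1), x*w)
print(torch.autograd.grad(mse, [w]))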
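A common way to avoid the problem altogether is to mark w as requiring gradients before the forward pass, so the graph is recorded correctly the first time. The sketch below assumes that approach. It passes the prediction first and an explicitly expanded all-ones target to mse_loss (MSE is symmetric, so the value is unchanged), and it uses backward() as an alternative to torch.autograd.grad.

```python
import torch

x = torch.ones((1, 3))
w = torch.full([3, 1], 2., requires_grad=True)   # tracked from the start

prod = x * w                                     # shape (3, 3) via broadcasting
mse = torch.nn.functional.mse_loss(prod, torch.ones_like(prod))
print(mse.grad_fn is not None)                   # True: the graph reaches w

mse.backward()                                   # accumulates the gradient into w.grad
print(w.grad)
```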
