pytorch 中 torch.sqrt 的坑

记录一下,避免后面踩坑

单步调了半小时,幸好这个问题出现很频繁
发现是函数与导函数定义域问题。。。

sqrt(x) 函数的定义域为 [0, 无穷大)
sqrt(x) 的导函数的定义域 却是 (0, 无穷大)

这些函数定义域跟导函数的定义域不一样,正向传播可以得到正常结果,但是一旦backward就会得到Nan。。。

问题重现

import torch
a = torch.zeros(1)
a.requireds_grad = True
b = torch.sqrt(a)
b.backward()
print(a.grad)
# 得到nan

如何解决
让输入的值符合sqrt的导函数定义域就可以解决该问题了。举个例子:设 x 的定义域为 [0, 无穷大) ,给 x 加个很小的数,例如1e-8,使其输入值的定义域略微往右偏移,就可以避开 0 这个未定义值了;y = sqrt(x + 1e-8)

你可能感兴趣的:(pytorch 中 torch.sqrt 的坑)