pytorch 中的 backward()

今天在学 pytorch 反向传播时发现 backward() 函数是可以往里面传参的,于是仔细查了一下这个函数及其参数是干什么的。

github上有大牛分析如下:

https://sherlockliao.github.io/2017/07/10/backward/

这里再简单总结一下。

如果 backward() 没有参数,调用 backward() 函数的变量必须是一个标量,即形状为(1,),否则就会报错。这时候其实相当于传参 backward(torch.Tensor([1.0])),这里的 1.0 可以理解为最后每个梯度要乘以的步长。

如果 backward() 有参数,比如 backward(torch.FloatTensor([0.1, 1.0, 0.001])),那么此时调用 backward() 函数的变量可以不是标量,也可以是向量,但是要注意的是,这个向量的维度必须和 backward() 的参数向量的维度相同,其实此时就相当于函数有多个输出,每个输出都要算一个梯度,并且每个输出算出来的梯度的步长对应于 backward() 的参数向量对应的维度,且最终某个输入 x 的梯度是所有这些输出针对 x 算出的梯度的加权求和(权值向量就是 backward 的参数向量)。

你可能感兴趣的:(pytorch 中的 backward())