pytorch导数

在pytorch中损失计算

loss = torch.nn.MSEloss()
l = loss(y_hat, y)
l.backward()

之后对每个参数的导数是每个样本求导后的平均值
如果说loss不是求均方误差,而是把误差加起来(不做平均),那么得到对每个参数的导数就是每个样本求导后的和!

你可能感兴趣的:(pytorch导数)