MSELoss (Mean Squared Error Loss) Explained, with Gradient Derivation for Backpropagation

Abstract

This article defines the mean squared error loss function MSELoss and derives its gradient for backpropagation.

Related

Series index:
https://blog.csdn.net/oBrightLamp/article/details/85067981

Main Text

The MSELoss function has a concise definition, a simple gradient derivation, and wide applications.

1. Gradient

Let the vector s be the predicted values and the vector y the actual values. MSELoss computes a scalar error e from them; we want the gradient of e with respect to s.
$$ e = \mathrm{MSELoss}(s, y) = \frac{1}{n}\sum_{t=1}^{n} (s_t - y_t)^2 $$
Derivation: differentiating componentwise gives $\frac{\partial e}{\partial s_t} = \frac{2}{n}(s_t - y_t)$, and collecting the components into a vector yields

$$ \frac{de}{ds} = \frac{2}{n}\bigl((s_1 - y_1),\ (s_2 - y_2),\ (s_3 - y_3),\ \cdots,\ (s_n - y_n)\bigr) $$

Since e depends on y only through the differences $(s_t - y_t)$,

$$ \frac{de}{dy} = -\frac{de}{ds} $$
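As a quick sanity check (a worked example added here, not part of the original derivation), take $n = 2$, $s = (1, 3)$, $y = (0, 1)$:

$$ e = \frac{1}{2}\left((1-0)^2 + (3-1)^2\right) = \frac{5}{2}, \qquad \frac{de}{ds} = \frac{2}{2}\bigl((1-0),\ (3-1)\bigr) = (1,\ 2), \qquad \frac{de}{dy} = (-1,\ -2) $$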

2. Code Implementation

import torch
import numpy as np


class MSELoss:
    def __init__(self):
        # cache the forward inputs so backward() can reuse them
        self.x = None
        self.y = None

    def __call__(self, x, y):
        self.x = x
        self.y = y
        # e = (1/n) * sum_t (x_t - y_t)^2
        return np.sum(np.square(x - y)) / x.size

    def backward(self):
        # de/dx = (2/n) * (x - y); de/dy is its negative
        dx = 2 * (self.x - self.y) / self.x.size
        return dx, -dx


np.random.seed(123)
np.set_printoptions(precision=6, suppress=True, linewidth=80)

x_numpy = np.random.random(27)
y_numpy = np.random.random(27)
x_torch = torch.tensor(x_numpy, requires_grad=True)
y_torch = torch.tensor(y_numpy, requires_grad=True)

loss_func_numpy = MSELoss()
loss_func_torch = torch.nn.MSELoss()  # default reduction='mean' matches the 1/n definition

loss_numpy = loss_func_numpy(x_numpy, y_numpy)
loss_torch = loss_func_torch(x_torch, y_torch)

loss_torch.backward()  # autograd fills x_torch.grad and y_torch.grad
dx_numpy, dy_numpy = loss_func_numpy.backward()  # analytic gradients from our class

print(loss_numpy)
print(loss_torch.data.numpy())
print("----------")
print(dx_numpy)
print(x_torch.grad.numpy())
print("----------")
print(dy_numpy)
print(y_torch.grad.numpy())

"""
0.116960116566
0.1169601165663142
----------
[ 0.034682 -0.000561 -0.029935  0.034016  0.021168 -0.000575  0.03608   0.019185
  0.012494 -0.002536 -0.040756 -0.015934 -0.004686 -0.041798  0.02092   0.031164
 -0.01721  -0.051175  0.020822  0.003614 -0.026012  0.02444   0.008264  0.036326
 -0.007696 -0.020748 -0.013576]
[ 0.034682 -0.000561 -0.029935  0.034016  0.021168 -0.000575  0.03608   0.019185
  0.012494 -0.002536 -0.040756 -0.015934 -0.004686 -0.041798  0.02092   0.031164
 -0.01721  -0.051175  0.020822  0.003614 -0.026012  0.02444   0.008264  0.036326
 -0.007696 -0.020748 -0.013576]
----------
[-0.034682  0.000561  0.029935 -0.034016 -0.021168  0.000575 -0.03608  -0.019185
 -0.012494  0.002536  0.040756  0.015934  0.004686  0.041798 -0.02092  -0.031164
  0.01721   0.051175 -0.020822 -0.003614  0.026012 -0.02444  -0.008264 -0.036326
  0.007696  0.020748  0.013576]
[-0.034682  0.000561  0.029935 -0.034016 -0.021168  0.000575 -0.03608  -0.019185
 -0.012494  0.002536  0.040756  0.015934  0.004686  0.041798 -0.02092  -0.031164
  0.01721   0.051175 -0.020822 -0.003614  0.026012 -0.02444  -0.008264 -0.036326
  0.007696  0.020748  0.013576]
"""
