Comparing Python and PyTorch implementations of the cross-entropy loss function and its backward pass

Abstract

This article implements the cross-entropy loss function and its backward pass by hand in plain Python (NumPy) and in PyTorch, and compares the two results.

Related

For the underlying theory and a detailed explanation, see:

A detailed, example-driven explanation of the cross-entropy loss function (通过案例详解cross-entropy交叉熵损失函数)

Series index:
https://blog.csdn.net/oBrightLamp/article/details/85067981

Main text

1. Definition

E = -\sum_{i=1}^{k} y_i \log(s_i)
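
To make the definition concrete, here is a minimal sketch with made-up numbers (the one-hot target y and the probability vector s below are illustrative assumptions, not values from the original post):

import numpy as np

y = np.array([0.0, 1.0, 0.0])   # target distribution (one-hot)
s = np.array([0.2, 0.7, 0.1])   # predicted probabilities
E = -np.sum(y * np.log(s))      # only the term with y_i = 1 survives: -log(0.7) ≈ 0.357
print(E)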

2. Gradient

\nabla E(s) = \left(\frac{\partial E}{\partial s_1}, \frac{\partial E}{\partial s_2}, \cdots, \frac{\partial E}{\partial s_k}\right) = \left(-\frac{y_1}{s_1}, -\frac{y_2}{s_2}, \cdots, -\frac{y_k}{s_k}\right)
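
This follows directly from the definition: only the j-th term of the sum depends on s_j, so

\frac{\partial E}{\partial s_j} = \frac{\partial}{\partial s_j}\left(-\sum_{i=1}^{k} y_i \log(s_i)\right) = -\frac{y_j}{s_j}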

3. Comparing the Python and PyTorch implementations

import torch
import numpy as np


class Entropy:
    def __init__(self):
        self.nx = None   # cached input s (predicted probabilities)
        self.ny = None   # cached target y
        self.dnx = None  # gradient dE/ds

    def loss(self, nx, ny):
        # E = -sum(y * log(s)), summed over all elements
        self.nx = nx
        self.ny = ny
        loss = np.sum(- ny * np.log(nx))
        return loss

    def backward(self):
        # dE/ds = -y / s, elementwise
        self.dnx = - self.ny / self.nx
        return self.dnx


np.random.seed(123)
np.set_printoptions(precision=3, suppress=True, linewidth=120)

entropy = Entropy()

x = np.random.random([5, 10])  # random positive values standing in for the probabilities s
y = np.random.random([5, 10])  # random values standing in for the targets y
x_tensor = torch.tensor(x, requires_grad=True)
y_tensor = torch.tensor(y, requires_grad=True)

loss_numpy = entropy.loss(x, y)
grad_numpy = entropy.backward()

# the same loss written directly with torch operations; autograd supplies the gradient
loss_tensor = (- y_tensor * torch.log(x_tensor)).sum()
loss_tensor.backward()
grad_tensor = x_tensor.grad

print("Python Loss :", loss_numpy)
print("PyTorch Loss :", loss_tensor.data.numpy())

print("\nPython dx :")
print(grad_numpy)
print("\nPyTorch dx :")
print(grad_tensor.data.numpy())

"""
Output:
Python Loss : 22.6094161164
PyTorch Loss : 22.609416116382963

Python dx :
[[ -0.173  -2.888  -2.658  -0.989  -0.476  -0.719  -0.425  -0.995  -1.82   -1.302]
 [ -1.95   -0.804  -1.425 -11.306  -2.116  -0.113  -4.185  -1.389  -0.365  -1.076]
 [ -0.151  -1.042  -0.866  -1.184  -0.022  -1.841  -1.539  -0.696  -0.521  -1.102]
 [ -3.461  -1.596  -1.287  -0.788  -2.173  -2.695  -0.838  -0.049  -0.323  -0.793]
 [ -1.13   -8.609  -1.122  -1.838  -0.685  -2.762  -0.313  -0.405  -0.464  -0.56 ]]
 
PyTorch dx :
[[ -0.173  -2.888  -2.658  -0.989  -0.476  -0.719  -0.425  -0.995  -1.82   -1.302]
 [ -1.95   -0.804  -1.425 -11.306  -2.116  -0.113  -4.185  -1.389  -0.365  -1.076]
 [ -0.151  -1.042  -0.866  -1.184  -0.022  -1.841  -1.539  -0.696  -0.521  -1.102]
 [ -3.461  -1.596  -1.287  -0.788  -2.173  -2.695  -0.838  -0.049  -0.323  -0.793]
 [ -1.13   -8.609  -1.122  -1.838  -0.685  -2.762  -0.313  -0.405  -0.464  -0.56 ]]
"""
