Log_Softmax()激活函数、NLLLoss()损失函数、CrossEntropyLoss()损失函数浅析

Log_Softmax()激活函数、NLLLoss()损失函数、CrossEntropyLoss()损失函数

1. Log_Softmax()激活函数。

Softmax() 函数的值域是 [ 0 , 1 ] [0, 1] [0,1],公式:
σ ( z ) j = e z j ∑ k = 1 n e z k \sigma(z)_j =\frac{e^{z_j}}{\sum_{k=1}^{n} {e^{z_k}}} σ(z)j=k=1nezkezj

Log_Softmax() 函数的值域是 ( − ∞ , 0 ] (- \infty, 0] (,0],公式:
l o g _ s o f t m a x = l o g e [ σ ( z ) j ] log\_softmax =log_e \left[\sigma(z)_j \right] log_softmax=loge[σ(z)j]

2. NLLLoss()损失函数

NLLLoss() ,即负对数似然损失函数(Negative Log Likelihood)。
NLLLoss() 损失函数公式:
n l l l o s s = − 1 N ∑ k = 1 N y k ( l o g _ s o f t m a x ) nllloss =-\frac{1}{N} {\sum_{k=1}^{N} y_k \left(log\_softmax \right)} nllloss=N1k=1Nyk(log_softmax)

y k y_k yk :one_hot 编码之后的数据标签

NLLLoss() 损失函数运算的结果,即是 y k y_k yk 与 经过 l o g _ s o f t m a x ( ) log\_softmax() log_softmax() 函数激活后的数据,两者相乘,再求平均值,最后取反。

实际使用NLLLoss()损失函数时,传入的标签,无需进行 one_hot 编码。

3. CrossEntropyLoss()损失函数

CrossEntropyLoss()损失函数,是将Log_Softmax()激活函数与NLLLoss()损失函数的功能综合在一起了。
c r o s s _ e n t r o p y = l o g _ s o f t m a x + n l l l o s s cross\_entropy = log\_softmax + nllloss cross_entropy=log_softmax+nllloss

传入的数据无需激活,标签无需做 one_hot 编码。

实际应用中,常选用NLLLoss()函数,如此可以控制数据的激活操作。


代码

代码运行之后,三种方式的运算结果是一致的。

import torch.nn.functional as F
import torch


# 手动实现 NLLLoss() 函数功能
data = torch.randn(5, 5)  # 随机生成一组数据
target = torch.tensor([0, 2, 4, 3, 1])  # 标签
one_hot = F.one_hot(target).float()  # 对标签作 one_hot 编码

exp = torch.exp(data)  # 以e为底作指数变换
sum = torch.sum(exp, dim=1).reshape(-1, 1)  # 按行求和
softmax = exp / sum  # 计算 softmax()
log_softmax = torch.log(softmax)  # 计算 log_softmax()
nllloss = -torch.sum(one_hot * log_softmax) / target.shape[0]  # 标签乘以激活后的数据,求平均值,取反
print("nllloss:", nllloss)


# 调用 NLLLoss() 函数计算
Log_Softmax = F.log_softmax(data, dim=1)  # log_softmax() 激活
Nllloss = F.nll_loss(Log_Softmax, target)  # 无需对标签作 one_hot 编码
print("Nllloss:", Nllloss)


# 直接使用交叉熵损失函数 CrossEntropy_Loss()
cross_entropy = F.cross_entropy(data, target)  # 无需对标签作 one_hot 编码
print('cross_entropy:', cross_entropy)

你可能感兴趣的:(三两闲事,深度学习,python,神经网络,机器学习,pytorch)