PyTorch 中的交叉熵函数 CrossEntropyLoss 的计算过程

CrossEntropyLoss() 函数联合调用了 nn.LogSoftmax() 和 nn.NLLLoss()。

假设网络得到的输出为 h h h,它的维度大小为 B × C B\times C B×C,其中 B B B 是 batch_size, C C C 是分类的总数目。与之对应的训练数据的标签 y y y 维度是 1 × B 1\times B 1×B y y y 中元素的取值范围是 [ 0 , C − 1 ] [0, C-1] [0,C1],即
0 ≤ y [ j ] ≤ C − 1 j = 0 , 1 , ⋯   , B − 1 0\le y[j]\le C-1 \qquad j = 0, 1, \cdots, B-1 0y[j]C1j=0,1,,B1

我们将CrossEntropyLoss() 函数的计算过程拆解为如下两个步骤:

  1. 对输出 h h h,执行LogSoftmax(dim=1),得到 s s s,维度仍然是 B × C B\times C B×C
  2. s s s 执行 − log ⁡ ( ) -\log() log()操作,得到负对数概率 p p p,维度仍然是 B × C B\times C B×C

则交叉熵的计算公式为:
(1) L = 1 B ∑ i = 0 B { − log ⁡ ( p [ i , y [ i ] ] ) } L = \frac{1}{B}\sum_{i=0}^B\left\{-\log(p[i,y[i]])\right\} \tag{1} L=B1i=0B{log(p[i,y[i]])}(1)

式(1)其实是从式(2)化简得来的:
(2) L = 1 B ∑ i = 0 B { − ∑ j = 0 C − 1 y [ i , j ] log ⁡ ( p [ i , j ] ) } L = \frac{1}{B}\sum_{i=0}^B\left\{-\sum_{j=0}^{C-1}y[i, j]\log(p[i,j])\right\} \tag{2} L=B1i=0B{j=0C1y[i,j]log(p[i,j])}(2)

举例说明:

对于 C = 10 C=10 C=10 y = [ 7 , 7 , 2 , 4 ] y=[7, 7, 2, 4] y=[7,7,2,4] 的情况,可知 B = 4 B=4 B=4,首先需要把 y y y扩展为 B × C B\times C B×C 的矩阵:
y = [ 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 ] y = \begin{bmatrix} 0 & 0 & 0 & 0 & 0 & 0 & 0 & 1 & 0 & 0\\ 0 & 0 & 0 & 0 & 0 & 0 & 0 & 1 & 0 & 0\\ 0 & 0 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0\\ 0 & 0 & 0 & 0 & 1 & 0 & 0 & 0 & 0 & 0 \end{bmatrix} y=0000000000100000000100000000110000000000
其中为1的元素位置,就是最终概率 p p p 中需要取值的位置。

网络得到的输出
h = [ − 0.1070 0.0083 − 0.0789 0.0341 0.0686 − 0.0088 0.0540 − 0.1017 0.0267 0.0925 − 0.0977 − 0.0053 − 0.0613 0.0576 0.0690 − 0.0104 0.0558 − 0.1133 0.0502 0.0775 − 0.1049 − 0.0091 − 0.0663 0.0611 0.0709 − 0.0168 0.0602 − 0.1072 0.0477 0.0878 − 0.1164 − 0.0018 − 0.0746 0.0531 0.0670 − 0.0142 0.0700 − 0.1005 0.0491 0.0939 ] h = \begin{bmatrix} -0.1070 & 0.0083 & -0.0789 & 0.0341 & 0.0686 & -0.0088 & 0.0540 & -0.1017 & 0.0267 & 0.0925\\ -0.0977 & -0.0053 & -0.0613 & 0.0576 & 0.0690 & -0.0104 & 0.0558 & -0.1133 & 0.0502 & 0.0775\\ -0.1049 & -0.0091 & -0.0663 & 0.0611 & 0.0709 & -0.0168 & 0.0602 & -0.1072 & 0.0477 & 0.0878\\ -0.1164 & -0.0018 & -0.0746 & 0.0531 & 0.0670 & -0.0142 & 0.0700 & -0.1005 & 0.0491 & 0.0939 \end{bmatrix} h=0.10700.09770.10490.11640.00830.00530.00910.00180.07890.06130.06630.07460.03410.05760.06110.05310.06860.06900.07090.06700.00880.01040.01680.01420.05400.05580.06020.07000.10170.11330.10720.10050.02670.05020.04770.04910.09250.07750.08780.0939

s = [ 0.0898 0.1007 0.0923 0.1034 0.1070 0.0990 0.1054 0.0902 0.1026 0.1096 0.0903 0.0990 0.0936 0.1055 0.1067 0.0985 0.1053 0.0889 0.1047 0.1076 0.0896 0.0986 0.0931 0.1058 0.1068 0.0979 0.1057 0.0894 0.1044 0.1087 0.0886 0.0993 0.0923 0.1049 0.1064 0.0981 0.1067 0.0900 0.1045 0.1093 ] s = \begin{bmatrix} 0.0898 & 0.1007 & 0.0923 & 0.1034 & 0.1070 & 0.0990 & 0.1054 & 0.0902 & 0.1026 & 0.1096\\ 0.0903 & 0.0990 & 0.0936 & 0.1055 & 0.1067 & 0.0985 & 0.1053 & 0.0889 & 0.1047 & 0.1076\\ 0.0896 & 0.0986 & 0.0931 & 0.1058 & 0.1068 & 0.0979 & 0.1057 & 0.0894 & 0.1044 & 0.1087\\ 0.0886 & 0.0993 & 0.0923 & 0.1049 & 0.1064 & 0.0981 & 0.1067 & 0.0900 & 0.1045 & 0.1093 \end{bmatrix} s=0.08980.09030.08960.08860.10070.09900.09860.09930.09230.09360.09310.09230.10340.10550.10580.10490.10700.10670.10680.10640.09900.09850.09790.09810.10540.10530.10570.10670.09020.08890.08940.09000.10260.10470.10440.10450.10960.10760.10870.1093

p = [ 2.4107 2.2954 2.3826 2.2696 2.2351 2.3125 2.2497 2.4054 2.2770 2.2112 2.4048 2.3124 2.3684 2.2495 2.2381 2.3175 2.2513 2.4204 2.2569 2.2296 2.4123 2.3165 2.3737 2.2463 2.2365 2.3242 2.2472 2.4146 2.2597 2.2196 2.4242 2.3096 2.3824 2.2547 2.2408 2.3220 2.2378 2.4083 2.2587 2.2139 ] p = \begin{bmatrix} 2.4107 & 2.2954 & 2.3826 & 2.2696 & 2.2351 & 2.3125 & 2.2497 & 2.4054 & 2.2770 & 2.2112\\ 2.4048 & 2.3124 & 2.3684 & 2.2495 & 2.2381 & 2.3175 & 2.2513 & 2.4204 & 2.2569 & 2.2296\\ 2.4123 & 2.3165 & 2.3737 & 2.2463 & 2.2365 & 2.3242 & 2.2472 & 2.4146 & 2.2597 & 2.2196\\ 2.4242 & 2.3096 & 2.3824 & 2.2547 & 2.2408 & 2.3220 & 2.2378 & 2.4083 & 2.2587 & 2.2139 \end{bmatrix} p=2.41072.40482.41232.42422.29542.31242.31652.30962.38262.36842.37372.38242.26962.24952.24632.25472.23512.23812.23652.24082.31252.31752.32422.32202.24972.25132.24722.23782.40542.42042.41462.40832.27702.25692.25972.25872.21122.22962.21962.2139

因此,最终的交叉熵
L = 2.4054 + 2.4204 + 2.3737 + 2.2408 4 = 2.36 L = \frac{2.4054 + 2.4204 + 2.3737 + 2.2408 }{4} = 2.36 L=42.4054+2.4204+2.3737+2.2408=2.36

你可能感兴趣的:(人工智能/深度学习/机器学习)