pytorch笔记:03)softmax和log_softmax,以及CrossEntropyLoss

softmax在神经网络里面比较常见,简而言之,就是多分类的概率输出

sotfmax(xi)=exp(xi)jexp(xj) s o t f m a x ( x i ) = exp ⁡ ( x i ) ∑ j exp ⁡ ( x j )

但是在pytorch里面发现额外有个log_softmax( 对softmax取了一个In的对数),为啥这样做呢?
其实涉及到 对数似然损失函数,对于用于分类的softmax激活函数,对应的损失函数一般都是用对数似然函数,即:
J(W,b,aL,y)=kyklnaLk J ( W , b , a L , y ) = − ∑ k y k l n a k L

其中 lnak l n a k 为softmax函数的输出元, yk y k 的取值为0或者1,如果某一训练样本的输出为第i类。则 yi=1 y i = 1 ,其余的 ji j ≠ i 都有 yj y j =0。由于每个样本只属于一个类别,所以这个对数似然函数可以简化为:
J(W,b,aL,y)=lnaLi J ( W , b , a L , y ) = − l n a i L

其中 i i 即为训练样本真实的类别序号。

pytorch里面提供了一个实现 torch.nn.CrossEntropyLoss(This criterion combines nn.LogSoftmax() and nn.NLLLoss() in one single class),其整合了上面的步骤。这和tensorflow中的tf.nn.softmax_cross_entropy_with_logits函数的功能是一致的。必须明确一点:在pytorch中若模型使用CrossEntropyLoss这个loss函数,则不应该在最后一层再使用softmax进行激活

然而在keras中,我们固化了模型的搭建,诸如:

model.add(Dense(num_classes, activation='softmax'))
model.compile(loss=keras.losses.categorical_crossentropy,optimizer=keras.optimizers.Adadelta(),metrics=['accuracy'])

我们通常在最后一层使用softmax进行激活,保证输出神经元的值即分类的概率值,然后在compile中使用损失函数categorical_crossentropy,这符合常理。其实可以看下keras底层的实现,其实它帮我们手动地计算了crossentropy。

def categorical_crossentropy(target, output, from_logits=False):

    if not from_logits:
        # scale preds so that the class probas of each sample sum to 1
        output /= tf.reduce_sum(output,
                                len(output.get_shape()) - 1,
                                True)
        # manual computation of crossentropy
        _epsilon = _to_tensor(epsilon(), output.dtype.base_dtype)
        output = tf.clip_by_value(output, _epsilon, 1. - _epsilon)
        return - tf.reduce_sum(target * tf.log(output),
                               len(output.get_shape()) - 1)
    else:
        return tf.nn.softmax_cross_entropy_with_logits(labels=target,
                                                       logits=output)

题外话:
为什么要纠结这个问题?
在天池的一个比赛中,要输出每个类别的取值概率,使用keras直接输出最后一层即可;然而在pytorch中softmax整合到了损失函数中,最后一层没有使用softmax进行激活。

reference:
对数似然损失函数 http://www.cnblogs.com/pinard/p/6437495.html

你可能感兴趣的:(机器·深度学习)