Focal Loss — Keras implementation

import keras.backend as K
import tensorflow as tf

def catergorical_focal_loss(gamma=2.0, alpha=0.25):
    """
    Build a categorical (multi-class) focal loss usable as a Keras loss.

    Formula (per class c, then summed over classes):
        loss = -alpha * ((1 - p_c)^gamma) * y_c * log(p_c)

    For one-hot targets only the true class contributes, which reduces to
        loss = -alpha * ((1 - p_t)^gamma) * log(p_t)

    Parameters:
        gamma -- focusing parameter for the modulating factor (1 - p);
                 gamma = 0 reduces the loss to alpha-scaled cross entropy.
        alpha -- weighting factor, as in alpha-balanced cross entropy. Here
                 it is a single scalar applied to every class, so it only
                 rescales the loss.
    Default values:
        gamma -- 2.0, as mentioned in the paper
        alpha -- 0.25, as mentioned in the paper

    Returns:
        A function ``focal_loss(y_true, y_pred)`` that returns one loss value
        per sample; the last axis (assumed to be the class axis, as for
        Keras' categorical cross entropy -- confirm for your model) is summed
        out.
    """

    def focal_loss(y_true, y_pred):
        # Match the target dtype to the predictions so integer labels can be
        # multiplied with the float tensors below.
        y_true = K.cast(y_true, K.dtype(y_pred))
        # Clip predictions away from 0 and 1 so log() stays finite and the
        # backpropagated gradients do not become NaN.
        epsilon = K.epsilon()
        y_pred = K.clip(y_pred, epsilon, 1.0 - epsilon)
        # Per-class cross entropy; y_true zeroes out non-target classes.
        cross_entropy = -y_true * K.log(y_pred)
        # Modulating factor (1 - p)^gamma scaled by alpha. y_true is already
        # applied in cross_entropy, so it is not multiplied in a second time
        # (that would square soft / label-smoothed targets).
        weight = alpha * K.pow(1.0 - y_pred, gamma)
        loss = weight * cross_entropy
        # Sum over the class (last) axis: one loss value per sample.
        return K.sum(loss, axis=-1)

    return focal_loss


# Correctly spelled alias; the original (misspelled) name stays for callers.
categorical_focal_loss = catergorical_focal_loss

def binary_focal_loss(gamma=2.0, alpha=0.25):
    """
    Build a binary focal loss usable as a Keras loss.

    Formula:
        loss = -alpha_t * ((1 - p_t)^gamma) * log(p_t)

        p_t = y_pred,      if y_true = 1
        p_t = 1 - y_pred,  otherwise

        alpha_t = alpha,      if y_true = 1
        alpha_t = 1 - alpha,  otherwise

        cross_entropy = -log(p_t)
    Parameters:
        gamma -- focusing parameter for the modulating factor (1 - p_t);
                 gamma = 0 reduces the loss to alpha-balanced cross entropy.
        alpha -- weighting factor for positive targets, as in balanced cross
                 entropy; negative targets are weighted by 1 - alpha.
    Default values:
        gamma -- 2.0, as mentioned in the paper
        alpha -- 0.25, as mentioned in the paper

    Returns:
        A function ``focal_loss(y_true, y_pred)`` that returns one loss value
        per sample, summed over the last axis. An element counts as positive
        only where ``y_true == 1`` exactly; every other value is treated as
        negative.
        NOTE(review): assumes y_true and y_pred have the same shape (tf.where
        requires matching shapes under TF1) -- confirm against the model.
    """

    def focal_loss(y_true, y_pred):
        # Match the target dtype to the predictions so integer labels work
        # with the float arithmetic below.
        y_true = K.cast(y_true, K.dtype(y_pred))
        # Clip predictions away from 0 and 1 so log(p_t) stays finite and the
        # backpropagated gradients do not become NaN.
        epsilon = K.epsilon()
        y_pred = K.clip(y_pred, epsilon, 1.0 - epsilon)
        is_positive = K.equal(y_true, 1)
        # p_t: the predicted probability assigned to the true label.
        p_t = tf.where(is_positive, y_pred, 1.0 - y_pred)
        # alpha_t: alpha for positives, 1 - alpha for negatives. ones_like
        # gives alpha the same shape as y_true so tf.where can select
        # element-wise.
        alpha_factor = K.ones_like(y_true) * alpha
        alpha_t = tf.where(is_positive, alpha_factor, 1.0 - alpha_factor)
        cross_entropy = -K.log(p_t)
        weight = alpha_t * K.pow(1.0 - p_t, gamma)
        loss = weight * cross_entropy
        # Sum over the last axis: one loss value per sample.
        return K.sum(loss, axis=-1)

    return focal_loss

你可能感兴趣的:(深度学习之Trick)