focal loss的tensorflow实现

最近在进行分类任务的时候,发现了数据存在类别不平衡问题。除了类别不平衡问题之外还有难学样本和易学样本之间的不平衡问题。因此考虑使用了focal loss。这里直接上代码:

def focal_loss(logits, labels, gamma):
    '''
    :param logits:  [batch_size, n_class]
    :param labels: [batch_size]
    :return: -(1-y)^r * log(y)
    '''
    softmax = tf.reshape(tf.nn.softmax(logits), [-1])  # [batch_size * n_class]
    labels = tf.range(0, logits.shape[0]) * logits.shape[1] + labels
    prob = tf.gather(softmax, labels)
    weight = tf.pow(tf.subtract(1., prob), gamma)
    loss = -tf.reduce_mean(tf.multiply(weight, tf.log(prob)))
    return loss   

附上链接:
论文:Focal Loss for Dense Object Detection
更好地理解focal loss

你可能感兴趣的:(focal,loss,样本不平衡,机器学习)