各种 loss 函数的实现

        bsz = pred.shape[0]
        if pred.dim() != target.dim():
            # one_hot_target, weight = _expand_onehot_labels(target, pred.size(-1))
            one_hot_target = F.one_hot(target).float()

        # pred_norm = pred.sigmoid() if self.require_sigmoid else pred
        # pred_norm = 1. / (torch.exp(2.*pred) + 1.0)
        # one_hot_target = one_hot_target.type_as(pred)
        pred_norm = torch.clamp_min(pred, 0.)

        if self.downweight_pos:
            pt = (1 - pred_norm) * one_hot_target + pred_norm * (1 - one_hot_target)
            focal_weight = (self.alpha * one_hot_target + (1 - self.alpha) * (1 - one_hot_target)) * pt.pow(self.gamma)
        else:
            pt = (1 / pred_norm) * one_hot_target + pred_norm * (1 - one_hot_target)
            focal_weight = pt.pow(self.gamma)

        pred_log_softmax = -F.log_softmax(pred, dim=1)
        loss = (one_hot_target*pred_log_softmax).sum() / bsz
        print('\n')
        print('nll_loss', loss)
        print('ce loss:', F.cross_entropy(pred, target))
        print('our binary loss:', -(pred.sigmoid().log()*one_hot_target+(1-one_hot_target)*(1-pred.sigmoid()).log()).mean())
        print('binary loss:', F.binary_cross_entropy_with_logits(pred, one_hot_target).mean())

        return loss

你可能感兴趣的：torch、深度学习、PyTorch、TensorFlow