js-divergence的pytorch实现

文章目录

    • 参考文档
    • JSD实现代码
    • 一些注意事项

参考文档

  • pytorch 中的 kl_div 函数
  • 关于logsoftmax与softmax的区别
  • KL散度、JS散度、Wasserstein距离–原理

JSD实现代码

若有纰漏,敬请指出,感谢!

def js_div(p_output, q_output, get_softmax=True):
    """
    Function that measures JS divergence between target and output logits:
    """
    KLDivLoss = nn.KLDivLoss(reduction='batchmean')
    if get_softmax:
        p_output = F.softmax(p_output)
        q_output = F.softmax(q_output)
    log_mean_output = ((p_output + q_output )/2).log()
    return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2

一些注意事项

  1. 关于dlv函数的使用:

    函数中的 p q 位置相反(也就是想要计算D(p||q),要写成kl_div(q.log(),p)的形式),而且q要先取 log

  2. JS 散度度量了两个概率分布的相似度,基于KL散度的变体,解决了KL散度非对称的问题。所以jsd(q, p)与jsd(p, q)一致。
    js-divergence的pytorch实现_第1张图片

你可能感兴趣的:(pytorch)