Implementing KL Divergence in PyTorch

KL Divergence (Kullback–Leibler Divergence)

$D_{KL}$ measures how much one probability distribution differs from another.

Consider two probability distributions $P$ and $Q$ (say, the former is the distribution of the model's output data and the latter is the expected distribution). KL divergence is then defined as:

$$D_{KL}(P \| Q) = \sum_x P(x)\log\frac{P(x)}{Q(x)}$$

for discrete distributions, and

$$D_{KL}(P \| Q) = \int_x P(x)\log\frac{P(x)}{Q(x)}\,dx$$

for continuous ones.

For more background, see https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
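As a quick sanity check of the discrete formula, KL divergence can be computed directly from two probability vectors (the distributions below are made up for illustration):

```python
import math

# Two illustrative discrete distributions over the same three outcomes
P = [0.5, 0.3, 0.2]  # e.g., the distribution of the model's output data
Q = [0.4, 0.4, 0.2]  # e.g., the expected distribution

# D_KL(P || Q) = sum_x P(x) * log(P(x) / Q(x))
kl = sum(p * math.log(p / q) for p, q in zip(P, Q))
print(kl)  # ≈ 0.0253
```

Note that the third term contributes nothing because $P(x) = Q(x)$ there, so $\log\frac{P(x)}{Q(x)} = 0$.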

PyTorch Implementation

torch.nn.functional.kl_div(input, target, size_average=None, reduce=None, reduction='mean', log_target=False)

The Kullback-Leibler divergence Loss

See KLDivLoss for details.

  • Parameters

    input – Tensor of arbitrary shape

    target – Tensor of the same shape as input

    size_average (bool, optional) – Deprecated (see reduction). By default, the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample. If the field size_average is set to False, the losses are instead summed for each minibatch. Ignored when reduce is False. Default: True

    reduce (bool, optional) – Deprecated (see reduction). By default, the losses are averaged or summed over observations for each minibatch depending on size_average. When reduce is False, returns a loss per batch element instead and ignores size_average. Default: True

    reduction (string, optional) – Specifies the reduction to apply to the output: 'none' | 'batchmean' | 'sum' | 'mean'. 'none': no reduction will be applied. 'batchmean': the sum of the output will be divided by the batch size. 'sum': the output will be summed. 'mean': the output will be divided by the number of elements in the output. Default: 'mean'

    log_target (bool) – A flag indicating whether target is passed in the log space. It is recommended to pass certain distributions (like softmax) in the log space to avoid numerical issues caused by explicit log. Default: False

input and target are tensors of the same shape, typically number × feature: an empirical distribution over features estimated from number samples. Note that input is expected to contain log-probabilities (e.g., the output of log_softmax), while target is in probability space unless log_target=True.

The size_average and reduce arguments are deprecated; use reduction instead.

With reduction='none', the output has the same shape as input; the other reduction modes return a scalar.

The argument that usually needs adjusting is reduction; 'mean' and 'sum' are common, but note that 'mean' divides by the total number of elements rather than the batch size, so 'batchmean' is the mode that matches the mathematical definition above.
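A minimal sketch of how the pieces fit together (tensor shapes and values are illustrative). Be aware that the argument order is reversed relative to the usual $D_{KL}(P \| Q)$ notation: kl_div(input, target) treats target as $P$ and input as $\log Q$.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

logits = torch.randn(4, 5)                    # 4 samples, 5 features (illustrative)
input = F.log_softmax(logits, dim=1)          # model output as log-probabilities
target = F.softmax(torch.randn(4, 5), dim=1)  # expected distribution as probabilities

# 'batchmean': summed loss divided by the batch size (4),
# i.e., the average per-sample KL divergence
loss = F.kl_div(input, target, reduction='batchmean')

# 'none': keeps the element-wise terms target * (log(target) - input),
# so the output has the same shape as input
per_element = F.kl_div(input, target, reduction='none')
print(per_element.shape)  # torch.Size([4, 5])

# 'batchmean' equals the per-sample sums averaged over the batch
assert torch.allclose(loss, per_element.sum(dim=1).mean())
```

If target is already in log space (e.g., also produced by log_softmax), pass log_target=True instead of exponentiating it first.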
