很多时候我们需要对比两个概率分布: 我们有一个随机变量和它的两个不同的概率分布 (P 与 Q), 比如一个真实的分布和一个模型估计的分布. 这种情况下, 我们需要对分布之间的差异进行量化, 这种量化就是计算两个统计对象(如概率分布)之间的统计距离.
一种方式是计算两个分布之间的距离 distance, 但这很难且无法解释. 另一种方法更通用, 即计算两个分布之间的散度 divergence.
Divergence 是一种度量但却不对称, 可以理解为 divergence 表示对 P 与 Q 的差异程度的评分 (a scoring of how P differs from Q), 且 P 和 Q 的 divergence 与 Q 和 P 的 divergence 值不同.
Divergence 是信息论中许多计算的基础, 也是机器学习中很多计算的基础. 比如 mutual information (information gain) 和 cross-entropy 都是用作分类模型的损失函数.
信息论中常用的两个 divergence 是 KL divergence (Kullback-Leibler divergence) 和 JS divergence (Jensen-Shannon divergence).
KL divergence 也称为 relative entropy, 度量了一个概率分布不同与另一个概率分布的程度(difference/ dissimilarity), 两个分布 P 和 Q 的 KL divergence 的计算公式为 (随机变量X为离散时):
sum 中的值是每个事件的 divergence.
当 P 中的一个事件的概率很大, 但 Q 中相同事件的概率很小时, 此时的 divergence 很大. 当 P 的概率很小, 而 Q 的概率很大时, divergence 也很大, 但没有上一种情况那么大. 所以需要注意:
KL divergence 也可用于测量连续随机变量的分布之间的差异, 此时将上述公式中的 sum 符号换位 integral 符号.
其中 log 可以以 2 为底, 以 bits 为单位; 也可以自然对数 e 为底, 以 nats 为单位. 当 KL divergence 等于 0 时, 表示两个分布相同, 否则为整数. 如果我们想要逼近一个目标概率分布 P, 我们模型的输出的分布为Q, 这时的 KL divergence (以2为底) 计算的是表达随机变量中事件所需的额外 bits. 我们的近似越好, 所需的额外信息就越少.
考虑一个随机变量 X, 它的三个事件为 X 等于三个不同的颜色. 对于 X,我们有两种不同的概率分布 P 和 Q:
# define distributions
events = ['red', 'green', 'blue']
p = [0.10, 0.40, 0.50]
q = [0.80, 0.15, 0.05]
可以将这些概率可视化以便进行比对:
# plot of distributions
from matplotlib import pyplot
print('P=%.3f Q=%.3f' % (sum(p), sum(q)))
# plot first distribution
pyplot.subplot(2,1,1)
pyplot.bar(events, p)
# plot second distribution
pyplot.subplot(2,1,2)
pyplot.bar(events, q)
# show the plot
pyplot.show()
定义一个计算 KL divergence (以 2 为底) 的方法, 并使用这个方法计算 P 与 Q, Q 与 P 的 KL divergence:
from math import log2
def kl_divergence(p, q):
return sum(p[i] * log2(p[i]/q[i]) for i in range(len(p)))
# calculate (P || Q)
kl_pq = kl_divergence(p, q)
print('KL(P || Q): %.3f bits' % kl_pq)
# calculate (Q || P)
kl_qp = kl_divergence(q, p)
print('KL(Q || P): %.3f bits' % kl_qp)
结果为:
KL(P || Q): 1.927 bits
KL(Q || P): 2.002 bits
若将方法种的 log2() 换为自然对数 log(), 结果为:
KL(P || Q): 1.336 nats
KL(Q || P): 1.401 nats
SciPy 库 提供了 kl_div() 和 rel_entr() 函数用于计算 KL divergence.
# example of calculating the kl divergence (relative entropy) with scipy
from scipy.special import rel_entr
# define distributions
p = [0.10, 0.40, 0.50]
q = [0.80, 0.15, 0.05]
# calculate (P || Q)
kl_pq = rel_entr(p, q)
print('KL(P || Q): %.3f nats' % sum(kl_pq))
# calculate (Q || P)
kl_qp = rel_entr(q, p)
print('KL(Q || P): %.3f nats' % sum(kl_qp))
KL(P || Q): 1.336 nats
KL(Q || P): 1.401 nats
运行结果与前面我们自己定义的以e为底的函数结果相同.
Jensen-Shannon divergence 是另一种度量两个概率分布之间的差异的方法. JS divergence 以 KL divergence 为基础, 但它是对称的, 归一化的:
计算公式为:
其中,
当使用以2为底的 log 时, JS divergence 是一种 KL divergence 的平滑和归一化版本, 当值为 0 时, 表示"完全相同", 当值为 1 时, 表示"最不相同".
JS divergence 的均方根表示 Jensen-Shannon distance, 即 JS distance.
首先定义一个 JS divergence 的方法(使用了前述的 kl_divergence() ):
from numpy import asarray
from math import log2
# define distributions
p = asarray([0.10, 0.40, 0.50])
q = asarray([0.80, 0.15, 0.05])
# calculate the kl divergence
def kl_divergence(p, q):
return sum(p[i] * log2(p[i]/q[i]) for i in range(len(p)))
# calculate the js divergence
def js_divergence(p, q):
m = 0.5 * (p + q)
return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)
分别测试一下 JS divergence 和 JS distance:
from math import sqrt
# calculate JS(P || Q)
js_pq = js_divergence(p, q)
print('JS(P || Q) divergence: %.3f bits' % js_pq)
print('JS(P || Q) distance: %.3f' % sqrt(js_pq))
# calculate JS(Q || P)
js_qp = js_divergence(q, p)
print('JS(Q || P) divergence: %.3f bits' % js_qp)
print('JS(Q || P) distance: %.3f' % sqrt(js_qp))
JS(P || Q) divergence: 0.420 bits
JS(P || Q) distance: 0.648
JS(Q || P) divergence: 0.420 bits
JS(Q || P) distance: 0.648
运行结果可以看出 JS divergence 和 JS distance 都是对称的.
同样地, SciPy 提供了 jensenshannon() 用于计算 JS distance:
# calculate the jensen-shannon distance metric
from scipy.spatial.distance import jensenshannon
from numpy import asarray
# define distributions
p = asarray([0.10, 0.40, 0.50])
q = asarray([0.80, 0.15, 0.05])
# calculate JS(P || Q)
js_pq = jensenshannon(p, q, base=2)
print('JS(P || Q) Distance: %.3f' % js_pq)
# calculate JS(Q || P)
js_qp = jensenshannon(q, p, base=2)
print('JS(Q || P) Distance: %.3f' % js_qp)
JS(P || Q) Distance: 0.648
JS(Q || P) Distance: 0.648