KL Divergence 与 JS Divergence

1. 统计距离 (Statistical Distance)

很多时候我们需要对比两个概率分布: 我们有一个随机变量和它的两个不同的概率分布 (P 与 Q), 比如一个真实的分布和一个模型估计的分布. 这种情况下, 我们需要对分布之间的差异进行量化, 这种量化就是计算两个统计对象(如概率分布)之间的统计距离.

一种方式是计算两个分布之间的距离 distance, 但这很难且无法解释. 另一种方法更通用, 即计算两个分布之间的散度 divergence.

Divergence 是一种度量但却不对称, 可以理解为 divergence 表示对 P 与 Q 的差异程度的评分 (a scoring of how P differs from Q), 且 P 和 Q 的 divergence 与 Q 和 P 的 divergence 值不同.

Divergence 是信息论中许多计算的基础, 也是机器学习中很多计算的基础. 比如 mutual information (information gain) 和 cross-entropy 都是用作分类模型的损失函数.

信息论中常用的两个 divergence 是 KL divergence (Kullback-Leibler divergence) 和 JS divergence (Jensen-Shannon divergence).

2. KL 散度 (KL Divergence)

2.1 介绍

KL divergence 也称为 relative entropy, 度量了一个概率分布不同与另一个概率分布的程度(difference/ dissimilarity), 两个分布 P 和 Q 的 KL divergence 的计算公式为 (随机变量X为离散时):

KL(P || Q) = -\sum_{x \in X} P(x)\times \log(\frac{Q(x)}{P(x)})=\sum_{x \in X} P(x)\times \log(\frac{P(x)}{Q(x)})

sum 中的值是每个事件的 divergence.

当 P 中的一个事件的概率很大, 但 Q 中相同事件的概率很小时, 此时的 divergence 很大. 当 P 的概率很小, 而 Q 的概率很大时, divergence 也很大, 但没有上一种情况那么大. 所以需要注意:

KL(P||Q) \neq KL(Q||P)

KL divergence 也可用于测量连续随机变量的分布之间的差异, 此时将上述公式中的 sum 符号换位 integral 符号. 

其中 log 可以以 2 为底, 以 bits 为单位; 也可以自然对数 e 为底, 以 nats 为单位. 当 KL divergence 等于 0 时, 表示两个分布相同, 否则为整数. 如果我们想要逼近一个目标概率分布 P, 我们模型的输出的分布为Q, 这时的 KL divergence (以2为底) 计算的是表达随机变量中事件所需的额外 bits. 我们的近似越好, 所需的额外信息就越少. 

2.2 python 实例

考虑一个随机变量 X, 它的三个事件为 X 等于三个不同的颜色. 对于 X,我们有两种不同的概率分布 P 和 Q:

# define distributions
events = ['red', 'green', 'blue']
p = [0.10, 0.40, 0.50]
q = [0.80, 0.15, 0.05]

可以将这些概率可视化以便进行比对:

# plot of distributions
from matplotlib import pyplot
print('P=%.3f Q=%.3f' % (sum(p), sum(q)))
# plot first distribution
pyplot.subplot(2,1,1)
pyplot.bar(events, p)
# plot second distribution
pyplot.subplot(2,1,2)
pyplot.bar(events, q)
# show the plot
pyplot.show()

KL Divergence 与 JS Divergence_第1张图片

 定义一个计算 KL divergence (以 2 为底) 的方法, 并使用这个方法计算 P 与 Q, Q 与 P 的 KL divergence:

from math import log2
def kl_divergence(p, q):
	return sum(p[i] * log2(p[i]/q[i]) for i in range(len(p)))
# calculate (P || Q)
kl_pq = kl_divergence(p, q)
print('KL(P || Q): %.3f bits' % kl_pq)
# calculate (Q || P)
kl_qp = kl_divergence(q, p)
print('KL(Q || P): %.3f bits' % kl_qp)

结果为:

KL(P || Q): 1.927 bits

KL(Q || P): 2.002 bits

若将方法种的 log2() 换为自然对数 log(), 结果为:

KL(P || Q): 1.336 nats

KL(Q || P): 1.401 nats

SciPy 库 提供了 kl_div() 和 rel_entr() 函数用于计算 KL divergence.  

# example of calculating the kl divergence (relative entropy) with scipy
from scipy.special import rel_entr
# define distributions
p = [0.10, 0.40, 0.50]
q = [0.80, 0.15, 0.05]
# calculate (P || Q)
kl_pq = rel_entr(p, q)
print('KL(P || Q): %.3f nats' % sum(kl_pq))
# calculate (Q || P)
kl_qp = rel_entr(q, p)
print('KL(Q || P): %.3f nats' % sum(kl_qp))

KL(P || Q): 1.336 nats

KL(Q || P): 1.401 nats

 运行结果与前面我们自己定义的以e为底的函数结果相同.

3. JS 散度 (JS Divergence)

3.1 介绍

Jensen-Shannon divergence 是另一种度量两个概率分布之间的差异的方法. JS divergence 以 KL divergence 为基础, 但它是对称的, 归一化的:

JS(P || Q) = JS (Q || P)

 计算公式为:

JS(P||Q)=\frac{1}{2}KL(P||M)+\frac{1}{2}KL(Q||M)

其中, 

M = \frac{1}{2}(P+Q)

当使用以2为底的 log 时, JS divergence 是一种 KL divergence 的平滑和归一化版本, 当值为 0 时, 表示"完全相同", 当值为 1 时, 表示"最不相同". 

JS divergence 的均方根表示 Jensen-Shannon distance, 即 JS distance.

3.2 python 实例

首先定义一个 JS divergence 的方法(使用了前述的 kl_divergence() ):

from numpy import asarray
from math import log2

# define distributions
p = asarray([0.10, 0.40, 0.50])
q = asarray([0.80, 0.15, 0.05])

# calculate the kl divergence
def kl_divergence(p, q):
	return sum(p[i] * log2(p[i]/q[i]) for i in range(len(p)))

# calculate the js divergence
def js_divergence(p, q):
	m = 0.5 * (p + q)
	return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)

分别测试一下 JS divergence 和 JS distance:

from math import sqrt
# calculate JS(P || Q)
js_pq = js_divergence(p, q)
print('JS(P || Q) divergence: %.3f bits' % js_pq)
print('JS(P || Q) distance: %.3f' % sqrt(js_pq))

# calculate JS(Q || P)
js_qp = js_divergence(q, p)
print('JS(Q || P) divergence: %.3f bits' % js_qp)
print('JS(Q || P) distance: %.3f' % sqrt(js_qp))

JS(P || Q) divergence: 0.420 bits

JS(P || Q) distance: 0.648

JS(Q || P) divergence: 0.420 bits

JS(Q || P) distance: 0.648

 运行结果可以看出 JS divergence 和 JS distance 都是对称的.

同样地, SciPy 提供了 jensenshannon() 用于计算 JS distance:

# calculate the jensen-shannon distance metric
from scipy.spatial.distance import jensenshannon
from numpy import asarray
# define distributions
p = asarray([0.10, 0.40, 0.50])
q = asarray([0.80, 0.15, 0.05])
# calculate JS(P || Q)
js_pq = jensenshannon(p, q, base=2)
print('JS(P || Q) Distance: %.3f' % js_pq)
# calculate JS(Q || P)
js_qp = jensenshannon(q, p, base=2)
print('JS(Q || P) Distance: %.3f' % js_qp)

JS(P || Q) Distance: 0.648

JS(Q || P) Distance: 0.648

你可能感兴趣的:(deep,learning,机器学习)