我们仍然以一个例子来说明Pyro的推断功能。首先,我们引入头文件。
import matplotlib.pyplot as plt
import numpy as np
import torch
import pyro
import pyro.infer
import pyro.optim
import pyro.distributions as dist
pyro.set_rng_seed(101)
例子:测量物体的重量
假如我们要测量物体的重量,而秤并不怎样精确,每每测量的结果存在稍许差异。为了补偿秤本身的误差,我们要把过程的“噪声”(即造成误差的不明因素)积分处理。下面的过程描述了数据产生的原理:
实现代码:
def scale(guess):
weight = pyro.sample('weight', dist.Normal(guess, 1.))
return pyro.sample('measurement', dist.Normal(weight, .75))
条件概率
上述是“正向”的数据产生过程,建模过程是很容易的。在实际的场景里,我们却只能通过观察数据,来“反推”数据的产生过程。Pyro中,产生数据的状态,是用sample()
来实现的。
考虑scale
,假如我们给定guess = 8.5
作为输入,并且观察到measurement==9.5
这样的样本,我们希望了解weight
的分布范围是多少,即
Pyro提供了pyro.condition
来限制采样的状态。pyro.condition
是一个“高阶函数”,即输入一个模型函数和一份观察值的字典,返回一个基于观察的新模型函数。
conditioned_scale = pyro.condition(scale, data={'measurement': 9.5})
这和Python的其他函数是一样的。我们可以用lambda
或def
的方法重写上面的句子:
def deferred_conditioned_scale(measurement, guess):
return pyro.condition(scale, data={'measurement': measurement})(guess)
还有一种更省事的写法,用obs
这一关键字来提示pyro.condition
观察值的情况。
def scale_obs(guess): # 该函数与 conditioned_scale是等价的
weight = pyro.sample('weight', dist.Normal(guess, 1.))
# 条件为给定观察值9.5
return pyro.sample('measurement', dist.Normal(weight, 1.), obs=9.5)
多提一句,Pyro在pyro.condion
中也集成了朱迪亚贝尔的“执行”命令pyro.do
。
用guide函数,灵活地推断
在conditioned_scale
函数里,我们在给定guess
和measurement==data
的条件下,对weight
进行推断。
推断算法在Pyro框架里,如pyro.infer.SVI
,被定义在pyro.infer
类中。对于被推断的任何随机函数,我们称其为guide或guides,用来表示后验分布的近似结果。guide函数需要满足两个条件:
- 所有的独立变量(它们不依赖于其他随机变量),在model中出现的,也必在guide中出现。
- guide与model具有相同的参数(argument)。
guide在多种场景下发挥作用,如重要采样、拒绝采样、序列蒙特卡洛采样、MCMC、独立Metropolis-Hastings采样、变分推断、推断网络,等等。现在已经在Pyro完成封装的,有重要采样、MCMC、变分推断。在未来其余场景也将陆续完成。
虽然在不同场景下,guide可以灵活规定,原则上我们需要在guide中涵盖独立变量的完整采样过程。
在scale
中,给定guess
和measurement
后,其后验概率为。由于这个例子比较简单,我们可以手算其后验概率的形式。(感兴趣的读者请参阅:http://www.stat.cmu.edu/~brian/463-663/week09/Chapter%2003.pdf)
def perfect_guide(guess):
# sigma=0.75,tau=1,n=1,x=9.5,M=guess=8.5
loc = (.75 ** 2 * guess + 9.5) / (1 + .75 ** 2) # 9.14
scale = np.sqrt(.75 ** 2 / (1 + .75 ** 2)) # 0.6
return pyro.sample('weight', dist.Normal(loc, scale))
从参数化的随机函数,到变分推断
上面的例子中,我们计算出了精确的后验概率分布。这是一种极为幸运的情况,而非通例。哪怕仍旧用scale
这个简单的例子,如果weight
经过某种非线性操作,后验分布就不再具有精确解了。
def scale(guess):
weight = pyro.sample('weight', dist.Normal(guess, 1.))
return pyro.sample('measurement', dist.Normal(some_nonlinear_function(weight), .75))
这时,我们需要重新估计一个函数,它的采样结果能最大程度地符合观察结果,或使某一损失函数最小化,这一过程叫做变分推断。在Pyro中,我们利用pyro.param
来具体化guides函数的可选范围。
pyro.param
是Pyro的键值对组成的容器。和pyro.sample
一样,pyro.param
通过第一个参数来命名。第一次声明pyro.sample
的名字,容器中就会存储这个参数的名字和值,在以后再次调用时返回它的值。这个过程就像下面的sample_param_store.setdefault
一样。
simple_param_store = {}
a = simple_param_store('a', torch.randn(1))
举个例子,我们要在scale_posterior_guide
中,参数化a
和b
,而非人工实例化它们:
def scale_parmeterized_guide(guess):
a = pyro.param('a', torch.tensor(guess))
b = pyro.param('b', torch.tensor(1.))
return pyro.sample('weight', dist.Normal(a, torch.abs(b)))
插句题外话,上面的b
加上了torch.abs
函数,是因为正态分布的标准差必须是非负数。我们也可以通过Pytorch的constraint module来明确规定这一限制。
from torch.distributions import constraints
def scale_parameterized_guide_constrained(guess):
a = pyro.param('a', torch.tensor(guess))
b = pyro.param('b', torch.tensor(1.), constraint=constrains.positive)
return pyro.sample('weight', dist.Normal(a, b)) # 不再需要 torch.abs
话说回来。Pyro这个代码库的最直接目的,就是执行随机变分推断(SVI)。这类操作包含下面三个特点:
- 参数都是实值张量
- 通过model和guide的执行历史,采样并计算得到损失函数的蒙特卡洛估计
- 通过梯度下降法,搜索最优的参数值
结合Pytorch的GPU加速和自动求导机制,Pyro能够在高维参数空间高效完成变分推断。在后面的教程中,我们会详细介绍。这里给出一个简单的例子:
guess = 8.5
pyro.clear_param_store()
svi = pyro.infer.SVI(model=conditioned_scale,
guide=scale_parameterized_guide,
optim=pyro.optim.SGD({'lr':0.001, 'momentum':0.1}),
loss=pyro.infer.Trace_ELBO())
losses, a, b = [], [], []
num_steps = 2500
for t in range(num_steps):
losses.append(svi.step(guess))
a.append(pyro.sample('a').item())
b.append(pyro.sample('b').item())
plt.plot(losses)
plt.title('ELBO')
plt.xlabel('step')
plt.ylabel('loss')
print('a = ', pyro.sample('a').item())
print('b = ', pyro.sample('b').item())
a = 9.107474327087402
b = 0.6285384893417358
plt.subplot(1, 2, 1)
plt.plot([0, num_steps], [9.14, 9.14], 'k:')
plt.plot(a)
plt.ylabel('a')
plt.subplot(1,2,2)
plt.plot([0, num_steps], [0.6, 0.6], 'k:')
plt.plot(b)
plt.ylabel('b')
plt.tight_layout()
由图可见,SVI的推断值,与真值是相当接近的。这正是我们所希望的。
应该注意的是,guide的参数优化过程,被存放在参数容器中。当我们需要做后验采样时,我们可以直接从guide中采样,为下游的任务所利用。
接下来的教程,我们将使用神经网络来构建scale
函数,并用随机变分推断的方法构建图像的产生式模型,敬请期待。