Pyro简介:产生式模型实现库(二),推断

我们仍然以一个例子来说明Pyro的推断功能。首先,我们引入头文件。

import matplotlib.pyplot as plt
import numpy as np
import torch

import pyro
import pyro.infer
import pyro.optim
import pyro.distributions as dist

pyro.set_rng_seed(101)

例子:测量物体的重量

假如我们要测量物体的重量,而秤并不怎样精确,每每测量的结果存在稍许差异。为了补偿秤本身的误差,我们要把过程的“噪声”(即造成误差的不明因素)积分处理。下面的过程描述了数据产生的原理:


实现代码:

def scale(guess):
    weight = pyro.sample('weight', dist.Normal(guess, 1.))
    return pyro.sample('measurement', dist.Normal(weight, .75))

条件概率

上述是“正向”的数据产生过程,建模过程是很容易的。在实际的场景里,我们却只能通过观察数据,来“反推”数据的产生过程。Pyro中,产生数据的状态,是用sample()来实现的。
考虑scale,假如我们给定guess = 8.5作为输入,并且观察到measurement==9.5这样的样本,我们希望了解weight的分布范围是多少,即

Pyro提供了pyro.condition来限制采样的状态。pyro.condition是一个“高阶函数”,即输入一个模型函数和一份观察值的字典,返回一个基于观察的新模型函数。

conditioned_scale = pyro.condition(scale, data={'measurement': 9.5})

这和Python的其他函数是一样的。我们可以用lambdadef的方法重写上面的句子:

def deferred_conditioned_scale(measurement, guess):
    return pyro.condition(scale, data={'measurement':  measurement})(guess)

还有一种更省事的写法,用obs这一关键字来提示pyro.condition观察值的情况。

def scale_obs(guess): # 该函数与 conditioned_scale是等价的
    weight = pyro.sample('weight', dist.Normal(guess, 1.))
    # 条件为给定观察值9.5
    return pyro.sample('measurement', dist.Normal(weight, 1.), obs=9.5)

多提一句,Pyro在pyro.condion中也集成了朱迪亚贝尔的“执行”命令pyro.do

用guide函数,灵活地推断

conditioned_scale函数里,我们在给定guessmeasurement==data的条件下,对weight进行推断。
推断算法在Pyro框架里,如pyro.infer.SVI,被定义在pyro.infer类中。对于被推断的任何随机函数,我们称其为guideguides,用来表示后验分布的近似结果。guide函数需要满足两个条件:

  1. 所有的独立变量(它们不依赖于其他随机变量),在model中出现的,也必在guide中出现。
  2. guide与model具有相同的参数(argument)。

guide在多种场景下发挥作用,如重要采样、拒绝采样、序列蒙特卡洛采样、MCMC、独立Metropolis-Hastings采样、变分推断、推断网络,等等。现在已经在Pyro完成封装的,有重要采样、MCMC、变分推断。在未来其余场景也将陆续完成。
虽然在不同场景下,guide可以灵活规定,原则上我们需要在guide中涵盖独立变量的完整采样过程。
scale中,给定guessmeasurement后,其后验概率为。由于这个例子比较简单,我们可以手算其后验概率的形式。(感兴趣的读者请参阅:http://www.stat.cmu.edu/~brian/463-663/week09/Chapter%2003.pdf)

def perfect_guide(guess):
    # sigma=0.75,tau=1,n=1,x=9.5,M=guess=8.5
    loc = (.75 ** 2 * guess + 9.5) / (1 + .75 ** 2) # 9.14
    scale = np.sqrt(.75 ** 2 / (1 + .75 ** 2)) # 0.6
    return pyro.sample('weight', dist.Normal(loc, scale))

从参数化的随机函数,到变分推断

上面的例子中,我们计算出了精确的后验概率分布。这是一种极为幸运的情况,而非通例。哪怕仍旧用scale这个简单的例子,如果weight经过某种非线性操作,后验分布就不再具有精确解了。

def scale(guess):
    weight = pyro.sample('weight', dist.Normal(guess, 1.))
    return pyro.sample('measurement', dist.Normal(some_nonlinear_function(weight), .75))

这时,我们需要重新估计一个函数,它的采样结果能最大程度地符合观察结果,或使某一损失函数最小化,这一过程叫做变分推断。在Pyro中,我们利用pyro.param来具体化guides函数的可选范围。
pyro.param是Pyro的键值对组成的容器。和pyro.sample一样,pyro.param通过第一个参数来命名。第一次声明pyro.sample的名字,容器中就会存储这个参数的名字和值,在以后再次调用时返回它的值。这个过程就像下面的sample_param_store.setdefault一样。

simple_param_store = {}
a = simple_param_store('a', torch.randn(1))

举个例子,我们要在scale_posterior_guide中,参数化ab,而非人工实例化它们:

def scale_parmeterized_guide(guess):
    a = pyro.param('a', torch.tensor(guess))
    b = pyro.param('b', torch.tensor(1.))
    return pyro.sample('weight', dist.Normal(a, torch.abs(b)))

插句题外话,上面的b加上了torch.abs函数,是因为正态分布的标准差必须是非负数。我们也可以通过Pytorch的constraint module来明确规定这一限制。

from torch.distributions import constraints

def scale_parameterized_guide_constrained(guess):
    a = pyro.param('a', torch.tensor(guess))
    b = pyro.param('b', torch.tensor(1.), constraint=constrains.positive)
    return pyro.sample('weight', dist.Normal(a, b)) # 不再需要 torch.abs

话说回来。Pyro这个代码库的最直接目的,就是执行随机变分推断(SVI)。这类操作包含下面三个特点:

  1. 参数都是实值张量
  2. 通过model和guide的执行历史,采样并计算得到损失函数的蒙特卡洛估计
  3. 通过梯度下降法,搜索最优的参数值

结合Pytorch的GPU加速和自动求导机制,Pyro能够在高维参数空间高效完成变分推断。在后面的教程中,我们会详细介绍。这里给出一个简单的例子:

guess = 8.5
pyro.clear_param_store()
svi = pyro.infer.SVI(model=conditioned_scale, 
                     guide=scale_parameterized_guide, 
                     optim=pyro.optim.SGD({'lr':0.001, 'momentum':0.1}),
                     loss=pyro.infer.Trace_ELBO())
losses, a, b = [], [], []
num_steps = 2500
for t in range(num_steps):
    losses.append(svi.step(guess))
    a.append(pyro.sample('a').item())
    b.append(pyro.sample('b').item())

plt.plot(losses)
plt.title('ELBO')
plt.xlabel('step')
plt.ylabel('loss')
print('a = ', pyro.sample('a').item())
print('b = ', pyro.sample('b').item())

a = 9.107474327087402
b = 0.6285384893417358

Pyro简介:产生式模型实现库(二),推断_第1张图片
plt.subplot(1, 2, 1)
plt.plot([0, num_steps], [9.14, 9.14], 'k:')
plt.plot(a)
plt.ylabel('a')

plt.subplot(1,2,2)
plt.plot([0, num_steps], [0.6, 0.6], 'k:')
plt.plot(b)
plt.ylabel('b')
plt.tight_layout()
Pyro简介:产生式模型实现库(二),推断_第2张图片

由图可见,SVI的推断值,与真值是相当接近的。这正是我们所希望的。
应该注意的是,guide的参数优化过程,被存放在参数容器中。当我们需要做后验采样时,我们可以直接从guide中采样,为下游的任务所利用。

接下来的教程,我们将使用神经网络来构建scale函数,并用随机变分推断的方法构建图像的产生式模型,敬请期待。

你可能感兴趣的:(Pyro简介:产生式模型实现库(二),推断)