翻译自Meta-Learning in 50 Lines of JAX
作者:Eric Jang
译者:尹肖贻
本文全部资料,位于Github地址:https://github.com/ericjang/maml-jax
说到适应环境,人与百兽皆同,用时有长有短。时间短的,就像我第一次摆弄洗浴用的温度把手,三两下就学会了;时间长的,就像学习演奏一种乐器,要想熟练演奏,终生都得刻意练习。
除了复杂的动物,各类低等生物也常表现出适应环境的行为。多细胞培养过程中,常可观察到高度灵活的适应行为,甚至在代际间积累了表观遗传学意义上的“记忆”。在较长的时间尺度上,进化本身可以被认为是种群级别的“学习”,即优势基因的代际传承;在较短的时间尺度上,粒子的能级跃迁也可被认为是激励下的“学习”,即在为了适应环境而做出反应。生物学家常有意地模糊“行为”(对环境的反应)、“学习”(从外界获取信息以提高适应度)和“优化”(提高适应度)的界限。
机器学习(ML)的核心问题,在于教会计算机获得自主使用数据的能力,去完成人类难以明确说明的任务。然而大多数机器学习专家所谓的“学习”,只是生物适应环境的行为里的很小的子集。深度学习模型很强大,但常须汗牛充栋的数据、擢发难数的梯度反传迭代才能训练除了。虽然学习过程旷日持久,但是模型的行为依旧呆板。在开发周期中,若想改变系统的输出(如更改一个错误),非得使用昂贵的重启手续不可。能不能设计一个训练更快、训练数据更少的系统呢?
“元学习”正是解决这个问题的热点课题。该课题的目的在于,不仅仅要模型“预测得好”,还要“学习得好”。虽然元学习在最近几年吸引了大量研究者,相关的问题和算法早已有之。(对于相关概念感兴趣的读者,请参看Hugo Larochelle的PPT和Lilian Weng的博文,二者对此有精彩的梳理)
本文不是介绍元学习系统方方面面的综述,而是一份开启你元学习研究的实用教程。特别是我要教你利用谷歌卓越的JAX库,用50行python代码,搭建元学习算法系统MAML。
读者可以自行下载Jupyter notebook版本的内容自洽的教程,复现本教程的结果。
从学习算子(learning operator)的角度,理解元学习
“元学习”这个词被研究人员严重滥用,以至于我在跟其他同行提到“元学习”的时候,很难在交流中保持统一。这一滥用现象的源头,在于一些术语定义含混,如“优化”、“学习”、“适应”、“记忆”,更不要说这些术语在应用场景下的泛滥使用。
这一节我试着用数学的角度定义“学习”和“元学习”,并解释一下为什么最近一大票不同的算法都打上“元学习”的牌子。要是你想要直接上手MAML+JAX代码的学习,请直接跳过这一节。
我们定义学习算子为针对某种场景的函数的算法,用以提升的表现效果。普通的学习算子,一般应用在深度学习和增强学习中,定义为针对某个损失函数的梯度下降算法。在典型的深度学习场景中,学习步骤往往持续几千乃至几百万次梯度更新;但在更一般的场景中,“学习”既可以发生在较短的时间尺度上(如求条件概率),也可以发生在更长的时间尺度上(如超参数的搜索)。除了显式的优化过程,学习还可以隐式地出现在动态系统中(如自回归神经网络RNN学习当下的条件概率)或概率推断中。
元学习算子(meta-learning operator)定义为两个学习算子嵌套的组合算子:“内环”和“外环”。进一步来说,是模型本身,为内环学习的算子。或者这么说,学习的学习规则,学习具体任务的规则。我们定义“任务”为这样一族逻辑自洽的问题,它们的可以充分更新。在元学习训练中,在的参数中选择对于一众任务最合适的一组;在元学习测试中,我们评估和的泛化能力是否支持不同的任务。
对于和的选择要依具体问题而定。在架构搜索(architecture search)的语境下(也称“学习如何学习”),的网络从零开始训练的过程相对缓慢,这时可以做神经控制器(neural controler),或者叫随机搜索算法,或者高斯过程搜索(Gaussian Process Bandit)。
可被称为元学习算子的机器学习问题有很多。在模仿式(元)学习(meta imitation learning)(或称为条件概率的目标强化学习(goal-conditioned reinforcement learning))中,的选择依据强化学习的代理人(agent)的操作反馈,比如针对在某项任务的数学表示(task embedding)的条件下、或在某种人为设定的场景(human demonstration)条件下的概率反馈。在元强化学习(meta reinforcement learning MRL)中,实现的手段是“快速强化学习(fast reinforcement learning)”算法,在其中代理人通过数次试错来优化自身策略。这里值得重申,强化学习的场景下,“学习(learning)”与“条件概率优化(conditioning)”没有区别,因为二者都要依赖测试时的输入(或称“环境提供的新信息”)。
MAML是一类通过随机梯度下降实现的元学习算法。形式化为:。随机梯度下降更新对于来说是可导的,这样就不必在优化的梯度反传时,使用额外的参数表示。
探索JAX:梯度
我们以JAX的即来即用的numpy库和梯度算符grad开始这个教程吧。
import jax.numpy as np
from jax import grad
梯度算符grad将一个python函数转化为另一个可求梯度的函数。这里,我们演示如何计算和的一阶导、二阶导、三阶导。
f = lambda x : np.exp(x)
g = lambda x : np.square(x)
print(grad(f)(1.)) # = e^{1}
print(grad(grad(f))(1.))
print(grad(grad(grad(f)))(1.))
print(grad(g)(2.)) # 2x=4
print(grad(grad(g))(2.)) # x=2
print(grad(grad(grad(g)))(2.)) #x=0
探索JAX:自向量化函数vmap
现在我们考虑一个简单的回归案例,我们去拟合函数,目的是熟悉怎样定义和训练网络。JAX设置了一些轻量级的函数,方便搭建简单的网络
from jax imort vmap # for auto-vectorizing functions
from functools import partial # for use with vmap
from jax import jit # for compiling fuctions for speedup
from jax.experimental import stax # neural network library
from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax # neural network layers
import matplotlib.pyplot as plt # visualization
我们将定义一个两层隐含层的神经网络,定义 in_shape为(-1,1),意思是可变的batchsize,而特征的维度是1(因为这是一个一维的回归问题)。JAX的工具箱提供的API全部是泛函形式的(这与TensorFlow不同,后者保持了图结构),所以我们返回了一个初始化参数的函数和一个前传网络。这些函数都是可调用的numpy序列的元组的列表(lists of tuples of numpy arrays)——一种存储网络参数的简单易用的数据结构。
# 使用stax来初始化或评估网络参数
net_init, net_apply = stax.serial(
Dense(40), Relu(),
Dense(40), Relu(),
Dense(1)
)
in_shape = (-1, 1, )
out_shape, net_params = net_init(in_shape)
然后,我们定义模型在一个batch的输入数据的平均平方误差(Mean-square Error MSE)损失。
def loss(params, inputs, targets):
# 计算一个batch的平均损失
predictions = net_apply(params, inputs)
return np.mean((targets - predictions)**2)
我们评估未初始化的网络在输入下的结果:
# 将K=100个输入成批推断
xrange_inputs = np.linspace(-5, 5, 100).reshape(100, 1) #(k,1)
targets = np.sin(xrange_inputs)
predictions = vmap(partial(net_apply, net_params))(xrange_inputs)
losses = vmap(partial(loss, net_params))(xrange_inputs, targets) # per-input loss
plt.plot(xrange_inputs, predictions, label='prediction')
plt.plot(xrange_inputs, losses, label='loss')
plt.plot(xrange_inputs, targets, label='targets')
plt.legend()
正如预期的那样,在随机初始化下,模型的预测(蓝线)完全偏离了目标函数(绿线)。
我们用梯度下降算法来更新参数。JAX的随机函数的产生器和numpy的随机函数产生器不同,所以用numpy的产生器(onp)来随机化网络参数。我们要引入tree_multimap函数来管理参数的梯度(对于TensorFlow的用户,这个函数类似于nest.map_stucture的张量的函数)
import numpy as onp
from jax.experimental import optimizers
from jax.tree_util import tree_multimap
# 对numpy的array的集合进行Element-wise级别的操作
我们初始化参数和优化器,将曲线拟合操作循环100次。值得注意,@jit这个修饰器,能将一整个训练函数(和优化器、内存和代码优化一起)都用上XLA编译成机器码。TensorFlow也使用XLA来加速统计类定义的网络。XLA使得计算非常快,和硬件的兼容性很强,因为它不需要返回一个Python解释器(或者在没有XLA时TensorFlow返回计算图解释器)。这里的代码只能运行在CPU、GPU或者TPU上。
opt_init, opt_update = optimizers.adm(step_size=1e-2)
opt_state = opt_init(net_params)
# 定义一个编译的更新步骤
@jit
def step(i, opt_state, x1, y1):
p = optimizers.get_params(opt_state)
g = grad(loss)(p, x1, y1)
return opt_update(i, g, opt_state)
for i in range(100):
opt_state = step(i, opt_state, xrange_inputs, targets)
net_params = optimizers.get_params(opt_state)
重新执行绘图代码
predictions = vmap(partial(net_apply, net_params))(xrange_inputs)
losses = vmap(partial(loss, net_params))(xrange_inputs, targets) # per-input loss
plt.plot(xrange_inputs, predictions, label='prediction')
plt.plot(xrange_inputs, losses, label='loss')
plt.plot(xrange_inputs, targets, label='target')
plt.legend()
在下面MAML的代码里,我们将反复用到上文提到的函数。
探索JAX:用数值检查MAML
在完成机器学习算法的代码时,一定要通过单元测试,测试案例的结果必须可以通过分析的方法得出真值。下面的例子对于toy目标函数做了测试代码。值得注意的是,默认情况下,JAX会对函数的第一个变量计算梯度。
# MAML的梯度检查
# 检查数值
g = lambda x, y: np.square(x) + y
x0 = 2.
y0 = 1.
print('grad(g)(x0) = {}'.format(grad(g)(x0, y0))) # 2x = 4
print('x0 - grad(g)(x0) = {}'.format(x0 - grad(g)(x0, y0))) # x - 2x = -2
def maml_objective(x, y):
return g(x - grad(g)(x, y), y)
# x**2 + 1 = 5
print('maml_objective(x,y)={}'.format(maml_objective(x0, y0)))
# x - (2x) = -2.
print('x0 - maml_objective(x,y) = {}'.format(x0 - grad(maml_objective)(x0, y0)))
用JAX编写MAML
现在我们拓展一下回归正弦曲线的例子,让正弦函数的相位和幅度都可以变化。这个例子是MAML论文里提到的简单示例。下图是从两个不同的任务中采样的点,每一个任务都有训练集(用来计算内部损失)和验证集(用来计算同一任务的外部损失)。
设 为损失函数,网络参数为 ,输入向量 ,输入标签 。设 和 为数据分布 中同分布采样的数据样本。则MAML优化的目标是:
MAML的内环优化算子通过梯度下降优化回归损失。外环损失maml_loss可以简单地通过使用内环算子后再优化总体损失。对于MAML的目标可以这样解释:它是对学习器交叉验证损失的可导的估计。元学习的结果
def inner_update(p, x1, y1, alpha=.1):
grads = grad(loss)(p, x1, y1)
inner_sgd_fn = lambda g, state: (state - alpha * g)
return tree_multimap(inner_sgd_fn, grads, p)
def maml_loss(p, x1, y1, x2, y2):
p2 = inner_update(p, x1, y1)
return loss(p2, x2, y2)
在每轮迭代中,我们都重新采样一个不同的数据集,训练和验证时输入的特征和标签是同一批。
opt_init, opt_update = optimizers.adam(step_size=1e-3)
# 这个学习率好于1e-2和1e-4
out_shape, net_params = net_init(in_shape)
opt_state = opt_init(net_parameters)
@jit
def step(i, opt_state, x1, y1, x2, y2):
p = optimizers.get_params(opt_state)
g = grad(maml_loss)(p, x1, y1, x2, y2)
l = maml_loss(p, x1, y1, x2, y2)
return opt_update(i, g, opt_state), l
K = 20
np_maml_loss = []
# Adam 优化
for i in range(20000):
# 定义任务
A = onp.random.uniform(low=0.1, high=.5)
phase = onp.random.uniform(low=0., high=np.pi)
# 元学习内环交叉验证
x1 = onp.random.uniform(low=-5., high=5., size=(K, 1))
y1 = A * onp.sin(x1 + phase)
# 元学习外环交叉验证,针对一个单独的样本
x2 = onp.random.uniform(low=-5., high=5.)
y2 = A * onp.sin(x2 + phase)
opt_state, l = step(i, opt_state, x1, y1, x2, y2)
np_maml_loss.append(l)
if i % 1000 == 0:
print(i)
net_params = optimizers.get_params(opt_state)
在元学习过程中,网络学着怎样快速地匹配 和 ,以减小交叉验证时在 的损失。在测试阶段(上图中的曲线所示),当我们遇到新任务(即新的振幅和相位的正弦函数,在训练阶段没有见过)时,模型利用inner_update算子来拟合目标值,这比从零开始用SGD训练数据要快得多。
为啥用inner_update比用SGD重新训练模型快呢?见证奇迹的时刻,当当当当:在多任务训练的场景下,inner_update算子能够在不同任务之间作泛化(generalization)。在上面的例子里,多个任务分别拟合了正弦函数的回归形式。在典型的深度学习场景下,泛化能力是从单一任务的不同样本中得到的(例如增强学习、图像分类等)。在元学习中,泛化能力是从多个任务中得到的,尽管每个任务的样本量是稀少的,但是共享的学习规则在不同的数据分布上是通用的。
# 推断的批数据数量设为100
targets = np.sin(xrange_inputs)
predictions = vmap(partial(net_apply, net_params))(xrange_inputs)
plt.plot(xrange_inputs, predictions, label='pre-update predictions')
plt.plot(xrange_inputs, targets, label='targets')
x1 = onp.random.uniform(low=-5., high=5., size=(K,1))
y1 = 1. * onp.sin(x1 + 0.)
for i in range(1, 5):
net_params = inner_update(net_params, x1, y1)
predictions = vmap(partial(net_apply, net_params))(xrange_inputs)
plt.plot(xrange_inputs, predictions, label='{}-shot predictions'.format(i))
plt.legend()
使用vmap实现批训练MAML的多任务梯度下降
到目前为止,我们可以用MAML算法,一次性地在多个任务上计算梯度,来减小学习算子的梯度的方差。这一技术受了SGD算法增加minibatch来减小参数梯度的方差的启发(这个技术让学习更加高效),在MAML的论文中被首次提到的。
由于vmap算子,我们能够自动地将单任务MAML的实现,改编成跨任务的“批次版本”。从软件工程和软件测试的角度来说,vmap非常好用,因为“批任务”MAML的实现可以重用非分批的代码,并且没有丢掉向量化的优势。这就意味着我们可以用单独的测试代码实现算法,验证好正确性,而后方便地扩展任务数量就可以实现批量训练的版本了。(比如应付更加复杂的机器人学习)
# vamp版本的maml损失
# 返回一个所有任务的实数
def batch_maml_loss(p, x1_b, y1_b, x2_b, y2_b):
task_losses = vmap(partial(maml_loss, p))(x1_b, y1_b, x2_b, y2_b)
return np.mean(task_losses)
下面就是一批不同的任务, outer_batch_size是元学习的训练的任务数,inner_batch_size是每个任务的数据点的个数。
def sample_tasks(outer_batch_size, inner_batch_size):
# 选择振幅和相位
As = []
phases = []
for _ in range(outer_batch_size):
As.append(onp.random.uniform(low=.1, high=.5))
phases.append(onp.random.uniform(low=0., high=np.pi))
def get_batch():
xs, ys = [], []
for A, phase in zip(As, phases):
x = onp.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))
y = A * onp.sin(x + phase)
xs.append(x)
ys.append(y)
return np.stack(xs), np.stack(ys)
x1, y1 = get_batch()
x2, y2 = get_batch()
return x1, y1, x2, y2
现在,我们改写训练循环, 和刚才的代码是高度近似的。正如你所见,基于梯度的元学习需要应付两种方差:内环的损失的方差,和外环损失的方差。
opt_init, opt_update = optimizers.adam(step_size=1e-3)
out_shape, net_params = net_init(in_shape)
opt_state = opt_init(net_parameters)
# vmap版maml批训练损失
# 返回所有任务的损失均值
def batch_maml_loss(p, x1_b, y1_b, x2_b, y2_b):
task_losses = vmap(partial(maml_loss, p))(x1_b, y1_b, x2_b, y2_b)
return np.mean(task_losses)
@jit
def step(i, opt_state, x1, y1, x2, y2):
p = optimizers.get_params(opt_state)
g = grad(maml_loss)(p, x1, y1, x2, y2)
l = batch_maml_loss(p, x1, y1, x2, y2)
return opt_update(i, g, opt_state), l
K = 20
np_maml_loss = []
for i in range(20000):
x1_b, y1_b, x2_b, y1_b = sample_tasks(4, K)
opt_state, l = step(i, opt_state, x1_b, y1_b, x2_b, y2_b)
np_maml_loss.append(l)
if i % 1000 == 0:
print(i)
net_params = optimizers.get_params(opt_state)
当我们将训练步数和损失画成曲线图,我们发现分批的MAML训练收敛更快,并且在训练中有更小的梯度方差。
结论
在本教程中,我们研究了MAML算法,并用大约50行Python代码重现了原文中的正弦回归任务。我很高兴地发现,grad、vmap和jit实现MAML非常容易,它们将继续用于我的元学习研究。
那么,“优化”、“学习”、“适应”和“记忆”之间有什么区别呢?我认为它们是等效的,因为使用优化技术(MAML)实现记忆功能是可能的,反之亦然(例如基于RNN的元增强学习)。在强化学习中,模仿教师网络、或根据用户指定的目标进行调节、或从失败中恢复,都可以使用相同的机制。
思考“学习”和“元学习”的精确定义,并尝试将它们与生物智能相对应,使我认识到生命活动的每一个过程,都可归结于不同层面的学习行为:从分子层面的化学反应,到物种层面的遗传进化,行为适应存在于每个时间尺度。在未来我将对人造生命和机器学习的话题做更多阐述,但现在,是时候结束了。感谢您阅读本篇拟合正弦函数的简单教程!
致谢
感谢Matthew Johnson帮助校对本文,并解决了一些有关JAX的问题。