技术分享 | 能微分会加速的 NumPy —— JAX

目录

# 使用介绍 #

# 自动微分

# vmap 和 pmap

# JIT 编译

# 内部实现 #

# Trace 变换

# Jaxpr:JAX 中间表达式

# 总结 #

参考


JAX [1] 是 Google 推出的可以对 NumPy 和 Python 代码进行自动微分并跑到 GPU/TPU(Google 自研张量加速器)加速的机器学习库。Numpy [2] 是 Python 著名的数组运算库,官方版本只支持 CPU 运行(后面 Nvidia 推出的 CuPy 支持 GPU 加速,这里按住不表)。JAX 前身是 AutoGrad [3],2015 年哈佛大学来自物理系和 SEAS(工程与应用科学学院)的师生发表论文推出的支持 NumPy 程序自动求导的机器学习库。AutoGrad 提供和 NumPy 库一致的编程接口,用户导入 AutoGrad 就可以让原来写的 NumPy 程序拿来求导。JAX 在 2018 年将 XLA [4](Tensorflow 线性代数领域编译器)引入进来,使得 Python 程序可以通过 XLA 编译跑到 GPU/TPU 加速器上。简单地理解 JAX = NumPy + AutoGrad + XLA。可以说,XLA 加持下的 JAX,才真正具备了实施深度学习训练的基础和能力。

JAX Github: https://github.com/google/jax

JAX API Docs: https://jax.readthedocs.io/en/latest/


# 使用介绍 #

# 自动微分

JAX 提供兼容 NumPy 风格的接口,照顾用户原 NumPy 编程习惯。JAX 面向 Python 用户提供自动微分接口,包括生成梯度函数、求导等。

例 1:使用 jax.grad() 求导

from jax import grad

def f(x):
  return x * x * x

D_f = grad(f) # 3x^2
D2_f = grad(D_f) # 6x
D3_f = grad(D2_f) # 6

f(1.0) # 1.0
D_f(1.0) # 3.0
D2_f(1.0) # 6.0
D3_f(0.0) # 6.0 (always)

jax.grad:只接受输出标量的原始函数 f,生成对应的梯度函数 ▼f▼f 接受和原始函数一样的入参 x,输出为参数梯度 dx▼f 亦可被 grad(),相当于对原始函数计算高阶梯度,但需满足一样的要求:输出为标量。如果被求导的函数计算结果不止一个数值,不能直接传给 grad()。需要先 reduce 成一个标量。

例 2:对数组函数求导

from jax import numpy as np
from jax import grad
import matplotlib.pyplot as plt

def f(x):
    return x * x * x

D_f = grad(lambda x: np.sum(f(x)))
D2_f = grad(lambda x: np.sum(D_f(x)))
D3_f = grad(lambda x: np.sum(D2_f(x)))

x = np.linspace(-1, 1, 200)
plt.plot(x, f(x), x, D_f(x), x, D2_f(x), x, D3_f(x))
plt.show()

和例 1 相比主要区别在于例 2 分别对函数(fD_fD2_f)结果进行求和(sum)再求导。函数 f 和它的一阶、二阶、三阶导函数曲线如下图所示。

技术分享 | 能微分会加速的 NumPy —— JAX_第1张图片

JAX 支持不同模式自动微分。grad() 默认采取反向模式自动微分。另外显式指定模式的微分接口有 jax.vjp 和 jax.jvp

  • jax.vjp:反向模式自动微分。根据原始函数 f、输入 x 计算函数结果  y 并生成梯度函数 ▼f▼f 输入是 dy,输出是 dxgrad() 实现上底层调用 vjp(),可看做 vjp() 的一种特例。

  • jax.jvp:前向模式自动微分,根据原始函数 f、输入 x 和 dx 计算结果 y 和 dy。在函数输入参数数量少于或持平输出参数数量的情况下前向模式自动微分比反向模式更省内存,内存利用效率上更具优势 [5]。

jvp() 中微分计算和原始函数计算是同时完成的。多次调用可能存在对原始函数重复计算。JAX 提供前向模式自动微分的缓存优化接口 jax.linearize。该接口根据原始函数 f 和输入 x,计算函数结果 y 并生成导函数 f'。导函数 f' 输入是 dx,输出是 dylinearize() 实现上为前向模式 jvp() 加上 partial evaluation(缓存了原始函数计算过程数据),在内存占用方面更接近于反向模式自动微分,相对于前向模式来说还是比较耗内存的。

为方便对照列出这几个微分接口的形式化信息:

微分接口 类型签名
grad() (a -> b) -> a -> T a
value_and_grad() (a -> b) -> a -> (b, T a)
jvp() ((a -> b), a, T a) -> (b, T b)
vjp() ((a -> b), a) -> (b, (T b -> T a))
linearize() ((a -> b), a) -> (b, (T a -> T b))

vmap 和 pmap

jax.vmap 负责对函数进行向量化,可指定向量化维度。假设要对原始函数 f 进行自动微分,输入输出参数数量持平,并且要多次执行导函数 f'

方式 1:使用 jvp()。前向模式自动微分省内存。但,多次执行 jvp() 意味着多次计算原始函数。慢!

for in_tangent in in_tangents:
  y, out_tangent = jax.jvp(f, (x,), (in_tangent,))

方式 2:使用 linearize(),只计算一次原始函数,比方式 1 的计算效率高。但是,如前面所说,linearize() 内存占用接近反向模式自动微分。耗内存!

y, f_jvp = jax.linearize(f, x)
for in_tangent in in_tangents:
  out_tangent = f_jvp(in_tangent)

方式 3:使用 jvp() 加 vmap()。利用前向模式自动微分省内存的特点优化方式 2 的内存占用问题,同时通过向量化相比方式 1 提高了计算效率。前提是提前已知导函数执行需要的这些输入数据。

pushfwd = partial(jvp, f, (x,))
y, out_tangents = vmap(pushfwd, out_axes=(None, 0))((in_tangents,))

jax.pmap 帮助实现 SPMD(即 single program, multiple data)编程,比如在 GPU 多个卡上并行计算,对用户屏蔽底层通信操作。调用 pmap() 之后,经过 JAX 编译可将数据和计算任务分布到多个设备上执行。

JIT 编译

JAX 通过 XLA 后端对 Python 函数进行 JIT 编译得到优化后的函数。被 JIT 的函数必须是纯函数,其中副作用代码只会执行一次。输入输出参数类型必须满足:数组、标量、容器(tuple、list、dict)的一种。

def f(x):
  return x * x * x
print(jax.jit(f)(1.0)) # 1.0
print(jax.jit(grad(f))(1.0)) # 3.0

JAX 要求参与微分、JIT 的必须是 Python 纯函数。使用 JAX 表达神经网络的计算过程都是由 Python 函数组成。假设 N 个全连接层串行构成神经网络结构。JAX 代码示意如下。

def loss(params, batch):
  inputs, targets = batch
  preds = predict(params, inputs)
  # 损失函数
  return -np.mean(np.sum(preds * targets, axis=1))

def predict(params, inputs):
  activations = inputs
  for w, b in params[:-1]:
    # 前 (N-1) 层分别由带 bias 的矩阵乘 + tanh 激活函数组成
    outputs = np.dot(activations, w) + b
    activations = np.tanh(outputs)
  # 第 N 层由带 bias 的矩阵乘组成
  final_w, final_b = params[-1]
  logits = np.dot(activations, final_w) + final_b
  # 归一层
  return logits - logsumexp(logits, axis=1, keepdims=True)

@jit
def update(params, batch):
  # 求 loss 函数的参数梯度
  grads = grad(loss)(params, batch)
  # 依次分别更新权重和 bias
  return [(w - 0.001 * dw, b - 0.001 * db)
          for (w, b), (dw, db) in zip(params, grads)]

params = ... # 初始化 N 层参数(权重和 bias)
for epoch in range(num_epochs):
  for _ in range(num_batches):
    # 参数随着 epoch 和 batch 迭代变化
    params = update(params, next(batches))
  # 基于最新参数数据评估精度(需要跑一遍前向 predict)
  test_acc = accuracy(params, (test_images, test_labels))

# 内部实现 #

JAX 提供的 grad()jvp()vmap()pmap() 等接口指定原始函数用于变换。最后用于真正执行计算求值的是变换生成的新函数,比如各种微分相关的函数、映射优化的函数。JAX 实现里,这些参与变换的原始函数以动态方式被记录并生成中间表达式。这一过程叫做 trace。JAX 生成的中间表达式叫 Jaxpr。Jaxpr 经过内部解释器执行变换。过程如下图所示 [1]。

技术分享 | 能微分会加速的 NumPy —— JAX_第2张图片

前向模式自动微分、vmap 等情况下,trace 只需要携带少部分上下文信息即可变换生成新函数。但反向模式自动微分等情况下,trace 需要生成 Jaxpr 以记录更多信息,再变换生成新函数。JIT 编译例外,JIT 会生成 Jaxpr,但底层通过 XLA 编译生成二进制代码并运行,不再回到 Python 代码。

# Trace 变换

JAX 提供多种 tracer,包括 jvp tracer、vjp tracer、vmap tracer、jaxpr builder tracer 等。这些 tracer 是在 JAX 代码运行过程中工作的。做 trace 的过程同时是 Python 代码特例化过程。能够被 trace 的都是 JAX 要求导的变量信息和操作信息,其他信息包括 Python 原生控制流、自定义类型变量 & 操作、打印语句等不会被 trace。

JAX 被 trace 的 Python 数组可看做抽象的符号表示,只有类型和 shape 信息,没有具体元素数值。比如调用 JAX 提供的 numpy.sum(),不会立刻触发 sum 计算。直到需要访问 Python 数组数值时才会真正求值,相当于惰性求值(Lazy Evaluation)。

JAX 变换负责对 trace 结果执行求值,求值后得到的 JAX 数组包含具体元素数值。JAX 数组通过 to_py() 可主动转成 NumPy 数组。

# Jaxpr:JAX 中间表达式

Jaxpr 全称 JAX Program Representation,用于待变换的函数的内部表示。Jaxpr 是强类型的,函数式的,定义形式符合 ANF form。引入 Jaxpr 主要有两方面考虑:

  1. JIT 需要对 Python 代码建立这样的中间表示来完成动态编译和计算;

  2. 反向模式自动微分对原始函数进行反向传播。Jaxpr 程序表示也可以帮助实现这一点。

Jaxpr 生成时机分为两种方式,对应的 trace 变换方法亦不同:

  • 求值前生成(偏静态)。通过 trace 生成 Jaxpr。而 trace 和 Jaxpr 执行分别位于不同阶段(类似多阶段计算)。不支持依赖数据的控制流。JIT 适用于这种。

  • 求值时生成(偏动态)。trace 不急于生成 Jaxpr,直到最后变换时刻生成 Jaxpr。变换过程就像正常调用 Python 函数一样。允许依赖数据的控制流。定制 jvp 适用于这种。

Jaxpr 函数是强类型的、纯函数的表示,输入、输出都带有类型信息。函数输出只依赖输入,不依赖全局变量。Jaxpr 变量类型只能是数组、标量、容器(tuple、list、dict)的一种。Jaxpr 定义形式比较简单。

jaxpr ::= { lambda Var* ; Var+.
            let Eqn*
            in  [Expr+] }

Eqn ::= let Var+ = Primitive [Param*] Expr+

JAX 内置常规数学原语和微分规则。

Primitive := add | sub | sin | mul | ...

如想查看函数对应的 Jaxpr,比如自动微分生成的新函数内部形式,可以使用 jax.make_jaxpr

def f(x):
    return x * x

jax.make_jaxpr(f)(1.0)
""" { lambda  ; a.
  let b = mul a a
  in (b,) }
"""
jax.make_jaxpr(jax.grad(f))(1.0)
""" { lambda  ; a.
  let _ = mul a a
      b = mul 1.0 a
      c = mul 1.0 a
      d = add_any b c
  in (d,) }
"""
jax.make_jaxpr(jax.linearize(f, x)[1])(x)
"""
{ lambda a ; b.
  let c = mul b a
      d = mul a b
      e = add_any c d
  in (e,) }
"""

注意 linearize() 生成的函数捕获原始函数 f 的输入 x(亦即 linearize() 函数第二个参数),反映在 Jaxpr 就是 lambda a; 的变量捕获。


# 总结 #

JAX 在机器学习开发上以 NumPy 库 API 作为切入点,提供 AI 需要的自动微分和 JIT 编译加速功能。JAX 编程上贴近 Python 原生语法,体验类似 PyTorch。但与 PyTorch 明显不同的是,JAX 把 Python 代码限制到纯函数。PyTorch 对 tensor 对象求导,而 JAX 选择对函数求导,包括 NumPy 函数和其他 Python 函数。JAX 自动微分除了反向模式,还提供前向模式以及高阶混合模式微分。为了平衡自动微分计算和内存墙的问题,支持按需 checkpoint。

JAX 内部表示是纯函数式的,但考虑到 Python 语言高度动态性特点,对用户使用上有一些编程限制。比如 JAX 自动微分的 Python 函数只支持纯函数,要求用户自行保证这一点。如用户代码写了副作用,可能经过 JAX 变换生成的函数执行结果不符合期望。因 JAX trace 函数为纯函数,当全局变量、配置信息发生变化,可能需要重新 trace。JAX 只能对固定类型进行微分求导,不支持自定义类型如 class 变量等分析和微分求导。JAX 最新版本是 v0.2。相对于其他 Python AI 框架,JAX 用户数和受到的关注度偏少。最近大热的 Alpha Fold2 开源项目里有 JAX 的身影,有希望给 JAX 带一波热度。

参考

[1] James Bradbury, Roy Frostig, Peter Hawkins, Matthew James Johnson, Chris Leary, Dougal Maclaurin, George Necula, Adam Paszke, Jake Vander-Plas, Skye Wanderman-Milne, and Qiao Zhang. JAX: composable transfor-mations of Python+NumPy programs, 2018.

[2] NumPy - https://numpy.org/

[3] HIPS/Autograd - https://github.com/HIPS/autograd

[4] XLA: Optimizing Compiler for Machine Learning  |  TensorFlow - https://www.tensorflow.org/xla

[5] Atilim Gunes Baydin, Barak A. Pearlmutter, and Alexey AndreyevichRadul. Automatic differentiation in machine learning: a survey. CoRR,abs/1502.05767, 2015.


技术分享 | 能微分会加速的 NumPy —— JAX_第3张图片

你可能感兴趣的:(技术文章,编程语言,numpy)