TensorFlow 线性代数编译框架 XLA

TensorFlow 线性代数编译框架 XLA

XLA 的工作原理

学过 C 语言的人可能知道,LLVM 是一个编译器的框架系统,用 C++编写而成,用于优化
以任意编程语言编写的程序的编译时间(compile time)、链接时间(link time)、运行时间(run
time)以及空闲时间(idle time)。在基于 LLVM 的编译器中,前端负责解析、验证和诊断输入代码中
的错误,然后将解析的代码转换为 LLVM 中间表示(intermediate representation,IR)。该 IR 
通过一系列分析和优化过程来改进代码,然后发送到代码生成器中,以产生本地机器代码。

TensorFlow 线性代数编译框架 XLA_第1张图片
XLA 目前支持在 x86-64 和 NVIDIA GPU 上进行 JIT 编译,以及在 x86-64 和 ARM 上进行
AOT 编译。因此,AOT 编译方式更适合移动端和嵌入式的深度学习使用。下面我们就以 JIT
编译为例进行说明。

JIT 编译方式

通过 XLA 运行 TensorFlow 计算有两种方法,一是打开 CPU 或 GPU 设备上的 JIT 编译,
二是将操作符放在 XLA_CPU 或 XLA_GPU 设备上。

打开 JIT 编译

打开 JIT 编译可以有两种方式。下面是在会话上打开,这种方式会把所有可能的操作符编
程成 XLA 计算。用法示例如下:
config = tf.ConfigProto()
config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
sess = tf.Session(config=config)
另一种方式是为一个或多个操作符手动打开 JIT 编译。这是通过使用属性_XlaCompile =
true 标记要编译的操作符来完成的。用法示例如下:
jit_scope = tf.contrib.compiler.jit.experimental_jit_scope
x = tf.placeholder(np.float32)
with jit_scope():
y = tf.add(x, x)

将操作符放在 XLA 设备上

with tf.device("/job:localhost/replica:0/task:0/device:XLA_GPU:0"):
output = tf.add(input1, input2)

JIT 编译在 MNIST 上的实现

不使用 XLA 来运行时,如下:
python mnist_softmax_xla.py --xla=false
运行完成后生成时间线文件 timeline.ctf.json,使用 Chrome 跟踪事件分析器(在浏览器中访
问 chrome://tracing),打开该时间线文件,呈现的时间线如图 15-3 所示。

在这里插入图片描述

图 15-3 中最左侧一列列出了本机的 4 个 GPU。可以清晰地看到图中 MatMul 操作符,跨越
4 个 CPU 的时间消耗情况。
让我们使用 XLA 来训练模型,如下:
TF_XLA_FLAGS=--xla_generate_hlo_graph=.* python mnist_softmax_xla.py
运行完成后,得到的时间线图像如图 15-4 所示。

在这里插入图片描述
关键的训练代码如下:

config = tf.ConfigProto()
jit_level = 0
if FLAGS.xla:
# 开启 XLA 的 JIT 编译
jit_level = tf.OptimizerOptions.ON_1
config.graph_options.optimizer_options.global_jit_level = jit_level
run_metadata = tf.RunMetadata()
sess = tf.Session(config=config)
tf.global_variables_initializer().run(session=sess)
run_metadata = tf.RunMetadata()
sess = tf.Session(config=config)
tf.global_variables_initializer().run(session=sess)
# 训练
train_loops = 1000
for i in range(train_loops):
	batch_xs, batch_ys = mnist.train.next_batch(100)
# 在最后一次循环中,创建时间线文件,可以用 chrome://tracing/打开和分析
if i == train_loops - 1:
	sess.run(train_step,
	feed_dict={x: batch_xs, y_: batch_ys},
	options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
	run_metadata=run_metadata)
	trace = timeline.Timeline(step_stats=run_metadata.step_stats)
	trace_file = open('xlatimeline.ctf.json', 'w')
	trace_file.write(trace.generate_chrome_trace_format())
else:
	sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

AOT 主要应用场景是一些内存较小的嵌入式设备、手机、树莓派等

你可能感兴趣的:(Tensorflow,tf,tensorflow,XLA,移动端,线性代数)