由于我现在还处在机器学习入门阶段,对很多知识也是一知半解,没有那个实力去写好的原创文章,所以还是翻译一篇文章分享给大家。如果有问题请参看原文或和我联系。原文地址:https://heartbeat.fritz.ai/intro-to-machine-learning-on-android-how-to-convert-a-custom-model-to-tensorflow-lite-e07d2d9d50e3
对于开发者来说,在移动设备上运行预先训练好的模型的能力意味着向边界计算(edge computing)迈进了一大步。[译注:所谓的边界计算,从字面意思理解,就是与现实世界的边界。数据中心是网络的中心,PC、手机、监控照相机处在边界。]数据能够直接在用户手机上处理,私人数据仍然掌握在他们手中。没有蜂窝网络的延迟,应用程序可以运行得更顺畅,并且可大幅减少公司的云服务账单。快速响应式应用现在可以运行复杂的机器学习模型,这种技术转变将赋予产品工程师跳出条条框框思考的力量,迎来应用程序开发的新潮流。
继Apple发布CoreML之后,Google发布了TensorFlow Lite的开发者预览版,这是TensorFlow Mobile的后续发展版本。通过在支持它的设备上利用硬件加速,TensorFlow Lite可以提供更好的性能。它也具有较少的依赖,从而比其前身有更小的尺寸。尽管目前还处于早期阶段,但显然谷歌将加速发展TF Lite,持续增加支持并逐渐将注意力从TFMobile转移。考虑到这一点,我们直接选择TFLite, 尝试创建一个简单的应用程序,做一个技术介绍。
显然从谷歌的TensorFlow Lite文档入手最好,这些文档主要在github上(https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite)。他们还发布了一些简单的教程来帮助其他人上手:
这些示例和教程更侧重于使用预先训练的模型或重新训练现有的模型。但是用户自己的模型呢? 如果我有一个训练的模型,想将其转换为.tflite文件,该怎么做?有一些简略提示我该怎么做,我按图索骥,无奈有一些进入了死胡同。经过一天费尽心思的搜索,一小撮脚本和几杯咖啡,我终于让它能够工作了 - 一个简单的,转换过的MNIST.tflite模型。(我发誓,这不会是另一个MNIST训练教程,Google和许多其他开发人员已经用尽了这个话题)。
在这篇文章中,我们将学习一些通用的技巧,一步一步为移动设备准备一个TFLite模型。
首先,我想选择一个未经过预先训练或转换成.tflite文件的TensorFlow模型,理所当然我选择使用MNIST数据训练的简单的神经网络(目前支持3种TFLite模型:MobileNet、Inception v3和On Device Smart Reply)。
幸运的是,Google在其模型库(model zoo)中开放了大量研究模型和可用模型,这其中包括MNIST训练脚本。我们将在本节中引用该代码,大致浏览一下,熟悉它。
我们应该对此训练脚本进行一些修改,以便稍后进行转换。
class Model(tf.keras.Model):
...
def __call__(self, inputs, training):
# Input layer
y = tf.reshape(inputs, self._input_shape)
y = self.conv1(y)
y = self.max_pool2d(y)
y = self.conv2(y)
y = self.max_pool2d(y)
y = tf.layers.flatten(y)
y = self.fc1(y)
y = self.dropout(y, training=training)
# Returns a logit layer
return self.fc2(y)
从这段代码,我们清楚地看到输入层是tf.reshape,所以给它一个名字。
y = tf.reshape(inputs, self._input_shape, name='input_tensor’)
一个好的做法是为输入和输出图层命名。这将为您在后面节省一些时间和精力,因此您不必在tensorboard上四处搜索以填写转换工具的某些参数。(另外一个好处是,如果您共享模型而没有共享训练脚本,开发人员可以研究模型并快速识别图形的输入输出)。
def model_fn(features, labels, mode, params):
...
logits = model(image, training=False)
predictions = {
'classes': tf.argmax(logits, axis=1),
'probabilities': tf.nn.softmax(logits, name='softmax_tensor'),
}
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(
mode=tf.estimator.ModeKeys.PREDICT,
predictions=predictions,
export_outputs={
'classify': tf.estimator.export.PredictOutput(predictions)
})
我们还需要在TensorFlow图中暴露softmax图层,因为它是用于推断的输出图层。现在它嵌入在推断方法中。作一个简单的修正,将其移出,这样当我们训练此模型时,图形将包含此图层。 显然有更好的方法来修改它,但这是编辑现有MNIST脚本的简单方法。
总而言之,我们研究了训练脚本,并专门命名了模型推理所需的输入和输出层。请记住,我们正在使用的MNIST脚本同时进行训练和推理。了解训练和推理层之间的区别很重要。 由于我们希望准备好的模型仅用于移动平台上的推断(在MNIST数据的情况下预测手写数字),因此我们只需要预测所需的图层。请记住,我们正在使用的MNIST脚本既有训练又有预测。稍后,我们将在Tensorboard中看到分离两者。
这里有完整的mnist.py文件供您参考。
python official/mnist/mnist.py --export_dir /tmp/mnist_saved_model --model-dir /tmp/mnist_graph_def_with_ckpts
这些导出目录保存检查点和定义图形的protobuf文件。我们来分析一下从训练文件中保存的不同的TF格式。
github文档中,对GraphDef(.pb)、FrozenGraphDef(带有冻结变量的.pb)、SavedModel(.pb - 用于推断服务器端的通用格式)和Checkpoint文件(在训练过程中的序列化变量)有明确的解释。 这是我创建的一张图表,展示了如何从一个转换到另一个,一步一步解释这中间涉及到的东西。
从MNIST训练脚本中,我们得到文本可读形式(.pbtxt)的Graph Def、检查点和保存的图形。 重要的是要注意GraphDef、Saved Model、FrozenGraph和Optimized Graphs都以protobuf格式保存(.pb)
>> ls /tmp/mnist_graph_def_with_ckpts
checkpoint
model.ckpt-48000
model.ckpt-35626
model.ckpt-39410
model.ckpt-43218
model.ckpt-47043
model.ckpt-48000
graph.pbtxt
.pbtxt是图形def的文本格式。 您应该能够像任何.pb文件一样使用它。
我强烈建议使用Tensorboard来检查图表。请参考附录了解如何导入和使用它。
审查.pbtxt图,我们看到:
训练后在Tensorboard中可视化graph.pbtxt - 在这里,我们标记了输入和输出图层以及仅用于模型训练中的不必要图层。
使用Tensorboard,我们可以看到训练脚本中生成的每个图层。由于我们命名了输入和输出图层,因此我们可以轻松识别它们,然后开始了解哪些图层对于推断是必需的,哪些图层可以丢弃掉的。 绿线框起来的所有内容都用于在训练过程中调整权重。同样,input_tensor之前的所有内容也是不必要的。在移动设备上运行之前,我们需要裁剪此图。 TFLite中大多数训练层也不受支持(请参阅附录)。
freeze_graph
--input_graph=/tmp/mnist_graph_def_with_ckpts/graph.pbtxt
--input_checkpoint=/tmp/mnist_graph_def_with_ckpts/model.ckpt-48000
--input_binary=false
--output_graph=/tmp/mnist_graph_def_with_ckpts/frozen_mnist.pb
--output_node_names=softmax_tensor
结果是:/tmp/mnist_graph_def_with_ckpts/frozen_mnist.pb下的冻结图。此时,再次检查Tensorboard中的图形是个好主意。
请注意,freeze_graph实际上删除了训练中使用的大部分图层。但是,我们仍然有一些与TFLite不兼容的东西。具体来说,请注意“dropout”和“iterator”层。这些图层用于训练,仍然需要裁剪。为了这一目的,我们使用优化器。
optimize_for_inference工具(安装指南)接受输入和输出名称,并执行另一次传递以去除不必要的图层。
optimize_for_inference \
--input=/tmp/mnist_graph_def_with_ckpts/frozen_mnist.pb \
--output=/tmp/mnist_graph_def_with_ckpts/opt_mnist_graph.pb \
--frozen_graph=True \
--input_names=input_tensor \
--output_names=softmax_tensor
我们需要指定输入和输出名称(input_tensor&softmax_tensor)。这个任务删除了图中的所有预处理。
在Tensorboard中评估opt_mnist_graph.pb。 注意dropout和iterator现在不见了。
结果应该是准备好转换为TFLite的图表。如果仍有不受支持的图层,请检查graph_transform工具。在本例中,所有操作都受支持。
最后一步是运行toco工具,及TensorFlow Lite优化转换器。唯一可能令人困惑的部分是输入形状。使用Tensorboard或summarize_graph工具,您可以获得形状。
在Tensorboard中,如果我们评估input_tensor,你会看到形状?x28x28x1。这里? 代表batch_size。在我们的例子中,我们将构建一个Android应用程序,该应用程序一次只能检测一个图像,因此在下面的toco工具中,我们将形状设置为1x28x28x1。
toco \
--input_file=/tmp/mnist_graph_def_with_ckpts/opt_mnist_graph.pb \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--inference_type=FLOAT \
--input_type=FLOAT \
--input_arrays=input_tensor \
--output_arrays=softmax_tensor \
--input_shapes=1,28,28,1 \
--output_file=/tmp/mnist_graph_def_with_ckpts/mnist.tflite
这里您得到一个可以直接添加到Android项目的TFLite文件。如果您已经完成了前面的步骤并确保所有操作都与TensorFlow Lite兼容,那么这部分应该非常简单。如果您有任何问题,请随时在下面留言。
从培训脚本开始,我们能够检查和修改TensorFlow图表,以便用于移动设备。通过遵循这些步骤,我们修剪了不必要的操作,并能够成功地将protobuf文件(.pb)转换为TFLite(.tflite)。
在接下来的文章中,我们将切换到移动开发并看看如何使用我们新近转换的mnist.tflite文件在Android应用程序中检测手写数字。
使用Tensorboard
# From anywhere though I suggest you make it outside of the git repos
mkdir training_summaries
# Runs tensorboard in the background at http://localhost:6006
tensorboard --logdir training_summaries &
# Using my modified import_pb_to_tensorboard.py in the tensorflow repo (feel free to edit to your liking)
import_pb_to_tensorboard.py --model_dir /tmp/mnist_graph_def_with_ckpts/graph.pbtxt --log_dir training_summaries/mnist --graph_type=PbTxt
training_summarizes目录用于存储导入图形的结果
支持的TFLite操作
Google正在继续增加对更多操作的支持,这里列出了当前可用的列表。