Tensorflow Lite初探(Android)

一、背景:

11月15日,谷歌正式发布了TensorFlow Lite开发者预览版。

TensorFlow Lite 是 Google I/O 2017 大会上的其中一个重要宣布,有了TensorFlow Lite,应用开发者可以在移动设备上部署人工智能。

Google 表示 Lite 版本 TensorFlow 是 TensorFlow Mobile 的一个延伸版本。尽管是一个轻量级版本,依然是在智能手机和嵌入式设备上部署深度学习的一大动作。此前,通过TensorFlow Mobile API,TensorFlow已经支持手机上的模型嵌入式部署。TensorFlow Lite应该被视为TensorFlow Mobile的升级版。

TensorFlow Lite 目前仍处于“积极开发”状态,目前仅有少量预训练AI模型面世,比如MobileNet、用于计算机视觉物体识别的Inception v3、用于自然语言处理的Smart Reply,当然,TensorFlow Lite上也可以部署用自己的数据集定制化训练的模型。

TensorFlow Lite可以与Android 8.1中发布的神经网络API完美配合,即便在没有硬件加速时也能调用CPU处理,确保模型在不同设备上的运行。 而Android端版本演进的控制权是掌握在谷歌手中的,从长期看,TensorFlow Lite会得到Android系统层面上的支持。

Tensorflow Lite初探(Android)_第1张图片

其组件包括:

  • TensorFlow 模型(TensorFlow Model):保存在磁盘中的训练模型。
  • TensorFlow Lite 转化器(TensorFlow Lite Converter):将模型转换成 TensorFlow Lite 文件格式的项目。
  • TensorFlow Lite 模型文件(TensorFlow Lite Model File):基于 FlatBuffers,适配最大速度和最小规模的模型。

github链接如下:https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite

二、环境:

Android Studio 3.0, SDK Version API26, NDK Version 14

步骤:
1. 将此项目导入到Android Studio:
https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo
2. 下载移动端的模型(model)和标签数据(lables):
https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip
3. 下载完成解压mobilenet_v1_224_android_quant_2017_11_08.zip文件得到一个xxx.tflite和labes.txt文件,分别是模型和标签文件,并且把这两个文件复制到assets文件夹下。
4. 构建app,run……

详情可参考:http://blog.csdn.net/wu__di/article/details/78570303

三、源码分析:

整个demo的代码非常少,仅包含4个java文件(相信随着正式版的发布,会有更加丰富的功能以及更多的预训练模型):
Tensorflow Lite初探(Android)_第2张图片

其中:
- AutoFitTextureView: 一个自定义View;
- CameraActivity: 整个app的入口activity,这个activity只做了一件事,就是加载了一个fragment;
- Camera2BasicFragment: 入口activity中加载的fragment,其中实现了所有跟UI相关的代码;首先在onActivityCreated中,初始化了一个ImageClassifier对象,此类是整个demo的核心,用于加载模型并实现推理运算功能。然后开启了一个后台线程,在线程中反复地对从摄像头获取的图像进行分类操作。

/** Load the model and labels. */
  @Override
  public void onActivityCreated(Bundle savedInstanceState) {
    super.onActivityCreated(savedInstanceState);
    try {
      classifier = new ImageClassifier(getActivity());
    } catch (IOException e) {
      Log.e(TAG, "Failed to initialize an image classifier.");
    }
    startBackgroundThread();
  }

startBackgroundThread()中做的轮询操作:

private Runnable periodicClassify =
      new Runnable() {
        @Override
        public void run() {
          synchronized (lock) {
            if (runClassifier) {
              classifyFrame();
            }
          }
          backgroundHandler.post(periodicClassify);
        }
      };

其中,classifyFrame()代码如下:

/** Classifies a frame from the preview stream. */
  private void classifyFrame() {
    if (classifier == null || getActivity() == null || cameraDevice == null) {
      showToast("Uninitialized Classifier or invalid context.");
      return;
    }
    Bitmap bitmap =
        textureView.getBitmap(ImageClassifier.DIM_IMG_SIZE_X, ImageClassifier.DIM_IMG_SIZE_Y);
    String textToShow = classifier.classifyFrame(bitmap);
    bitmap.recycle();
    showToast(textToShow);
  }

大致过程就是从控件textureView中以指定的长宽读取一个Bitmap出来(也就是摄像头的实时画面),然后交给classifier的classifyFrame进行处理,返回一个结果,这个结果就是图片分类的结果,然后显示在手机屏幕上。

ImageClassifier:demo最重要的部分,但只有两个函数比较重要,一个是构造函数:

/** Initializes an {@code ImageClassifier}. */
  ImageClassifier(Activity activity) throws IOException {
    tflite = new Interpreter(loadModelFile(activity));
    labelList = loadLabelList(activity);
    imgData =
        ByteBuffer.allocateDirect(
            DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE);
    imgData.order(ByteOrder.nativeOrder());
    labelProbArray = new byte[1][labelList.size()];
    Log.d(TAG, "Created a Tensorflow Lite Image Classifier.");
  }

其中Interpreter类非常关键,这是Android app与tensorflow lite之间的桥梁,位于org.tensorflow:tensorflow-lite-0.1.1中:
Tensorflow Lite初探(Android)_第3张图片
这个包实现了对张量(tensor)的基本操作,而整个tensorflow就是以张量为单位处理各种运算。

tflite = new Interpreter(loadModelFile(activity))这里通过loadModelFile将asset中的tflite格式的模型文件加载并返回一个MappedByteBuffer传给Interpreter。labelList = loadLabelList(activity)将asset中的labels文件中的分类标签加载到字符串列表labelList中。imgData则是一个存放输入张量的buffer,一个非常典型的(batch_size, x, y, channel)结构,在这里可以理解为一个placeholder。
最后labelProbArray是一个1 x labelList.size()的张量,可以认为是一个向量,元素的个数就是模型输出结果的总类别数,每一个元素代表模型判断到图片为某一类别的概率,对应于labels。

另一个是实现图片分类的函数:

/** Classifies a frame from the preview stream. */
  String classifyFrame(Bitmap bitmap) {
    if (tflite == null) {
      Log.e(TAG, "Image classifier has not been initialized; Skipped.");
      return "Uninitialized Classifier.";
    }
    convertBitmapToByteBuffer(bitmap);
    // Here's where the magic happens!!!
    long startTime = SystemClock.uptimeMillis();
    tflite.run(imgData, labelProbArray);
    long endTime = SystemClock.uptimeMillis();
    Log.d(TAG, "Timecost to run model inference: " + Long.toString(endTime - startTime));
    String textToShow = printTopKLabels();
    textToShow = Long.toString(endTime - startTime) + "ms" + textToShow;
    return textToShow;
  }

首先convertBitmapToByteBuffer将bitmap中的像素值读出,并放入刚才初始化的imgData中,这里相当于为placeholder填充了数据。然后是最关键的一行tflite.run(imgData, labelProbArray),喂数据,得出结果,分类的结果存入labelProbArray中。

#对于这行代码,有没有似曾相识的感觉:
tf.Session().run(output, feed_dict={x:input})

最后labelProbArray转换为需要显示的文字,传给UI层。

四、关于tflite模型

关于tflite,官方有比较详细的说明:
https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite

这里总结一下,生成tflite有两种方式,一种是直接在模型设计流程中,通过tflite提供的接口tf.contrib.lite.toco_convert将推理图转化为可供移动端直接使用的tflite文件(由于目前是预览版,这个接口在正式版的tensorflow中还无法使用):

import tensorflow as tf
img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
out = tf.identity(val, name="out")
with tf.Session() as sess:
  tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out])
  open("converteds_model.tflite", "wb").write(tflite_model)

还有就是将已经训练好的模型文件,转化为tflite格式。由于涉及到模型文件,这里先科普一下tensorflow的模型持久化。

从这里可以找到一些现成的模型:
https://github.com/tensorflow/models

随便下载一个,比如research/adv_imagenet_models当中的模型ens4_adv_inception_v3_2017_08_18.tar.gz,解压后可以得到这些文件:
这里写图片描述
这些文件保存了模型的信息,一般可通过如下代码生成:

import tensorflow as tf

...

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.save(sess, "/model/xxxx.ckpt") #在session中将计算图和变量信息保存到ckpt文件中

虽然只指定了一个文件路径,但是这个目录下会生成3个文件,分别是xxx.ckpt.data,xxx.ckpt.meta,xxx.ckpt.index,正如上图所示。其中,xxx.ckpt.meta保存了计算图结构,xxx.ckpt.data保存了所有变量的取值,xxx.ckpt.index保存了所有变量名。有了这三个文件,就能得到模型的信息并加载到其他项目中。
还有一种文件需要介绍一下,*.pb,官方的描述是这样的:

  • GraphDef (.pb) - a protobuf that represents the TensorFlow training and or computation graph. This contains operators, tensors, and variables definitions.
  • FrozenGraphDef - a subclass of GraphDef that contains no variables. A GraphDef can be converted to a frozen graphdef by taking a checkpoint and a graphdef and converting every variable into a constant with the value looked up in the checkpoint.

这里可以简单理解为*.pb文件有两种情况,一种是仅保存了计算图结构,不包含变量值,可以通过如下代码生成:

tf.train.write_graph()

还有一种就是上面提到的FrozenGraphDef ,不仅包含计算图结构,还包含了训练产生的变量值,这类*.pb可以直接被加载用于推理运算,tensorflow mobile的一个android应用demo就是很好的例子:
https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android
这个demo里,android应用正是通过FrozenGraphDef的*.pb文件将模型加载到app中,从而实现模型的推理功能。

那么如何使用现有的模型文件生成tflite呢?正式需要这样一个包含计算图和变量值的冻结图文件(*.pb)。
如果已经有了这个冻结图文件,根据官方文档,可以使用如下命令生成tflite:

bazel build tensorflow/contrib/lite/toco:toco

bazel-bin/tensorflow/contrib/lite/toco/toco -- \
  --input_file=$(pwd)/mobilenet_v1_1.0_224/frozen_graph.pb \
  --input_format=TENSORFLOW_GRAPHDEF  --output_format=TFLITE \
  --output_file=/tmp/mobilenet_v1_1.0_224.lite --inference_type=FLOAT \
  --input_type=FLOAT --input_arrays=input \
  --output_arrays=MobilenetV1/Predictions/Reshape_1 --input_shapes=1,224,224,3

如果没有冻结图,也可以根据包含变量值的ckpt和仅包含计算图结构的pb文件生成一个冻结图文件:

bazel build tensorflow/python/tools:freeze_graph

bazel-bin/tensorflow/python/tools/freeze_graph\
    --input_graph=/tmp/mobilenet_v1_224.pb \
    --input_checkpoint=/tmp/checkpoints/mobilenet-10202.ckpt \
    --input_binary=true --output_graph=/tmp/frozen_mobilenet_v1_224.pb \
    --output_node_names=MobileNet/Predictions/Reshape_1

最后,如果想要使用一些现成的tflite模型,可以从这里找到:
https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/models.md

你可能感兴趣的:(人工智能,深度学习,移动端)