Tensorflow-Lite Android笔记

Tensorflow-Lite

针对移动设备和Iot设备的开源深度学习框架。可以让我们原本运行在服务器上的模型得以运行到移动设备或Iot设备上,使得服务器能够节省出更多的资源处理其他业务。
在这里仅做安卓设备的学习(博主没有学过IOS、嵌入式就不做学习了,具体流程其实都差不多),主要有几个步骤,小新对于官网文档实在难懂,陆陆续续在网上查阅多方资料还是磕磕碰碰,所幸后来想通了。
官网提供的例子各位看官可以去下载运行。
该博文旨在了解Tensorflow-Lite在安卓上使用的简单步骤

一、模型转换

这一步的基础是模型已经训练出来了,如果还没开始训练模型,请移步到其他大神的博客或者官网。
官方推荐用python做模型转换以下是代码

import tensorflow as tf
# 加载模型
model = tf.keras.models.load_model('xxx.h5')
converter = tf.lite.TFLiteConverter.from_keras_model(model)
# 量化:降低某些值的精确度使得tflite包减小,我试过不用量化打包8.5M,量化后2.1M。不太懂只敢用默认的量化规则,相信其应该还能再精简。
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
open("xxx.tflite", "wb").write(tflite_model)

以上代码是keras模型,还有其他模型就不一一说明了。

二、导入安卓项目

上一步我们得到了tflite包,现在我们可以将它放到安卓应用的asset文件夹下面。注意:asset文件夹的资源在安卓打包时会进行压缩,所以为了避免代码中打开tflite包出现异常,我们需要配置参数使得安卓打包时不将tflite包压缩。

aaptOptions {
	noCompress "tflite"  //表示不让aapt压缩的文件后缀
}

其次我们需要在安卓项目中使用tensorflow-lite的类也是需要引入相关的包,已经由谷歌提供了:

dependencies {
    implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
}

这个aar里面包含了Android ABIs中的所有的二进制文件,但是其实我们用到的其实没有那么多,所以可以做一个删减,官方已经给出了方法:

android {
    defaultConfig {
        ndk {
            abiFilters 'armeabi-v7a', 'arm64-v8a'
        }
    }
}

三、加载模型及运用

由于是耗时操作,所以最好是放在子线程去执行。

Interpreter tflite = new Interpreter(FileUtil.loadMappedFile(mContext, "xxx.tflite"), new Interpreter.Options());
// input和output的数据类型在训练模型的时候就知道了的,如果不知道可以通过在python代码里面测试打印观察,以下为博主测试的数据类型
float[][] input = new float[1][20];
// 此处应该给input赋值,不过我略去了
float[][] output = new float[1][1];
tflite.run(input, output);
// 以下可以通过output的结果进行其他操作

python测试tensorflow-lite:

import tflite_runtime.interpreter as tflite
import numpy as np

interpreter = tflite.Interpreter(model_path='xxx.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Test model on random input data.
input_shape = input_details[0]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)

interpreter.invoke()

# The function `get_tensor()` returns a copy of the tensor data.
# Use `tensor()` in order to get a pointer to the tensor.
output_data = interpreter.get_tensor(output_details[0]['index'])
print(input_data)
print(output_data)

在我的模型测试中input和output都是一个二维数组,所以在安卓项目中也是定义成二维数组,具体的可以参考。

一开始我遇到的问题是不知道怎么在安卓中确定input和output的类型,导致浪费了一些时间,而且官方例子都有一个labels.txt,我以为这个是必须的,但是在后来才确定这只是一个根据output来查找对应的结果的一个类型标签文件,因为官方例子差不多都是图像识别类,但是我们自己的模型不一定是这种类型,所以不一定需要labels.txt。我们只需要知道output提供的结果就行,不管是官方的例子还是我们的例子。

至此,tensorflow-lite在安卓上的应用就到此结束。

你可能感兴趣的:(Tensorflow-Lite,Android)