Android端实现深度学习

这里截取了本人毕业设计关于移动端实现深度学习的章节。本章节将详细介绍如何实现移动端调用深度学习模型进行使用,简单来说就是两个步骤,生成可供调用的模型和调用模型。这里我们用到的人脸检测模型为第三章节训练出来的mAP最高的模型。
4.1 固定模型
为了使android能够调用,检测模型一定要转成pb文件。
4.1.1 读取检查点生成新的pbtxt文件
SSD模型训练过程中生成的checkpoints文件和pbtxt文件虽然包含了整个模型的参数,但是我们需要的只是生成预测值的那一部分网络结点。对于其他loss,gradient这些参数,我们并不需要。明确了我们需要的网络后,我们还要给它的输入层和输出层取名字,不然在android中无法调用这个模型。具体来说,我们需要写一段代码来重新得到需要的pbtxt文件,代码中关键的步骤如下:

1.首先定义一个有名字的占位符,用于表示输入输出数据的格式。告诉系统:这里有一个值/向量/矩阵,现在没法给你具体的数据,不过正式运行的会补上的。这里的正式运行指的就是后续在android中调用这个模型的时候,我们就要给它输入这个格式的数据。这里要注意的是,这个占位符的名字一定要取。因为调用的时候是要明确接收数据的变量的,否则调用模型将出错。本文定义如下:
这里写图片描述

2.在占位符后面紧跟着接收数据后,要处理的函数或者网络。一般情况下我们要对网络输出的预测值进行处理,便于后续在android编程中的调用。因为Tensorflow[8,9]和android里面的数据格式不同,最好是能都化成float一维数组形式。SSD模型有6层的输出,所以我们不能直接使用网络输出的预测值,而要将其合并为一层,再对最后的输出值取个名字,用于后续模型的读取。

3.恢复checkpoint文件,重新生成pbtxt文件。在经过网络得到最终输出值后,我们要再其后面添加saver,表明以上的网络模型参数将恢复。之后再打开一个Tensorflow会话,从checkpoint中恢复参数,并将上述的图(如果没有明确取名的话,图将为默认图),写入新的pbtxt文件。为了能够更清楚地被理解,这里将上一步的最后输出和这一步的代码写下来:

Android端实现深度学习_第1张图片
4.1.2 运行官方工具固定模型参数
首先明确运行文件需要的参数:输入的图文件(上一步骤生成的pbtxt文件),输入的检查点文件(上一步中中使用的checkpoint文件),输出的pb文件存放路径和名字,最后还要输出结点的名字(上一步中输出值的名字)。确定后,运行Tensorflow[8,9]官方工具freeze_graph.py文件,便会生成可供调用的pb文件。
4.2 Android项目调用
关于Android Studio环境的配置,大多可以自行搜索找到教程,本文就不做细说。关键的点有1.修改配置文件,2.根据tensorflow github上的教程生成libtensorflow_inference.so和libandroid_tensorflow_inference_java.jar,3.将上一节生成的pb文件放入项目的app/src/main/assets下,assets不存在则自行创建,4.将jar包放在app/libs下,并add as library,5.将so文件放置在app/src/main/jniLibs下,jniLibs不存在则自行创建。

4.2.1 具体调用方法
环境都配置好后,要能顺利调用深度学习模型还需要以下几个步骤:
1.导入jar包:

这里写图片描述

2.在要调用的java类的类定义首行,导入so文件:

这里写图片描述

3.定义变量和对象

Android端实现深度学习_第2张图片

4.Tensorflow接口初始化

这里写图片描述

5.人脸检测模型的调用

这里写图片描述

在得到推理输出的数据后,将其转变为最终需要的形式。本文中对数据的处理流程为,先进行数据格式的转换(数据转为xmin,ymin,xmax,ymax的形式)和挑选,具体为选择一定数量预测分数大于0.3的数据,再将它们进行非最大抑制处理得到最终的预测位置和类别。

你可能感兴趣的:(开发)