前言
当我们把使用Python训练的模型固化成PB文件之后,再进行相应的模型压缩之后可以考虑往Mobile端移植了,本文主要讲解TensorFlow Model移植到Android端。
TensorFlow1.0之后推出了Java版本,所以间接为Android开发TensorFlow程序带来便利,以前我们需要用JNI去编写,可是JNI难于调试,C++代码对于普通Android开发者来讲还是比Java繁琐,所以本文以Java API讲述开发过程。
正文
下面就正式开始一直TensorFlow model到Android中啦。
- 引入依赖
在TensorFlow更新到1.2.0版本之后,TensorFlow为广大开发者提供了gradle依赖,现在我们想要引入TensorFlow只需要在gradle中加入
compile 'org.tensorflow:tensorflow-android:1.2.0-rc0'
即可引入TensorFlow的库。
- 复制PB文件
快速开发的话直接把PB文件放在assets文件夹里就行,如果正式上线的时候觉得PB文件一起打包较大的话可以放在服务器,打开APP的时候提示下载再复制进去就好。
- 创建TensorFlowInterface类
这个类指的是我们读取、识别等一系列方法存放的类,名字随你取。
- 载入TensorFlow
在类的第一行加入这句话,会在加载类的时候首先加载TensorFlow
{
System.loadLibrary("tensorflow_inference");
}
- 定义常量
在这一步,我们先定义一些常量,比如输入节点名、输出节点名、输出图像的尺寸、通道、输入节点数据类型、输出节点数据类型。代码如下
private static final String input_layer = "inputs/X";
private static final String output_layer = "output/predict";
private Context context;
private static final int HEIGHT = 64;
private static final int WIDTH = 256;
private static final int CHANNEL = 1;
private float[] inputs = new float[HEIGHT*WIDTH*CHANNEL];
private long[] outputs = new long[11];
- 初始化模型
这一步TensorFlow的模型会载入到内存中,传入assets和PB文件名
TensorFlowInferenceInterface inferenceInterface = new TensorFlowInferenceInterface(context.getAssets(),"rounded_graph.pb");
- 喂数据给输入节点
这里的参数是输入节点名,输入数据,输入数据的shape
inferenceInterface.feed(input_layer,inputs,1,16384);
- run session
inferenceInterface.run(new String[] { output_layer }, false);
- 获取输出数据
根据你在Python定义的输出格式,new一个接收输出数据的变量,从输出节点获取数据
byte[] outPuts = new byte[88];
inferenceInterface.fetch(output_layer,outPuts);
- 数据变换
从输出节点获取到数据之后就需要你对自己的输出数据进行操作,比如我在我们model里最终输出的结果进行了Argmax的操作,Argmax返回的值类型是Int64的,在Android里只有long对应,但fetch方法的接受变量的参数类型只有double、float、int、byte,所以这里需要使用byte获取,再进行转换。这里跟传统的byte[8]转long有些不同,具体处理方式要看你定义的数据格式,我这里的byte[8]用网上的方法转long发现数值非常大,于是遍历一遍byte[8],发现每个子元素都是相同的数值,所以这里只取第一个元素,组成一个新的数组,再对这个数组进行解析。
long[] tOutputs=new long[11];
for (int i=0;i<11;i++)
{
int k=i*8;
tOutputs[i]=outPuts[k];
Log.i("output",tOutputs[i]+"");
}
String outputStr="";
for(int i=0;i<11;i++){
long char_idx=tOutputs[i];
long char_code = 0;
if (char_idx<10){
char_code = char_idx + (int)('0');
}
else if (char_idx<36){
char_code = char_idx-10 + (int)('A');
}
else if (char_idx<62){
char_code = char_idx + (int)('a');
}
outputStr+= (char)char_code;
}
后记
有Java API确实相比C++来的更直观方便,而且native debug也比JNI好操作,等TensorFlowLite出来的时候,Android TensorFlow应用会更加广泛吧。