需要在Android studio上开发一个apk,使得可以调用Tensorflow生成的模型进行计算。
在app/src
文件夹下的build.gradle
中添加tensorflow的应用包(需联网下载)。
implementation 'org.tensorflow:tensorflow-android:1.13.1'
把生成好的模型文件放入assets
文件夹中,以便后续的模型调用。
首先,编写模型初始化、调用的代码如下所示。其中getPredict()
用于调用模型进行计算,本例中简化为直接输入数组inputdata []
,实际使用的时候需要接受真实的数据。
注意1:一定要保证输入、输出变量的名称、维度大小正确(如果不正确会有报错提醒,及时调整)
注意2:pd文件在生成时的TensorFlow版本需要和训练时相一致
package com.example.test;
import android.content.res.AssetManager;
import android.util.Log;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
public class PredictionTF {
private static final String TAG = "PredictionTF";
// 设置模型输入/输出节点的数据维度
private static final int IN_COL = 6;
private static final int IN_ROW = 8;
private static final int OUT_COL = 1;
private static final int OUT_ROW = 1;
// 模型中输入变量的名称
private static final String inputName = "actor/InputData/X";
// 模型中输出变量的名称
private static final String outputName = "actor/FullyConnected_1/Softmax";
TensorFlowInferenceInterface inferenceInterface;
static {
// 加载libtensorflow_inference.so库文件
System.loadLibrary("tensorflow_inference");
Log.e(TAG,"libtensorflow_inference.so库加载成功");
}
PredictionTF(AssetManager assetManager, String modePath) {
// 初始化TensorFlowInferenceInterface对象
inferenceInterface = new TensorFlowInferenceInterface(assetManager,modePath);
Log.e(TAG,"TensoFlow模型文件加载成功");
}
public float[] getPredict() {
float[] inputdata = {0.361f ,-0.422f ,-0.992f ,-0.196f ,-0.564f ,0.947f ,-0.339f ,0.167f,
0.434f ,0.287f ,-0.704f ,0.065f ,-0.083f ,0.747f ,0.874f ,-0.796f,
-0.822f ,-0.366f ,0.21f ,-0.493f ,0.97f ,-0.779f ,0.947f ,0.118f,
0.798f ,0.911f ,0.42f ,-0.219f ,-0.572f ,0.033f ,-0.515f ,-0.846f,
-0.994f ,0.254f ,0.775f ,0.782f ,0.046f ,-0.403f ,0.056f ,0.731f,
-0.714f ,0.982f ,0.117f ,-0.912f ,0.467f ,-0.015f ,-0.998f ,0.703f};
//将数据feed给tensorflow的输入节点
inferenceInterface.feed(inputName, inputdata,1, IN_COL, IN_ROW);
float[] outputs = new float[6];
//运行tensorflow
inferenceInterface.run(new String[] { outputName }, false);
//获取输出节点的输出信息
inferenceInterface.fetch(outputName, outputs);
return outputs;
}
}
接着,编写主体的MainActivity
代码,由于网络最后输出的是一个1*6
的概率数组,所以加入了choose
以输出最终的结果。
package com.example.test;
import androidx.appcompat.app.AppCompatActivity;
import android.content.res.AssetManager;
import android.os.Bundle;
import android.util.Log;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
import java.util.Random;
public class MainActivity extends AppCompatActivity {
PredictionTF preTF;
private static final String TAG = "MainActivity";
private static final String MODEL_FILE = "file:///android_asset/frozen_model(1).pb"; //模型存放路径
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
preTF = new PredictionTF(getAssets(),MODEL_FILE);
float[] result = preTF.getPredict();
int true_result = choose(result);
Log.i(TAG, "输出的结果为:");
Log.i(TAG, String.valueOf(true_result));
}
public int choose(float result[]){
int rand_times = 100;
int A_DIM = 6;
float[] action_cumsum = {0,0,0,0,0,0};
action_cumsum[0] = result[0];
for (int i = 1; i < A_DIM; i++) {
action_cumsum[i] = action_cumsum[i-1] + result[i];
}
int[] hitcount = {0,0,0,0,0,0};
int action = 0;
int max = 0;
for (int i = 0; i < rand_times; i++) {
Random r = new Random();
float randomnum = r.nextFloat();
for (int j = 0; j < action_cumsum.length; j++) {
if (action_cumsum[j] >= randomnum) {
hitcount[j] = hitcount[j] + 1;
break;
}
}
}
for(int i = 0; i < hitcount.length; i++){
if(max < hitcount[i]){
max = hitcount[i];
action = i;
}
}
return action;
}
}