tflite图像的Demo官方已经给了,下载下来看了,其实写的比较复杂,如果需要改成时序信号的使用,需要改的东西挺多的,也比较耗时间,还不如自己新建一个。
其实,如果模型本来就不大的话,直接使用Tensorflow的.pb模型也是可以的,并且,感觉使用.pb的Java代码看上去清爽一些,参考:Android运行Keras/TF模型。
由于我并不是直接使用tensorflow来进行训练的,是使用keras训练的,由于Tensorflow也使用了Keras的API所以,Keras的模型(.h5/.hdf5)文件是可以直接转化为tflite的。
import tensorflow as tf
converter = tf.lite.TFLiteConverter.from_keras_model_file('models/unet_siftflow.hdf5')
tflite_model = converter.convert()
open("models/converted_model.tflite", "wb").write(tflite_model)
如果使用的tensorflow训练的,可以使用converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
来转化。这里的saved_model_dir
就是保持checkpoint的文件夹,如果已经使用freeze_graph.py冻结了的模型,则需要使用tf.lite.TFLiteConverter.from_frozen_graph(freezed_model)
。
导出模型以后需要进行输入输出的检查,使用下面这段代码进行检查:
# model_test.py
import numpy as np
import tensorflow as tf
# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="models/converted_model.tflite")
interpreter.allocate_tensors()
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
input_shape = input_details[0]['shape']
print('input size should be:', input_shape)
input_data = np.array([[i for i in range(400)]], dtype=np.float32)
# input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
print('the output is:', output_data, 'with size of', output_data.shape)
这里输出的是:
input size should be: [ 1 400]
the output is: [[1. 0.]] with size of (1, 2)
也就是说输入数据是(1,400)的,输出是(1,2)的,这个很重要,因为在Java中使用时候,需要输入输出数据的大小。
如果是一个图像项目,首先可以直接到tf_android图像实例项目下载官方Demo进行。然后替换一下自己的模型即可。
如果完全新建一个Android Studio项目,在新建一个空白项目后(新建一个Java项目,不要新建Kotlin项目),在src/main里面创建一个asset文件夹,把模型文件放在这里:
在build.gradle
(Module:App)添加依赖:
dependencies {
implementation 'org.tensorflow:tensorflow-lite:+'
}
在android下增加 aaptOptions:
aaptOptions {
noCompress "tflite"
}
这个时候,这个文件为:
apply plugin: 'com.android.application'
android {
compileSdkVersion 29
buildToolsVersion "29.0.2"
defaultConfig {
applicationId "com.example.tflite_as"
minSdkVersion 15
targetSdkVersion 29
versionCode 1
versionName "1.0"
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
}
buildTypes {
release {
minifyEnabled false
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
}
}
aaptOptions {
noCompress "tflite"
}
}
dependencies {
implementation fileTree(dir: 'libs', include: ['*.jar'])
implementation 'androidx.appcompat:appcompat:1.1.0'
implementation 'androidx.constraintlayout:constraintlayout:1.1.3'
testImplementation 'junit:junit:4.12'
androidTestImplementation 'androidx.test.ext:junit:1.1.1'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.2.0'
implementation 'org.tensorflow:tensorflow-lite:+'
}
打开MainActicity.java或者其他要使用这个模型的地方,先导入库:
import org.tensorflow.lite.Interpreter
这是tensorflow需要的,另外还需要一些辅助的库:
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
创建一个private的变量private Interpreter tflite;
并定义好加载模型的函数,以及执行predict的函数就可以了,为了体现执行的结果,在layout上创建一个TextView,ID为textView_1,用于显示结果。模型的输出与输出数据是比较容易出错的,详见代码中注释,MainActivity.java
的所有代码为:
package com.example.tflite_as;
import androidx.appcompat.app.AppCompatActivity;
import android.content.res.AssetFileDescriptor;
import android.os.Bundle;
import android.util.Log;
import android.widget.TextView;
import android.widget.Toast;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import org.tensorflow.lite.Interpreter;
public class MainActivity extends AppCompatActivity {
private static final String TAG = "Test";
private Interpreter tflite;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
loadModule();
}
private void loadModule() {
String model = "converted_model";//模型的名字
try {
Interpreter.Options options = new Interpreter.Options();
options.setNumThreads(4);
options.setUseNNAPI(true);
options.setAllowFp16PrecisionForFp32(true);
// 加载模型文件
tflite = new Interpreter(loadModelFile(model), options);
Toast.makeText(MainActivity.this, model + " model load success", Toast.LENGTH_SHORT).show();
// 调用Test函数
test();
} catch (IOException e) {
Toast.makeText(MainActivity.this, model + " model load fail", Toast.LENGTH_SHORT).show();
e.printStackTrace();
}
}
// 加载模型文件的函数
private MappedByteBuffer loadModelFile(String model) throws IOException {
AssetFileDescriptor fileDescriptor = getApplicationContext().getAssets().openFd(model + ".tflite");
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
public void test() {
Log.d(TAG, "-----test()----");
try {
// 模型输出,定义需要与模型的输出一致,上一节有检查输入输出
float[][] labelProbArray = new float[1][2];
// 这里需要使用ByteBuffer的形式进行输入,输入的数据是(1,400)个float的数据
// float是4个字节,所以这里需要写成1*400*4=1600,这里容易出错
ByteBuffer inputData = ByteBuffer.allocateDirect(1 * 400 * 4);
// 填入数据
inputData.order(ByteOrder.nativeOrder());
for (int i = 0; i < 400; i++) {
inputData.putFloat(i);
}
// 运行模型的predict功能
tflite.run(inputData, labelProbArray);
// 根据模型的定义,输出相应的信息
String result;
if (labelProbArray[0][0] > 0.5)
result = "Not Pulse";
else
result = "Pulse";
TextView resultView = findViewById(R.id.textView_1);
resultView.setText(result);
} catch (Exception e) {
e.printStackTrace();
}
}
}