时序信号的模型使用tflite的示例

tflite图像的Demo官方已经给了,下载下来看了,其实写的比较复杂,如果需要改成时序信号的使用,需要改的东西挺多的,也比较耗时间,还不如自己新建一个。

其实,如果模型本来就不大的话,直接使用Tensorflow的.pb模型也是可以的,并且,感觉使用.pb的Java代码看上去清爽一些,参考:Android运行Keras/TF模型。

模型转化

由于我并不是直接使用tensorflow来进行训练的,是使用keras训练的,由于Tensorflow也使用了Keras的API所以,Keras的模型(.h5/.hdf5)文件是可以直接转化为tflite的。

import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_keras_model_file('models/unet_siftflow.hdf5')
tflite_model = converter.convert()
open("models/converted_model.tflite", "wb").write(tflite_model)

如果使用的tensorflow训练的,可以使用converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)来转化。这里的saved_model_dir就是保持checkpoint的文件夹,如果已经使用freeze_graph.py冻结了的模型,则需要使用tf.lite.TFLiteConverter.from_frozen_graph(freezed_model)

导出模型以后需要进行输入输出的检查,使用下面这段代码进行检查:

# model_test.py
import numpy as np
import tensorflow as tf

# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="models/converted_model.tflite")
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

input_shape = input_details[0]['shape']
print('input size should be:', input_shape)
input_data = np.array([[i for i in range(400)]], dtype=np.float32)
# input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
print('the output is:', output_data, 'with size of', output_data.shape)

这里输出的是:

input size should be: [  1 400]
the output is: [[1. 0.]] with size of (1, 2)

也就是说输入数据是(1,400)的,输出是(1,2)的,这个很重要,因为在Java中使用时候,需要输入输出数据的大小。

建立Android项目

如果是一个图像项目,首先可以直接到tf_android图像实例项目下载官方Demo进行。然后替换一下自己的模型即可。
如果完全新建一个Android Studio项目,在新建一个空白项目后(新建一个Java项目,不要新建Kotlin项目),在src/main里面创建一个asset文件夹,把模型文件放在这里:

时序信号的模型使用tflite的示例_第1张图片

build.gradle(Module:App)添加依赖:

dependencies {
    implementation 'org.tensorflow:tensorflow-lite:+'
}

在android下增加 aaptOptions:

    aaptOptions {
        noCompress "tflite"
    }

这个时候,这个文件为:

apply plugin: 'com.android.application'

android {
    compileSdkVersion 29
    buildToolsVersion "29.0.2"
    defaultConfig {
        applicationId "com.example.tflite_as"
        minSdkVersion 15
        targetSdkVersion 29
        versionCode 1
        versionName "1.0"
        testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
    }
    buildTypes {
        release {
            minifyEnabled false
            proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
        }
    }
    aaptOptions {
        noCompress "tflite"
    }
}

dependencies {
    implementation fileTree(dir: 'libs', include: ['*.jar'])
    implementation 'androidx.appcompat:appcompat:1.1.0'
    implementation 'androidx.constraintlayout:constraintlayout:1.1.3'
    testImplementation 'junit:junit:4.12'
    androidTestImplementation 'androidx.test.ext:junit:1.1.1'
    androidTestImplementation 'androidx.test.espresso:espresso-core:3.2.0'
    implementation 'org.tensorflow:tensorflow-lite:+'
}

打开MainActicity.java或者其他要使用这个模型的地方,先导入库:

import org.tensorflow.lite.Interpreter

这是tensorflow需要的,另外还需要一些辅助的库:

import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;

创建一个private的变量private Interpreter tflite;并定义好加载模型的函数,以及执行predict的函数就可以了,为了体现执行的结果,在layout上创建一个TextView,ID为textView_1,用于显示结果。模型的输出与输出数据是比较容易出错的,详见代码中注释,MainActivity.java的所有代码为:

package com.example.tflite_as;

import androidx.appcompat.app.AppCompatActivity;


import android.content.res.AssetFileDescriptor;
import android.os.Bundle;
import android.util.Log;
import android.widget.TextView;
import android.widget.Toast;

import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;

import org.tensorflow.lite.Interpreter;

public class MainActivity extends AppCompatActivity {

    private static final String TAG = "Test";
    private Interpreter tflite;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);
        loadModule();
    }

    private void loadModule() {
        String model = "converted_model";//模型的名字
        try {
            Interpreter.Options options = new Interpreter.Options();
            options.setNumThreads(4);
            options.setUseNNAPI(true);
            options.setAllowFp16PrecisionForFp32(true);
            // 加载模型文件
            tflite = new Interpreter(loadModelFile(model), options);
            Toast.makeText(MainActivity.this, model + " model load success", Toast.LENGTH_SHORT).show();
            // 调用Test函数
            test();
        } catch (IOException e) {
            Toast.makeText(MainActivity.this, model + " model load fail", Toast.LENGTH_SHORT).show();
            e.printStackTrace();

        }
    }

    // 加载模型文件的函数
    private MappedByteBuffer loadModelFile(String model) throws IOException {
        AssetFileDescriptor fileDescriptor = getApplicationContext().getAssets().openFd(model + ".tflite");
        FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
        FileChannel fileChannel = inputStream.getChannel();
        long startOffset = fileDescriptor.getStartOffset();
        long declaredLength = fileDescriptor.getDeclaredLength();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }

    public void test() {
        Log.d(TAG, "-----test()----");
        try {
            // 模型输出,定义需要与模型的输出一致,上一节有检查输入输出
            float[][] labelProbArray = new float[1][2];
            // 这里需要使用ByteBuffer的形式进行输入,输入的数据是(1,400)个float的数据
            // float是4个字节,所以这里需要写成1*400*4=1600,这里容易出错
            ByteBuffer inputData = ByteBuffer.allocateDirect(1 * 400 * 4);
            // 填入数据
            inputData.order(ByteOrder.nativeOrder());
            for (int i = 0; i < 400; i++) {
                inputData.putFloat(i);
            }
            // 运行模型的predict功能
            tflite.run(inputData, labelProbArray);
            // 根据模型的定义,输出相应的信息
            String result;
            if (labelProbArray[0][0] > 0.5)
                result = "Not Pulse";
            else
                result = "Pulse";
            TextView resultView = findViewById(R.id.textView_1);
            resultView.setText(result);

        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

参考

  1. 开始使用 TensorFlow Lite
  2. Android 快速上手
  3. tf_android图像实例项目

你可能感兴趣的:(Python,深度学习,tensorflow,python,深度学习)