边缘计算:基于tflite实现andriod边缘端回归预测推理实战

读了本文,你可以实现从云端利用DNN模型进行训练,模型保存.h5格式(基于keras)或是saved model格式(tf2.0版本),模型转化为tflite,利用android studio 编写java接口程序,实现模型最终的推理预测,并利用studio自带的手机模拟器,将推理结果显示到手机上,最终的效果如下。

边缘计算:基于tflite实现andriod边缘端回归预测推理实战_第1张图片

实现步骤:

1、下载MPG数据集

边缘计算:基于tflite实现andriod边缘端回归预测推理实战_第2张图片

2、利用tf2.0实现云端训练,生成mymodel.h5或者.savedmodel目录

可以参考这篇帖子,写的相对清晰

时序信号的模型使用tflite的示例_sinat_18131557的博客-CSDN博客_tflite文件怎么打开

 3、转化为tflite

参考代码:以.h5为例,其他类同

mymodel = load_model('mymodel.h5')
converter = tf.lite.TFLiteConverter.from_keras_model(mymodel)
converter.post_training_quantize = True
tflite_model = converter.convert()
open('converted_model.tflite', 'wb').write(tflite_model)

4、利用android studio打包成apk,完成手机端推理预测

这一步对于仅熟悉云端的同学来说很陌生,因为手机安卓端是不一样的,因此可以打包成apk实现,本步骤写的稍微详细些

4.1 利用android studio新建工程,设置项参考下图

边缘计算:基于tflite实现andriod边缘端回归预测推理实战_第3张图片

边缘计算:基于tflite实现andriod边缘端回归预测推理实战_第4张图片

边缘计算:基于tflite实现andriod边缘端回归预测推理实战_第5张图片

 4.2 添加一个可以显示预测结果的控件

在Android -> app -> res -> layout -> activity_main.xml
android:id="@+id/result"

边缘计算:基于tflite实现andriod边缘端回归预测推理实战_第6张图片

4.3 将转化完成的tflite模型,放到app-src-main-assets目录下;

边缘计算:基于tflite实现andriod边缘端回归预测推理实战_第7张图片

4.4 修改如下文件,此步骤很关键: 

 修改app下的 Gradle Scripts -> build.gradle,注意是Module:My_Application.app这个,添加依赖项:

implementation 'org.tensorflow:tensorflow-lite:+'

边缘计算:基于tflite实现andriod边缘端回归预测推理实战_第8张图片

 同时添加如下代码:

aaptOptions {
    noCompress "tflite"
}

边缘计算:基于tflite实现andriod边缘端回归预测推理实战_第9张图片

完整build.gradle代码,供参考不一定完全相同:

plugins {
    id 'com.android.application'
}

android {
    compileSdkVersion 33
    buildToolsVersion "33.0.0"

    defaultConfig {
        applicationId "com.example.myapplication"
        minSdkVersion 26
        targetSdkVersion 33
        versionCode 1
        versionName "1.0"

        testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
    }

    buildTypes {
        release {
            minifyEnabled false
            proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
        }
    }

    aaptOptions {
        noCompress "tflite"
    }

    compileOptions {
        sourceCompatibility JavaVersion.VERSION_1_8
        targetCompatibility JavaVersion.VERSION_1_8
    }
    buildFeatures {
        viewBinding true
    }
    sourceSets {
        main {
            assets {
                srcDirs 'src\\main\\assets'
            }
        }
    }
}

dependencies {

    implementation 'androidx.appcompat:appcompat:1.2.0'
    implementation 'com.google.android.material:material:1.2.1'
    implementation 'androidx.constraintlayout:constraintlayout:2.0.1'
    testImplementation 'junit:junit:4.+'
    androidTestImplementation 'androidx.test.ext:junit:1.1.2'
    androidTestImplementation 'androidx.test.espresso:espresso-core:3.3.0'
    implementation 'org.tensorflow:tensorflow-lite:+'
}

4.5  修改app -> java -> com.example.myapplication -> MainActivity,该步骤主要是加载tflite模型,定义输入输出,调用模型推理,并在安卓手机模拟器上显示推理结果:

边缘计算:基于tflite实现andriod边缘端回归预测推理实战_第10张图片

MainActivity的完整代码如下,主要改动以黑色加粗标出:

package com.example.myapplication;

import androidx.appcompat.app.AppCompatActivity;

import android.content.Context;
import android.content.res.AssetFileDescriptor;
import android.os.Bundle;
import android.util.Log;
import android.widget.TextView;
import android.widget.Toast;

import java.io.FileInputStream;
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import org.tensorflow.lite.Interpreter;



public class MainActivity extends AppCompatActivity {
    private static final String TAG = "Test";
    private Interpreter tflite;
    private Context mContext;
    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);
        loadModule();
    }

    private void loadModule() {
        String model = "converted_model";//模型的名字
        try {
            Interpreter.Options options = new Interpreter.Options();
            options.setNumThreads(4);
            options.setUseNNAPI(true);
            options.setAllowFp16PrecisionForFp32(true);
            // 加载模型文件
            tflite = new Interpreter(loadModelFile(model), options);
            Toast.makeText(MainActivity.this, model + " model load success", Toast.LENGTH_SHORT).show();
            // 调用Test函数
            test();
        } catch (IOException e) {
            Toast.makeText(MainActivity.this, model + " model load fail", Toast.LENGTH_SHORT).show();
            e.printStackTrace();

        }
    }

    // 加载模型文件的函数
    private MappedByteBuffer loadModelFile(String model) throws IOException {
        AssetFileDescriptor fileDescriptor = getApplicationContext().getAssets().openFd(model + ".tflite");
        FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
        FileChannel fileChannel = inputStream.getChannel();
        long startOffset = fileDescriptor.getStartOffset();
        long declaredLength = fileDescriptor.getDeclaredLength();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }

    public void test() {
        Log.d(TAG, "-----test()----");
        try {
//          推理预测数据初始化

//            30.02904
            float[] inputData = new float[]{-0.869348F, -1.009459F, -0.784052F, -1.025303F, -0.379759F,
                   -0.516397F, 0.774676F, -0.465148F, -0.495225F};

//          模型输出,定义需要与模型的输出一致,上一节有检查输入输出
            float[][] labelProbArray = new float[1][1];
//             这里需要使用ByteBuffer的形式进行输入,输入的数据是(1,400)个float的数

//            运行模型的predict功能
            tflite.run(inputData, labelProbArray);

//             根据模型的定义,输出相应的信息
            String result;
            result = labelProbArray[0][0] + "";
            Log.d("test", "setText = " + result);
            TextView tv = findViewById(R.id.result);
            tv.setText(result);

        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

 4.6 run app即可,会生成apk

边缘计算:基于tflite实现andriod边缘端回归预测推理实战_第11张图片

边缘计算:基于tflite实现andriod边缘端回归预测推理实战_第12张图片

 4.5 最终显示结果,结束

边缘计算:基于tflite实现andriod边缘端回归预测推理实战_第13张图片

错误处理:

1、Apps targeting Android 12 and higher are required to specify an explicit value for `android:exported` when the corresponding component has an intent filter defined. See https://developer.android.com/guide/topics/manifest/activity-element#exported for details

如下:

边缘计算:基于tflite实现andriod边缘端回归预测推理实战_第14张图片

 修改:app -> manifests -> AndroidManifest.xml ,在添加如下代码

android:exported="true"

边缘计算:基于tflite实现andriod边缘端回归预测推理实战_第15张图片

你可能感兴趣的:(tensorflow学习,边缘计算,tflite)