读了本文,你可以实现从云端利用DNN模型进行训练,模型保存.h5格式(基于keras)或是saved model格式(tf2.0版本),模型转化为tflite,利用android studio 编写java接口程序,实现模型最终的推理预测,并利用studio自带的手机模拟器,将推理结果显示到手机上,最终的效果如下。
实现步骤:
1、下载MPG数据集
2、利用tf2.0实现云端训练,生成mymodel.h5或者.savedmodel目录
可以参考这篇帖子,写的相对清晰
时序信号的模型使用tflite的示例_sinat_18131557的博客-CSDN博客_tflite文件怎么打开
3、转化为tflite
参考代码:以.h5为例,其他类同
mymodel = load_model('mymodel.h5')
converter = tf.lite.TFLiteConverter.from_keras_model(mymodel)
converter.post_training_quantize = True
tflite_model = converter.convert()
open('converted_model.tflite', 'wb').write(tflite_model)
4、利用android studio打包成apk,完成手机端推理预测
这一步对于仅熟悉云端的同学来说很陌生,因为手机安卓端是不一样的,因此可以打包成apk实现,本步骤写的稍微详细些
4.1 利用android studio新建工程,设置项参考下图
4.2 添加一个可以显示预测结果的控件
在Android -> app -> res -> layout -> activity_main.xml android:id="@+id/result"
4.3 将转化完成的tflite模型,放到app-src-main-assets目录下;
4.4 修改如下文件,此步骤很关键:
修改app下的 Gradle Scripts -> build.gradle,注意是Module:My_Application.app这个,添加依赖项:
implementation 'org.tensorflow:tensorflow-lite:+'
同时添加如下代码:
aaptOptions { noCompress "tflite" }
完整build.gradle代码,供参考不一定完全相同:
plugins { id 'com.android.application' } android { compileSdkVersion 33 buildToolsVersion "33.0.0" defaultConfig { applicationId "com.example.myapplication" minSdkVersion 26 targetSdkVersion 33 versionCode 1 versionName "1.0" testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" } buildTypes { release { minifyEnabled false proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' } } aaptOptions { noCompress "tflite" } compileOptions { sourceCompatibility JavaVersion.VERSION_1_8 targetCompatibility JavaVersion.VERSION_1_8 } buildFeatures { viewBinding true } sourceSets { main { assets { srcDirs 'src\\main\\assets' } } } } dependencies { implementation 'androidx.appcompat:appcompat:1.2.0' implementation 'com.google.android.material:material:1.2.1' implementation 'androidx.constraintlayout:constraintlayout:2.0.1' testImplementation 'junit:junit:4.+' androidTestImplementation 'androidx.test.ext:junit:1.1.2' androidTestImplementation 'androidx.test.espresso:espresso-core:3.3.0' implementation 'org.tensorflow:tensorflow-lite:+' }
4.5 修改app -> java -> com.example.myapplication -> MainActivity,该步骤主要是加载tflite模型,定义输入输出,调用模型推理,并在安卓手机模拟器上显示推理结果:
MainActivity的完整代码如下,主要改动以黑色加粗标出:
package com.example.myapplication; import androidx.appcompat.app.AppCompatActivity; import android.content.Context; import android.content.res.AssetFileDescriptor; import android.os.Bundle; import android.util.Log; import android.widget.TextView; import android.widget.Toast; import java.io.FileInputStream; import java.io.IOException; import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; import org.tensorflow.lite.Interpreter; public class MainActivity extends AppCompatActivity { private static final String TAG = "Test"; private Interpreter tflite; private Context mContext; @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); loadModule(); } private void loadModule() { String model = "converted_model";//模型的名字 try { Interpreter.Options options = new Interpreter.Options(); options.setNumThreads(4); options.setUseNNAPI(true); options.setAllowFp16PrecisionForFp32(true); // 加载模型文件 tflite = new Interpreter(loadModelFile(model), options); Toast.makeText(MainActivity.this, model + " model load success", Toast.LENGTH_SHORT).show(); // 调用Test函数 test(); } catch (IOException e) { Toast.makeText(MainActivity.this, model + " model load fail", Toast.LENGTH_SHORT).show(); e.printStackTrace(); } } // 加载模型文件的函数 private MappedByteBuffer loadModelFile(String model) throws IOException { AssetFileDescriptor fileDescriptor = getApplicationContext().getAssets().openFd(model + ".tflite"); FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); FileChannel fileChannel = inputStream.getChannel(); long startOffset = fileDescriptor.getStartOffset(); long declaredLength = fileDescriptor.getDeclaredLength(); return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); } public void test() { Log.d(TAG, "-----test()----"); try { // 推理预测数据初始化 // 30.02904 float[] inputData = new float[]{-0.869348F, -1.009459F, -0.784052F, -1.025303F, -0.379759F, -0.516397F, 0.774676F, -0.465148F, -0.495225F}; // 模型输出,定义需要与模型的输出一致,上一节有检查输入输出 float[][] labelProbArray = new float[1][1]; // 这里需要使用ByteBuffer的形式进行输入,输入的数据是(1,400)个float的数 // 运行模型的predict功能 tflite.run(inputData, labelProbArray); // 根据模型的定义,输出相应的信息 String result; result = labelProbArray[0][0] + ""; Log.d("test", "setText = " + result); TextView tv = findViewById(R.id.result); tv.setText(result); } catch (Exception e) { e.printStackTrace(); } } }
4.6 run app即可,会生成apk
4.5 最终显示结果,结束
错误处理:
1、Apps targeting Android 12 and higher are required to specify an explicit value for `android:exported` when the corresponding component has an intent filter defined. See https://developer.android.com/guide/topics/manifest/activity-element#exported for details
如下:
修改:app -> manifests -> AndroidManifest.xml ,在添加如下代码
android:exported="true"