第一次集成使用 tensorflow ,内心还是有些激动的。开始时候,并不知道怎么进行,其实是一脸茫然的,然后就看了不少文章,关于集成的,大致有了个思路,然后就开始集成测试。这次就总结下具体集成思路和步骤。
方式: tensorflow android 而不是 tensorflow lite
基本思路
首先,在 android 上集成 tensorflow ,我们可以确定使用 TensorFlowInferenceInterface 类,不知道的自己去查资料。通过阅读官方给的源码,我们大概就可以知道所需要的内容和怎么使用。
TensorFlowInferenceInterface 构造函数
首先,在实例化此类的时候,我们需要提供 assets 和 modelName ,所以可以确定要将模型放在 assets 文件夹下,同时传入模型名字即可。其次,在初始化的时候首先执行的 prepareNativeRuntime() 函数,可以确定要加载相关的 so 库,并且在集成后初始化后,不需要在业务代码中去重复加载,因为这里已经加载过了。
模型
模型放在 assets 文件夹下
so 库
so 库 不需要去加载
public TensorFlowInferenceInterface(AssetManager assetManager, String model) {
prepareNativeRuntime();
this.modelName = model;
this.g = new Graph();
this.sess = new Session(g);
this.runner = sess.runner();
....
}
private void prepareNativeRuntime() {
····
try {
System.loadLibrary("tensorflow_inference");
Log.i(TAG, "Successfully loaded TensorFlow native methods (RunStats error may be ignored)");
} catch (UnsatisfiedLinkError e2) {
throw new RuntimeException(
"Native TF methods not found; check that the correct native"
+ " libraries are present in the APK.");
}
}
}
TensorFlowInferenceInterface 中 feed 、run 和 fetch 函数
feed 重载函数有很多,根据需要传入参数即可,必须传入的为 intputName 和数据 src 。必须要的是传入数据的类型是什么,不然是不成功的。
作用:注入数据
run 重载函数也有好几个,是执行运行的,需要传入 outputName 数组,这里的outputName 需要和 fetch 相关函数中的一致。
作用:运行
fetch 重载的函数也有很多,也是需要传出的即可,必须传入的是 outputName 和 要存储结果的数组 dst。必须要确定传出结果的数据类型。
作用:取出结果
总结下操作函数
inputName
验证的数据
outputName : run 和 fetch 均需要
结果放置的数组
实现步骤
1.准备模型 ( .pb 文件)
这个要算法工程师给训练好的模型并打包成 .pb 文件,当然自己可以,自己来。并确定使用模型的时候所需要的输入、输出参数 ,即上述 feed 和 fetch 的参数。
注意将 .pb 文件放入 assets 文件夹下。
2.在项目中引入 tensorflow 提供的 jar 包和 so 库
jar 包集成方式有两种,一是放置 libs 文件夹下,进行集成,另一种为下面办法。我这边使用的版本是 1.6.0 。
dependencies {
implementation 'org.tensorflow:tensorflow-android:+' // 1.6.0
}
so 库集成方式,将下载的 so 库相关文件夹放置到 libs 下,在 app 下 build.gradle 文件指定 jniLibs.srcDirs 目录即可, 当然其它方式也可以。
sourceSets {
main {
jniLibs.srcDirs = ['libs']
}
}
这时目录结构为:
3.调用
这边是使用 kotlin 实现的
class MainActivity : AppCompatActivity() {
// modelName
private val MODEL_FILE: String = "lstm_150_2_50.pb"
// inputName
private val INPUT_NODE: String = "input"
// outputName
private val OUTPUT_NODE: String = "output"
// TensorFlowInferenceInterface
private lateinit var mTensorFlowInferenceInterface: TensorFlowInferenceInterface
private lateinit var btnResult: AppCompatButton
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
setContentView(R.layout.activity_main)
btnResult = findViewById(R.id.btn_result)
// 初始化模型
mTensorFlowInferenceInterface = TensorFlowInferenceInterface(assets, MODEL_FILE)
btnResult.setOnClickListener {
// 整理数据
val valuesList = resources.getString(R.string.tf_values).split(",")
val valuesDatas: ArrayList = ArrayList()
valuesList.forEach {
valuesDatas.add(it.toInt())
}
val datas = valuesDatas.toIntArray()
// 注入: 必须确定 inputName 和数据类型
mTensorFlowInferenceInterface.feed(INPUT_NODE, datas,这里根据模型定义)
// 运行:outputName
mTensorFlowInferenceInterface.run(arrayOf(OUTPUT_NODE))
// 输出: 必须确定 outputName 和数据类型
val result = IntArray(1)
mTensorFlowInferenceInterface.fetch(OUTPUT_NODE, result)
Log.v("Main", "输出结果: ${Arrays.toString(result)}")
tv_result.text = Arrays.toString(result)
}
}
}
4.确认 so 库 和 模型加载成功
so 库加载成功,日志中会出现
Successfully loaded TensorFlow native methods (....)
模型加载成功,日志中出现下面内容
Successfully loaded model from $modelName