第一次集成使用 tensorflow
,内心还是有些激动的。开始时候,并不知道怎么进行,其实是一脸茫然的,然后就看了不少文章,关于集成的,大致有了个思路,然后就开始集成测试。这次就总结下具体集成思路和步骤。
方式:
tensorflow android
而不是tensorflow lite
首先,在 android
上集成 tensorflow
,我们可以确定使用 TensorFlowInferenceInterface 类,不知道的自己去查资料。通过阅读官方给的源码,我们大概就可以知道所需要的内容和怎么使用。
TensorFlowInferenceInterface
构造函数
首先,在实例化此类的时候,我们需要提供 assets
和 modelName
,所以可以确定要将模型放在 assets
文件夹下,同时传入模型名字即可。其次,在初始化的时候首先执行的 prepareNativeRuntime()
函数,可以确定要加载相关的 so
库,并且在集成后初始化后,不需要在业务代码中去重复加载,因为这里已经加载过了。
assets
文件夹下so
库 so
库 不需要去加载 public TensorFlowInferenceInterface(AssetManager assetManager, String model) {
prepareNativeRuntime();
this.modelName = model;
this.g = new Graph();
this.sess = new Session(g);
this.runner = sess.runner();
....
}
private void prepareNativeRuntime() {
····
try {
System.loadLibrary("tensorflow_inference");
Log.i(TAG, "Successfully loaded TensorFlow native methods (RunStats error may be ignored)");
} catch (UnsatisfiedLinkError e2) {
throw new RuntimeException(
"Native TF methods not found; check that the correct native"
+ " libraries are present in the APK.");
}
}
}
TensorFlowInferenceInterface
中 feed
、run
和 fetch
函数
feed
重载函数有很多,根据需要传入参数即可,必须传入的为 intputName
和数据 src
。必须要的是传入数据的类型是什么,不然是不成功的。
作用:注入数据
run
重载函数也有好几个,是执行运行的,需要传入 outputName
数组,这里的outputName
需要和 fetch
相关函数中的一致。
作用:运行
fetch
重载的函数也有很多,也是需要传出的即可,必须传入的是 outputName
和 要存储结果的数组 dst
。必须要确定传出结果的数据类型。
作用:取出结果
总结下操作函数
inputName
outputName
: run
和 fetch
均需要1.准备模型 ( .pb
文件)
这个要算法工程师给训练好的模型并打包成 .pb
文件,当然自己可以,自己来。并确定使用模型的时候所需要的输入、输出参数 ,即上述 feed
和 fetch
的参数。
注意将 .pb
文件放入 assets
文件夹下。
2.在项目中引入 tensorflow
提供的 jar
包和 so
库
jar
包和 so
库 下载地址:http://ci.tensorflow.org/view/Nightly/job/nightly-android/ws/out/
jar
包集成方式有两种,一是放置 libs
文件夹下,进行集成,另一种为下面办法。我这边使用的版本是 1.6.0
。
dependencies {
implementation 'org.tensorflow:tensorflow-android:+' // 1.6.0
}
so
库集成方式,将下载的 so
库相关文件夹放置到 libs
下,在 app
下 build.gradle
文件指定 jniLibs.srcDirs
目录即可, 当然其它方式也可以。
sourceSets {
main {
jniLibs.srcDirs = ['libs']
}
}
这时目录结构为:
3.调用
这边是使用 kotlin 实现的
class MainActivity : AppCompatActivity() {
// modelName
private val MODEL_FILE: String = "lstm_150_2_50.pb"
// inputName
private val INPUT_NODE: String = "input"
// outputName
private val OUTPUT_NODE: String = "output"
// TensorFlowInferenceInterface
private lateinit var mTensorFlowInferenceInterface: TensorFlowInferenceInterface
private lateinit var btnResult: AppCompatButton
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
setContentView(R.layout.activity_main)
btnResult = findViewById(R.id.btn_result)
// 初始化模型
mTensorFlowInferenceInterface = TensorFlowInferenceInterface(assets, MODEL_FILE)
btnResult.setOnClickListener {
// 整理数据
val valuesList = resources.getString(R.string.tf_values).split(",")
val valuesDatas: ArrayList = ArrayList()
valuesList.forEach {
valuesDatas.add(it.toInt())
}
val datas = valuesDatas.toIntArray()
// 注入: 必须确定 inputName 和数据类型
mTensorFlowInferenceInterface.feed(INPUT_NODE, datas,这里根据模型定义)
// 运行:outputName
mTensorFlowInferenceInterface.run(arrayOf(OUTPUT_NODE))
// 输出: 必须确定 outputName 和数据类型
val result = IntArray(1)
mTensorFlowInferenceInterface.fetch(OUTPUT_NODE, result)
Log.v("Main", "输出结果: ${Arrays.toString(result)}")
tv_result.text = Arrays.toString(result)
}
}
}
4.确认 so
库 和 模型加载成功
so
库加载成功,日志中会出现
Successfully loaded TensorFlow native methods (....)
模型加载成功,日志中出现下面内容
Successfully loaded model from $modelName