tensorflow - 在 Android 中 集成 tensorflow 并使用训练后的模型

第一次集成使用 tensorflow ,内心还是有些激动的。开始时候,并不知道怎么进行,其实是一脸茫然的,然后就看了不少文章,关于集成的,大致有了个思路,然后就开始集成测试。这次就总结下具体集成思路和步骤。

方式: tensorflow android 而不是 tensorflow lite

基本思路

首先,在 android 上集成 tensorflow ,我们可以确定使用 TensorFlowInferenceInterface 类,不知道的自己去查资料。通过阅读官方给的源码,我们大概就可以知道所需要的内容和怎么使用。

TensorFlowInferenceInterface 构造函数

首先,在实例化此类的时候,我们需要提供 assetsmodelName ,所以可以确定要将模型放在 assets 文件夹下,同时传入模型名字即可。其次,在初始化的时候首先执行的 prepareNativeRuntime() 函数,可以确定要加载相关的 so 库,并且在集成后初始化后,不需要在业务代码中去重复加载,因为这里已经加载过了。

  • 模型
  • 模型放在 assets 文件夹下
  • so
  • so 库 不需要去加载
 public TensorFlowInferenceInterface(AssetManager assetManager, String model) {
    prepareNativeRuntime();

    this.modelName = model;
    this.g = new Graph();
    this.sess = new Session(g);
    this.runner = sess.runner();
    ....
}

 private void prepareNativeRuntime() {
      ····
      try {
        System.loadLibrary("tensorflow_inference");
        Log.i(TAG, "Successfully loaded TensorFlow native methods (RunStats error may be ignored)");
      } catch (UnsatisfiedLinkError e2) {
        throw new RuntimeException(
            "Native TF methods not found; check that the correct native"
                + " libraries are present in the APK.");
      }
    }
  }

TensorFlowInferenceInterfacefeedrunfetch 函数

feed 重载函数有很多,根据需要传入参数即可,必须传入的为 intputName 和数据 src 。必须要的是传入数据的类型是什么,不然是不成功的。

作用:注入数据

这里写图片描述

run 重载函数也有好几个,是执行运行的,需要传入 outputName 数组,这里的outputName 需要和 fetch 相关函数中的一致。

作用:运行

这里写图片描述

fetch 重载的函数也有很多,也是需要传出的即可,必须传入的是 outputName 和 要存储结果的数组 dst。必须要确定传出结果的数据类型。

作用:取出结果

这里写图片描述

总结下操作函数

  • inputName
  • 验证的数据
  • outputName : runfetch 均需要
  • 结果放置的数组

实现步骤

1.准备模型 ( .pb 文件)

这个要算法工程师给训练好的模型并打包成 .pb 文件,当然自己可以,自己来。并确定使用模型的时候所需要的输入、输出参数 ,即上述 feedfetch 的参数。

注意将 .pb 文件放入 assets 文件夹下。

2.在项目中引入 tensorflow 提供的 jar 包和 so

jar 包和 so 库 下载地址:http://ci.tensorflow.org/view/Nightly/job/nightly-android/ws/out/

jar 包集成方式有两种,一是放置 libs 文件夹下,进行集成,另一种为下面办法。我这边使用的版本是 1.6.0

dependencies {
    implementation 'org.tensorflow:tensorflow-android:+' // 1.6.0
}

so 库集成方式,将下载的 so 库相关文件夹放置到 libs 下,在 appbuild.gradle 文件指定 jniLibs.srcDirs 目录即可, 当然其它方式也可以。

 sourceSets {
        main {
            jniLibs.srcDirs = ['libs']
        }
    }

这时目录结构为:

这里写图片描述

3.调用

这边是使用 kotlin 实现的

class MainActivity : AppCompatActivity() {

    // modelName
    private val MODEL_FILE: String = "lstm_150_2_50.pb"
    // inputName
    private val INPUT_NODE: String = "input"
    // outputName
    private val OUTPUT_NODE: String = "output"

    // TensorFlowInferenceInterface
    private lateinit var mTensorFlowInferenceInterface: TensorFlowInferenceInterface

    private lateinit var btnResult: AppCompatButton

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)
        btnResult = findViewById(R.id.btn_result)

        // 初始化模型
        mTensorFlowInferenceInterface = TensorFlowInferenceInterface(assets, MODEL_FILE)

        btnResult.setOnClickListener {
            // 整理数据
            val valuesList = resources.getString(R.string.tf_values).split(",")
            val valuesDatas: ArrayList = ArrayList()
            valuesList.forEach {
                valuesDatas.add(it.toInt())
            }
            val datas = valuesDatas.toIntArray()
            // 注入: 必须确定 inputName 和数据类型
            mTensorFlowInferenceInterface.feed(INPUT_NODE, datas,这里根据模型定义)
            // 运行:outputName
            mTensorFlowInferenceInterface.run(arrayOf(OUTPUT_NODE))
            // 输出: 必须确定 outputName 和数据类型
            val result = IntArray(1)
            mTensorFlowInferenceInterface.fetch(OUTPUT_NODE, result)
            Log.v("Main", "输出结果: ${Arrays.toString(result)}")
            tv_result.text = Arrays.toString(result)
        }

    }
}

4.确认 so 库 和 模型加载成功

so 库加载成功,日志中会出现

Successfully loaded TensorFlow native methods (....)

模型加载成功,日志中出现下面内容

Successfully loaded model from $modelName

这里写图片描述

你可能感兴趣的:(tensorflow)