Loading a TensorFlow model file (.tflite) on Android

This article uses a flower recognition project as its example.

I. Train the flower recognition model

Train the TensorFlow Lite model in Colab. After training, download the model file (model.tflite) and the label file (labels.txt).

II. Configure the project

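1. After creating the project, copy the downloaded model.tflite and labels.txt into app/src/main/assets/.
2. Add the TensorFlow Lite dependencies: in the app module's build.gradle file, add the following to the dependencies {} block: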

implementation('org.tensorflow:tensorflow-lite:0.0.0-nightly') { changing = true }
implementation('org.tensorflow:tensorflow-lite-gpu:0.0.0-nightly') { changing = true }
implementation('org.tensorflow:tensorflow-lite-support:0.0.0-nightly') { changing = true }

To keep the TensorFlow Lite model file from being compressed when the project is built, add the following to the android {} block:

aaptOptions {
  noCompress "tflite"
}
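
The reason the model must stay uncompressed is that it is loaded as a memory-mapped file (see the FileUtil.loadMappedFile call in the constructor in section III), and a compressed asset cannot be mapped directly. As a rough illustration in plain Java, outside Android, mapping a file read-only might look like the sketch below. ModelMapper and mapReadOnly are hypothetical names introduced here, not part of TensorFlow Lite.

```java
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;

// Hypothetical helper for illustration only; not part of TensorFlow Lite.
final class ModelMapper {
  /** Maps the whole file into memory read-only, without copying it onto the heap. */
  static MappedByteBuffer mapReadOnly(Path modelFile) throws IOException {
    try (FileChannel channel = FileChannel.open(modelFile, StandardOpenOption.READ)) {
      // The mapping remains valid after the channel is closed.
      return channel.map(FileChannel.MapMode.READ_ONLY, 0, channel.size());
    }
  }
}
```

On Android the bytes come from the APK's assets rather than a plain file path, which is why the stored bytes have to match the model bytes exactly.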

III. Write the code

1. Initialize the interpreter: load the model and label files, and specify the delegate and the number of threads.

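/**
   * Initializes the classifier.
   * @param activity the current Activity, used to read the model and label assets
   * @param device the device to run inference on (CPU, GPU, or NNAPI)
   * @param numThreads the number of threads to use for inference
   * @throws IOException if the model or label file cannot be loaded
   */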
  protected Classifier(Activity activity, Device device, int numThreads) throws IOException {
    tfliteModel = FileUtil.loadMappedFile(activity, getModelPath());
    switch (device) {
      case NNAPI:
        nnApiDelegate = new NnApiDelegate();
        tfliteOptions.addDelegate(nnApiDelegate);
        break;
      case GPU:
        gpuDelegate = new GpuDelegate();
        tfliteOptions.addDelegate(gpuDelegate);
        break;
      case CPU:
        break;
    }
    tfliteOptions.setNumThreads(numThreads);
    tflite = new Interpreter(tfliteModel, tfliteOptions);

    // Loads labels out from the label file.
    labels = FileUtil.loadLabels(activity, getLabelPath());

    // Reads type and shape of input and output tensors, respectively.
    int imageTensorIndex = 0;
    int[] imageShape = tflite.getInputTensor(imageTensorIndex).shape(); // {1, height, width, 3}
    imageSizeY = imageShape[1];
    imageSizeX = imageShape[2];
    DataType imageDataType = tflite.getInputTensor(imageTensorIndex).dataType();
    int probabilityTensorIndex = 0;
    int[] probabilityShape =
        tflite.getOutputTensor(probabilityTensorIndex).shape(); // {1, NUM_CLASSES}
    DataType probabilityDataType = tflite.getOutputTensor(probabilityTensorIndex).dataType();

    // Creates the input tensor.
    inputImageBuffer = new TensorImage(imageDataType);

    // Creates the output tensor and its processor.
    outputProbabilityBuffer = TensorBuffer.createFixedSize(probabilityShape, probabilityDataType);

    // Creates the post processor for the output probability.
    probabilityProcessor = new TensorProcessor.Builder().add(getPostprocessNormalizeOp()).build();

    LOGGER.d("Created a Tensorflow Lite Image Classifier.");
  }
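
The constructor above sizes its input and output buffers from the tensor shape and data type. To make that relationship concrete, here is a minimal sketch of the arithmetic in plain Java. TensorSize is a hypothetical helper, and the 224 x 224 input size in the usage note is an assumed example, not a value taken from this project's model.

```java
// Hypothetical helper for illustration only; not part of TensorFlow Lite.
final class TensorSize {
  /** Number of elements in a tensor: the product of all its dimensions. */
  static int numElements(int[] shape) {
    int count = 1;
    for (int dim : shape) {
      count *= dim;
    }
    return count;
  }

  /** Buffer size in bytes, e.g. 4 bytes per element for float32, 1 for uint8. */
  static int numBytes(int[] shape, int bytesPerElement) {
    return numElements(shape) * bytesPerElement;
  }
}
```

For an assumed shape of {1, 224, 224, 3}, numElements gives 150528, so a float32 input needs 602112 bytes and a uint8 (quantized) input needs 150528 bytes.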

2. Run the interpreter and obtain the prediction results for an image.

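/**
   * Runs the interpreter and returns the results.
   * @param bitmap the image to classify
   * @param sensorOrientation the camera sensor orientation, used when loading the image
   * @return the top-k recognition results
   */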
  public List<Recognition> recognizeImage(final Bitmap bitmap, int sensorOrientation) {
    // Logs this method so that it can be analyzed with systrace.
    Trace.beginSection("recognizeImage");

    Trace.beginSection("loadImage");
    long startTimeForLoadImage = SystemClock.uptimeMillis();
    inputImageBuffer = loadImage(bitmap, sensorOrientation);
    long endTimeForLoadImage = SystemClock.uptimeMillis();
    Trace.endSection();
    LOGGER.v("Timecost to load the image: " + (endTimeForLoadImage - startTimeForLoadImage));

    // Runs the inference call.
    Trace.beginSection("runInference");
    long startTimeForReference = SystemClock.uptimeMillis();
    tflite.run(inputImageBuffer.getBuffer(), outputProbabilityBuffer.getBuffer().rewind());
    long endTimeForReference = SystemClock.uptimeMillis();
    Trace.endSection();
    LOGGER.v("Timecost to run model inference: " + (endTimeForReference - startTimeForReference));

    // Gets the map of label and probability.
    Map<String, Float> labeledProbability =
        new TensorLabel(labels, probabilityProcessor.process(outputProbabilityBuffer))
            .getMapWithFloatValue();
    Trace.endSection();

    // Gets top-k results.
    return getTopKProbability(labeledProbability);
  }
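
The listing above ends by calling getTopKProbability, which is not shown in this article. One possible way to pick the k most probable labels from the label-to-probability map is a small min-heap, sketched below in plain Java. TopK and topK are hypothetical names, and the project's actual method returns Recognition objects rather than map entries, so treat this as an illustration of the selection step only.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;

// Hypothetical helper for illustration only; not the project's getTopKProbability.
final class TopK {
  /** Returns up to k entries with the highest probability, best first. */
  static List<Map.Entry<String, Float>> topK(Map<String, Float> labeledProbability, int k) {
    List<Map.Entry<String, Float>> result = new ArrayList<>();
    if (k <= 0) {
      return result;
    }
    // Min-heap on probability: the weakest of the current best k sits at the head.
    PriorityQueue<Map.Entry<String, Float>> heap =
        new PriorityQueue<>(Map.Entry.<String, Float>comparingByValue());
    for (Map.Entry<String, Float> entry : labeledProbability.entrySet()) {
      heap.add(entry);
      if (heap.size() > k) {
        heap.poll(); // Drop the lowest-probability candidate.
      }
    }
    // The heap yields ascending order; reverse it so the best label comes first.
    while (!heap.isEmpty()) {
      result.add(heap.poll());
    }
    Collections.reverse(result);
    return result;
  }
}
```

Keeping only k entries in the heap avoids sorting the whole map, which matters little for a handful of flower classes but scales better for models with many labels.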

3. Close the interpreter and release resources.

  public void close() {
    if (tflite != null) {
      tflite.close();
      tflite = null;
    }
    if (gpuDelegate != null) {
      gpuDelegate.close();
      gpuDelegate = null;
    }
    if (nnApiDelegate != null) {
      nnApiDelegate.close();
      nnApiDelegate = null;
    }
    tfliteModel = null;
  }
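
The close() method above releases the interpreter first and the delegates afterwards. When the resources live only for the duration of one block, try-with-resources can produce the same ordering automatically, because resources are closed in the reverse order of their declaration. The sketch below demonstrates this with a hypothetical Resource stand-in class. It does not use the real TensorFlow Lite classes, and it assumes the delegate type in use can also be declared in a try-with-resources statement.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical demo for illustration only; Resource stands in for native handles.
final class CloseOrderDemo {
  /** Records its name in a shared log when closed. */
  static final class Resource implements AutoCloseable {
    private final String name;
    private final List<String> log;

    Resource(String name, List<String> log) {
      this.name = name;
      this.log = log;
    }

    @Override
    public void close() {
      log.add(name);
    }
  }

  /** Returns the order in which the two resources were closed. */
  static List<String> run() {
    List<String> closed = new ArrayList<>();
    // Declared delegate first, interpreter second: closed in reverse order.
    try (Resource delegate = new Resource("delegate", closed);
        Resource interpreter = new Resource("interpreter", closed)) {
      // Inference would run here.
    }
    return closed;
  }
}
```

Calling CloseOrderDemo.run() returns the list [interpreter, delegate], which matches the order used in close().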

IV. Key classes explained

Interpreter: the TensorFlow Lite interpreter. It takes a model file, executes the operations that the model defines on the input data, and provides access to the output.
Calling the interpreter from Java:

try (Interpreter interpreter = new Interpreter(tensorflow_lite_model_file)) {
  interpreter.run(input, output);
}

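Delegate: the TensorFlow Lite interpreter can be configured with delegates to use hardware acceleration on different devices. In addition to running on the CPU, NnApiDelegate and GpuDelegate are available.
Note: the official Delegate API is still experimental and may change at any time. You can also implement a custom delegate if needed.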

GpuDelegate delegate = new GpuDelegate();
Interpreter.Options options = (new Interpreter.Options()).addDelegate(delegate);
Interpreter interpreter = new Interpreter(tensorflow_lite_model_file, options);
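try {
  interpreter.run(input, output);
} finally {
  // Release the interpreter first, then the delegate it was using.
  interpreter.close();
  delegate.close();
}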
