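This article uses a flower recognition project as its example.
Train a TensorFlow Lite model with Colab, then download the resulting model file (model.tflite) and label file (labels.txt).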
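1. After creating the project, copy the downloaded model.tflite and labels.txt into the project's app/src/main/assets/ directory.
2. Add the TensorFlow Lite dependencies by adding the following to the dependencies {} block of the app module's build.gradle: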
implementation('org.tensorflow:tensorflow-lite:0.0.0-nightly') { changing = true }
implementation('org.tensorflow:tensorflow-lite-gpu:0.0.0-nightly') { changing = true }
implementation('org.tensorflow:tensorflow-lite-support:0.0.0-nightly') { changing = true }
To keep the TensorFlow Lite model file from being compressed when the APK is built, add the following to the android {} block:
aaptOptions {
    noCompress "tflite"
}
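Why noCompress matters: FileUtil.loadMappedFile opens the asset through AssetManager.openFd and memory-maps its byte range, and openFd only works on assets stored uncompressed in the APK. The sketch below shows the same memory-mapping technique in plain Java on a desktop JVM, using FileChannel.map on an ordinary file. The class name ModelMapper and the default path are hypothetical, and no Android or TensorFlow Lite API is involved.

```java
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;

// Hypothetical plain-Java sketch of memory-mapping a model file read-only,
// the FileChannel.map technique FileUtil.loadMappedFile relies on for assets.
final class ModelMapper {

    // Maps the whole file into memory without copying it onto the Java heap.
    static MappedByteBuffer mapReadOnly(Path modelPath) throws IOException {
        try (FileChannel channel = FileChannel.open(modelPath, StandardOpenOption.READ)) {
            // A mapping stays valid after the channel that created it is closed.
            return channel.map(FileChannel.MapMode.READ_ONLY, 0, channel.size());
        }
    }

    public static void main(String[] args) throws IOException {
        // "model.tflite" is a placeholder path for running the sketch on a desktop JVM.
        Path path = Paths.get(args.length > 0 ? args[0] : "model.tflite");
        MappedByteBuffer buffer = mapReadOnly(path);
        System.out.println("Mapped " + buffer.capacity() + " bytes");
    }
}
```

On Android, a compressed .tflite asset cannot be opened this way, which is the failure the aaptOptions setting avoids.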
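1. Initialize the interpreter: load the model and label files, and set the delegate (hardware accelerator) and the number of threads.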
/**
 * Initializes the classifier.
 * @param activity   Activity used to read the model and label files from assets
 * @param device     hardware to run on: CPU, GPU, or NNAPI
 * @param numThreads number of threads the interpreter may use
 * @throws IOException if the model or label file cannot be loaded
 */
protected Classifier(Activity activity, Device device, int numThreads) throws IOException {
    // Memory-maps the model from assets (requires the .tflite asset to be uncompressed).
    tfliteModel = FileUtil.loadMappedFile(activity, getModelPath());
    switch (device) {
        case NNAPI:
            nnApiDelegate = new NnApiDelegate();
            tfliteOptions.addDelegate(nnApiDelegate);
            break;
        case GPU:
            gpuDelegate = new GpuDelegate();
            tfliteOptions.addDelegate(gpuDelegate);
            break;
        case CPU:
            // No delegate: the interpreter runs on the CPU.
            break;
    }
    tfliteOptions.setNumThreads(numThreads);
    tflite = new Interpreter(tfliteModel, tfliteOptions);
    // Loads labels out from the label file.
    labels = FileUtil.loadLabels(activity, getLabelPath());
    // Reads type and shape of input and output tensors, respectively.
    int imageTensorIndex = 0;
    int[] imageShape = tflite.getInputTensor(imageTensorIndex).shape(); // {1, height, width, 3}
    imageSizeY = imageShape[1];
    imageSizeX = imageShape[2];
    DataType imageDataType = tflite.getInputTensor(imageTensorIndex).dataType();
    int probabilityTensorIndex = 0;
    int[] probabilityShape =
        tflite.getOutputTensor(probabilityTensorIndex).shape(); // {1, NUM_CLASSES}
    DataType probabilityDataType = tflite.getOutputTensor(probabilityTensorIndex).dataType();
    // Creates the input tensor.
    inputImageBuffer = new TensorImage(imageDataType);
    // Creates the output tensor and its processor.
    outputProbabilityBuffer = TensorBuffer.createFixedSize(probabilityShape, probabilityDataType);
    // Creates the post processor for the output probability.
    probabilityProcessor = new TensorProcessor.Builder().add(getPostprocessNormalizeOp()).build();
    LOGGER.d("Created a Tensorflow Lite Image Classifier.");
}
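The constructor reads the input shape {1, height, width, 3} and the output shape {1, NUM_CLASSES}, and loads one label per class. The plain-Java sketch below, built around a hypothetical TensorShapes class, shows the arithmetic behind a fixed-size buffer and a sanity check that labels.txt has as many entries as the model has output classes. The 224x224 input size and the five flower labels (those of the public flower_photos dataset) are assumptions for illustration. readLabels does roughly what FileUtil.loadLabels does (one label per non-blank line) and additionally trims whitespace.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper mirroring the shape and label bookkeeping in the constructor.
final class TensorShapes {

    // Bytes needed for a tensor of this shape: 4 per FLOAT32 element, 1 per UINT8.
    static int numBytes(int[] shape, int bytesPerElement) {
        int elements = 1;
        for (int dim : shape) {
            elements *= dim;
        }
        return elements * bytesPerElement;
    }

    // One label per non-blank line; line i names output class i.
    static List<String> readLabels(Reader source) throws IOException {
        List<String> labels = new ArrayList<>();
        try (BufferedReader reader = new BufferedReader(source)) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (!line.trim().isEmpty()) {
                    labels.add(line.trim());
                }
            }
        }
        return labels;
    }

    public static void main(String[] args) throws IOException {
        int[] imageShape = {1, 224, 224, 3}; // assumed {1, height, width, 3}
        int[] probabilityShape = {1, 5};      // assumed {1, NUM_CLASSES}
        int imageSizeY = imageShape[1];
        int imageSizeX = imageShape[2];
        List<String> labels =
            readLabels(new StringReader("daisy\ndandelion\nroses\nsunflowers\ntulips\n"));
        if (labels.size() != probabilityShape[1]) {
            throw new IllegalStateException("labels.txt does not match the model output");
        }
        System.out.println(imageSizeX + "x" + imageSizeY
            + ", float32 input bytes: " + numBytes(imageShape, 4));
    }
}
```

With a float32 {1, 224, 224, 3} input this comes to 602,112 bytes; a quantized uint8 model with the same shape needs a quarter of that.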
2. Run the interpreter to get the predictions for an image.
/**
 * Runs the interpreter on an image and returns the results.
 * @param bitmap            input image
 * @param sensorOrientation camera sensor orientation in degrees, used to rotate the image
 * @return the top-k recognitions
 */
public List<Recognition> recognizeImage(final Bitmap bitmap, int sensorOrientation) {
    // Logs this method so that it can be analyzed with systrace.
    Trace.beginSection("recognizeImage");
    Trace.beginSection("loadImage");
    long startTimeForLoadImage = SystemClock.uptimeMillis();
    inputImageBuffer = loadImage(bitmap, sensorOrientation);
    long endTimeForLoadImage = SystemClock.uptimeMillis();
    Trace.endSection();
    LOGGER.v("Timecost to load the image: " + (endTimeForLoadImage - startTimeForLoadImage));
    // Runs the inference call.
    Trace.beginSection("runInference");
    long startTimeForInference = SystemClock.uptimeMillis();
    tflite.run(inputImageBuffer.getBuffer(), outputProbabilityBuffer.getBuffer().rewind());
    long endTimeForInference = SystemClock.uptimeMillis();
    Trace.endSection();
    LOGGER.v("Timecost to run model inference: " + (endTimeForInference - startTimeForInference));
    // Gets the map of label and probability.
    Map<String, Float> labeledProbability =
        new TensorLabel(labels, probabilityProcessor.process(outputProbabilityBuffer))
            .getMapWithFloatValue();
    Trace.endSection();
    // Gets top-k results.
    return getTopKProbability(labeledProbability);
}
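getTopKProbability is not shown in the article. Below is a hedged sketch of the post-processing stage. The normalize step applies the same (x - mean) / std formula as the support library's NormalizeOp: for a quantized model, mean 0 and std 255 typically map uint8 scores into 0..1, while for a float model mean 0 and std 1 leave them unchanged. topK then keeps the k best entries with a size-k min-heap. The TopK class, the simplified Recognition type, and the sample scores are hypothetical.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;

// Hypothetical stand-in for getTopKProbability(); Recognition is simplified here.
final class TopK {

    // Minimal result type: label plus confidence (not the sample's Recognition class).
    static final class Recognition {
        final String title;
        final float confidence;

        Recognition(String title, float confidence) {
            this.title = title;
            this.confidence = confidence;
        }

        @Override
        public String toString() {
            return title + "=" + confidence;
        }
    }

    // Same formula as NormalizeOp(mean, std): (x - mean) / std.
    static float normalize(float raw, float mean, float std) {
        return (raw - mean) / std;
    }

    // Returns at most k results, highest confidence first.
    static List<Recognition> topK(Map<String, Float> labeledProbability, int k) {
        // Min-heap: the weakest result kept so far sits at the head.
        PriorityQueue<Recognition> heap =
            new PriorityQueue<>(k + 1, (a, b) -> Float.compare(a.confidence, b.confidence));
        for (Map.Entry<String, Float> entry : labeledProbability.entrySet()) {
            heap.add(new Recognition(entry.getKey(), entry.getValue()));
            if (heap.size() > k) {
                heap.poll(); // drop the current lowest score
            }
        }
        List<Recognition> result = new ArrayList<>(heap);
        result.sort((a, b) -> Float.compare(b.confidence, a.confidence));
        return result;
    }

    public static void main(String[] args) {
        // Hypothetical raw uint8 scores from a quantized model.
        Map<String, Float> scores = new LinkedHashMap<>();
        scores.put("daisy", normalize(12f, 0f, 255f));
        scores.put("roses", normalize(204f, 0f, 255f));
        scores.put("tulips", normalize(39f, 0f, 255f));
        System.out.println(topK(scores, 2));
    }
}
```

The heap keeps memory at O(k) no matter how many classes the model outputs; for a handful of flower classes, simply sorting all entries would work just as well.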
3. Close the interpreter and release its resources.
public void close() {
    if (tflite != null) {
        tflite.close();
        tflite = null;
    }
    if (gpuDelegate != null) {
        gpuDelegate.close();
        gpuDelegate = null;
    }
    if (nnApiDelegate != null) {
        nnApiDelegate.close();
        nnApiDelegate = null;
    }
    tfliteModel = null;
}
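close() releases each native object once and then sets the field to null, so calling it twice is harmless. Here is a minimal, hypothetical sketch of the same pattern, with stand-in Handle objects in place of Interpreter and GpuDelegate; implementing AutoCloseable additionally lets the holder be used in try-with-resources. The close order (interpreter before the delegate it used) follows the article's close() method.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch of the close() pattern: release each handle once,
// null the field so a second call is a no-op, and implement AutoCloseable.
final class ClassifierResources implements AutoCloseable {

    // Stand-in for Interpreter / GpuDelegate; records the order in which it is closed.
    static final class Handle implements AutoCloseable {
        private final String name;
        private final List<String> closedLog;

        Handle(String name, List<String> closedLog) {
            this.name = name;
            this.closedLog = closedLog;
        }

        @Override
        public void close() {
            closedLog.add(name);
        }
    }

    private Handle interpreter;
    private Handle gpuDelegate;

    ClassifierResources(Handle interpreter, Handle gpuDelegate) {
        this.interpreter = interpreter;
        this.gpuDelegate = gpuDelegate;
    }

    @Override
    public void close() {
        // Interpreter first, then the delegate it was using.
        if (interpreter != null) {
            interpreter.close();
            interpreter = null;
        }
        if (gpuDelegate != null) {
            gpuDelegate.close();
            gpuDelegate = null;
        }
    }

    public static void main(String[] args) {
        List<String> log = new ArrayList<>();
        try (ClassifierResources res =
                 new ClassifierResources(new Handle("interpreter", log), new Handle("gpu", log))) {
            res.close(); // explicit close; try-with-resources calls it again safely
        }
        System.out.println(log);
    }
}
```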
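Interpreter: the TensorFlow Lite interpreter. It takes a model file, runs the operations the model defines on the input data, and provides access to the output.
Calling the interpreter from Java (Interpreter implements AutoCloseable, so try-with-resources closes it):
try (Interpreter interpreter = new Interpreter(tensorflow_lite_model_file)) {
    interpreter.run(input, output);
}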
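interpreter.run() takes its input and output as arrays or buffers. In the article the input comes from loadImage() and TensorImage; the sketch below shows, under stated assumptions, what such a float32 input holds. Each ARGB pixel (as returned by Bitmap.getPixels) is split into R, G and B, normalized with a hypothetical mean and std of 127.5, and written into a direct ByteBuffer in native byte order, laid out as {1, height, width, 3}. InputPacker and the normalization constants are illustrative only; the right constants depend on how the model was trained.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

// Hypothetical sketch: pack ARGB pixels into a float32 input buffer
// laid out as {1, height, width, 3} (pixels in row-major order).
final class InputPacker {

    // Assumed normalization for a float model; depends on how the model was trained.
    static final float IMAGE_MEAN = 127.5f;
    static final float IMAGE_STD = 127.5f;

    static ByteBuffer pack(int[] argbPixels) {
        // 3 channels x 4 bytes per float, in native byte order.
        ByteBuffer buffer = ByteBuffer.allocateDirect(argbPixels.length * 3 * 4)
            .order(ByteOrder.nativeOrder());
        for (int pixel : argbPixels) {
            buffer.putFloat((((pixel >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD); // R
            buffer.putFloat((((pixel >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);  // G
            buffer.putFloat(((pixel & 0xFF) - IMAGE_MEAN) / IMAGE_STD);         // B
        }
        buffer.rewind(); // the reader starts from position 0
        return buffer;
    }

    public static void main(String[] args) {
        int[] pixels = {0xFFFF0000, 0xFF00FF00}; // one red pixel, one green pixel
        ByteBuffer input = pack(pixels);
        System.out.println(input.capacity() + " bytes, first value " + input.getFloat(0));
    }
}
```

With these constants a channel value of 255 becomes 1.0 and 0 becomes -1.0, so the red pixel above starts with 1.0, -1.0, -1.0.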
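Delegate: the TensorFlow Lite interpreter can be configured with delegates to use hardware acceleration on different devices. Besides the default CPU execution, the available delegates include NnApiDelegate and GpuDelegate.
Note: the official Delegate API is still experimental and may change at any time. You can also write a custom delegate if needed.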
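GpuDelegate delegate = new GpuDelegate();
Interpreter.Options options = (new Interpreter.Options()).addDelegate(delegate);
Interpreter interpreter = new Interpreter(tensorflow_lite_model_file, options);
try {
    interpreter.run(input, output);
} finally {
    // Release native resources: the interpreter first, then the delegate it used.
    interpreter.close();
    delegate.close();
}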
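The Classifier constructor adds whichever delegate was requested without checking whether the device supports it. One way to guard against that is to fall back to the CPU, sketched below in plain Java. DeviceChooser, its enum, and the gpuSupported flag are hypothetical. On a real device the flag could come from the GPU delegate's CompatibilityList.isDelegateSupportedOnThisDevice() in TensorFlow Lite versions that provide it (verify against the version you depend on). The NNAPI check relies on NNAPI having first shipped in Android 8.1 (API level 27).

```java
// Hypothetical sketch of choosing the Device for the Classifier constructor,
// falling back to CPU when the requested accelerator is unavailable.
final class DeviceChooser {

    enum Device { CPU, GPU, NNAPI }

    // NNAPI first shipped in Android 8.1 (API level 27).
    static final int MIN_NNAPI_API_LEVEL = 27;

    // gpuSupported stands in for a real compatibility check; apiLevel for Build.VERSION.SDK_INT.
    static Device choose(Device requested, boolean gpuSupported, int apiLevel) {
        switch (requested) {
            case GPU:
                return gpuSupported ? Device.GPU : Device.CPU;
            case NNAPI:
                return apiLevel >= MIN_NNAPI_API_LEVEL ? Device.NNAPI : Device.CPU;
            default:
                return Device.CPU;
        }
    }

    public static void main(String[] args) {
        System.out.println(choose(Device.GPU, false, 30));  // GPU unsupported, falls back to CPU
        System.out.println(choose(Device.NNAPI, true, 26)); // API 26 has no NNAPI, falls back to CPU
        System.out.println(choose(Device.NNAPI, true, 29)); // NNAPI available
    }
}
```

Keeping the decision in one pure function makes the fallback rules easy to test without a device; the constructor's switch statement would then receive the already-resolved Device.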