在安卓上玩转Tensorflow Lite :数据输入

文章目录

    • 1. 输入图片数据
      • (1)导入依赖
      • (2)读取一个TensorImage
      • (3)把输入图片转换成模型要求的格式和大小
      • (4)把TensorImage转换成 ByteBuffer输入给 tensorflow lite解释器

Tensorflow lite不具备模型的训练功能,只能运行训练好的模型,所以他的核心就是一个模型的解释器(Interpreter)。

要用神经网络模型识别检测一个东西,需要四个要素:

  1. 需要一个神经网络模型(怎么导出tensorflow lite格式的模型?)
  2. 需要把这个神经网络模型运行起来(怎么构造解释器?)
  3. 需要准备输入数据(Java中怎么封装tensor格式的数据)
  4. 需要读取输出数据(如何获取输出的格式?怎么解析它?)

本文主要针对各种各样的输入数据:

1. 输入图片数据

比如有一个图像识别的模型,需要的输入数据格式为

 type: `uint8[1,300,300,3]`
quantization: -1 ≤ 0.0078125 * (q - 128) ≤ 0.9921875

即整型数组[300*300],共三个通道RGB。

比如,我们现在有一张任意大小的图片(Bitmap:width=1920,height=1080)要放在模型里面去识别。

一般步骤是:

  1. 首先把图片裁剪加缩放到指定的大小(要保证图片不能拉伸,且裁剪时要保留最大的内容)。
  2. 然后把图片转换成tensorflowlite能读懂的类型

这里可以借助一个tensorflow lite support库,它可以帮助我们把安卓中的常见格式的图片转换成Tensor的格式输入给tf lite。

(1)导入依赖

dependencies {
    // 导入tflite依赖库
    implementation 'org.tensorflow:tensorflow-lite:2.3.0'
    // 导入tflite support依赖库
    implementation 'org.tensorflow:tensorflow-lite-support:0.1.0-rc1'
}

(2)读取一个TensorImage

首先把我们的Bitmap图片转换成TensorImage图片

// 现有一张Bitmap图 bitmap

// 按照输入的类型要求,创建TensorImage对象
TensorImage tensorImage = new TensorImage(DataType.UINT8);
// 然后把bitmap加载进来
tensorImage.load(bitmap);

(3)把输入图片转换成模型要求的格式和大小

然后把输入的图片按照要求裁剪

// 输入图片可能不是正方形,需要按短的一边裁剪,尽量保留中心最大的面积
int crop=Math.min(bitmap.getWidth(),bitmap.getHeight());
// 制造一个 冰激凌的磨具
ImageProcessor imageProcessor = new ImageProcessor
        .Builder()
        .add(new ResizeWithCropOrPadOp(crop,crop)) 
        .add(new ResizeOp(width_in,height_in,ResizeOp.ResizeMethod.NEAREST_NEIGHBOR))
    	.add(new NormalizeOp(0f,1f)).build();

第一个步骤add(new ResizeWithCropOrPadOp(crop,crop))设置裁剪的输出长度为我们刚才求的长度,会从中心裁剪,实现最大化的裁剪。

第二个步骤.add(new ResizeOp(长,宽,ResizeOp.ResizeMethod.NEAREST_NEIGHBOR)),把之前裁剪下来的区域,进行缩放,缩放到指定的长宽,缩放的法式有:临近插值(NEAREST_NEIGHBOR),双线性插值(BILINEAR)。

第三个步骤.add(new NormalizeOp(均值:0f,标准差:1f))把像素值的范围进行归一化(从[0,255]区间映射到[-1,1]区间)

最重要的一步,用刚才那个模具来裁剪我们之前输入的图片

imageProcessor.process(tensorImage);

(4)把TensorImage转换成 ByteBuffer输入给 tensorflow lite解释器

ByteBuffer byte_buffer= tensorImage.getBuffer();
//最后把准备好的数据传递给 tflite解释器
(Interpreter)tflite.run(byte_buffer,output);

你可能感兴趣的:(android,晋级之路,深度学习)