*This article has been authorized for exclusive publication by the WeChat official account guolin_blog (郭霖)
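TensorFlow Lite is a deep learning framework built specifically for mobile devices. Frameworks of this kind are deployed on phones, the Raspberry Pi, and other small devices, where they run inference with an already trained model. Because inference runs locally, the app no longer has to call a server over the network, which greatly reduces prediction time. Earlier articles in this series introduced Baidu's paddle-mobile, Xiaomi's mace, and Tencent's ncnn. This chapter introduces Google's TensorFlow Lite.
TensorFlow Lite on GitHub: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite
Running predictions on a phone first requires a trained model. It cannot be a model in TensorFlow's original format, because TensorFlow Lite uses its own format, tflite. The following describes how to obtain a model in that format.
There are two main ways to get one: save a tflite model directly at training time, or convert a TensorFlow model saved in another format into a tflite model.
1. The most convenient way is to save the model in tflite format at training time, mainly through the tf.contrib.lite.toco_convert() interface. Here is a simple example: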
import tensorflow as tf
img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
out = tf.identity(val, name="out")
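with tf.Session() as sess:
    # Convert the session graph into a TensorFlow Lite model (returned as bytes)
    tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out])
    # Write the converted model to disk and close the file afterwards
    with open("converteds_model.tflite", "wb") as f:
        f.write(tflite_model)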
The resulting converteds_model.tflite file can be used directly with TensorFlow Lite.
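Before copying a converted file to the phone, a quick sanity check can save some debugging. The sketch below is not part of the original workflow. It relies only on the fact that a TensorFlow Lite model is a FlatBuffers file whose file identifier, stored in bytes 4 to 7, is `TFL3`; the helper names `looks_like_tflite` and `file_looks_like_tflite` are made up for illustration.

```python
TFLITE_IDENTIFIER = b"TFL3"  # FlatBuffers file identifier of the TFLite schema


def looks_like_tflite(data: bytes) -> bool:
    """Return True if the leading bytes carry the TFLite file identifier."""
    # Bytes 0-3 hold the offset of the root table, bytes 4-7 the identifier.
    return len(data) >= 8 and data[4:8] == TFLITE_IDENTIFIER


def file_looks_like_tflite(path: str) -> bool:
    """Apply the same check to the first 8 bytes of a file on disk."""
    with open(path, "rb") as f:
        return looks_like_tflite(f.read(8))
```

This only confirms the container format. It says nothing about whether every operator in the graph is supported on the device.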
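2. The second way is to convert a model that TensorFlow saved in another format into tflite. Pre-trained models can be downloaded here:
TensorFlow models: https://github.com/tensorflow/models/tree/master/research/slim#pre-trained-models
The archives there already include a tflite model that can be used as-is, but we can also convert one of the other formats ourselves. For example, download mobilenet_v1_1.0_224.tgz; extracting it yields the following files: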
mobilenet_v1_1.0_224.ckpt.data-00000-of-00001 mobilenet_v1_1.0_224_eval.pbtxt mobilenet_v1_1.0_224.tflite
mobilenet_v1_1.0_224.ckpt.index mobilenet_v1_1.0_224_frozen.pb
mobilenet_v1_1.0_224.ckpt.meta mobilenet_v1_1.0_224_info.txt
First install Bazel; see https://docs.bazel.build/versions/master/install-ubuntu.html . Completing the "Installing using binary installer" section is enough.
Then clone the TensorFlow source code:
git clone https://github.com/tensorflow/tensorflow.git
Next, build the conversion tools. This build can take quite a while:
cd tensorflow/
bazel build tensorflow/python/tools:freeze_graph
bazel build tensorflow/contrib/lite/toco:toco
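With the conversion tools built, we can start converting the model. The following step freezes the graph.
input_graph is the .pb file.
input_checkpoint corresponds to mobilenet_v1_1.0_224.ckpt.data-00000-of-00001, but it is passed without that suffix, that is, as mobilenet_v1_1.0_224.ckpt.
output_node_names can be found in mobilenet_v1_1.0_224_info.txt.
Note that the model we downloaded is already frozen, so this step is not needed for it. For any other model, freeze the graph first and then continue with the steps below.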
./freeze_graph --input_graph=/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224_frozen.pb \
--input_checkpoint=/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.ckpt \
--input_binary=true \
--output_graph=/tmp/frozen_mobilenet_v1_224.pb \
--output_node_names=MobilenetV1/Predictions/Reshape_1
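The rule about dropping the suffix for input_checkpoint is easy to get wrong, so here is a small sketch of it. It assumes the usual checkpoint file naming shown in the file list above (`.data-NNNNN-of-NNNNN`, `.index`, `.meta`); the function name `checkpoint_prefix` is hypothetical and is not part of freeze_graph.

```python
import re

# Suffixes of the files that together make up one checkpoint
_CHECKPOINT_SUFFIX = re.compile(r"\.(data-\d{5}-of-\d{5}|index|meta)$")


def checkpoint_prefix(filename: str) -> str:
    """Strip the per-file suffix to get the value for --input_checkpoint."""
    return _CHECKPOINT_SUFFIX.sub("", filename)


if __name__ == "__main__":
    print(checkpoint_prefix("mobilenet_v1_1.0_224.ckpt.data-00000-of-00001"))
```

All three checkpoint files map to the same prefix, which is why the command above passes mobilenet_v1_1.0_224.ckpt.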
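The following step converts the frozen graph to .tflite:
input_file is the frozen graph.
output_file is the path of the converted output.
output_arrays can be found in mobilenet_v1_1.0_224_info.txt.
input_shapes is the shape of the data to predict on.
./toco --input_file=/tmp/mobilenet_v1_1.0_224_frozen.pb \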
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--output_file=/tmp/mobilenet_v1_1.0_224.tflite \
--inference_type=FLOAT \
--input_type=FLOAT \
--input_arrays=input \
--output_arrays=MobilenetV1/Predictions/Reshape_1 \
--input_shapes=1,224,224,3
After these steps we have the mobilenet_v1_1.0_224.tflite model, which we will use in the Android project below.
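The input_shapes and inference_type values also determine how many bytes the app has to feed the model on every prediction. As a rough sketch, assuming 4 bytes per element for FLOAT and 1 byte for QUANTIZED_UINT8 (the helper names here are made up for illustration):

```python
from math import prod

# Bytes per tensor element for the two inference types considered here
BYTES_PER_ELEMENT = {"FLOAT": 4, "QUANTIZED_UINT8": 1}


def parse_shape(text: str) -> tuple:
    """Turn a value such as '1,224,224,3' into a tuple of ints."""
    return tuple(int(dim) for dim in text.split(","))


def input_byte_size(input_shapes: str, inference_type: str = "FLOAT") -> int:
    """Number of bytes one input tensor of this shape occupies."""
    return prod(parse_shape(input_shapes)) * BYTES_PER_ELEMENT[inference_type]


if __name__ == "__main__":
    print(input_byte_size("1,224,224,3"))  # 602112
```

For the shape used above this gives 602112 bytes, which is the size of the direct ByteBuffer allocated later in PhotoUtil.getScaledMatrix().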
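With the model in hand, create an Android project in Android Studio, accepting the defaults throughout. C++ support is not needed, because the TensorFlow Lite API we use is Java, which makes development very convenient.
1. After the project is created, add the following to the build.gradle file in the app directory.
Under dependencies, add the libraries. The first is the image-loading library Glide; the second is the core of this project, TensorFlow Lite: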
implementation 'com.github.bumptech.glide:glide:4.3.1'
implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
Then add the following under android. It keeps the TensorFlow Lite model files from being compressed; a compressed model cannot be loaded:
// do not compress .tflite model files
aaptOptions {
    noCompress "tflite"
}
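Why does compression matter? The app later opens the model with openFd() and memory-maps it at a byte offset and length inside the package (see loadModelFile() in MainActivity.java), and that only works when the bytes are stored as-is. The Python sketch below imitates the idea on an ordinary file. It is not Android code, and `map_region` and `demo` are hypothetical names.

```python
import mmap
import os
import tempfile


def map_region(path: str, start: int, length: int) -> bytes:
    """Read `length` bytes at offset `start` through a read-only memory map."""
    gran = mmap.ALLOCATIONGRANULARITY
    aligned = (start // gran) * gran  # mmap offsets must be aligned
    delta = start - aligned           # distance from aligned offset to start
    with open(path, "rb") as f:
        with mmap.mmap(f.fileno(), delta + length,
                       access=mmap.ACCESS_READ, offset=aligned) as mm:
            return bytes(mm[delta:delta + length])


def demo() -> bytes:
    """Store a fake model uncompressed inside a larger file and map it back."""
    payload = b"model-bytes"
    with tempfile.TemporaryDirectory() as folder:
        path = os.path.join(folder, "bundle.bin")
        with open(path, "wb") as f:
            f.write(b"\0" * 70000 + payload + b"tail")
        return map_region(path, 70000, len(payload))


if __name__ == "__main__":
    print(demo())
```

If the payload had been compressed inside the bundle, the mapped bytes would be compressed data rather than the model itself, so mapping by offset and length could not hand a usable model to the interpreter.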
2. Create an assets folder under the main directory. It holds the tflite models and the label name file.
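3. Below is the code for the main screen, MainActivity.java. It is fairly long, so here is a summary of the important methods:
loadModelFile() reads the model file into a MappedByteBuffer, which is then used to initialize the Interpreter class. The model is stored in the assets directory under main.
load_model() loads the model and produces the tflite object. That object is used to predict images and also to set options such as the number of threads, for example tflite.setNumThreads(4);
showDialog() shows a dialog for choosing between the models.
readCacheLabelFromLocalFile() reads the names of the class labels from a file. The file is long; you can follow this article to obtain the label names, or download the author's project, which includes the file. This file, cacheLabel.txt, is stored in the assets directory alongside the models.
predict_image() predicts an image and displays the result. The flow is: get the image path, downscale the image, convert it to a ByteBuffer, and finally call tflite.run() to predict.
get_max_result() returns the label with the highest probability.
package com.yeyupiaoling.testtflite;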
import android.Manifest;
import android.app.Activity;
import android.content.DialogInterface;
import android.content.Intent;
import android.content.pm.PackageManager;
import android.content.res.AssetFileDescriptor;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.net.Uri;
import android.os.Bundle;
import android.support.annotation.NonNull;
import android.support.annotation.Nullable;
import android.support.v4.app.ActivityCompat;
import android.support.v4.content.ContextCompat;
import android.support.v7.app.AlertDialog;
import android.support.v7.app.AppCompatActivity;
import android.text.method.ScrollingMovementMethod;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.TextView;
import android.widget.Toast;
import com.bumptech.glide.Glide;
import com.bumptech.glide.load.engine.DiskCacheStrategy;
import com.bumptech.glide.request.RequestOptions;
import org.tensorflow.lite.Interpreter;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.List;
public class MainActivity extends AppCompatActivity {
private static final String TAG = MainActivity.class.getName();
private static final int USE_PHOTO = 1001;
private static final int START_CAMERA = 1002;
private String camera_image_path;
private ImageView show_image;
private TextView result_text;
private String assets_path = "lite_images";
private boolean load_result = false;
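// input dimensions; getScaledMatrix() uses ddims[2] x ddims[3] as the image size and writes RGB values pixel by pixel
private int[] ddims = {1, 3, 224, 224};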
private int model_index = 0;
private List<String> resultLabel = new ArrayList<>();
private Interpreter tflite = null;
private static final String[] PADDLE_MODEL = {
"mobilenet_v1",
"mobilenet_v2"
};
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
init_view();
readCacheLabelFromLocalFile();
}
// initialize view
private void init_view() {
request_permissions();
show_image = (ImageView) findViewById(R.id.show_image);
result_text = (TextView) findViewById(R.id.result_text);
result_text.setMovementMethod(ScrollingMovementMethod.getInstance());
Button load_model = (Button) findViewById(R.id.load_model);
Button use_photo = (Button) findViewById(R.id.use_photo);
Button start_photo = (Button) findViewById(R.id.start_camera);
load_model.setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View view) {
showDialog();
}
});
// use photo click
use_photo.setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View view) {
if (!load_result) {
Toast.makeText(MainActivity.this, "never load model", Toast.LENGTH_SHORT).show();
return;
}
PhotoUtil.use_photo(MainActivity.this, USE_PHOTO);
}
});
// start camera click
start_photo.setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View view) {
if (!load_result) {
Toast.makeText(MainActivity.this, "never load model", Toast.LENGTH_SHORT).show();
return;
}
camera_image_path = PhotoUtil.start_camera(MainActivity.this, START_CAMERA);
}
});
}
/**
* Memory-map the model file in Assets.
*/
private MappedByteBuffer loadModelFile(String model) throws IOException {
AssetFileDescriptor fileDescriptor = getApplicationContext().getAssets().openFd(model + ".tflite");
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
// load infer model
private void load_model(String model) {
try {
tflite = new Interpreter(loadModelFile(model));
Toast.makeText(MainActivity.this, model + " model load success", Toast.LENGTH_SHORT).show();
Log.d(TAG, model + " model load success");
tflite.setNumThreads(4);
load_result = true;
} catch (IOException e) {
Toast.makeText(MainActivity.this, model + " model load fail", Toast.LENGTH_SHORT).show();
Log.d(TAG, model + " model load fail");
load_result = false;
e.printStackTrace();
}
}
public void showDialog() {
AlertDialog.Builder builder = new AlertDialog.Builder(MainActivity.this);
// set dialog title
builder.setTitle("Please select model");
// set dialog icon
builder.setIcon(android.R.drawable.ic_dialog_alert);
// able click other will cancel
builder.setCancelable(true);
// cancel button
builder.setNegativeButton("cancel", null);
// set list
builder.setSingleChoiceItems(PADDLE_MODEL, model_index, new DialogInterface.OnClickListener() {
@Override
public void onClick(DialogInterface dialog, int which) {
model_index = which;
load_model(PADDLE_MODEL[model_index]);
dialog.dismiss();
}
});
// show dialog
builder.show();
}
private void readCacheLabelFromLocalFile() {
try {
AssetManager assetManager = getApplicationContext().getAssets();
BufferedReader reader = new BufferedReader(new InputStreamReader(assetManager.open("cacheLabel.txt")));
String readLine = null;
while ((readLine = reader.readLine()) != null) {
resultLabel.add(readLine);
}
reader.close();
} catch (Exception e) {
Log.e("labelCache", "error " + e);
}
}
@Override
protected void onActivityResult(int requestCode, int resultCode, @Nullable Intent data) {
String image_path;
RequestOptions options = new RequestOptions().skipMemoryCache(true).diskCacheStrategy(DiskCacheStrategy.NONE);
if (resultCode == Activity.RESULT_OK) {
switch (requestCode) {
case USE_PHOTO:
if (data == null) {
Log.w(TAG, "user photo data is null");
return;
}
Uri image_uri = data.getData();
Glide.with(MainActivity.this).load(image_uri).apply(options).into(show_image);
// get image path from uri
image_path = PhotoUtil.get_path_from_URI(MainActivity.this, image_uri);
// predict image
predict_image(image_path);
break;
case START_CAMERA:
// show photo
Glide.with(MainActivity.this).load(camera_image_path).apply(options).into(show_image);
// predict image
predict_image(camera_image_path);
break;
}
}
}
// predict image
private void predict_image(String image_path) {
// picture to float array
Bitmap bmp = PhotoUtil.getScaleBitmap(image_path);
ByteBuffer inputData = PhotoUtil.getScaledMatrix(bmp, ddims);
try {
// Data format conversion takes too long
// Log.d("inputData", Arrays.toString(inputData));
float[][] labelProbArray = new float[1][1001];
long start = System.currentTimeMillis();
// get predict result
tflite.run(inputData, labelProbArray);
long end = System.currentTimeMillis();
long time = end - start;
float[] results = new float[labelProbArray[0].length];
System.arraycopy(labelProbArray[0], 0, results, 0, labelProbArray[0].length);
// show predict result and time
int r = get_max_result(results);
String show_text = "result:" + r + "\nname:" + resultLabel.get(r) + "\nprobability:" + results[r] + "\ntime:" + time + "ms";
result_text.setText(show_text);
} catch (Exception e) {
e.printStackTrace();
}
}
// get max probability label
private int get_max_result(float[] result) {
float probability = result[0];
int r = 0;
for (int i = 0; i < result.length; i++) {
if (probability < result[i]) {
probability = result[i];
r = i;
}
}
return r;
}
// request permissions
private void request_permissions() {
List<String> permissionList = new ArrayList<>();
if (ContextCompat.checkSelfPermission(this, Manifest.permission.CAMERA) != PackageManager.PERMISSION_GRANTED) {
permissionList.add(Manifest.permission.CAMERA);
}
if (ContextCompat.checkSelfPermission(this, Manifest.permission.WRITE_EXTERNAL_STORAGE) != PackageManager.PERMISSION_GRANTED) {
permissionList.add(Manifest.permission.WRITE_EXTERNAL_STORAGE);
}
if (ContextCompat.checkSelfPermission(this, Manifest.permission.READ_EXTERNAL_STORAGE) != PackageManager.PERMISSION_GRANTED) {
permissionList.add(Manifest.permission.READ_EXTERNAL_STORAGE);
}
// if list is not empty will request permissions
if (!permissionList.isEmpty()) {
ActivityCompat.requestPermissions(this, permissionList.toArray(new String[permissionList.size()]), 1);
}
}
@Override
public void onRequestPermissionsResult(int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) {
super.onRequestPermissionsResult(requestCode, permissions, grantResults);
switch (requestCode) {
case 1:
if (grantResults.length > 0) {
for (int i = 0; i < grantResults.length; i++) {
int grantResult = grantResults[i];
if (grantResult == PackageManager.PERMISSION_DENIED) {
String s = permissions[i];
Toast.makeText(this, s + " permission was denied", Toast.LENGTH_SHORT).show();
}
}
}
break;
}
}
}
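The result handling in predict_image() and get_max_result() comes down to an argmax over the probability array followed by a label lookup. Here is a minimal Python sketch of the same logic. The `top_k` helper is an addition for illustration and does not exist in the Android code.

```python
def argmax(probabilities):
    """Index of the largest value; the first one wins on ties."""
    best = 0
    for i, p in enumerate(probabilities):
        if p > probabilities[best]:
            best = i
    return best


def top_k(probabilities, labels, k=3):
    """The k most probable classes as (index, label, probability) tuples."""
    order = sorted(range(len(probabilities)),
                   key=lambda i: probabilities[i], reverse=True)
    return [(i, labels[i], probabilities[i]) for i in order[:k]]


if __name__ == "__main__":
    probs = [0.1, 0.7, 0.2]
    names = ["cat", "dog", "bird"]
    print(argmax(probs), names[argmax(probs)])
    print(top_k(probs, names, k=2))
```

Note that resultLabel.get(r) in the Java code needs the label file to have at least as many lines as the model has outputs (1001 in labelProbArray), otherwise the lookup fails for high indices.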
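4. The following is a utility class, PhotoUtil.java. Its methods are:
start_camera() launches the camera to take a photo and returns the image path; it is compatible with Android 7.0.
use_photo() opens the gallery to obtain the URI of the selected image.
get_path_from_URI() converts an image URI to a file path.
getScaledMatrix() converts a Bitmap into the data format TensorFlow Lite expects.
getScaleBitmap() downscales the image to avoid running out of memory.
package com.yeyupiaoling.testtflite;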
import android.app.Activity;
import android.content.Context;
import android.content.Intent;
import android.database.Cursor;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.net.Uri;
import android.os.Build;
import android.os.Environment;
import android.provider.MediaStore;
import android.support.v4.content.FileProvider;
import android.util.Log;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
public class PhotoUtil {
// start camera
public static String start_camera(Activity activity, int requestCode) {
Uri imageUri;
// save image in cache path
File outputImage = new File(Environment.getExternalStorageDirectory().getAbsolutePath()
+ "/lite_mobile/", System.currentTimeMillis() + ".jpg");
Log.d("outputImage", outputImage.getAbsolutePath());
try {
if (outputImage.exists()) {
outputImage.delete();
}
File out_path = new File(Environment.getExternalStorageDirectory().getAbsolutePath()
+ "/lite_mobile/");
if (!out_path.exists()) {
out_path.mkdirs();
}
outputImage.createNewFile();
} catch (IOException e) {
e.printStackTrace();
}
if (Build.VERSION.SDK_INT >= 24) {
// compatible with Android 7.0 or over
imageUri = FileProvider.getUriForFile(activity,
"com.yeyupiaoling.testtflite.fileprovider", outputImage);
} else {
imageUri = Uri.fromFile(outputImage);
}
// set system camera Action
Intent intent = new Intent(MediaStore.ACTION_IMAGE_CAPTURE);
intent.addFlags(Intent.FLAG_GRANT_READ_URI_PERMISSION);
// set save photo path
intent.putExtra(MediaStore.EXTRA_OUTPUT, imageUri);
// EXTRA_VIDEO_QUALITY is defined for video capture (0 = low, 1 = high); it may have no effect on a still photo
intent.putExtra(MediaStore.EXTRA_VIDEO_QUALITY, 0);
activity.startActivityForResult(intent, requestCode);
// return image absolute path
return outputImage.getAbsolutePath();
}
// get picture in photo
public static void use_photo(Activity activity, int requestCode) {
Intent intent = new Intent(Intent.ACTION_PICK);
intent.setType("image/*");
activity.startActivityForResult(intent, requestCode);
}
// get photo from Uri
public static String get_path_from_URI(Context context, Uri uri) {
String result;
Cursor cursor = context.getContentResolver().query(uri, null, null, null, null);
if (cursor == null) {
result = uri.getPath();
} else {
cursor.moveToFirst();
int idx = cursor.getColumnIndex(MediaStore.Images.ImageColumns.DATA);
result = cursor.getString(idx);
cursor.close();
}
return result;
}
// TensorFlow model,get predict data
public static ByteBuffer getScaledMatrix(Bitmap bitmap, int[] ddims) {
ByteBuffer imgData = ByteBuffer.allocateDirect(ddims[0] * ddims[1] * ddims[2] * ddims[3] * 4);
imgData.order(ByteOrder.nativeOrder());
// get image pixel
int[] pixels = new int[ddims[2] * ddims[3]];
Bitmap bm = Bitmap.createScaledBitmap(bitmap, ddims[2], ddims[3], false);
bm.getPixels(pixels, 0, bm.getWidth(), 0, 0, ddims[2], ddims[3]);
int pixel = 0;
for (int i = 0; i < ddims[2]; ++i) {
for (int j = 0; j < ddims[3]; ++j) {
final int val = pixels[pixel++];
imgData.putFloat(((((val >> 16) & 0xFF) - 128f) / 128f));
imgData.putFloat(((((val >> 8) & 0xFF) - 128f) / 128f));
imgData.putFloat((((val & 0xFF) - 128f) / 128f));
}
}
// free the scaled copy; createScaledBitmap may return the source bitmap itself
if (bm != bitmap && !bm.isRecycled()) {
bm.recycle();
}
return imgData;
}
// compress picture
public static Bitmap getScaleBitmap(String filePath) {
BitmapFactory.Options opt = new BitmapFactory.Options();
opt.inJustDecodeBounds = true;
BitmapFactory.decodeFile(filePath, opt);
int bmpWidth = opt.outWidth;
int bmpHeight = opt.outHeight;
int maxSize = 500;
// compress picture with inSampleSize
opt.inSampleSize = 1;
while (true) {
if (bmpWidth / opt.inSampleSize < maxSize || bmpHeight / opt.inSampleSize < maxSize) {
break;
}
opt.inSampleSize *= 2;
}
opt.inJustDecodeBounds = false;
return BitmapFactory.decodeFile(filePath, opt);
}
}
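The two numeric pieces of PhotoUtil are the pixel normalization in getScaledMatrix() and the sample-size loop in getScaleBitmap(). The sketch below restates both in Python so the arithmetic can be checked off-device. The function names are made up, and `struct` stands in for the direct ByteBuffer in native byte order.

```python
import struct


def normalize_pixel(argb: int) -> tuple:
    """Split a packed ARGB int into R, G, B and scale each to about [-1, 1)."""
    r = (argb >> 16) & 0xFF
    g = (argb >> 8) & 0xFF
    b = argb & 0xFF
    return tuple((c - 128.0) / 128.0 for c in (r, g, b))


def pack_pixels(pixels) -> bytes:
    """Write three float32 values per pixel, in native byte order."""
    out = bytearray()
    for argb in pixels:
        out += struct.pack("=3f", *normalize_pixel(argb))
    return bytes(out)


def sample_size(width: int, height: int, max_size: int = 500) -> int:
    """Double the sample size until one side drops below max_size."""
    size = 1
    while width // size >= max_size and height // size >= max_size:
        size *= 2
    return size


if __name__ == "__main__":
    print(normalize_pixel(0xFFFF8000))
    print(len(pack_pixels([0] * (224 * 224))))
    print(sample_size(4000, 3000))
```

For a 4000 x 3000 photo the loop stops at a sample size of 8, so the decoded bitmap is about 500 x 375 before it is scaled to 224 x 224.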
5. In AndroidManifest.xml, add the required permissions for the camera and for reading and writing external storage:
<uses-permission android:name="android.permission.CAMERA"/>
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
Then add the following configuration under application, mainly for camera compatibility with Android 7.0:
<provider
android:name="android.support.v4.content.FileProvider"
android:authorities="com.yeyupiaoling.testtflite.fileprovider"
android:exported="false"
android:grantUriPermissions="true">
<meta-data
android:name="android.support.FILE_PROVIDER_PATHS"
android:resource="@xml/file_paths"/>
</provider>
6. Next, create an xml directory under res, then create a file_paths.xml file in it with the following content. This is where photos taken with the camera are stored:
<resources>
<external-path
name="images"
path="lite_mobile/" />
</resources>
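To see what this mapping does, here is a rough Python sketch of how a file under lite_mobile/ is expected to turn into a content URI of the form content://authority/name/relative-path. This is an illustration, not the FileProvider implementation; the external storage root /storage/emulated/0 and the function name `provider_uri` are assumptions.

```python
from urllib.parse import quote


def provider_uri(authority: str, name: str, root_dir: str, file_path: str) -> str:
    """Build the content URI for a file that lives under a configured root."""
    root = root_dir.rstrip("/") + "/"
    if not file_path.startswith(root):
        raise ValueError("file is outside the configured root: " + file_path)
    relative = file_path[len(root):]
    return "content://{}/{}/{}".format(authority, quote(name), quote(relative))


if __name__ == "__main__":
    print(provider_uri(
        "com.yeyupiaoling.testtflite.fileprovider",  # android:authorities
        "images",                                    # name in file_paths.xml
        "/storage/emulated/0/lite_mobile/",          # assumed external root + path
        "/storage/emulated/0/lite_mobile/1700000000000.jpg",
    ))
```

The camera app receives this URI instead of a raw file path, which is the behavior Android 7.0 and later require.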
7. The layout for the main screen, activity_main.xml:
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context=".MainActivity">
<LinearLayout
android:id="@+id/btn1_ll"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_alignParentBottom="true"
android:orientation="horizontal">
<Button
android:id="@+id/use_photo"
android:layout_width="0dp"
android:layout_height="wrap_content"
android:layout_weight="1"
android:text="Gallery" />
<Button
android:id="@+id/start_camera"
android:layout_width="0dp"
android:layout_height="wrap_content"
android:layout_weight="1"
android:text="Take photo" />
</LinearLayout>
<LinearLayout
android:id="@+id/btn2_ll"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_above="@id/btn1_ll"
android:orientation="horizontal">
<Button
android:id="@+id/load_model"
android:layout_width="0dp"
android:layout_height="wrap_content"
android:layout_weight="1"
android:text="Load model" />
</LinearLayout>
<TextView
android:id="@+id/result_text"
android:layout_width="match_parent"
android:layout_height="150dp"
android:layout_above="@id/btn2_ll"
android:hint="Prediction results are shown here"
android:inputType="textMultiLine"
android:textSize="16sp"
tools:ignore="TextViewEdits" />
<ImageView
android:id="@+id/show_image"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:layout_above="@id/result_text"
android:layout_alignParentTop="true" />
</RelativeLayout>
All of the code is provided above. To make debugging easier, readers can also download the project here and open it in Android Studio.