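System: macOS
TensorFlow: 2.2.0
Note: the steps in this article were run only in the environment above. Other versions have not been tested, so treat them as a reference only.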
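TensorFlow Lite is a set of tools that helps developers run TensorFlow models on mobile, embedded, and IoT devices. It supports on-device machine learning inference with low latency and a small binary size.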
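TensorFlow Lite consists of two main components:
- the TensorFlow Lite interpreter, which runs specially optimized models on many kinds of hardware, including mobile phones, embedded Linux devices, and microcontrollers;
- the TensorFlow Lite converter, which converts TensorFlow models into an efficient form for the interpreter and can apply optimizations that improve binary size and performance.
For more details, see the official guide.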
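Below is a complete example. It converts a Keras model (architecture + weights) into a TFLite model, following the Python API guide for the model converter (Converter). It then runs the result with the Python TensorFlow Lite interpreter, following "TensorFlow Lite inference: Load and run a model in Python".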
import os

import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model


def h5_convert_tflite(model_path, output_path):
    """Convert a Keras .h5/.hdf5 model (architecture + weights) to a .tflite file."""
    model = load_model(model_path)
    keras_to_tflite_converter = tf.lite.TFLiteConverter.from_keras_model(model)
    # Optimize.DEFAULT enables post-training (dynamic-range) quantization.
    keras_to_tflite_converter.optimizations = [tf.lite.Optimize.DEFAULT]
    keras_tflite = keras_to_tflite_converter.convert()
    os.makedirs(output_path, exist_ok=True)
    tflite_file_path = os.path.join(output_path, "keras.tflite")
    with open(tflite_file_path, 'wb') as f:
        f.write(keras_tflite)
    return tflite_file_path


def test_h5_convert_tflite():
    model_path = "kears_model.hdf5"
    output_dir = "output_tflite"
    tflite_file_path = h5_convert_tflite(
        model_path=model_path,
        output_path=output_dir
    )

    # Load the TFLite model and allocate tensors.
    interpreter = tf.lite.Interpreter(model_path=tflite_file_path)
    interpreter.allocate_tensors()

    # Get input and output tensors.
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    print(input_details)
    print(output_details)

    # Test the model on an image (or on random data, see the commented-out line).
    input_shape = input_details[0]['shape']
    # input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
    # Read as grayscale (flag 0) and scale to [0, 1]; test.jpg must already
    # have the model's input height and width for the reshape to succeed.
    input_data = cv2.imread("test.jpg", 0) / 255
    input_data = input_data.reshape(input_shape).astype(np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()

    # The function `get_tensor()` returns a copy of the tensor data.
    # Use `tensor()` in order to get a pointer to the tensor.
    output_data = interpreter.get_tensor(output_details[0]['index'])
    print(output_data)


if __name__ == "__main__":
    test_h5_convert_tflite()
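The preprocessing in the script reads the image as grayscale, divides it by 255, and reshapes it to the interpreter's input shape. The sketch below does the same without OpenCV or NumPy, which makes the expected tensor layout explicit. It assumes a single-channel model input of shape [1, H, W, 1]. The pixels arrive as a nested list of 0-255 values, and `to_nhwc_gray` is a hypothetical helper name.

```python
from typing import List, Sequence


def to_nhwc_gray(pixels: Sequence[Sequence[int]],
                 input_shape: Sequence[int]) -> List[List[List[List[float]]]]:
    """Normalize an H x W grayscale image to [0, 1] and wrap it as [1, H, W, 1].

    Hypothetical helper mirroring `cv2.imread(path, 0) / 255` followed by
    `reshape(input_shape)` for a single-channel model.
    """
    batch, height, width, channels = input_shape
    if batch != 1 or channels != 1:
        raise ValueError("expected an input shape of [1, H, W, 1]")
    if len(pixels) != height or any(len(row) != width for row in pixels):
        raise ValueError("image size does not match the model input")
    # Row-major order: element [0][i][j][0] holds pixel (row i, column j).
    return [[[[p / 255.0] for p in row] for row in pixels]]
```

For example, `to_nhwc_gray([[0, 255], [51, 102]], [1, 2, 2, 1])` puts `1.0` at `[0][0][1][0]`. This row-major layout is also the one the C++ `fill_input` writes into later.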
Note: if you get the error `tflite AttributeError: module 'six' has no attribute 'ensure_str'`, fix it with `pip install six==1.12.0` (see the corresponding Stack Overflow answer).
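This error appears when the installed `six` is older than `six.ensure_str`, which was added in six 1.12.0. A small check, sketched under that assumption, can tell you whether you need to pin or upgrade before converting. `parse_version` and `needs_six_upgrade` are hypothetical helpers that only compare version strings.

```python
import re


def parse_version(version: str) -> tuple:
    """Turn '1.11.0' into (1, 11, 0), padding to three parts.

    Non-numeric suffixes such as 'rc1' are ignored.
    """
    parts = [int(re.match(r"\d*", piece).group() or 0)
             for piece in version.split(".")[:3]]
    parts += [0] * (3 - len(parts))
    return tuple(parts)


def needs_six_upgrade(installed: str) -> bool:
    """True if this six version predates ensure_str (added in 1.12.0)."""
    return parse_version(installed) < (1, 12, 0)
```

In a live environment you would pass it `six.__version__`.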
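At this point we have a working tflite model, and it passes a test with the tflite interpreter. However, that test runs in Python, and our goal is to run the model on Android, so we need to build TensorFlow Lite for Android. Many apps ship on both iOS and Android, and some of the preprocessing relies on OpenCV. For these reasons, the Android target built here is an arm64 C++ shared library.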
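Can you build a static library instead? Of course; build whatever your project needs. For the difference between static libraries (.a) and shared libraries (.so), see "Linux中的动态库和静态库(.a/.la/.so/.o)" (shared and static libraries in Linux). Personally, I think shared libraries are the better choice when shipping a product, because the files are smaller.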
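Downloading the source from GitHub may be slow. Fortunately, curl (which git uses for HTTPS) can go through a proxy. In my case, the connection to GitHub was refused outright, so I had to give git a socks5 proxy scoped to github.com: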
(base) jiangzhigang@192 Code % git clone https://github.com/tensorflow/tensorflow.git
Cloning into 'tensorflow'...
fatal: unable to access 'https://github.com/tensorflow/tensorflow.git/': LibreSSL SSL_connect: SSL_ERROR_SYSCALL in connection to github.com:443
(base) jiangzhigang@192 Code % git config --global http.https://github.com.proxy socks5://127.0.0.1:1086
(base) jiangzhigang@192 Code % git clone https://github.com/tensorflow/tensorflow.git
Cloning into 'tensorflow'...
remote: Enumerating objects: 26, done.
remote: Counting objects: 100% (26/26), done.
remote: Compressing objects: 100% (24/24), done.
Receiving objects: 13% (136962/1028385), 81.85 MiB | 1.92 MiB/s
Now the download works.
Then switch to the branch for the matching version:
(base) jiangzhigang@192 tensorflow % git checkout r2.2
Updating files: 100% (13763/13763), done.
Branch 'r2.2' set up to track remote branch 'r2.2' from 'origin'.
Switched to a new branch 'r2.2'
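The tflite runtime you build should match the TensorFlow version the model was trained with. That said, I found that a 1.x tflite runtime could also run a model trained with 2.2. My guess is that the model is simple and uses no features from newer versions.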
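You can read the training version from `tf.__version__`. The TensorFlow repository names its release branches `r<major>.<minor>`, as with `r2.2` above. The sketch below maps a version string to that branch name. `release_branch` is a hypothetical helper.

```python
import re


def release_branch(tf_version: str) -> str:
    """Map a TensorFlow version such as '2.2.0' to its release branch 'r2.2'."""
    match = re.match(r"^(\d+)\.(\d+)", tf_version)
    if match is None:
        raise ValueError(f"unrecognized version: {tf_version!r}")
    return f"r{match.group(1)}.{match.group(2)}"
```

The result is what you would pass to `git checkout`.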
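For Bazel usage, see: Bazel学习笔记 (Bazel study notes).
I'm on a Mac and first planned to install Bazel directly with brew, which is the simplest option. However, the Bazel version it downloaded was too new to build TensorFlow 2.2.0. brew installs the latest version by default; you can switch sources to get a specific version, but that is a hassle.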
(base) jiangzhigang@192 Code % cd tensorflow
(base) jiangzhigang@192 tensorflow % bazel
ERROR: The project you're trying to build requires Bazel 2.0.0 (specified in /Users/jiangzhigang/Code/tensorflow/.bazelversion), but it wasn't found in /Users/jiangzhigang/.bazel/bin.
Bazel binaries for all official releases can be downloaded from here:
https://github.com/bazelbuild/bazel/releases
You can download the required version directly using this command:
(cd "/Users/jiangzhigang/.bazel/bin" && curl -LO https://releases.bazel.build/2.0.0/release/bazel-2.0.0-darwin-x86_64 && chmod +x bazel-2.0.0-darwin-x86_64)
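As shown, TensorFlow 2.2.0 requires bazel-2.0.0 (pinned in `.bazelversion`). I therefore installed that exact version with the official installer script instead of brew.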
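Be sure to download the installer with curl. The Bazel installation guide says: "on macOS Catalina, due to Apple’s new app notarization requirements, you will need to download the installer from the terminal using curl". I also tried downloading it with a browser, and that copy could not be installed.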
# Example installing version `2.0.0`. Replace the version below as appropriate.
export BAZEL_VERSION=2.0.0
curl -fLO "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-darwin-x86_64.sh"
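Instead of typing the version by hand, you can read it from `.bazelversion`, the file named in the error message above. The sketch below does that and builds the installer URL. It assumes the macOS x86_64 installer naming used in this article, and `required_bazel_version` and `installer_url` are hypothetical helpers.

```python
from pathlib import Path

# Same URL pattern as the curl command above (macOS x86_64 installer).
INSTALLER_URL = (
    "https://github.com/bazelbuild/bazel/releases/download/"
    "{v}/bazel-{v}-installer-darwin-x86_64.sh"
)


def required_bazel_version(repo_root: str) -> str:
    """Read the Bazel version pinned in <repo_root>/.bazelversion."""
    return Path(repo_root, ".bazelversion").read_text().strip()


def installer_url(repo_root: str) -> str:
    """Build the installer URL for the Bazel version the repo pins."""
    return INSTALLER_URL.format(v=required_bazel_version(repo_root))
```

For the r2.2 checkout, `installer_url("tensorflow")` should give the same URL as the curl command.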
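If the download is slow, the same trick applies: give curl a socks5 proxy. For more, see "curl 设置代理" (setting a proxy for curl).
To make curl use a SOCKS5 proxy globally: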
# Edit the curl config file
vim ~/.curlrc
# Add this line
socks5 = "127.0.0.1:1086"
# When you temporarily don't want the proxy, use:
curl --noproxy "*" http://www.google.com
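Editing `~/.curlrc` by hand works fine. If you script the setup, append the line only once so repeated runs don't duplicate it. In the sketch below, the address `127.0.0.1:1086` is the author's local proxy and `ensure_curlrc_proxy` is a hypothetical helper. In practice you would pass it `Path.home() / ".curlrc"`.

```python
from pathlib import Path

PROXY_LINE = 'socks5 = "127.0.0.1:1086"'


def ensure_curlrc_proxy(curlrc: Path, line: str = PROXY_LINE) -> bool:
    """Append the proxy line to a curlrc file unless it is already present.

    Returns True if the file was modified.
    """
    text = curlrc.read_text() if curlrc.exists() else ""
    if line in (existing.strip() for existing in text.splitlines()):
        return False  # already configured
    # Keep the new entry on its own line even if the file lacks a final newline.
    prefix = "" if text == "" or text.endswith("\n") else "\n"
    with curlrc.open("a") as f:
        f.write(prefix + line + "\n")
    return True
```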
Full installation session:
(base) jiangzhigang@192 Code % export BAZEL_VERSION=2.0.0
(base) jiangzhigang@192 Code % curl -fLO "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-darwin-x86_64.sh"
% Total % Received % Xferd Average Speed Time Time Time Current
Dload Upload Total Spent Left Speed
100 662 100 662 0 0 559 0 0:00:01 0:00:01 --:--:-- 559
100 36.3M 100 36.3M 0 0 381k 0 0:01:37 0:01:37 --:--:-- 403k
(base) jiangzhigang@192 Code % chmod +x "bazel-${BAZEL_VERSION}-installer-darwin-x86_64.sh"
(base) jiangzhigang@192 Code % ./bazel-${BAZEL_VERSION}-installer-darwin-x86_64.sh --user
Bazel installer
---------------
Bazel is bundled with software licensed under the GPLv2 with Classpath exception.
You can find the sources next to the installer on our release page:
https://github.com/bazelbuild/bazel/releases
# Release 2.0.0 (2019-12-19)
Baseline: 807ed23e4f53a5e008ec823e9c23e2c9baa36d0d
Cherry picks:
+ db0e32ca6296e56e5314993fe9939bc7331768ec:
build.sh: Fix bug in build script for RC release
+ 85e84f7812f04bc0dbc36376f31b6dd2d229b905:
Set --incompatible_prohibit_aapt1 default to true.
+ 84eae2ff550c433a3d0409cf2b5525059939439d:
Let shellzelisk fallback to bazel-real if it's the requested
version.
+ d5ae460f1581ddf27514b4be18255481b47b4075:
Fix a typo in bazel.sh
Incompatible changes:
- --incompatible_remap_main_repo is enabled by default. Therefore,
both ways of addressing the main repository, by its name and by
'@' are now considered referring to the same repository.
see https://github.com/bazelbuild/bazel/issues/7130
- --incompatible_disallow_dict_lookup_unhashable_keys is enabled by
default https://github.com/bazelbuild/bazel/issues/9184
- --incompatible_remove_native_maven_jar is now enabled by default
and the flag removed. See https://github.com/bazelbuild/bazel/issues/6799
- --incompatible_prohibit_aapt1 is enabled by default.
See https://github.com/bazelbuild/bazel/issues/10000
Important changes:
- --incompatible_proto_output_v2: proto v2 for aquery proto output
formats, which reduces the output size compared to v1. Note that
the messages' ids in v2 are in uint64 instead of string like in
v1.
- Adds --incompatible_remove_enabled_toolchain_types.
- Package loading now consistently fails if package loading had a
glob evaluation that encountered a symlink cycle or symlink
infinite expansion. Previously, such package loading with such
glob evaluations would fail only in some cases.
- The --disk_cache flag can now also be used together
with the gRPC remote cache.
- An action's discover inputs runtime metrics is now categorized as
parse time on the CriticalPathComponent.
- Make the formatting example more like to the written text by
adding an initial description.
- An action's discover inputs runtime metrics is now categorized as
parse time on the CriticalPathComponent.
- Bazel's Debian package and the binary installer now include an
improved wrapper that understands `<WORKSPACE>/.bazelversion`
files and the `$USE_BAZEL_VERSION` environment variable. This is
similar to what Bazelisk offers
(https://github.com/bazelbuild/bazelisk#how-does-bazelisk-know-whi
ch-bazel-version-to-run-and-where-to-get-it-from), except that it
works offline and integrates with apt-get.
- We are planning to deprecate the runfiles manifest files, which
aren't safe in the presence of whitespace, and also unnecessarily
require local CPU when remote execution is used. This release
adds --experimental_skip_runfiles_manifests to disable the
generation of the input manifests (rule.manifest files) in most
cases. Note that this flag has no effect on Windows by default or
if --experimental_enable_runfiles is explicitly set to false.
This release contains contributions from many people at Google, as well as aldersondrive, Benjamin Peterson, Bor Kae Hwang, David Ostrovsky, Jakob Buchgraber, Jin, John Millikin, Keith Smiley, Lauri Peltonen, nikola-sh, Peter Mounce, Tony Hsu.
## Build information
- [Commit](https://github.com/bazelbuild/bazel/commit/50514fc6c1)
Uncompressing......Extracting Bazel installation....
Bazel is now installed!
Make sure you have "/Users/jiangzhigang/bin" in your path. You can also activate bash
completion by adding the following line to your ~/.bash_profile:
source /Users/jiangzhigang/.bazel/bin/bazel-complete.bash
See http://bazel.build/docs/getting-started.html to start a new project!
(base) jiangzhigang@192 Code % export PATH="$PATH:$HOME/bin"
(base) jiangzhigang@192 Code % bazel --version
bazel 2.0.0
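The installer puts the `bazel` launcher in `$HOME/bin`, and the `export PATH=...` above only affects the current shell. To make the change permanent, add the `export` line to your shell profile; for zsh, the shell used here, that is `~/.zshrc`. The sketch below checks whether a directory is already on PATH. `dir_on_path` is a hypothetical helper; call it as `dir_on_path("~/bin", os.environ["PATH"])`.

```python
import os


def dir_on_path(directory: str, path_value: str) -> bool:
    """Check whether `directory` is one of the entries in a PATH string."""
    target = os.path.normpath(os.path.expanduser(directory))
    return any(
        os.path.normpath(os.path.expanduser(entry)) == target
        for entry in path_value.split(os.pathsep)
        if entry  # skip empty entries such as a trailing ':'
    )
```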
After the installation succeeds, go to the TensorFlow root directory and check that this Bazel works with it:
(base) jiangzhigang@192 Code % cd tensorflow
(base) jiangzhigang@192 tensorflow % bazel
Starting local Bazel server and connecting to it...
[bazel release 2.0.0]
Usage: bazel <command> <options> ...
Available commands:
analyze-profile Analyzes build profile data.
aquery Analyzes the given targets and queries the action graph.
build Builds the specified targets.
canonicalize-flags Canonicalizes a list of bazel options.
clean Removes output files and optionally stops the server.
coverage Generates code coverage report for specified test targets.
cquery Loads, analyzes, and queries the specified targets w/ configurations.
dump Dumps the internal state of the bazel server process.
fetch Fetches external repositories that are prerequisites to the targets.
help Prints help for commands, or the index.
info Displays runtime info about the bazel server.
license Prints the license of this software.
mobile-install Installs targets to mobile devices.
print_action Prints the command line args for compiling a file.
query Executes a dependency graph query.
run Runs the specified target.
shutdown Stops the bazel server.
sync Syncs all repositories specified in the workspace file
test Builds and runs the specified test targets.
version Prints version information for bazel.
Getting more help:
bazel help <command>
Prints help and options for <command>.
bazel help startup_options
Options for the JVM hosting bazel.
bazel help target-syntax
Explains the syntax for specifying targets.
bazel help info-keys
Displays a list of keys used by the info command.
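OK, Bazel is installed. Now let's use Bazel to build the library.
First, in the tensorflow source directory, run the script that downloads the dependencies. Its main job is to fetch flatbuffers, which is used to load model files.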
(base) jiangzhigang@192 tensorflow % ./tensorflow/lite/tools/make/download_dependencies.sh
downloading https://gitlab.com/libeigen/eigen/-/archive/52a2fbbb008a47c5e3fb8ac1c65c2feecb0c511c/eigen-52a2fbbb008a47c5e3fb8ac1c65c2feecb0c511c.tar.gz
% Total % Received % Xferd Average Speed Time Time Time Current
Dload Upload Total Spent Left Speed
100 2524k 0 2524k 0 0 1048k 0 --:--:-- 0:00:02 --:--:-- 1048k
checking sha256 of tensorflow/lite/tools/make/downloads/eigen
/var/folders/ms/h4r5t2v90k151rsr51v_81vm0000gn/T/tmp.QLyl9rZF/eigen-52a2fbbb008a47c5e3fb8ac1c65c2feecb0c511c.tar.gz: OK
downloading https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/gemmlowp/archive/12fed0cd7cfcd9e169bf1925bc3a7a58725fdcc3.zip
% Total % Received % Xferd Average Speed Time Time Time Current
Dload Upload Total Spent Left Speed
100 913k 100 913k 0 0 344k 0 0:00:02 0:00:02 --:--:-- 344k
checking sha256 of tensorflow/lite/tools/make/downloads/gemmlowp
/var/folders/ms/h4r5t2v90k151rsr51v_81vm0000gn/T/tmp.bpc7qCBI/12fed0cd7cfcd9e169bf1925bc3a7a58725fdcc3.zip: OK
Archive: /var/folders/ms/h4r5t2v90k151rsr51v_81vm0000gn/T/tmp.bpc7qCBI/12fed0cd7cfcd9e169bf1925bc3a7a58725fdcc3.zip
12fed0cd7cfcd9e169bf1925bc3a7a58725fdcc3
creating: /var/folders/ms/h4r5t2v90k151rsr51v_81vm0000gn/T/tmp.7VZuSwcc/gemmlowp-12fed0cd7cfcd9e169bf1925bc3a7a58725fdcc3/
inflating: /var/folders/ms/h4r5t2v90k151rsr51v_81vm0000gn/T/tmp.7VZuSwcc/gemmlowp-12fed0cd7cfcd9e169bf1925bc3a7a58725fdcc3/.gitignore
inflating: /var/folders/ms/h4r5t2v90k151rsr51v_81vm0000gn/T/tmp.7VZuSwcc/gemmlowp-12fed0cd7cfcd9e169bf1925bc3a7a58725fdcc3/.travis.yml
...
...
...
inflating: /var/folders/ms/h4r5t2v90k151rsr51v_81vm0000gn/T/tmp.0IvYH1Bq/FP16-febbb1c163726b5db24bed55cc9dc42529068997/third-party/eigen-half.h
inflating: /var/folders/ms/h4r5t2v90k151rsr51v_81vm0000gn/T/tmp.0IvYH1Bq/FP16-febbb1c163726b5db24bed55cc9dc42529068997/third-party/float16-compressor.h
inflating: /var/folders/ms/h4r5t2v90k151rsr51v_81vm0000gn/T/tmp.0IvYH1Bq/FP16-febbb1c163726b5db24bed55cc9dc42529068997/third-party/half.hpp
inflating: /var/folders/ms/h4r5t2v90k151rsr51v_81vm0000gn/T/tmp.0IvYH1Bq/FP16-febbb1c163726b5db24bed55cc9dc42529068997/third-party/npy-halffloat.h
download_dependencies.sh completed successfully.
The downloaded dependencies are in: /Users/jiangzhigang/Code/tensorflow/tensorflow/lite/tools/make/downloads
(base) jiangzhigang@192 tensorflow % ls /Users/jiangzhigang/Code/tensorflow/tensorflow/lite/tools/make/downloads
absl eigen farmhash fft2d flatbuffers fp16 gemmlowp googletest neon_2_sse
These dependency libraries are also needed later when integrating with the Android project.
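If the script is interrupted, for example by a slow connection, some folders may be missing. The sketch below compares the downloads directory against the listing above. The expected names are copied from that listing for r2.2 and may differ on other branches; `missing_dependencies` is a hypothetical helper.

```python
from pathlib import Path
from typing import Iterable, List

# Folders listed by `ls` above after download_dependencies.sh finished (r2.2).
EXPECTED = ("absl", "eigen", "farmhash", "fft2d", "flatbuffers",
            "fp16", "gemmlowp", "googletest", "neon_2_sse")


def missing_dependencies(downloads: str,
                         expected: Iterable[str] = EXPECTED) -> List[str]:
    """Return the expected dependency folders that are absent."""
    root = Path(downloads)
    return [name for name in expected if not (root / name).is_dir()]
```

An empty list means every folder from the listing is present.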
Still in the tensorflow root directory, run the configure script:
(base) jiangzhigang@192 tensorflow % ./configure
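I had already installed Android Studio, so I pointed configure directly at its SDK and NDK. Without Android Studio, you need to download the SDK and NDK yourself, or use the Docker environment provided on the TensorFlow website.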
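configure asks about many settings. The first two are Python paths; the defaults are usually correct, so just press Enter.
It then asks about various other features we don't need; answer N to each.
The most important step: when asked whether to configure Android, answer y.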
Would you like to interactively configure ./WORKSPACE for Android builds? [y/N]: y
Searching for NDK and SDK installations.
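Next, it asks for the Android NDK path. Paste in the path where your NDK is installed and press Enter. It then detects the version and asks you to confirm; just press Enter.
Then it asks for the Android SDK path. As in the previous step, paste in its path.
Full configuration session: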
(base) jiangzhigang@192 tensorflow % ./configure
You have bazel 2.0.0 installed.
Please specify the location of python. [Default is /Users/jiangzhigang/opt/miniconda3/bin/python]:
Found possible Python library paths:
/Users/jiangzhigang/opt/miniconda3/lib/python3.7/site-packages
Please input the desired Python library path to use. Default is [/Users/jiangzhigang/opt/miniconda3/lib/python3.7/site-packages]
Do you wish to build TensorFlow with OpenCL SYCL support? [y/N]: N
No OpenCL SYCL support will be enabled for TensorFlow.
Do you wish to build TensorFlow with ROCm support? [y/N]: N
No ROCm support will be enabled for TensorFlow.
Do you wish to build TensorFlow with CUDA support? [y/N]: N
No CUDA support will be enabled for TensorFlow.
Do you wish to download a fresh release of clang? (Experimental) [y/N]: N
Clang will not be downloaded.
Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native -Wno-sign-compare]:
Would you like to interactively configure ./WORKSPACE for Android builds? [y/N]: y
Searching for NDK and SDK installations.
Please specify the home path of the Android NDK to use. [Default is /Users/jiangzhigang/library/Android/Sdk/ndk-bundle]:
WARNING: The NDK version in /Users/jiangzhigang/library/Android/Sdk/ndk-bundle is 21, which is not supported by Bazel (officially supported versions: [10, 11, 12, 13, 14, 15, 16, 17, 18]). Please use another version. Compiling Android targets may result in confusing errors.
Please specify the (min) Android NDK API level to use. [Available levels: ['16', '17', '18', '19', '21', '22', '23', '24', '26', '27', '28', '29', '30']] [Default is 21]:
Please specify the home path of the Android SDK to use. [Default is /Users/jiangzhigang/library/Android/Sdk]:
Please specify the Android SDK API level to use. [Available levels: ['29', '30']] [Default is 30]:
Please specify an Android build tools version to use. [Available versions: ['28.0.3', '30.0.2']] [Default is 30.0.2]:
Do you wish to build TensorFlow with iOS support? [y/N]: y
iOS support will be enabled for TensorFlow.
Preconfigured Bazel build configs. You can use any of the below by adding "--config=<>" to your build command. See .bazelrc for more details.
--config=mkl # Build with MKL support.
--config=monolithic # Config for mostly static monolithic build.
--config=ngraph # Build with Intel nGraph support.
--config=numa # Build with NUMA support.
--config=dynamic_kernels # (Experimental) Build kernels into separate shared objects.
--config=v2 # Build TensorFlow 2.x instead of 1.x.
Preconfigured Bazel build configs to DISABLE default on features:
--config=noaws # Disable AWS S3 filesystem support.
--config=nogcp # Disable GCP support.
--config=nohdfs # Disable HDFS support.
--config=nonccl # Disable NVIDIA NCCL support.
Configuration finished
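Note: newer versions of Android Studio install the NDK by default into a versioned folder such as /Users/jiangzhigang/Library/Android/sdk/ndk/21.1.6352462, not into the ndk-bundle folder that configure suggests. You therefore need to enter this path manually.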
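Both configure and the build warn about NDK 21. To see which NDK a directory contains before pointing configure at it, you can read its `source.properties` file. The sketch below assumes that file has the standard `Pkg.Revision = ...` line found in NDK distributions. The supported range is copied from the Bazel 2.0.0 warning in the build log further down; the configure prompt lists a shorter range. `ndk_revision` and `is_supported_by_bazel` are hypothetical helpers.

```python
import re
from pathlib import Path

# NDK majors that Bazel 2.0.0 reports as supported (10 through 20).
BAZEL_SUPPORTED_NDK_MAJORS = range(10, 21)


def ndk_revision(ndk_home: str) -> str:
    """Read Pkg.Revision (e.g. '21.1.6352462') from the NDK's source.properties."""
    text = Path(ndk_home, "source.properties").read_text()
    match = re.search(r"^Pkg\.Revision\s*=\s*(\S+)", text, re.MULTILINE)
    if match is None:
        raise ValueError("Pkg.Revision not found in source.properties")
    return match.group(1)


def is_supported_by_bazel(revision: str) -> bool:
    """True if the NDK major version is in Bazel's supported list."""
    return int(revision.split(".")[0]) in BAZEL_SUPPORTED_NDK_MAJORS
```

An unsupported version is only a warning. As the build log shows, NDK 21 still built successfully here.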
We want the output to be a shared library (a .so file), so we add a target for it. Open tensorflow/lite/BUILD and append the following at the end:
# Shared library containing the TFLite framework plus all builtin ops.
cc_binary(
    name = "libtensorflowLite.so",
    linkopts = ["-shared", "-Wl,-soname=libtensorflowLite.so"],
    visibility = ["//visibility:public"],
    linkshared = 1,
    copts = tflite_copts(),
    deps = [
        ":framework",
        "//tensorflow/lite/kernels:builtin_ops",
    ],
)
Run:
(base) jiangzhigang@192 tensorflow % bazel build -c opt //tensorflow/lite:libtensorflowLite.so --config=android_arm64 --cxxopt="-std=c++11"
If nothing goes wrong, the build output looks like this:
(base) jiangzhigang@192 tensorflow % bazel build -c opt //tensorflow/lite:libtensorflowLite.so --config=android_arm64 --cxxopt="-std=c++11"
INFO: Options provided by the client:
Inherited 'common' options: --isatty=1 --terminal_columns=210
INFO: Reading rc options for 'build' from /Users/jiangzhigang/Code/tensorflow/.bazelrc:
Inherited 'common' options: --experimental_repo_remote_exec
INFO: Reading rc options for 'build' from /Users/jiangzhigang/Code/tensorflow/.bazelrc:
'build' options: --apple_platform_type=macos --define framework_shared_object=true --define open_source_build=true --java_toolchain=//third_party/toolchains/java:tf_java_toolchain --host_java_toolchain=//third_party/toolchains/java:tf_java_toolchain --define=use_fast_cpp_protos=true --define=allow_oversize_protos=true --spawn_strategy=standalone -c opt --announce_rc --define=grpc_no_ares=true --noincompatible_remove_legacy_whole_archive --noincompatible_prohibit_aapt1 --enable_platform_specific_config --config=v2
INFO: Reading rc options for 'build' from /Users/jiangzhigang/Code/tensorflow/.tf_configure.bazelrc:
'build' options: --action_env PYTHON_BIN_PATH=/Users/jiangzhigang/opt/miniconda3/bin/python --action_env PYTHON_LIB_PATH=/Users/jiangzhigang/opt/miniconda3/lib/python3.7/site-packages --python_path=/Users/jiangzhigang/opt/miniconda3/bin/python --config=xla --action_env ANDROID_NDK_HOME=/Users/jiangzhigang/Library/Android/sdk/ndk/21.3.6528147 --action_env ANDROID_NDK_API_LEVEL=21 --action_env ANDROID_BUILD_TOOLS_VERSION=30.0.2 --action_env ANDROID_SDK_API_LEVEL=30 --action_env ANDROID_SDK_HOME=/Users/jiangzhigang/library/Android/Sdk --action_env TF_CONFIGURE_IOS=0
INFO: Found applicable config definition build:v2 in file /Users/jiangzhigang/Code/tensorflow/.bazelrc: --define=tf_api_version=2 --action_env=TF2_BEHAVIOR=1
INFO: Found applicable config definition build:xla in file /Users/jiangzhigang/Code/tensorflow/.bazelrc: --action_env=TF_ENABLE_XLA=1 --define=with_xla_support=true
INFO: Found applicable config definition build:android_arm64 in file /Users/jiangzhigang/Code/tensorflow/.bazelrc: --config=android --cpu=arm64-v8a --fat_apk_cpu=arm64-v8a
INFO: Found applicable config definition build:android in file /Users/jiangzhigang/Code/tensorflow/.bazelrc: --crosstool_top=//external:android/crosstool --host_crosstool_top=@bazel_tools//tools/cpp:toolchain
INFO: Found applicable config definition build:macos in file /Users/jiangzhigang/Code/tensorflow/.bazelrc: --copt=-w --define=PREFIX=/usr --define=LIBDIR=$(PREFIX)/lib --define=INCLUDEDIR=$(PREFIX)/include --cxxopt=-std=c++14 --host_cxxopt=-std=c++14
WARNING: The major revision of the Android NDK referenced by android_ndk_repository rule 'androidndk' is 21. The major revisions supported by Bazel are [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]. Bazel will attempt to treat the NDK as if it was r20. This may cause compilation and linkage problems. Please download a supported NDK version.
INFO: Analyzed target //tensorflow/lite:libtensorflowLite.so (61 packages loaded, 7812 targets configured).
INFO: Found 1 target...
Target //tensorflow/lite:libtensorflowLite.so up-to-date:
bazel-bin/tensorflow/lite/libtensorflowLite.so
INFO: Elapsed time: 369.020s, Critical Path: 119.45s
INFO: 308 processes: 308 local.
INFO: Build completed successfully, 385 total actions
The .so file is at /Users/jiangzhigang/Code/tensorflow/bazel-bin/tensorflow/lite/libtensorflowLite.so
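Before copying the library into the Android project, confirm that it really targets arm64-v8a. On macOS, the `file` command does this. The sketch below does the same check in Python by reading the ELF header: it tests the magic bytes, the 64-bit class, and an `e_machine` value of 183 (EM_AARCH64). `is_arm64_elf` and `so_is_arm64` are hypothetical helpers.

```python
import struct

EM_AARCH64 = 183  # e_machine value for ARM64 in the ELF specification


def is_arm64_elf(header: bytes) -> bool:
    """Check the ELF magic, the 64-bit class, and the AArch64 machine field."""
    if len(header) < 20 or header[:4] != b"\x7fELF":
        return False
    if header[4] != 2:  # EI_CLASS: 2 means ELFCLASS64
        return False
    byte_order = "<" if header[5] == 1 else ">"  # EI_DATA: 1 = little endian
    # e_machine is the 2-byte field right after e_ident (16 bytes) and e_type (2 bytes).
    (machine,) = struct.unpack_from(byte_order + "H", header, 18)
    return machine == EM_AARCH64


def so_is_arm64(path: str) -> bool:
    """Read the first 20 bytes of a file and check them with is_arm64_elf."""
    with open(path, "rb") as f:
        return is_arm64_elf(f.read(20))
```

For this build, `so_is_arm64("bazel-bin/tensorflow/lite/libtensorflowLite.so")` should return True.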
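A shared library on its own is not enough. Code can only be compiled against the library through its header files, so we need those too.
Where do the headers come from? Simply collect all the .h files under tensorflow/lite, keeping their directory hierarchy.
A shell command makes this easy: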
# Find all .h files and pack them into headers.tar
(base) jiangzhigang@192 tensorflow % find tensorflow/lite -name "*.h" | tar -cf headers.tar -T -
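If you would rather not depend on how your `tar` parses `-T -`, Python's `tarfile` module can do the same job. The sketch below takes paths relative to the TensorFlow source root, like the shell command. `pack_headers` is a hypothetical helper.

```python
import os
import tarfile


def pack_headers(src_root: str, subdir: str, archive: str) -> int:
    """Archive every .h file under src_root/subdir, keeping relative paths.

    Similar in effect to: find <subdir> -name "*.h" | tar -cf <archive> -T -
    Returns the number of headers added.
    """
    count = 0
    with tarfile.open(archive, "w") as tar:
        for dirpath, _dirnames, filenames in os.walk(os.path.join(src_root, subdir)):
            for name in sorted(filenames):
                if name.endswith(".h"):
                    full = os.path.join(dirpath, name)
                    # Store e.g. "tensorflow/lite/interpreter.h", not the absolute path.
                    tar.add(full, arcname=os.path.relpath(full, src_root))
                    count += 1
    return count
```

Run from the TensorFlow root, `pack_headers(".", "tensorflow/lite", "headers.tar")` corresponds to the shell command.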
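Another important set of headers comes from flatbuffers, a library TF Lite depends on. It lives in tensorflow/lite/tools/make/downloads/flatbuffers. Copy its entire include folder into a folder named flatbuffers. The eigen library is usually required as well; it plays roughly the role that numpy plays in Python. TensorFlow 2.2.0 adds one more required dependency, absl, Google's common base library. All of these are in the tensorflow/lite/tools/make/downloads/ directory.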
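The copying described above can be scripted. In the sketch below, the destination layout (`third_party/flatbuffers`, `third_party/absl`, `third_party/eigen`) is an assumption. It is chosen so that `#include "flatbuffers/flatbuffers.h"` and `#include "absl/..."` resolve when those folders are added as include directories. `assemble_third_party` is a hypothetical helper, and it needs Python 3.8+ for `dirs_exist_ok`.

```python
import shutil
from pathlib import Path


def assemble_third_party(downloads: str, third_party: str) -> None:
    """Copy dependency headers from the TFLite downloads folder into third_party.

    Assumed (hypothetical) result:
      third_party/flatbuffers/flatbuffers/*.h   (contents of flatbuffers/include)
      third_party/absl/absl/...                 (whole absl tree)
      third_party/eigen/Eigen/...               (whole eigen tree, if present)
    """
    src, dst = Path(downloads), Path(third_party)
    # flatbuffers: only the public headers under include/.
    shutil.copytree(src / "flatbuffers" / "include", dst / "flatbuffers",
                    dirs_exist_ok=True)
    # absl and eigen: copy whole trees so their include prefixes keep working.
    for name in ("absl", "eigen"):
        if (src / name).is_dir():
            shutil.copytree(src / name, dst / name, dirs_exist_ok=True)
```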
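The organized headers end up in a third_party directory (the screenshot of the layout is not preserved). The Android module's CMakeLists.txt then references that directory: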
cmake_minimum_required(VERSION 3.4.1)
set(libs "${CMAKE_SOURCE_DIR}/src/main/jniLibs")
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_SOURCE_DIR}/src/main/jniLibs/${ANDROID_ABI})
# Add the TensorFlow headers
include_directories(../../third_party/tensorflow/)
# Add the headers of TensorFlow's dependencies
include_directories(../../third_party/flatbuffers/)
include_directories(../../third_party/absl/)
# Collect the project's C++ source files into the SRC_LIST variable
aux_source_directory(src/main/cpp/ SRC_LIST)
# Build the C++ code as a native shared library
add_library(native-lib SHARED ${SRC_LIST})
# Link the native library directly against the TensorFlow Lite .so
target_link_libraries(native-lib ${libs}/${ANDROID_ABI}/libtensorflowLite.so)
set(CMAKE_BUILD_TYPE Debug CACHE STRING "set build type")
The CMakeLists.txt above is only a demonstration and is not very standard. Add OpenCV as needed.
Then copy the .so into the jniLibs/arm64-v8a directory, and copy third_party to the matching location.
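In the CMakeLists.txt, `libs` points at `src/main/jniLibs` and the link path appends `${ANDROID_ABI}`, so the library must sit in a per-ABI subfolder. The sketch below does that copy. It assumes `module_dir` is the Android module folder that contains the CMakeLists.txt (e.g. `app/`); `install_so` is a hypothetical helper.

```python
import shutil
from pathlib import Path


def install_so(so_path: str, module_dir: str, abi: str = "arm64-v8a") -> Path:
    """Copy the built library to <module_dir>/src/main/jniLibs/<abi>/ and return its new path."""
    dest_dir = Path(module_dir) / "src" / "main" / "jniLibs" / abi
    dest_dir.mkdir(parents=True, exist_ok=True)
    # copy2 keeps timestamps and returns the destination file path.
    return Path(shutil.copy2(so_path, dest_dir))
```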
detector.hpp:
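#pragma once

#include <memory>
#include <string>

#include <opencv2/opencv.hpp>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

class Detector {
public:
    cv::Mat sourceImage;    // expected to be a single-channel (grayscale) 8-bit image
    std::string modelPath;  // path to the .tflite file on the device

    Detector(cv::Mat sourceImage_, const std::string &modelPath_)
        : sourceImage(sourceImage_), modelPath(modelPath_) {}

    cv::Mat clip_center(float ratio = 1.0, cv::Vec2f pan = cv::Vec2f(0, 0), bool keep_ratio = false);
    void fill_input(const std::unique_ptr<tflite::Interpreter> &interpreter);
    float blockDetection();
};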
detector.cpp
#include "detector.hpp"

#include <algorithm>
#include <cstdlib>

#include <spdlog/spdlog.h>

// Log the failing location and abort, like TFLITE_MINIMAL_CHECK in minimal.cc.
// spdlog uses {} placeholders, not printf-style %s/%d.
#define TFLITE_MINIMAL_CHECK(x)                                  \
    do {                                                         \
        if (!(x)) {                                              \
            spdlog::error("Error at {}:{}", __FILE__, __LINE__); \
            exit(1);                                             \
        }                                                        \
    } while (0)

cv::Mat Detector::clip_center(float ratio, cv::Vec2f pan, bool keep_ratio) {
    int h = sourceImage.rows;
    int w = sourceImage.cols;
    float h_ = ratio * h;
    float w_ = ratio * w;
    if (keep_ratio) {
        // Square crop based on the shorter side.
        h_ = std::min(h_, w_);
        w_ = h_;
    }
    // Centered ROI, optionally shifted by `pan`; it shares data with sourceImage.
    return cv::Mat(sourceImage, cv::Rect(int((w - w_) / 2 + pan[0]), int((h - h_) / 2 + pan[1]), int(w_), int(h_)));
}

void Detector::fill_input(const std::unique_ptr<tflite::Interpreter> &interpreter) {
    cv::Mat input_image = clip_center(1, cv::Vec2f(0, 0), true);
    int input_height = 256;
    int input_width = 256;
    // cv::Size takes (width, height).
    cv::resize(input_image, input_image, cv::Size(input_width, input_height));
    float *input = interpreter->typed_input_tensor<float>(0);
    for (int i = 0; i < input_height; i++) {
        const uchar *row = input_image.ptr<uchar>(i);
        for (int j = 0; j < input_width; j++) {
            // Row-major [1, H, W, 1] layout, scaled to [0, 1] like the Python test.
            input[i * input_width + j] = row[j] / 255.0f;
        }
    }
}

float Detector::blockDetection() {
    // Load the model
    std::unique_ptr<tflite::FlatBufferModel> model = tflite::FlatBufferModel::BuildFromFile(modelPath.c_str());
    TFLITE_MINIMAL_CHECK(model != nullptr);

    // Build the interpreter
    tflite::ops::builtin::BuiltinOpResolver resolver;
    std::unique_ptr<tflite::Interpreter> interpreter;
    tflite::InterpreterBuilder(*model, resolver)(&interpreter);
    TFLITE_MINIMAL_CHECK(interpreter != nullptr);

    // Resize input tensors, if desired.
    TFLITE_MINIMAL_CHECK(interpreter->AllocateTensors() == kTfLiteOk);

    // Fill `input`.
    fill_input(interpreter);

    TFLITE_MINIMAL_CHECK(interpreter->Invoke() == kTfLiteOk);
    float output = interpreter->typed_output_tensor<float>(0)[0];
    spdlog::info("camera block detection result:{}", output);
    return output;
}
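To debug the C++ preprocessing on a desktop, you can reproduce the crop geometry of `clip_center` and the indexing in `fill_input` in Python and compare them with the Python script. The sketch below mirrors the C++ arithmetic, including truncation toward zero by `int(...)`. It does not check that the rectangle stays inside the image; OpenCV's ROI constructor would reject such a rectangle at runtime. `clip_center_rect` and `flat_index` are hypothetical helper names.

```python
from typing import Tuple


def clip_center_rect(width: int, height: int, ratio: float = 1.0,
                     pan: Tuple[float, float] = (0.0, 0.0),
                     keep_ratio: bool = False) -> Tuple[int, int, int, int]:
    """Return the (x, y, w, h) rectangle that Detector::clip_center crops."""
    h_ = ratio * height
    w_ = ratio * width
    if keep_ratio:
        # Square crop based on the shorter side, as in the C++ code.
        h_ = min(h_, w_)
        w_ = h_
    return (int((width - w_) / 2 + pan[0]),
            int((height - h_) / 2 + pan[1]),
            int(w_), int(h_))


def flat_index(i: int, j: int, input_width: int) -> int:
    """Offset of pixel (row i, column j) in the flattened [1, H, W, 1] tensor."""
    return i * input_width + j
```

For a 640x480 frame with `keep_ratio=True`, the crop is `(80, 0, 480, 480)`, i.e. the centered 480x480 square that is then resized to 256x256.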
For more sample code, see minimal.cc and label_image.cc. Also see "TensorFlow Lite inference: Load and run a model in C++" and "c++ - TensorFlow Lite model test: correct code, class probabilities".
Android calls the C++ code through JNI, which I won't cover here.
In short, the final call chain is: Java/Kotlin → JNI → C++.