Tensorflow lite
源码中提供了对个op的单元测试源码,但是在官方的tflite Makefile中默认并没有编译该部分代码。本文主要是记录在tflite中对op进行单独测试的方法,平台为ARM嵌入式。
在tflite的源码中单元测试的源码一般在op名后面添加有test,在目录 tensorflow/contrib/lite/kernels
下可以看到很多op的单元测试源码,如convolution的实现源码为conv.cc
,则对应的单元测试源码为conv_test.cc
,查看源码后可以知道单元测试采用googletest来实现。另外,基本上所有op的单元测试都会继承tensorflow/contrib/lite/kernels/test_util.h
里面的SingleOpModel
类,而官方源码中的Makefile默认是没有编译test_util.cc的。
下面,本文以编译conv_test.cc
为例说明怎么使用单元测试。
基本思想为:先修改Makefile把test_util.cc
编译进libtensorflow-lite.a
,然后对要测试的conv_test.cc
源码单独写一个cmake去调用新的libtensorflow-lite.a
。
由于unit test需要用到Googletest库,所以需要提前编译准备好Googletest,另外还需要用到absl库。
安装Googletest库
git clone https://github.com/google/googletest
cd googletest
mkdir build
cd build
cmake -DCMAKE_INSTALL_PREFIX=/path/to/yourdir ..
make install
absl地址:https://github.com/abseil/abseil-cpp 先下载放到制定位置,可以暂时不用编译。
准备工作做好以后就可以修改lite源码中的Makefile了,修改的地方主要是添加googletest到INCLUDES,以及添加对其他源码的编译。
添加新的INCLUDES
INCLUDES += -I/path/to/googletest/include
修改CORE_CC_EXCLUDE_SRCS变量
CORE_CC_EXCLUDE_SRCS := \
$(wildcard tensorflow/contrib/lite/*test.cc) \
$(wildcard tensorflow/contrib/lite/*/*test.cc) \
$(wildcard tensorflow/contrib/lite/*/*/*test.cc) \
$(wildcard tensorflow/contrib/lite/*/*/*/*test.cc)
##$(wildcard tensorflow/contrib/lite/kernels/test_util.cc) \ 把这行注释掉
另外还需要添加
CORE_CC_ALL_SRCS += \
$(wildcard tensorflow/core/platform/default/logging.cc) \
$(wildcard tensorflow/core/platform/env_time.cc)
到此,对Makefile的修改就完成了,运行./tensorflow/contrib/lite/tools/make/build_rpi_lib.sh
编译可生成新的libtensorflow-lite.a
库。
Tips: 注意上述编译的库在被conv_test.cc
调用时会出错,原因是env_time.cc
中的 EnvTime
类部分函数还没有实现,这个可以自己把相关函数实现下,可以参考 tensorflow/core/platform/posix/env_time.cc
的实现方式。
接下来编译conv_test.cc
,编译CMAKE的时候遇到一些坑,下面是填坑后的完整CMakeLists.txt
。
cmake_minimum_required(VERSION 3.0)
add_definitions(-std=c++11) #must use c++11
set(CMAKE_SYSTEM_PROCESSOR aarch64)
set(GCC_COMPILER_VERSION "" STRING "GCC Compiler version")
SET(CMAKE_C_COMPILER aarch64-linux-gnu-gcc)
SET(CMAKE_CXX_COMPILER aarch64-linux-gnu-g++)
find_package(Threads)
SET(CMAKE_BUILD_TYPE "Release")
#set(CMAKE_EXE_LINKER_FLAGS "-lpthread -lrt -ldl") #special for tflite compile
INCLUDE_DIRECTORIES("/path/to/tflite_lib/include")
INCLUDE_DIRECTORIES("/path/to/googletest/include")
INCLUDE_DIRECTORIES("/path/to/absl")
LINK_DIRECTORIES("/path/to/tflite/lib")
LINK_DIRECTORIES("/path/to/googletest/lib")
add_executable(ConvUnitTest conv_test.cc)
target_link_libraries(ConvUnitTest libtensorflow-lite.a libgtest.a libgmock.a ${CMAKE_THREAD_LIBS_INIT} ${CMAKE_DL_LIBS})
如果需要对其他的op进行单元测试,则把对应的op_test.cc替换掉上面的conv_test.cc即可。