在tensorfow lite中对各op进行单元测试

Tensorflow lite源码中提供了对个op的单元测试源码,但是在官方的tflite Makefile中默认并没有编译该部分代码。本文主要是记录在tflite中对op进行单独测试的方法,平台为ARM嵌入式。

概要

在tflite的源码中单元测试的源码一般在op名后面添加有test,在目录 tensorflow/contrib/lite/kernels下可以看到很多op的单元测试源码,如convolution的实现源码为conv.cc,则对应的单元测试源码为conv_test.cc,查看源码后可以知道单元测试采用googletest来实现。另外,基本上所有op的单元测试都会继承tensorflow/contrib/lite/kernels/test_util.h里面的SingleOpModel类,而官方源码中的Makefile默认是没有编译test_util.cc的。
下面,本文以编译conv_test.cc为例说明怎么使用单元测试。

基本思想为:先修改Makefile把test_util.cc编译进libtensorflow-lite.a,然后对要测试的conv_test.cc源码单独写一个cmake去调用新的libtensorflow-lite.a

由于unit test需要用到Googletest库,所以需要提前编译准备好Googletest,另外还需要用到absl库。

安装Googletest库

git clone https://github.com/google/googletest 
cd googletest
mkdir build
cd build
cmake -DCMAKE_INSTALL_PREFIX=/path/to/yourdir ..
make install 

absl地址:https://github.com/abseil/abseil-cpp 先下载放到制定位置,可以暂时不用编译。

修改Makefile,编译新的tflite库

准备工作做好以后就可以修改lite源码中的Makefile了,修改的地方主要是添加googletest到INCLUDES,以及添加对其他源码的编译。

添加新的INCLUDES

INCLUDES += -I/path/to/googletest/include

修改CORE_CC_EXCLUDE_SRCS变量

CORE_CC_EXCLUDE_SRCS := \
$(wildcard tensorflow/contrib/lite/*test.cc) \
$(wildcard tensorflow/contrib/lite/*/*test.cc) \
$(wildcard tensorflow/contrib/lite/*/*/*test.cc) \
$(wildcard tensorflow/contrib/lite/*/*/*/*test.cc) 
##$(wildcard tensorflow/contrib/lite/kernels/test_util.cc) \   把这行注释掉

另外还需要添加

CORE_CC_ALL_SRCS += \
$(wildcard tensorflow/core/platform/default/logging.cc) \
$(wildcard tensorflow/core/platform/env_time.cc)

到此,对Makefile的修改就完成了,运行./tensorflow/contrib/lite/tools/make/build_rpi_lib.sh编译可生成新的libtensorflow-lite.a库。

Tips: 注意上述编译的库在被conv_test.cc调用时会出错,原因是env_time.cc中的 EnvTime类部分函数还没有实现,这个可以自己把相关函数实现下,可以参考 tensorflow/core/platform/posix/env_time.cc的实现方式。

编译conv_test.cc

接下来编译conv_test.cc,编译CMAKE的时候遇到一些坑,下面是填坑后的完整CMakeLists.txt

cmake_minimum_required(VERSION 3.0)
add_definitions(-std=c++11)  #must use c++11
set(CMAKE_SYSTEM_PROCESSOR aarch64)
set(GCC_COMPILER_VERSION "" STRING "GCC Compiler version")

SET(CMAKE_C_COMPILER   aarch64-linux-gnu-gcc) 
SET(CMAKE_CXX_COMPILER aarch64-linux-gnu-g++) 
find_package(Threads)
SET(CMAKE_BUILD_TYPE "Release")
#set(CMAKE_EXE_LINKER_FLAGS "-lpthread -lrt -ldl") #special for tflite compile
INCLUDE_DIRECTORIES("/path/to/tflite_lib/include")
INCLUDE_DIRECTORIES("/path/to/googletest/include")
INCLUDE_DIRECTORIES("/path/to/absl")

LINK_DIRECTORIES("/path/to/tflite/lib")
LINK_DIRECTORIES("/path/to/googletest/lib")

add_executable(ConvUnitTest conv_test.cc)
target_link_libraries(ConvUnitTest libtensorflow-lite.a libgtest.a libgmock.a ${CMAKE_THREAD_LIBS_INIT} ${CMAKE_DL_LIBS})

如果需要对其他的op进行单元测试,则把对应的op_test.cc替换掉上面的conv_test.cc即可。

你可能感兴趣的:(tensorflow)