TensorFlow不重新编译源码使用C/C++ API推理

TensorFlow可以使用pip 安装tensorflow包然后调用其python接口,或者使用其C++或者C api进行推理。出于性能或者业务等因素部分用户选择了C/C++接口进行推理,C接口推理tensorflow提供了预编译好的头文件和so(https://www.tensorflow.org/install/lang_c),其缺点是不能调用TensorFlow的C++接口,比较不方便。而C++接口通常需要用户自己重新基于源码编译,费事费力(参考Tensorflow C API 从训练到部署:使用 C API 进行预测和部署 - 技术刘   使用C++调用TensorFlow模型简单说明 | Dannyw's Blog等博客)。

如果开发C++代码,链接pip安装的Tensorflow安装目录下面的so,会报如下错误:
E tensorflow/core/common_runtime/session.cc:67] Not found: No session factory registered for the given session options: {target: "" config: } Registered factories are {}.
同时会发现TensorFlow内部的算子都未注册,即使使用-Wl,--whole-archive处理也无法解决。

那么是否可以实现直接使用pip安装的tensorflow的so和头文件,实现C++接口调用推理呢?作者发现了一个方法并分享如下。

main.cpp推理代码example

#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/env.h"
#include 
#include 

using namespace tensorflow;

#ifdef __cplusplus
extern "C" {
#endif

// instantiated in tensorflow _pywrap_tensorflow_internal.so
extern const char* TF_Version(void);

#ifdef __cplusplus
}
#endif

int main() {
  // must be called to load op register
  TF_Version();

  std::string model_path = "resnet_50.pb";
  tensorflow::GraphDef graphdef;
  tensorflow::Status status_load = ReadBinaryProto(tensorflow::Env::Default(), model_path, &graphdef);

  tensorflow::SessionOptions options;
  tensorflow::Session* session;

  session = tensorflow::NewSession(options);
  if (session == nullptr) {
    std::cout << "create new session failed" << std::endl;
    return -1;
  }
  tensorflow::Status status;
  status = session->Extend(graphdef);
  if (!status.ok()) {
    std::cout << "session extend graph failed" << std::endl;
    return -1;
  }

  Tensor x(DT_FLOAT, TensorShape({1, 3, 224, 224}));

  std::vector> input_tensors;
  input_tensors.push_back({"input", x});

  std::vector output_names = {"resnet_model/stage_1/Relu_2"};
  std::vector outputs;
  TF_CHECK_OK(session->Run(input_tensors, output_names, {}, &outputs));

  // release session
  session->Close();
  delete session;
  session = nullptr;

  return 0;
}

这里的核心是调用了TF_Version();(可能其他函数也有类似功效) 从而成功加载so里面的符号,否则并不会加载。具体原因欢迎大家在评论区讨论。这个函数tf 2.x的pip安装包里已经提供了接口定义,而1.1x没有,需要手动定义下。

cmake文件编译选项

核心是需要包含python的so,tf的两个so

project(tf_cpp_test LANGUAGES CXX)

add_compile_options(-fPIC)

# tf version >=1.15 use ABI=0
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)

add_executable(
    main
    main.cpp
)

target_include_directories(
    main
    PUBLIC
    $ENV{TF_INCLUDE_PATH}
    $ENV{PYTHON_INCLUDE_PATH}
)

target_link_libraries(
    main 
    PUBLIC
    $ENV{TF_SO_FILE}
    $ENV{TF_SO_PATH}/python/_pywrap_tensorflow_internal.so
    $ENV{PYTHON_SO_FILE}
)

上面的TF_INCLUDE_PATH等可以通过bash脚本获取:

#!/bin/bash
TOOL_SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"

export TF_INCLUDE_PATH=$(python3 -c 'import tensorflow as tf; print(tf.sysconfig.get_compile_flags()[0].strip("-I"))')
export TF_SO_PATH=$(python3 -c 'import tensorflow as tf; print(tf.sysconfig.get_link_flags()[0].strip("-L"))')
export TF_SO_FILE=$(ls $TF_SO_PATH/libtensorflow_framework.* |head -1)
export PYTHON_INCLUDE_PATH=$(python3 -c 'import sysconfig; print(sysconfig.get_path("include"))')
export PYTHON_SO_PATH=$(python3 -c 'import sysconfig; print(sysconfig.get_path("stdlib"))')
export PYTHON_SO_FILE=$(find $PYTHON_SO_PATH/../ -name libpython3*.so|head -1)

mkdir ${TOOL_SCRIPT_DIR}/build
cd ${TOOL_SCRIPT_DIR}/build
cmake ..
make

上述代码测试环境:tf1.15+python3.7(基于conda虚拟环境)

你可能感兴趣的:(TensorFlow,tensorflow,c++,推理,编译,源码)