假如想在 c++ 的应用中使用 tensorflow 做诸如 model inference 的操作,一个简单的方法是链接Python.dll,使用 python 脚本代替完成。但是众所周知,tensorflow 的底层是由 c++ 写的,上述方式虽然简单,但是丧失了性能和实现的优雅性。本文主要介绍如何在 windows 平台上直接调用 c++ 版本的tensorflow
模型在使用 model.save 保存后,必须固化为与语言无关的 pb 格式才能进行被 c++ 版 tensorflow 所调用,可以用从网上抄的一个 h5 转 pb 的脚本:
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
def h5_to_pb(h5_save_path):
model = tf.keras.models.load_model(h5_save_path, compile=False)
model.summary()
full_model = tf.function(lambda Input: model(Input))
full_model = full_model.get_concrete_function(tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))
# Get frozen ConcreteFunction
frozen_func = convert_variables_to_constants_v2(full_model)
frozen_func.graph.as_graph_def()
layers = [op.name for op in frozen_func.graph.get_operations()]
print("-" * 50)
print("Frozen model layers: ")
for layer in layers:
print(layer)
print("-" * 50)
print("Frozen model inputs: ")
print(frozen_func.inputs)
print("Frozen model outputs: ")
print(frozen_func.outputs)
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
logdir="./pb",
name="model.pb",
as_text=False)
h5_to_pb('./model.h5')
可以参考这篇文章在 windows 下编译 tensorflow(链接 ),但是最新版本的tensorflow在本机编译时导出的符号不全,导致报 undefined reference 的链接错误,可以改成用以下命令编译生成动态链接库
bazel build --config=opt //tensorflow/tools/lib_package:libtensorflow
需要注意的是,默认情况下 bazel 编译时会占用 cpu 所有的核,可能占用大量内存导致编译失败,可以用 --jobs 选项限制使用的 cpu 核数
c++ 版 tensorflow 是 bazel 使用 msvc 编译的,所以直接链接 tensorflow_cc.lib 即可,代码和下文中的 “3.a 把 model inference 封装为 dll” 基本类似
mingw 编译器直接链接 tensorflow 是有问题的,其直接原因是二者是不同的编译器,对于 c++ 来说 mangle 后的符号不同(mangle 相关的知识可以自行百度),肯定会报 undefined reference 这个错误。再者,由于标准库的实现方式也不同,所以即使正常链接接,生成了可执行文件,但是 dll 导出的涉及到以标准库为参或为返回值的函数其本身的执行也有可能出现问题。(诸如同一成员函数的偏移地址不一样,可能会导致程序直接崩溃)
有两个方案可以解决这个问题:
这里方案2由于 tensorflow_cc.lib 里面的符号太多,笔者也暂时没找到高效的从 msvc mangle 后的符号 demangle 在 mangle 成 mingw 所认识的符号的方法(下文只会介绍一个低效的方式),加上不确定 tensorflow 和 model inference 相关的函数中是否涉及标准库相关的(比如 std::shared_ptr),所以暂时没有尝试。所以下文只介绍第一种方案。
对于 windows 下的动态链接库,可以用如下命令生成导出符号的定义文件和 .a 文件
gendef xxxx.dll
dlltool -D xxxx.dll -d xxxx.def -l xxxx.a
gendef 和 dlltool 在mingw 的 bin 目录下就能找到,需要注意的是,dlltool 和 动态链接库的位数要保持相同。
假如直接用 extern “C” 导出和使用 C 风格的函数接口时,那么就不存在 mangle 的问题,但是为什么不导出可读性和封装性更强的 c++ 的接口呢?
可是,直接导出 c++ 接口会面临和上文中提到的相同的两个错误,然而都有相应的解决对策:
针对解决对策一,我们先把 model inference 的过程封装为 dll ,接口全部使用数组传参,代码如下:
libevaluate.h
#pragma once
class chessEvaluate {
private:
_declspec(dllexport) chessEvaluate();
public:
_declspec(dllexport) float evaluate(int map[10][9]);
static chessEvaluate& instance() {
static chessEvaluate instance_;
return instance_;
}
};
libevaluate.cpp
#include
#include
#include
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/public/session.h"
template<typename T>
struct is_vector_type : std::false_type {};
template<typename T>
struct is_vector_type<std::vector<T>> : std::true_type {};
template <typename T>
void vectorSize(T& vec,std::vector<int64_t> &size) {
if constexpr (is_vector_type<T>::value) {
size.push_back(vec.size());
vectorSize(vec[0], size);
}
}
template <typename T>
constexpr int vectorDepth() {
if constexpr (is_vector_type<T>::value) {
return vectorDepth<typename T::value_type>() + 1;
}
else {
return 0;
}
}
template <typename Mat,typename Vec,typename ...Args>
void tensorAssign(Mat& mat, Vec& vec, Args... args) {
if constexpr (is_vector_type<Vec>::value) {
for (int i = 0; i < vec.size(); i++) {
tensorAssign(mat, vec[i], args..., i);
}
}
else {
mat(args...) = vec;
}
}
template <typename Mat, typename Vec>
void tensorAssign(Mat& mat, Vec& vec) {
for (int i = 0; i < vec.size(); i++) {
tensorAssign(mat, vec[i], i);
}
}
template <typename T>
std::shared_ptr<tensorflow::Tensor> vector2Tensor(T vec) {
std::vector<int64_t> vecSize_;
vectorSize(vec, vecSize_);
const std::vector<int64_t> vecSize=vecSize_;
auto span = absl::Span<const int64_t>(vecSize.data(), vecSize.size());
auto tensor=std::make_shared<tensorflow::Tensor>(tensorflow::DT_FLOAT, tensorflow::TensorShape(span));
auto input_tensor_mapped = tensor->tensor<float, vectorDepth<T>()>();
tensorAssign(input_tensor_mapped, vec);
return tensor;
}
static tensorflow::Session* session;
static float evaluate_impl(std::vector<std::vector<std::vector<std::vector<float>>>>&& chessboard) {
auto input_tensor_ptr = vector2Tensor(chessboard);
std::vector<tensorflow::Tensor> outputs;
std::string output_node = "Identity:0";
//开始预测,这里的输入名images要和模型的输入相匹配
tensorflow::Status status_run = session->Run({ {"Input:0", *input_tensor_ptr} }, { output_node }, {}, &outputs);
if (!status_run.ok()) {
std::cout << "ERROR: RUN failed..." << std::endl;
std::cout << status_run.ToString() << "\n";
return -1;
}
assert(outputs.size() == 1);
auto p = outputs[0].flat<float>();
return p(0) * 256;
}
#include "libevaluate.h"
chessEvaluate::chessEvaluate() {
std::string model_file = "model.pb";
session = tensorflow::NewSession(tensorflow::SessionOptions()); //创建新会话Session
tensorflow::GraphDef graphdef; //当前模型的图定义
tensorflow::Status status_load = ReadBinaryProto(tensorflow::Env::Default(), model_file, &graphdef); //从pb文件中读取图模型;
if (!status_load.ok()) {
std::cout << "ERROR: Loading model failed..." << model_file << std::endl;
std::cout << status_load.ToString() << "\n";
return;
}
tensorflow::Status status_create = session->Create(graphdef); //将图模型导入会话Session中;
if (!status_create.ok()) {
std::cout << "ERROR: Creating graph in session failed..." << status_create.ToString() << std::endl;
return;
}
return;
}
float chessEvaluate::evaluate(int map[10][9]) {
std::vector<std::vector<std::vector<float>>> chessboard;
for (int i = 0; i < 15; i++) {
std::vector<std::vector<float>> one_piece_chessboard;
for (int j = 0; j < 10; j++) {
std::vector<float> line;
for (int k = 0; k < 9; k++) {
line.push_back((map[j][k] + 1) == i);
}
one_piece_chessboard.emplace_back(line);
}
chessboard.emplace_back(one_piece_chessboard);
}
return evaluate_impl({chessboard});
}
在前面生成的导出符号的定义文件中,可以找到函数名对应的 mangle 后的符号。
;
; Definition file of libevaluate.dll
; Automatic generated by gendef
; written by Kai Tietz 2008
;
LIBRARY "libevaluate.dll"
EXPORTS
; private: __cdecl chessEvaluate::chessEvaluate(void)__ptr64
??0chessEvaluate@@AEAA@XZ
; public: float __cdecl chessEvaluate::evaluate(unknown ecsu[])__ptr64 throw()
?evaluate@chessEvaluate@@QEAAMQEAY08H@Z
以 chessEvaluate::evaluate 这个函数为例,可以看出 msvc 对其 mangle 后的符号是 ?evaluate@chessEvaluate@@QEAAMQEAY08H@Z。接下来的问题是,如何找出 mingw 对上述函数 mangle 后的符号表示?
一个简单的方法是,把上述的 libevaluate.h 加以改造,在 mingw 下编译一下,看看编译后的符号是什么。改造后的 test.cpp :
class chessEvaluate {
private:
__attribute((used)) chessEvaluate() {
}
public:
__attribute((used)) float evaluate(int map[10][9]) {
}
static chessEvaluate& instance() {
static chessEvaluate instance_;
return instance_;
}
};
然后通过以下命令编译并查看 mangle 后的符号
g++ -c test.cpp
nm test.o|grep evaluate
可以看出,chessEvaluate::evaluate 在 mingw 下 mangle 后的符号为 _ZN13chessEvaluate8evaluateEPA9_i 。同理,也可以按照上述方式找出 chessEvaluate的构造函数 mangle 后的符号为 _ZN13chessEvaluateC1Ev
此时可以使用 objcopy 命令,将 .a 文件里 msvc mangle 后的符号替换为 mingw mangle 后的符号。
objcopy --redefine-sym ?evaluate@chessEvaluate@@QEAAMQEAY08H@Z=_ZN13chessEvaluate8evaluateEPA9_i libevaluate.a libevaluate.out.a
objcopy --redefine-sym ??0chessEvaluate@@AEAA@XZ=_ZN13chessEvaluateC1Ev libevaluate.out.a libevaluate.out.a
此时,在 mingw 系的编译器里便可以链接 libevaluate.out.a 正常编译,从而可以在运行时动态链接 libevaluate.dll ,调用封装好的类进行 model inference