Loading a TensorFlow graph with the C++ API

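In the TensorFlow repo, the C++ tutorials are far less thorough than the Python ones. This post shows how to load a pre-trained graph through the C++ API so that it can be used standalone or embedded in another application.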

Requirements

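  • Install bazel: TensorFlow is built with bazel, so building any other target that depends on TensorFlow also requires bazel. To learn more about bazel, see my two other posts: "Getting started with Bazel: building a C++ project" and "Getting started with Bazel 2: common C++ build use cases".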

  • Clone the TensorFlow repo:

    git clone --recursive https://github.com/tensorflow/tensorflow

Building the graph

First we create a TensorFlow graph and save it as a protobuf for later use.

import tensorflow as tf  # TensorFlow 1.x API (tf.Session)

with tf.Session() as sess:
    a = tf.Variable(5.0, name='a')
    b = tf.Variable(6.0, name='b')
    c = tf.multiply(a, b, name="c")

    sess.run(tf.global_variables_initializer())

    print(a.eval())  # 5.0
    print(b.eval())  # 6.0
    print(c.eval())  # 30.0

    # Serialize the graph definition to models/graph.pb in binary format
    tf.train.write_graph(sess.graph_def, 'models/', 'graph.pb', as_text=False)

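Creating the binary

Inside the tensorflow/tensorflow directory, create a directory named loader (that is, tensorflow/tensorflow/loader) to hold the code that loads the graph we just created.

In loader/, create a new file named loader.cc. It needs to do the following:

  1. Initialize a TensorFlow session
  2. Load the graph we created earlier
  3. Add the graph to the session
  4. Set up the inputs and outputs
  5. Run the graph to produce the outputs
  6. Read the values from the outputs
  7. Close the session and release its resources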

#include <iostream>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/public/session.h"

using namespace tensorflow;

int main(int argc, char* argv[]) {
  // Initialize a tensorflow session
  Session* session;
  Status status = NewSession(SessionOptions(), &session);
  if (!status.ok()) {
    std::cout << status.ToString() << "\n";
    return 1;
  }

  // Read in the protobuf graph we exported
  // (The path seems to be relative to the cwd. Keep this in mind
  // when using `bazel run` since the cwd isn't where you call
  // `bazel run` but from inside a temp folder.)
  GraphDef graph_def;
  status = ReadBinaryProto(Env::Default(), "models/graph.pb", &graph_def);
  if (!status.ok()) {
    std::cout << status.ToString() << "\n";
    return 1;
  }

  // Add the graph to the session
  status = session->Create(graph_def);
  if (!status.ok()) {
    std::cout << status.ToString() << "\n";
    return 1;
  }

  // Setup inputs and outputs:

  // Our graph doesn't require any inputs, since it specifies default values,
  // but we'll change an input to demonstrate.
  Tensor a(DT_FLOAT, TensorShape());
  a.scalar<float>()() = 3.0;

  Tensor b(DT_FLOAT, TensorShape());
  b.scalar<float>()() = 2.0;

  std::vector<std::pair<string, tensorflow::Tensor>> inputs = {
    { "a", a },
    { "b", b },
  };

  // The session will initialize the outputs
  std::vector<tensorflow::Tensor> outputs;

  // Run the session, evaluating our "c" operation from the graph
  status = session->Run(inputs, {"c"}, {}, &outputs);
  if (!status.ok()) {
    std::cout << status.ToString() << "\n";
    return 1;
  }

  // Grab the first output (we only evaluated one graph node: "c")
  // and convert the node to a scalar representation.
  auto output_c = outputs[0].scalar<float>();

  // (There are similar methods for vectors and matrices here:
  // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/public/tensor.h)

  // Print the results
  std::cout << outputs[0].DebugString() << "\n"; // Tensor
  std::cout << output_c() << "\n"; // 6 (3.0 * 2.0, from the fed inputs)

  // Free any resources used by the session
  session->Close();
  return 0;
}
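
One thing to note about loader.cc is that every early `return 1` after NewSession succeeds skips session->Close(). A minimal sketch of a scoped wrapper that could cover those paths is shown below. The names CloseAndDelete, ScopedSession, FakeSession and DemoScopedSession are hypothetical and not part of TensorFlow; the only assumption about the real API is that the session type has a Close() method and is owned by the caller, as with the raw pointer filled in by NewSession.

```cpp
#include <cassert>
#include <memory>

// Deleter that closes a session-like object before deleting it.
// SessionT is assumed to expose Close(), as tensorflow::Session does.
template <typename SessionT>
struct CloseAndDelete {
  void operator()(SessionT* session) const {
    if (session == nullptr) return;
    (void)session->Close();  // the returned status is ignored in this sketch
    delete session;
  }
};

// Owning handle: Close() and delete run automatically when it leaves scope.
template <typename SessionT>
using ScopedSession = std::unique_ptr<SessionT, CloseAndDelete<SessionT>>;

// Hypothetical stand-in, used only to exercise the wrapper without TensorFlow.
struct FakeSession {
  explicit FakeSession(int* close_count) : close_count_(close_count) {}
  bool Close() {
    ++*close_count_;
    return true;
  }
  int* close_count_;
};

// Returns how many times Close() ran once the scoped handle was destroyed.
inline int DemoScopedSession() {
  int closes = 0;
  {
    ScopedSession<FakeSession> session(new FakeSession(&closes));
    assert(session != nullptr);
    // An early return inside this scope would still trigger Close().
  }
  return closes;
}
```

In loader.cc this might be used as `ScopedSession<tensorflow::Session> guard(session);` right after NewSession succeeds, in which case the explicit session->Close() at the end of main would no longer be needed.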

Next we need a BUILD file for the project, which tells bazel what to build. In it we define a cc_binary rule, meaning the output is an executable binary.

cc_binary(
    name = "loader",
    srcs = ["loader.cc"],
    deps = [
        "//tensorflow/core:tensorflow",
    ]
)

The final file layout is:

  • tensorflow/tensorflow/loader/
  • tensorflow/tensorflow/loader/loader.cc
  • tensorflow/tensorflow/loader/BUILD

Building and running

  • From the root of the tensorflow repo, run ./configure
  • From the tensorflow/tensorflow/loader directory, run bazel build :loader
    • If the build fails with a long list of undefined reference to ... errors, try building with bazel build --config=monolithic :loader; see https://github.com/tensorflow/tensorflow/issues/13267
  • From the root of the tensorflow repo, cd into bazel-bin/tensorflow/loader
  • Copy the graph protobuf to models/graph.pb
  • Run ./loader and you get the output!
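
If ./loader reports that models/graph.pb cannot be found, the working directory is the likely cause, since the path in loader.cc is relative. The following is a rough sketch, assuming a C++17 compiler with std::filesystem, of a check that shows where the relative path actually resolves. ResolvedPath, ResolveFrom, ResolveFromCwd and ReportIfMissing are hypothetical helper names introduced for illustration; none of them come from TensorFlow.

```cpp
#include <cassert>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <string>
#include <system_error>

namespace fs = std::filesystem;

// Result of resolving a relative path such as "models/graph.pb".
struct ResolvedPath {
  std::string absolute;  // where the relative path actually points
  bool readable;         // true if a regular file there can be opened
};

// Resolves `relative` against `base_dir` and checks that the file can be read.
inline ResolvedPath ResolveFrom(const fs::path& base_dir,
                                const fs::path& relative) {
  assert(relative.is_relative());
  const fs::path full = fs::absolute(base_dir / relative).lexically_normal();
  std::error_code ec;  // non-throwing overload: errors count as "not a file"
  const bool readable = fs::is_regular_file(full, ec) &&
                        std::ifstream(full, std::ios::binary).is_open();
  return ResolvedPath{full.string(), readable};
}

// Same check against the process's current working directory.
inline ResolvedPath ResolveFromCwd(const fs::path& relative) {
  return ResolveFrom(fs::current_path(), relative);
}

// Prints a hint to stderr and returns false if the file is not readable.
inline bool ReportIfMissing(const fs::path& relative) {
  const ResolvedPath resolved = ResolveFromCwd(relative);
  if (!resolved.readable) {
    std::cerr << "Cannot open " << resolved.absolute
              << " (cwd: " << fs::current_path().string() << ")\n";
  }
  return resolved.readable;
}
```

Calling ReportIfMissing("models/graph.pb") at the top of main, before ReadBinaryProto, would print the absolute location being tried, which helps when the binary is started through bazel run rather than from bazel-bin/tensorflow/loader.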

Reference

  1. Loading a TensorFlow graph with the C++ API
  2. TensorFlow issue: Packaged TensorFlow C++ library for bazel-independent use
