现在的深度学习框架一般都是基于 Python 来实现,构建、训练、保存和调用模型都可以很容易地在 Python 下完成。但有时候,我们在实际应用这些模型的时候可能需要在其他编程语言下进行,本文将通过直接调用 TensorFlow 的 C/C++ 接口来导入 TensorFlow 预训练好的模型。
1. 环境配置：C/C++ 接口的编译过程请参见单独的编译教程（原文“点此查看”处为链接）。
2. 导入预定义的图和训练好的参数值
// set up your input paths
const string pathToGraph = "/home/senius/python/c_python/test/model-10.meta";
const string checkpointPath = "/home/senius/python/c_python/test/model-10";
auto session = NewSession(SessionOptions()); // 创建会话
if (session == nullptr)
{
throw runtime_error("Could not create Tensorflow session.");
}
Status status;
// Read in the protobuf graph we exported
MetaGraphDef graph_def;
status = ReadBinaryProto(Env::Default(), pathToGraph, &graph_def); // 导入图模型
if (!status.ok())
{
throw runtime_error("Error reading graph definition from " + pathToGraph + ": " + status.ToString());
}
// Add the graph to the session
status = session->Create(graph_def.graph_def()); // 将图模型加入到会话中
if (!status.ok())
{
throw runtime_error("Error creating graph: " + status.ToString());
}
// Read weights from the saved checkpoint
Tensor checkpointPathTensor(DT_STRING, TensorShape());
checkpointPathTensor.scalar()() = checkpointPath; // 读取预训练好的权重
status = session->Run({{graph_def.saver_def().filename_tensor_name(), checkpointPathTensor},}, {},
{graph_def.saver_def().restore_op_name()}, nullptr);
if (!status.ok())
{
throw runtime_error("Error loading checkpoint from " + checkpointPath + ": " + status.ToString());
}
3. 准备测试数据
const string filename = "/home/senius/python/c_python/test/04t30t00.npy";
//Read TXT data to array
float Array[1681*41];
ifstream is(filename);
for (int i = 0; i < 1681*41; i++){
is >> Array[i];
}
is.close();
tensorflow::Tensor input_tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, 41, 41, 41, 1}));
auto input_tensor_mapped = input_tensor.tensor();
float *pdata = Array;
// copying the data into the corresponding tensor
for (int x = 0; x < 41; ++x)//depth
{
for (int y = 0; y < 41; ++y) {
for (int z = 0; z < 41; ++z) {
const float *source_value = pdata + x * 1681 + y * 41 + z;
input_tensor_mapped(0, x, y, z, 0) = *source_value;
}
}
}
4. 前向传播得到预测值
std::vector finalOutput;
std::string InputName = "X"; // Your input placeholder's name
std::string OutputName = "sigmoid"; // Your output tensor's name
vector > inputs;
inputs.push_back(std::make_pair(InputName, input_tensor));
// Fill input tensor with your input data
session->Run(inputs, {OutputName}, {}, &finalOutput);
auto output_y = finalOutput[0].scalar();
std::cout << output_y() << "\n";
5. 一些问题
6. 完整代码
#include <fstream>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <tensorflow/cc/ops/standard_ops.h>
#include <tensorflow/core/platform/env.h>
#include <tensorflow/core/protobuf/meta_graph.pb.h>
#include <tensorflow/core/public/session.h>
using namespace std;
using namespace tensorflow;
using namespace tensorflow::ops;
int main()
{
// set up your input paths
const string pathToGraph = "/home/senius/python/c_python/test/model-10.meta";
const string checkpointPath = "/home/senius/python/c_python/test/model-10";
auto session = NewSession(SessionOptions());
if (session == nullptr)
{
throw runtime_error("Could not create Tensorflow session.");
}
Status status;
// Read in the protobuf graph we exported
MetaGraphDef graph_def;
status = ReadBinaryProto(Env::Default(), pathToGraph, &graph_def);
if (!status.ok())
{
throw runtime_error("Error reading graph definition from " + pathToGraph + ": " + status.ToString());
}
// Add the graph to the session
status = session->Create(graph_def.graph_def());
if (!status.ok())
{
throw runtime_error("Error creating graph: " + status.ToString());
}
// Read weights from the saved checkpoint
Tensor checkpointPathTensor(DT_STRING, TensorShape());
checkpointPathTensor.scalar()() = checkpointPath;
status = session->Run({{graph_def.saver_def().filename_tensor_name(), checkpointPathTensor},}, {},
{graph_def.saver_def().restore_op_name()}, nullptr);
if (!status.ok())
{
throw runtime_error("Error loading checkpoint from " + checkpointPath + ": " + status.ToString());
}
cout << 1 << endl;
const string filename = "/home/senius/python/c_python/test/04t30t00.npy";
//Read TXT data to array
float Array[1681*41];
ifstream is(filename);
for (int i = 0; i < 1681*41; i++){
is >> Array[i];
}
is.close();
tensorflow::Tensor input_tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, 41, 41, 41, 1}));
auto input_tensor_mapped = input_tensor.tensor();
float *pdata = Array;
// copying the data into the corresponding tensor
for (int x = 0; x < 41; ++x)//depth
{
for (int y = 0; y < 41; ++y) {
for (int z = 0; z < 41; ++z) {
const float *source_value = pdata + x * 1681 + y * 41 + z;
// input_tensor_mapped(0, x, y, z, 0) = *source_value;
input_tensor_mapped(0, x, y, z, 0) = 1;
}
}
}
std::vector finalOutput;
std::string InputName = "X"; // Your input placeholder's name
std::string OutputName = "sigmoid"; // Your output placeholder's name
vector > inputs;
inputs.push_back(std::make_pair(InputName, input_tensor));
// Fill input tensor with your input data
session->Run(inputs, {OutputName}, {}, &finalOutput);
auto output_y = finalOutput[0].scalar();
std::cout << output_y() << "\n";
return 0;
}
cmake_minimum_required(VERSION 3.8)
project(Tensorflow_test)

# Build as C++11 and fail instead of silently falling back to an older standard.
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Root of the TensorFlow r1.4 source/build tree. The default matches the original
# hard-coded location; override with -DTENSORFLOW_ROOT=/path/to/tensorflow.
set(TENSORFLOW_ROOT /home/senius/tensorflow-r1.4 CACHE PATH "Root of the TensorFlow r1.4 source/build tree")

set(SOURCE_FILES main.cpp)

include_directories(
    ${TENSORFLOW_ROOT}
    # NOTE(review): bazel normally places bazel-genfiles at the repo root; verify this path exists.
    ${TENSORFLOW_ROOT}/tensorflow/bazel-genfiles
    ${TENSORFLOW_ROOT}/tensorflow/contrib/makefile/gen/protobuf/include
    ${TENSORFLOW_ROOT}/tensorflow/contrib/makefile/gen/host_obj
    ${TENSORFLOW_ROOT}/tensorflow/contrib/makefile/gen/proto
    ${TENSORFLOW_ROOT}/tensorflow/contrib/makefile/downloads/nsync/public
    ${TENSORFLOW_ROOT}/tensorflow/contrib/makefile/downloads/eigen
    ${TENSORFLOW_ROOT}/bazel-out/local_linux-py3-opt/genfiles
)

add_executable(Tensorflow_test ${SOURCE_FILES})

target_link_libraries(Tensorflow_test
    ${TENSORFLOW_ROOT}/bazel-bin/tensorflow/libtensorflow_cc.so
    ${TENSORFLOW_ROOT}/bazel-bin/tensorflow/libtensorflow_framework.so
)