tensorflow c++ 读入pb文件

最近因为项目原因,开始接触c++ tensorflow,感觉c++的比python复杂太多了。不过好在不需要使用c++来训练模型,而是python端训练好,把模型文件保存下来,c++端导入模型文件,然后读入图片进行预测。
代码分为以下几个部分

  1. 创建session,同时设置gpu选项
Session* session;
SessionOptions options; 
options.config.mutable_gpu_options()->set_visible_device_list(gpus); //设置使用的gpu
options.config.mutable_gpu_options()->set_allow_growth(true); //设置GPU内存自动增长
TF_CHECK_OK(NewSession(options, &session));//创建新会话Session
  1. 从pb文件中读取模型,将模型导入session
GraphDef graphdef; //Graph Definition for current model
TF_CHECK_OK(ReadBinaryProto(Env::Default(), model_path, &graphdef)); //从pb文件中读取图模型;
TF_CHECK_OK(session->Create(graphdef)); //将模型导入会话Session中;
  1. 开始测试
Mat image = imread(images_address + image_name);
mat2Tensor(image, input);    //这个函数是将Mat转为Tensor
std::vector > in;    //模型输入数据
std::vector out;             //输出数据名称
std::vector outputs;      //输出数据存放的数组
in.push_back(pair(input_name, input));       //input_name为输入数据的名称,这个是将输入数据与名称一起放到数组中
out.push_back(output_name); //将输出数据的名称放到数组中
TF_CHECK_OK(session->Run(in, out, {}, &outputs));   //运行模型,得到输出
  1. 完整代码
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/default_device.h"

#include 
#include 
#include 
#include 
#include 

#include 
#include 
#include 
#include 

//linux
#include 
#include 
#include 

using namespace cv;
using namespace tensorflow;
using namespace std;

const int IMAGE_SIZE = 512;
const int CLASS = 36;

void mat2Tensor(Mat &image, Tensor &t) {
    resize(image, image, Size(IMAGE_SIZE, IMAGE_SIZE));

    cvtColor(image, image, cv::COLOR_RGB2BGR);
    float *tensor_data_ptr = t.flat().data();
    cv::Mat fake_mat(image.rows, image.cols, CV_32FC(image.channels()), tensor_data_ptr);
    image.convertTo(fake_mat, CV_32FC(image.channels()));
}

void tensor2Mat(Tensor &t, Mat &image) {
    int *p = t.flat().data();
    image = Mat(IMAGE_SIZE, IMAGE_SIZE, CV_32SC1, p);
    image.convertTo(image, CV_8UC1);
}
void solve() {
    string model_path = "../modules/20190707.pb";

    string input_name = "inputs/X:0";

    string output_name = "preds:0";

    string images_address = "../Images/";

    string output_address = "../outputs/";

    string gpus = "0";

    /*--------------------------------创建session------------------------------*/
    Session* session;
    SessionOptions options;
    
    options.config.mutable_gpu_options()->set_visible_device_list(gpus);

    options.config.mutable_gpu_options()->set_allow_growth(true);
    

    TF_CHECK_OK(NewSession(options, &session));//创建新会话Session
    

    /*--------------------------------从pb文件中读取模型--------------------------------*/
    
    GraphDef graphdef; //Graph Definition for current model
    TF_CHECK_OK(ReadBinaryProto(Env::Default(), model_path, &graphdef)); //从pb文件中读取图模型;

    TF_CHECK_OK(session->Create(graphdef)); //将模型导入会话Session中;

    std::cout << "<----Successfully created session and load graph.------->" << std::endl;
    Tensor input(DT_FLOAT, TensorShape({ 1, IMAGE_SIZE, IMAGE_SIZE, 3 }));

    /*--------------------------------读取目录下的图片--------------------------------*/

    vector images;  //图片名称

    for(int i = 4800; i < 5550; i++) {
        char tmp[20];
        sprintf(tmp, "img_%06d.png",i*4);
        // cout << tmp << endl;
        images.push_back(tmp);
    }

    cout << "load data success" << endl;

    /*--------------------------------开始测试--------------------------------*/
    cout << "start run" << endl;
    double total_time = 0;
    Mat lut(1, 256, CV_8UC3, lutData);
    for(string image_name : images) {
        Mat image = imread(images_address + image_name);
        mat2Tensor(image, input);
        std::vector > in;
        std::vector out;
        std::vector outputs;
        in.push_back(pair(input_name, input));
        out.push_back(output_name);     
        TF_CHECK_OK(session->Run(in, out, {}, &outputs));
        Mat res (IMAGE_SIZE, IMAGE_SIZE, CV_8UC1);
        tensor2Mat(outputs[0], res);
        cv::imwrite(output_address + image_name, res);
    }
}
int main()
{
    solve();
    return 0;
}

你可能感兴趣的:(tensorflow c++ 读入pb文件)