TensorRT/parsers/caffe/caffeParser/readProto.h Source Code Study

  • TensorRT/parsers/caffe/caffeParser/readProto.h
  • trtcaffe.pb.h
  • std::ifstream
  • google::protobuf::io
  • Reference Links

TensorRT/parsers/caffe/caffeParser/readProto.h

/*
 * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef TRT_CAFFE_PARSER_READ_PROTO_H
#define TRT_CAFFE_PARSER_READ_PROTO_H

#include <fstream>

#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/text_format.h"

#include "caffeMacros.h"
#include "trtcaffe.pb.h"

namespace nvcaffeparser1
{
// There are some challenges associated with importing caffe models. One is that
// a .caffemodel file just consists of layers and doesn't have the specs for its
// input and output blobs.
//
// So we need to read the deploy file to get the input
//Read the model weights from the .caffemodel file into net
bool readBinaryProto(trtcaffe::NetParameter* net, const char* file, size_t bufSize)
{
    //CHECK_NULL_RET_VAL is a macro, so no trailing semicolon is needed
    //If net or file is nullptr, return false
    CHECK_NULL_RET_VAL(net, false)
    CHECK_NULL_RET_VAL(file, false)
    using namespace google::protobuf::io;

    std::ifstream stream(file, std::ios::in | std::ios::binary);
    if (!stream)
    {
        //Log an error message and return false
        RETURN_AND_LOG_ERROR(false, "Could not open file " + std::string(file));
    }

    //Create a stream that reads data from a C++ istream
    IstreamInputStream rawInput(&stream);
    /*
    IstreamInputStream is a subclass of ZeroCopyInputStream.
    CodedInputStream reads data from a ZeroCopyInputStream and decodes it.
    */
    CodedInputStream codedInput(&rawInput);
    //Set the maximum number of bytes the CodedInputStream will read; the second argument is ignored
    codedInput.SetTotalBytesLimit(int(bufSize), -1);

    /*
    Parse a protocol buffer from the given input stream
    and fill the result into the message object net
    */
    bool ok = net->ParseFromCodedStream(&codedInput);
    stream.close();

    if (!ok)
    {
        RETURN_AND_LOG_ERROR(false, "Could not parse binary model file");
    }

    return ok;
}

//Read the model architecture from deploy.prototxt into net
bool readTextProto(trtcaffe::NetParameter* net, const char* file)
{
    //CHECK_NULL_RET_VAL is a macro, so no trailing semicolon is needed
    CHECK_NULL_RET_VAL(net, false)
    CHECK_NULL_RET_VAL(file, false)
    using namespace google::protobuf::io;

    std::ifstream stream(file, std::ios::in);
    if (!stream)
    {
        RETURN_AND_LOG_ERROR(false, "Could not open file " + std::string(file));
    }

    //Create a stream that reads data from a C++ istream
    IstreamInputStream input(&stream);
    /*
    Read and parse a text-format protocol message from the given
    ZeroCopyInputStream, and store it in the given Message object net
    */
    bool ok = google::protobuf::TextFormat::Parse(&input, net);
    stream.close();
    return ok;
}
} //namespace nvcaffeparser1
#endif //TRT_CAFFE_PARSER_READ_PROTO_H

trtcaffe.pb.h

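The TensorRT source tree does not contain a file named trtcaffe.pb.h, so where does it come from? It is generated by the protobuf compiler (protoc) from trtcaffe.proto when TensorRT is built. For details, see Protocol Buffer (proto2) and C++ API.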

std::ifstream

Both readBinaryProto and readTextProto use std::ifstream. For details, see C++ ifstream.

google::protobuf::io

Both readBinaryProto and readTextProto use functions from the google/protobuf library. For details, see C++ google protobuf.

Reference Links

Protocol Buffer (proto2) and C++ API

C++ ifstream

C++ google protobuf
