TensorFlow Lite (Part 2)

1. Getting the Settings values in main

The tflite::label_image::Main function:

All input parameters are stored in Settings. The label_image options, with the example values used here:

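./label_image

-i ./grace_hopper.bmp
The input image (a BMP file).

-l ./labels.txt
The labels used when reporting results, e.g. background, tench, goldfish, great white shark, tiger shark, hammerhead, ... The file contains many labels, and you can replace them with Chinese names if you like. Each reported label is printed together with its confidence.

-m ./mobilenet_quant_v1_224.tflite
The .tflite model file.

-a 0
Whether to use Android NNAPI acceleration [interpreter->UseNNAPI(s->accel);]

-c 1
Number of inference iterations, loop_count [
for (int i = 0; i < s->loop_count; i++) {
  if (interpreter->Invoke() != kTfLiteOk) {
    LOG(FATAL) << "Failed to invoke tflite!\n";
  }
}]

-b 128
Input mean. The default in the code is 127.5. For a float model, each pixel is normalized as (pixel - input_mean) / input_std before inference. For a quantized (uint8) model, the pixels are copied as-is and this value is not used.

-p 0
Whether to enable profiling (per-operator timing, for performance analysis).

-t 1
Number of threads [
if (s->number_of_threads != -1) {
  interpreter->SetNumThreads(s->number_of_threads);
}]

-v 1
Whether to print more runtime information (verbose mode).

-s
Input std. The default in the code is 127.5.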
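The -b/-s normalization for float models can be sketched as a standalone function. This is a minimal sketch that assumes the 8-bit pixels have already been resized to the model's input size. NormalizePixels is a hypothetical name and is not part of TF Lite or label_image.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring label_image's float-input preprocessing:
// each 8-bit pixel value is mapped to (value - input_mean) / input_std.
// With the defaults 127.5 / 127.5, 0 maps to -1.0f and 255 maps to 1.0f.
std::vector<float> NormalizePixels(const std::vector<uint8_t>& pixels,
                                   float input_mean, float input_std) {
  assert(input_std != 0.0f);  // dividing by zero would produce inf/NaN
  std::vector<float> out;
  out.reserve(pixels.size());
  for (uint8_t p : pixels) {
    out.push_back((static_cast<float>(p) - input_mean) / input_std);
  }
  return out;
}
```

Quantized models skip this step, so -b and -s only matter when the input tensor is kTfLiteFloat32.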

Default values in the header:

external/tensorflow$ vi tensorflow/contrib/lite/examples/label_image/label_image.h +24

#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H
#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H

#include "tensorflow/contrib/lite/string.h"

namespace tflite {
namespace label_image {

struct Settings {
  bool verbose = false;
  bool accel = false;
  bool input_floating = false;
  int loop_count = 1;
  float input_mean = 127.5f;
  float input_std = 127.5f;
  string model_name = "./mobilenet_quant_v1_224.tflite";
  string input_bmp_name = "./grace_hopper.bmp";
  string labels_file_name = "./labels.txt";
  string input_layer_type = "uint8_t";
  int number_of_threads = 4;
};

}  // namespace label_image
}  // namespace tflite

#endif  // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H

The corresponding option table in main (used with getopt_long):

static struct option long_options[] = {
        {"accelerated", required_argument, 0, 'a'},
        {"count", required_argument, 0, 'c'},
        {"verbose", required_argument, 0, 'v'},
        {"image", required_argument, 0, 'i'},
        {"labels", required_argument, 0, 'l'},
        {"tflite_model", required_argument, 0, 'm'},
        {"threads", required_argument, 0, 't'},
        {"input_mean", required_argument, 0, 'b'},
        {"input_std", required_argument, 0, 's'},
        {0, 0, 0, 0}};
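getopt_long consumes this table. The following is a minimal sketch of such a parsing loop. It assumes a trimmed copy of Settings (input_floating and input_layer_type are omitted) and uses strtol/strtof for number conversion. It is not label_image's exact code, and ParseSettings is a hypothetical name.

```cpp
#include <cassert>
#include <cstdlib>
#include <getopt.h>
#include <string>

// Trimmed, hypothetical copy of the Settings struct shown above.
struct Settings {
  bool verbose = false;
  bool accel = false;
  int loop_count = 1;
  float input_mean = 127.5f;
  float input_std = 127.5f;
  std::string model_name = "./mobilenet_quant_v1_224.tflite";
  std::string input_bmp_name = "./grace_hopper.bmp";
  std::string labels_file_name = "./labels.txt";
  int number_of_threads = 4;
};

// Parses argv with getopt_long using the same option table as above.
// Options that are not given keep their default values.
Settings ParseSettings(int argc, char** argv) {
  assert(argc >= 1 && argv != nullptr);
  static struct option long_options[] = {
      {"accelerated", required_argument, nullptr, 'a'},
      {"count", required_argument, nullptr, 'c'},
      {"verbose", required_argument, nullptr, 'v'},
      {"image", required_argument, nullptr, 'i'},
      {"labels", required_argument, nullptr, 'l'},
      {"tflite_model", required_argument, nullptr, 'm'},
      {"threads", required_argument, nullptr, 't'},
      {"input_mean", required_argument, nullptr, 'b'},
      {"input_std", required_argument, nullptr, 's'},
      {nullptr, 0, nullptr, 0}};
  Settings s;
  optind = 1;  // start scanning at argv[1]
  int c;
  while ((c = getopt_long(argc, argv, "a:b:c:i:l:m:s:t:v:", long_options,
                          nullptr)) != -1) {
    switch (c) {
      case 'a': s.accel = std::strtol(optarg, nullptr, 10) != 0; break;
      case 'b': s.input_mean = std::strtof(optarg, nullptr); break;
      case 'c':
        s.loop_count = static_cast<int>(std::strtol(optarg, nullptr, 10));
        break;
      case 'i': s.input_bmp_name = optarg; break;
      case 'l': s.labels_file_name = optarg; break;
      case 'm': s.model_name = optarg; break;
      case 's': s.input_std = std::strtof(optarg, nullptr); break;
      case 't':
        s.number_of_threads =
            static_cast<int>(std::strtol(optarg, nullptr, 10));
        break;
      case 'v': s.verbose = std::strtol(optarg, nullptr, 10) != 0; break;
      default: break;  // unknown option: getopt_long already printed a message
    }
  }
  return s;
}
```

Short and long forms map to the same case. For example, `-t 2` and `--threads 2` both set number_of_threads.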

2. RunInference(&s);

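First, model construction. model.h contains two classes: FlatBufferModel and InterpreterBuilder.

I used to think that FlatBufferModel built a read-only model and InterpreterBuilder built a modifiable one. Judging from the comments, that is wrong. FlatBufferModel represents the tflite model itself (read-only, copied from disk or mmapped), and InterpreterBuilder builds the Interpreter.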

// An RAII object that represents a read-only tflite model, copied from disk,
// or mmapped. This uses flatbuffers as the serialization format.
// FlatBuffers is Google's serialization library: the binary data can be read in place, without a separate parsing/unpacking step.
class FlatBufferModel {
 public:
  // Builds a model based on a file. Returns a nullptr in case of failure.
  static std::unique_ptr<FlatBufferModel> BuildFromFile(
      const char* filename,
      ErrorReporter* error_reporter = DefaultErrorReporter());

  // Builds a model based on a pre-loaded flatbuffer. The caller retains
  // ownership of the buffer and should keep it alive until the returned object
  // is destroyed. Returns a nullptr in case of failure.
  static std::unique_ptr<FlatBufferModel> BuildFromBuffer(
      const char* buffer, size_t buffer_size,
      ErrorReporter* error_reporter = DefaultErrorReporter());

  // Builds a model directly from a flatbuffer pointer. The caller retains
  // ownership of the buffer and should keep it alive until the returned object
  // is destroyed. Returns a nullptr in case of failure.
  static std::unique_ptr<FlatBufferModel> BuildFromModel(
      const tflite::Model* model_spec,
      ErrorReporter* error_reporter = DefaultErrorReporter());

  // Releases memory or unmaps mmaped memory.
  ~FlatBufferModel();

  // Copying or assignment is disallowed to simplify ownership semantics.
  FlatBufferModel(const FlatBufferModel&) = delete;
  FlatBufferModel& operator=(const FlatBufferModel&) = delete;

  bool initialized() const { return model_ != nullptr; }
  const tflite::Model* operator->() const { return model_; }
  const tflite::Model* GetModel() const { return model_; }
  ErrorReporter* error_reporter() const { return error_reporter_; }
  const Allocation* allocation() const { return allocation_; }

  // Returns true if the model identifier is correct (otherwise false and
  // reports an error).
  bool CheckModelIdentifier() const;

 private:
  // Loads a model from `filename`. If `mmap_file` is true then use mmap,
  // otherwise make a copy of the model in a buffer.
  //
  // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be
  // used.
  explicit FlatBufferModel(
      const char* filename, bool mmap_file = true,
      ErrorReporter* error_reporter = DefaultErrorReporter(),
      bool use_nnapi = false);

  // Loads a model from `ptr` and `num_bytes` of the model file. The `ptr` has
  // to remain alive and unchanged until the end of this flatbuffermodel's
  // lifetime.
  //
  // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be
  // used.
  FlatBufferModel(const char* ptr, size_t num_bytes,
                  ErrorReporter* error_reporter = DefaultErrorReporter());

  // Loads a model from Model flatbuffer. The `model` has to remain alive and
  // unchanged until the end of this flatbuffermodel's lifetime.
  FlatBufferModel(const Model* model, ErrorReporter* error_reporter);

  // Flatbuffer traverser pointer. (Model* is a pointer that is within the
  // allocated memory of the data allocated by allocation's internals.
  const tflite::Model* model_ = nullptr;
  ErrorReporter* error_reporter_;
  Allocation* allocation_ = nullptr;
};
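A heavily simplified stand-in helps make two FlatBufferModel conventions concrete: RAII ownership, and returning nullptr on failure. The sketch makes several assumptions. It always copies the file into memory, whereas the real class can also mmap it. It performs no FlatBuffers verification. ModelFile is a hypothetical name. The only format knowledge it relies on is that a FlatBuffers file identifier occupies bytes 4..7, and that .tflite files use the identifier "TFL3".

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <fstream>
#include <iterator>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical, simplified stand-in for FlatBufferModel: it always copies
// the file into memory and only checks the 4-byte file identifier.
class ModelFile {
 public:
  // Returns nullptr on failure, like FlatBufferModel::BuildFromFile.
  static std::unique_ptr<ModelFile> BuildFromFile(const char* filename) {
    assert(filename != nullptr);
    std::ifstream in(filename, std::ios::binary);
    if (!in) return nullptr;
    std::vector<char> bytes((std::istreambuf_iterator<char>(in)),
                            std::istreambuf_iterator<char>());
    return std::unique_ptr<ModelFile>(new ModelFile(std::move(bytes)));
  }

  // Copying is disallowed, as in FlatBufferModel, to keep ownership simple.
  ModelFile(const ModelFile&) = delete;
  ModelFile& operator=(const ModelFile&) = delete;

  size_t size() const { return bytes_.size(); }

  // A FlatBuffers file identifier sits at bytes [4, 8), right after the
  // 4-byte offset to the root table.
  bool CheckModelIdentifier() const {
    return bytes_.size() >= 8 &&
           std::memcmp(bytes_.data() + 4, "TFL3", 4) == 0;
  }

 private:
  explicit ModelFile(std::vector<char> bytes) : bytes_(std::move(bytes)) {}
  std::vector<char> bytes_;  // released automatically in the destructor
};
```

The constructor is private and the factory returns a std::unique_ptr, so callers can only obtain a fully loaded object or nullptr. This is the same contract that RunInference's `if (!model)` check relies on.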

Next, InterpreterBuilder:

// Build an interpreter capable of interpreting `model`.
//
// model: a scoped model whose lifetime must be at least as long as
//   the interpreter. In principle multiple interpreters can be made from
//   a single model.
// op_resolver: An instance that implements the Resolver interface which maps
//   custom op names and builtin op codes to op registrations.
// reportError: a functor that is called to report errors that handles
//   printf var arg semantics. The lifetime of the reportError object must
//   be greater than or equal to the Interpreter created by operator().
//
// Returns a kTfLiteOk when successful and sets interpreter to a valid
// Interpreter. Note: the user must ensure the model lifetime is at least as
// long as interpreter's lifetime.
class InterpreterBuilder {
 public:
  InterpreterBuilder(const FlatBufferModel& model,
                     const OpResolver& op_resolver);
  // Builds an interpreter given only the raw flatbuffer Model object (instead
  // of a FlatBufferModel). Mostly used for testing.
  // If `error_reporter` is null, then DefaultErrorReporter() is used.
  InterpreterBuilder(const ::tflite::Model* model,
                     const OpResolver& op_resolver,
                     ErrorReporter* error_reporter = DefaultErrorReporter());
  InterpreterBuilder(const InterpreterBuilder&) = delete;
  InterpreterBuilder& operator=(const InterpreterBuilder&) = delete;
  TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter);
  TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter,
                          int num_threads);

 private:
  TfLiteStatus BuildLocalIndexToRegistrationMapping();
  TfLiteStatus ParseNodes(
      const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators,
      Interpreter* interpreter);
  TfLiteStatus ParseTensors(
      const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
      const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
      Interpreter* interpreter);

  const ::tflite::Model* model_;
  const OpResolver& op_resolver_;
  ErrorReporter* error_reporter_;

  std::vector<TfLiteRegistration*> flatbuffer_op_index_to_registration_;
  std::vector<BuiltinOperator> flatbuffer_op_index_to_registration_types_;
  const Allocation* allocation_ = nullptr;
};
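The one-liner tflite::InterpreterBuilder(*model, resolver)(&interpreter) works because the builder is a function object. The constructor only stores its arguments, and operator() creates the interpreter and returns it through the std::unique_ptr out-parameter. The toy types below (ToyModel, ToyResolver, ToyInterpreter, ToyBuilder) are hypothetical. They illustrate only that shape, not TF Lite's parsing logic.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

enum Status { kOk, kError };

// Toy stand-ins (hypothetical, not TF Lite types).
struct ToyModel {
  std::vector<std::string> op_names;
};
struct ToyResolver {
  bool Knows(const std::string& op) const {
    return op == "CONV_2D" || op == "SOFTMAX";
  }
};
struct ToyInterpreter {
  std::vector<std::string> ops;
  size_t nodes_size() const { return ops.size(); }
};

// Same shape as InterpreterBuilder: the constructor only stores references,
// and operator() does the work and hands the result back through an
// out-parameter, leaving it null on failure.
class ToyBuilder {
 public:
  ToyBuilder(const ToyModel& model, const ToyResolver& resolver)
      : model_(model), resolver_(resolver) {}

  Status operator()(std::unique_ptr<ToyInterpreter>* interpreter) {
    assert(interpreter != nullptr);
    interpreter->reset();
    std::unique_ptr<ToyInterpreter> result(new ToyInterpreter);
    for (const std::string& op : model_.op_names) {
      if (!resolver_.Knows(op)) return kError;  // unresolved op: stays null
      result->ops.push_back(op);
    }
    *interpreter = std::move(result);
    return kOk;
  }

 private:
  const ToyModel& model_;
  const ToyResolver& resolver_;
};
```

Writing `ToyBuilder(model, resolver)(&interpreter)` constructs a temporary builder and invokes it in the same statement. Leaving the out-parameter null on failure mirrors the `if (!interpreter)` check in RunInference.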

So the overall layout of model.h is:

#ifndef TENSORFLOW_CONTRIB_LITE_MODEL_H_
#define TENSORFLOW_CONTRIB_LITE_MODEL_H_

#include <memory>
#include "tensorflow/contrib/lite/error_reporter.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"

namespace tflite {

class FlatBufferModel {
  // ... (full declaration shown above)
};

class InterpreterBuilder {
  // ... (full declaration shown above)
};

}  // namespace tflite

#endif  // TENSORFLOW_CONTRIB_LITE_MODEL_H_

Next comes Interpreter, the interpreter itself (the "translator" that actually runs the model), declared in interpreter.h. The class has a lot of members:

// Interpreter is essentially the "translator" that executes the model
class Interpreter {
 public:
  // Instantiate an interpreter. All errors associated with reading and
  // processing this model will be forwarded to the error_reporter object.
  //
  // Note, if error_reporter is nullptr, then a default StderrReporter is
  // used.
  explicit Interpreter(ErrorReporter* error_reporter = DefaultErrorReporter());

  ~Interpreter();

  Interpreter(const Interpreter&) = delete;
  Interpreter& operator=(const Interpreter&) = delete;

  // Functions to build interpreter

  // Provide a list of tensor indexes that are inputs to the model.
  // Each index is bound check and this modifies the consistent_ flag of the
  // interpreter.
  TfLiteStatus SetInputs(std::vector<int> inputs);

  // Provide a list of tensor indexes that are outputs to the model
  // Each index is bound check and this modifies the consistent_ flag of the
  // interpreter.
  TfLiteStatus SetOutputs(std::vector<int> outputs);

  // Adds a node with the given parameters and returns the index of the new
  // node in `node_index` (optionally). Interpreter will take ownership of
  // `builtin_data` and destroy it with `free`. Ownership of 'init_data'
  // remains with the caller.
  TfLiteStatus AddNodeWithParameters(const std::vector<int>& inputs,
                                     const std::vector<int>& outputs,
                                     const char* init_data,
                                     size_t init_data_size, void* builtin_data,
                                     const TfLiteRegistration* registration,
                                     int* node_index = nullptr);

  // Adds `tensors_to_add` tensors, preserving pre-existing Tensor entries.
  // The value pointed to by `first_new_tensor_index` will be set to the
  // index of the first new tensor if `first_new_tensor_index` is non-null.
  TfLiteStatus AddTensors(int tensors_to_add,
                          int* first_new_tensor_index = nullptr);

  // Set description of inputs/outputs/data/fptrs for node `node_index`.
  // This variant assumes an external buffer has been allocated of size
  // bytes. The lifetime of buffer must be ensured to be greater or equal
  // to Interpreter.
  TfLiteStatus SetTensorParametersReadOnly(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantizationParams quantization,
      const char* buffer, size_t bytes, const Allocation* allocation = nullptr);

  // Set description of inputs/outputs/data/fptrs for node `node_index`.
  // This variant assumes an external buffer has been allocated of size
  // bytes. The lifetime of buffer must be ensured to be greater or equal
  // to Interpreter.
  TfLiteStatus SetTensorParametersReadWrite(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantizationParams quantization);

  // Functions to access tensor data

  // Read only access to list of inputs.
  const std::vector<int>& inputs() const { return inputs_; }

  // Return the name of a given input. The given index must be between 0 and
  // inputs().size().
  const char* GetInputName(int index) const {
    return context_.tensors[inputs_[index]].name;
  }

  // Read only access to list of outputs.
  const std::vector<int>& outputs() const { return outputs_; }

  // Return the name of a given output. The given index must be between 0 and
  // outputs().size().
  const char* GetOutputName(int index) const {
    return context_.tensors[outputs_[index]].name;
  }

  // Return the number of tensors in the model.
  int tensors_size() const { return context_.tensors_size; }

  // Return the number of ops in the model.
  int nodes_size() const { return nodes_and_registration_.size(); }

  // WARNING: Experimental interface, subject to change
  const std::vector<int>& execution_plan() const { return execution_plan_; }

  // WARNING: Experimental interface, subject to change
  // Overrides execution plan. This bounds checks indices sent in.
  TfLiteStatus SetExecutionPlan(const std::vector<int>& new_plan);

  // Get a tensor data structure.
  // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this
  // read/write access to structure
  TfLiteTensor* tensor(int tensor_index) {
    if (tensor_index >= context_.tensors_size || tensor_index < 0)
      return nullptr;
    return &context_.tensors[tensor_index];
  }

  // Get an immutable tensor data structure.
  const TfLiteTensor* tensor(int tensor_index) const {
    if (tensor_index >= context_.tensors_size || tensor_index < 0)
      return nullptr;
    return &context_.tensors[tensor_index];
  }

  // Get a pointer to an operation and registration data structure if in bounds.
  // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this
  // read/write access to structure
  const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration(
      int node_index) const {
    if (node_index >= nodes_and_registration_.size() || node_index < 0)
      return nullptr;
    return &nodes_and_registration_[node_index];
  }

  // Perform a checked cast to the appropriate tensor type.
  template <class T>
  T* typed_tensor(int tensor_index) {
    if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) {
      if (tensor_ptr->type == typeToTfLiteType<T>()) {
        return reinterpret_cast<T*>(tensor_ptr->data.raw);
      }
    }
    return nullptr;
  }

  // Return a pointer into the data of a given input tensor. The given index
  // must be between 0 and inputs().size().
  template <class T>
  T* typed_input_tensor(int index) {
    return typed_tensor<T>(inputs_[index]);
  }

  // Return a pointer into the data of a given output tensor. The given index
  // must be between 0 and outputs().size().
  template <class T>
  T* typed_output_tensor(int index) {
    return typed_tensor<T>(outputs_[index]);
  }

  // Change the dimensionality of a given tensor. Note, this is only acceptable
  // for tensor indices that are inputs.
  // Returns status of failure or success.
  // TODO(aselle): Consider implementing ArraySlice equivalent to make this
  //   more adept at accepting data without an extra copy. Use absl::ArraySlice
  //   if our partners determine that dependency is acceptable.
  TfLiteStatus ResizeInputTensor(int tensor_index,
                                 const std::vector<int>& dims);

  // Update allocations for all tensors. This will redim dependent tensors using
  // the input tensor dimensionality as given. This is relatively expensive.
  // If you know that your sizes are not changing, you need not call this.

  // Returns status of success or failure.
  TfLiteStatus AllocateTensors();

  // Invoke the interpreter (run the whole graph in dependency order).
  //
  // NOTE: It is possible that the interpreter is not in a ready state
  // to evaluate (i.e. if a ResizeTensor() has been performed without an
  // AllocateTensors().
  // Returns status of success or failure.
  TfLiteStatus Invoke();  // Runs the graph; arguably the most important function here

  // Enable or disable the NN API (true to enable)
  void UseNNAPI(bool enable);

  // Set the number of threads available to the interpreter.
  void SetNumThreads(int num_threads);

  // Allow a delegate to look at the graph and modify the graph to handle
  // parts of the graph themselves. After this is called, the graph may
  // contain new nodes that replace 1 or more nodes.
  TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate);

  // Retrieve an operator's description of its work, for profiling purposes.
  const char* OpProfilingString(const TfLiteRegistration& op_reg,
                                const TfLiteNode* node) const {
    // haili TODO:
    //if (op_reg.profiling_string == nullptr) return nullptr;
    //return op_reg.profiling_string(&context_, node);
    return nullptr;
  }

  void SetProfiler(profiling::Profiler* profiler) { profiler_ = profiler; }

  profiling::Profiler* GetProfiler() { return profiler_; }

 private:
  // Give 'op_reg' a chance to initialize itself using the contents of
  // 'buffer'.
  void* OpInit(const TfLiteRegistration& op_reg, const char* buffer,
               size_t length) {
    if (op_reg.init == nullptr) return nullptr;
    return op_reg.init(&context_, buffer, length);
  }

  // Let 'op_reg' release any memory it might have allocated via 'OpInit'.
  void OpFree(const TfLiteRegistration& op_reg, void* buffer) {
    if (op_reg.free == nullptr) return;
    if (buffer) {
      op_reg.free(&context_, buffer);
    }
  }

  // Prepare the given 'node' for execution.
  TfLiteStatus OpPrepare(const TfLiteRegistration& op_reg, TfLiteNode* node) {
    if (op_reg.prepare == nullptr) return kTfLiteOk;
    return op_reg.prepare(&context_, node);
  }

  // Invoke the operator represented by 'node'.
  TfLiteStatus OpInvoke(const TfLiteRegistration& op_reg, TfLiteNode* node) {
    if (op_reg.invoke == nullptr) return kTfLiteError;
    return op_reg.invoke(&context_, node);
  }

  // Call OpPrepare() for as many ops as possible, allocating memory for their
  // tensors. If an op containing dynamic tensors is found, preparation will be
  // postponed until this function is called again. This allows the interpreter
  // to wait until Invoke() to resolve the sizes of dynamic tensors.
  TfLiteStatus PrepareOpsAndTensors();

  // Call OpPrepare() for all ops starting at 'first_node'. Stop when a
  // dynamic tensors is found or all ops have been prepared. Fill
  // 'last_node_prepared' with the id of the op containing dynamic tensors, or
  // the last in the graph.
  TfLiteStatus PrepareOpsStartingAt(int first_execution_plan_index,
                                    int* last_execution_plan_index_prepared);

  // Tensors needed by the interpreter. Use `AddTensors` to add more blank
  // tensor entries. Note, `tensors_.data()` needs to be synchronized to the
  // `context_` whenever this std::vector is reallocated. Currently this
  // only happens in `AddTensors()`.
  std::vector<TfLiteTensor> tensors_;

  // Check if an array of tensor indices are valid with respect to the Tensor
  // array.
  // NOTE: this changes consistent_ to be false if indices are out of bounds.
  TfLiteStatus CheckTensorIndices(const char* label, const int* indices,
                                  int length);

  // Compute the number of bytes required to represent a tensor with dimensions
  // specified by the array dims (of length dims_size). Returns the status code
  // and bytes.
  TfLiteStatus BytesRequired(TfLiteType type, const int* dims, int dims_size,
                             size_t* bytes);

  // Request an tensor be resized implementation. If the given tensor is of
  // type kTfLiteDynamic it will also be allocated new memory.
  TfLiteStatus ResizeTensorImpl(TfLiteTensor* tensor, TfLiteIntArray* new_size);

  // Report a detailed error string (will be printed to stderr).
  // TODO(aselle): allow user of class to provide alternative destinations.
  void ReportErrorImpl(const char* format, va_list args);

  // Entry point for C node plugin API to request an tensor be resized.
  static TfLiteStatus ResizeTensor(TfLiteContext* context, TfLiteTensor* tensor,
                                   TfLiteIntArray* new_size);
  // Entry point for C node plugin API to report an error.
  static void ReportError(TfLiteContext* context, const char* format, ...);

  // Entry point for C node plugin API to add new tensors.
  static TfLiteStatus AddTensors(TfLiteContext* context, int tensors_to_add,
                                 int* first_new_tensor_index);

  // WARNING: This is an experimental API and subject to change.
  // Entry point for C API ReplaceSubgraphsWithDelegateKernels
  static TfLiteStatus ReplaceSubgraphsWithDelegateKernels(
      TfLiteContext* context, TfLiteRegistration registration,
      const TfLiteIntArray* nodes_to_replace);

  // Update the execution graph to replace some of the nodes with stub
  // nodes. Specifically any node index that has `nodes[index]==1` will be
  // slated for replacement with a delegate kernel specified by registration.
  // WARNING: This is an experimental interface that is subject to change.
  TfLiteStatus ReplaceSubgraphsWithDelegateKernels(
      TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace);

  // WARNING: This is an experimental interface that is subject to change.
  // Gets the internal pointer to a TensorFlow lite node by node_index.
  TfLiteStatus GetNodeAndRegistration(int node_index, TfLiteNode** node,
                                      TfLiteRegistration** registration);

  // WARNING: This is an experimental interface that is subject to change.
  // Entry point for C node plugin API to get a node by index.
  static TfLiteStatus GetNodeAndRegistration(struct TfLiteContext*,
                                             int node_index, TfLiteNode** node,
                                             TfLiteRegistration** registration);

  // WARNING: This is an experimental interface that is subject to change.
  // Gets an TfLiteIntArray* representing the execution plan. The caller owns
  // this memory and must free it with TfLiteIntArrayFree().
  TfLiteStatus GetExecutionPlan(TfLiteIntArray** execution_plan);

  // WARNING: This is an experimental interface that is subject to change.
  // Entry point for C node plugin API to get the execution plan
  static TfLiteStatus GetExecutionPlan(struct TfLiteContext* context,
                                       TfLiteIntArray** execution_plan);

  // A pure C data structure used to communicate with the pure C plugin
  // interface. To avoid copying tensor metadata, this is also the definitive
  // structure to store tensors.
  TfLiteContext context_;

  // Node inputs/outputs are stored in TfLiteNode and TfLiteRegistration stores
  // function pointers to actual implementation.
  std::vector<std::pair<TfLiteNode, TfLiteRegistration>>
      nodes_and_registration_;

  // Whether the model is consistent. That is to say if the inputs and outputs
  // of every node and the global inputs and outputs are valid indexes into
  // the tensor array.
  bool consistent_ = true;

  // Whether the model is safe to invoke (if any errors occurred this
  // will be false).
  bool invokable_ = false;

  // Array of indices representing the tensors that are inputs to the
  // interpreter.
  std::vector<int> inputs_;

  // Array of indices representing the tensors that are outputs to the
  // interpreter.
  std::vector<int> outputs_;

  // The error reporter delegate that tflite will forward queries errors to.
  ErrorReporter* error_reporter_;

  // Index of the next node to prepare.
  // During Invoke(), Interpreter will allocate input tensors first, which are
  // known to be fixed size. Then it will allocate outputs from nodes as many
  // as possible. When there is a node that produces dynamic sized tensor.
  // Intepreter will stop allocating tensors, set the value of next allocate
  // node id, and execute the node to generate the output tensor before continue
  // to allocate successors. This process repeats until all nodes are executed.
  // NOTE: this relies on the order of nodes that is in topological order.
  int next_execution_plan_index_to_prepare_;

  // WARNING: This is an experimental interface that is subject to change.
  // This is a list of node indices (to index into nodes_and_registration).
  // This represents a valid topological sort (dependency ordered) execution
  // plan. In particular, it is valid for this ordering to contain only a
  // subset of the node indices.
  std::vector<int> execution_plan_;

  // In the future, we'd like a TfLiteIntArray compatible representation.
  // TODO(aselle): replace execution_plan_ with this.
  std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> plan_cache_;

  // Whether to delegate to NN API
  std::unique_ptr<NNAPIDelegate> nnapi_delegate_;

  std::unique_ptr<MemoryPlanner> memory_planner_;

  // Profiler for this interpreter instance.
  profiling::Profiler* profiler_;
};
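typed_tensor<T>() is a checked cast. It compares the tensor's runtime type tag with the tag that typeToTfLiteType<T>() computes for T, and returns nullptr on a mismatch instead of reinterpreting the bytes. The sketch below shows the same idea in self-contained form. The names ToyType, ToyTensor, TypeTag and TypedTensor are hypothetical stand-ins for the TF Lite types.

```cpp
#include <cstdint>

// Hypothetical miniature of TfLiteTensor: a type tag plus untyped storage.
enum ToyType { kToyFloat32, kToyUInt8 };
struct ToyTensor {
  ToyType type;
  void* raw;
};

// Compile-time mapping from a C++ type to its tag, like typeToTfLiteType<T>().
template <class T>
ToyType TypeTag();
template <>
inline ToyType TypeTag<float>() { return kToyFloat32; }
template <>
inline ToyType TypeTag<uint8_t>() { return kToyUInt8; }

// Checked cast in the spirit of Interpreter::typed_tensor<T>(): returns the
// data pointer only when the stored tag matches T, otherwise nullptr.
template <class T>
T* TypedTensor(ToyTensor* tensor) {
  if (tensor != nullptr && tensor->type == TypeTag<T>()) {
    return static_cast<T*>(tensor->raw);
  }
  return nullptr;
}
```

This is why RunInference switches on `interpreter->tensor(input)->type` before calling typed_tensor<float> or typed_tensor<uint8_t>. Asking for the wrong type would return nullptr.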

Building the OpResolver (kernels/register.h):

#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_

#include <unordered_map>
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/model.h"

namespace tflite {
namespace ops {
namespace builtin {
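// OpResolver is the base class.
// BuiltinOpResolver maps each builtin op code (and custom op name) to its
// TfLiteRegistration, i.e. the function pointers that implement the op.
class BuiltinOpResolver : public OpResolver {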
 public:
  BuiltinOpResolver();
  TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override;
  TfLiteRegistration* FindOp(const char* op) const override;
  void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration);
  void AddCustom(const char* name, TfLiteRegistration* registration);

 private:
  struct BuiltinOperatorHasher {
    size_t operator()(const tflite::BuiltinOperator& x) const {
      return std::hash<int>()(static_cast<int>(x));
    }
  };
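  std::unordered_map<tflite::BuiltinOperator, TfLiteRegistration*,
                     BuiltinOperatorHasher>
      builtins_;
  std::unordered_map<std::string, TfLiteRegistration*> custom_ops_;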
};

}  // namespace builtin
}  // namespace ops
}  // namespace tflite

#endif  // TENSORFLOW_CONTRIB_LITE_KERNELS_BUILTIN_KERNELS_H
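Toy types can illustrate how a resolver and its registrations fit together. The resolver maps an op code to a registration that holds a function pointer. Running the graph means walking the execution plan and calling each node's function. Everything here (MiniOp, MiniRegistration, MiniResolver, RunPlan) is hypothetical. Real kernels receive a TfLiteContext* and a TfLiteNode* rather than two floats. The int key stands in for what BuiltinOperatorHasher does when it casts the enum to int.

```cpp
#include <cassert>
#include <unordered_map>
#include <vector>

// Toy op codes and a registration bundling the kernel's function pointer.
enum class MiniOp { kAdd = 0, kMul = 1 };

struct MiniRegistration {
  float (*invoke)(float value, float operand);
};

float AddKernel(float value, float operand) { return value + operand; }
float MulKernel(float value, float operand) { return value * operand; }

// Maps op codes to registrations, like BuiltinOpResolver's builtins_ map.
class MiniResolver {
 public:
  MiniResolver() {
    AddBuiltin(MiniOp::kAdd, &add_);
    AddBuiltin(MiniOp::kMul, &mul_);
  }
  // The map stores pointers to members, so copying is disabled.
  MiniResolver(const MiniResolver&) = delete;
  MiniResolver& operator=(const MiniResolver&) = delete;

  void AddBuiltin(MiniOp op, const MiniRegistration* registration) {
    builtins_[static_cast<int>(op)] = registration;
  }
  const MiniRegistration* FindOp(MiniOp op) const {
    auto it = builtins_.find(static_cast<int>(op));
    return it == builtins_.end() ? nullptr : it->second;
  }

 private:
  MiniRegistration add_{AddKernel};
  MiniRegistration mul_{MulKernel};
  std::unordered_map<int, const MiniRegistration*> builtins_;
};

// A node names its op plus a constant operand. The plan lists nodes in
// dependency order, and each node consumes the previous node's result.
struct MiniNode {
  MiniOp op;
  float operand;
};

float RunPlan(const MiniResolver& resolver, const std::vector<MiniNode>& plan,
              float input) {
  float value = input;
  for (const MiniNode& node : plan) {
    const MiniRegistration* registration = resolver.FindOp(node.op);
    assert(registration != nullptr && registration->invoke != nullptr);
    value = registration->invoke(value, node.operand);
  }
  return value;
}
```

With the plan {add 2, multiply 3}, a starting value of 1.0 becomes (1 + 2) * 3 = 9.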

The complete RunInference function:

// Multiply by a double so that tv_sec * 1e6 cannot overflow a 32-bit long.
double get_us(struct timeval t) { return t.tv_sec * 1000000.0 + t.tv_usec; }

void RunInference(Settings* s) {
  if (s->model_name.empty()) {  // c_str() never returns null, so test for an empty name
    LOG(ERROR) << "no model file name\n";
    exit(-1);
  }

  std::unique_ptr<tflite::FlatBufferModel> model;
  std::unique_ptr<tflite::Interpreter> interpreter;
  // 1) Build the model
  /*
   public:
  // Builds a model based on a file. Returns a nullptr in case of failure.
  static std::unique_ptr<FlatBufferModel> BuildFromFile(
      const char* filename,
      ErrorReporter* error_reporter = DefaultErrorReporter());
  */
  model = tflite::FlatBufferModel::BuildFromFile(s->model_name.c_str());
  if (!model) {
    LOG(FATAL) << "\nFailed to mmap model " << s->model_name << "\n";
    exit(-1);
  }
  LOG(INFO) << "Loaded model " << s->model_name << "\n";
  /* ErrorReporter* error_reporter() const { return error_reporter_; }*/
  model->error_reporter();
  LOG(INFO) << "resolved reporter\n";
  // 2) Build the OpResolver, which maps each node's op to its implementation
  tflite::ops::builtin::BuiltinOpResolver resolver;

  // 3) Build the interpreter: tflite::InterpreterBuilder(*model, resolver)(&interpreter);
  /*
  // Builds an interpreter given only the raw flatbuffer Model object (instead
  // of a FlatBufferModel). Mostly used for testing.
  // If `error_reporter` is null, then DefaultErrorReporter() is used.
  InterpreterBuilder(const ::tflite::Model* model, const OpResolver& op_resolver, ErrorReporter* error_reporter = DefaultErrorReporter());
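  The second argument is passed by reference. There are several constructors; the call below passes *model (a FlatBufferModel&), so it actually uses InterpreterBuilder(const FlatBufferModel& model, const OpResolver& op_resolver).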
  */
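  // Building produces an Interpreter object.
  // This creates a temporary InterpreterBuilder and immediately calls its
  // operator(), which constructs the Interpreter and stores it in `interpreter`
  // (left null on failure, hence the check below).
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);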
  if (!interpreter) {
    LOG(FATAL) << "Failed to construct interpreter\n";
    exit(-1);
  }
  // 4) Configure the interpreter: NNAPI, verbose logging, thread count
  interpreter->UseNNAPI(s->accel);

  // See the remaining Interpreter member functions for details
  if (s->verbose) {
    LOG(INFO) << "tensors size: " << interpreter->tensors_size() << "\n";
    LOG(INFO) << "nodes size: " << interpreter->nodes_size() << "\n";
    LOG(INFO) << "inputs: " << interpreter->inputs().size() << "\n";
    LOG(INFO) << "input(0) name: " << interpreter->GetInputName(0) << "\n";

    int t_size = interpreter->tensors_size();
    for (int i = 0; i < t_size; i++) {
      // tensor() returns a TfLiteTensor*: every tensor in the model
      // is loaded into a TfLiteTensor
      if (interpreter->tensor(i)->name)
        LOG(INFO) << i << ": " << interpreter->tensor(i)->name << ", "
                  << interpreter->tensor(i)->bytes << ", "
                  << interpreter->tensor(i)->type << ", "
                  << interpreter->tensor(i)->params.scale << ", "
                  << interpreter->tensor(i)->params.zero_point << "\n";
    }
  }

  if (s->number_of_threads != -1) {
    interpreter->SetNumThreads(s->number_of_threads);
  }

  // 5) Read the BMP file and resize it as needed
  int image_width = 224;
  int image_height = 224;
  int image_channels = 3;
  // See examples/label_image/bitmap_helpers.cc for how read_bmp works
  uint8_t* in = read_bmp(s->input_bmp_name, &image_width, &image_height,
                         &image_channels, s);

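  // Only the first input is used because this model has a single input
  // tensor; inputs()[0] is that tensor's index.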
  int input = interpreter->inputs()[0];
  if (s->verbose) LOG(INFO) << "input: " << input << "\n";

  /*
  // Array of indices representing the tensors that are inputs to the
  // interpreter.
  std::vector<int> inputs_;

  // Array of indices representing the tensors that are outputs to the
  // interpreter.
  std::vector<int> outputs_;
  */
  const std::vector<int> inputs = interpreter->inputs();
  const std::vector<int> outputs = interpreter->outputs();

  if (s->verbose) {
    LOG(INFO) << "number of inputs: " << inputs.size() << "\n";
    LOG(INFO) << "number of outputs: " << outputs.size() << "\n";
  }
  /*
  // Returns status of success or failure.
  TfLiteStatus AllocateTensors();
  */
  if (interpreter->AllocateTensors() != kTfLiteOk) {
    LOG(FATAL) << "Failed to allocate tensors!";
  }

  // Print the interpreter state (tensors, nodes, etc.)
  //optional_debug_tools.cc +72
  if (s->verbose) PrintInterpreterState(interpreter.get());

  // get input dimension from the input tensor metadata
  // assuming one input only
  /*
	// Fixed size list of integers. Used for dimensions and inputs/outputs tensor
	// indices
	typedef struct {
	int size;
	// gcc 6.1+ have a bug where flexible members aren't properly handled
	// https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c
	#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \
		__GNUC_MINOR__ >= 1
	int data[0];
	#else
	int data[];
	#endif
	} TfLiteIntArray;
  */
  TfLiteIntArray* dims = interpreter->tensor(input)->dims;
  int wanted_height = dims->data[1];
  int wanted_width = dims->data[2];
  int wanted_channels = dims->data[3];

  // Copy (and resize) the image into the input tensor, converting it to the
  // tensor's element type
  switch (interpreter->tensor(input)->type) {
    case kTfLiteFloat32:
      s->input_floating = true;
      resize<float>(interpreter->typed_tensor<float>(input), in, image_height,
                    image_width, image_channels, wanted_height, wanted_width,
                    wanted_channels, s);
      break;
    case kTfLiteUInt8:
      resize<uint8_t>(interpreter->typed_tensor<uint8_t>(input), in,
                      image_height, image_width, image_channels, wanted_height,
                      wanted_width, wanted_channels, s);
      break;
      break;
    default:
      LOG(FATAL) << "cannot handle input type "
                 << interpreter->tensor(input)->type << " yet";
      exit(-1);
  }

  struct timeval start_time, stop_time;
  gettimeofday(&start_time, NULL);
  // Run the model and measure the elapsed time
  for (int i = 0; i < s->loop_count; i++) {
    if (interpreter->Invoke() != kTfLiteOk) {
      LOG(FATAL) << "Failed to invoke tflite!\n";
    }
  }
  gettimeofday(&stop_time, NULL);
  LOG(INFO) << "invoked \n";
  LOG(INFO) << "average time: "
            << (get_us(stop_time) - get_us(start_time)) / (s->loop_count * 1000)
            << " ms \n";

  const int output_size = 1000;
  const size_t num_results = 5;
  const float threshold = 0.001f;

  std::vector<std::pair<float, int>> top_results;

  // Likewise, only the first output is used: this model has a single output
  // tensor holding the class scores.
  int output = interpreter->outputs()[0];
  // Read the output; as with the input, dispatch on the tensor's data type
  switch (interpreter->tensor(output)->type) {
    case kTfLiteFloat32:
      get_top_n<float>(interpreter->typed_output_tensor<float>(0), output_size,
                       num_results, threshold, &top_results, true);
      break;
    case kTfLiteUInt8:
      get_top_n<uint8_t>(interpreter->typed_output_tensor<uint8_t>(0),
                         output_size, num_results, threshold, &top_results,
                         false);
      break;
    default:
      LOG(FATAL) << "cannot handle output type "
                 << interpreter->tensor(output)->type << " yet";
      exit(-1);
  }

  // Load the labels and print the matching results
  std::vector<string> labels;
  size_t label_count;

  // Read the labels file (see examples/label_image/label_image.cc +52)
  if (ReadLabelsFile(s->labels_file_name, &labels, &label_count) != kTfLiteOk)
    exit(-1);

  // first: the confidence (float)
  // second: the class index (int)
  for (const auto& result : top_results) {
    const float confidence = result.first;
    const int index = result.second;
    LOG(INFO) << confidence << ": " << index << " " << labels[index] << "\n";
  }
}
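get_top_n itself is not shown above. The sketch below is one possible implementation, based on these assumptions: scores below threshold are dropped, at most num_results are kept, uint8 scores are scaled to [0, 1] by dividing by 255, and results come back as (confidence, index) pairs sorted best-first. It uses a min-heap capped at num_results entries. TopN is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <queue>
#include <utility>
#include <vector>

// Keeps the num_results highest scores that reach threshold, using a
// min-heap whose top is the weakest result kept so far.
template <class T>
std::vector<std::pair<float, int>> TopN(const T* prediction,
                                        int prediction_size,
                                        size_t num_results, float threshold,
                                        bool input_floating) {
  assert(prediction != nullptr && num_results > 0);
  std::priority_queue<std::pair<float, int>, std::vector<std::pair<float, int>>,
                      std::greater<std::pair<float, int>>>
      heap;
  for (int i = 0; i < prediction_size; ++i) {
    // Quantized scores are assumed to map 0..255 onto 0.0..1.0.
    const float value = input_floating ? static_cast<float>(prediction[i])
                                       : prediction[i] / 255.0f;
    if (value < threshold) continue;
    heap.push({value, i});
    if (heap.size() > num_results) heap.pop();  // drop the current smallest
  }
  std::vector<std::pair<float, int>> top;
  while (!heap.empty()) {
    top.push_back(heap.top());
    heap.pop();
  }
  // The heap pops in ascending order; reverse so the best result is first.
  return std::vector<std::pair<float, int>>(top.rbegin(), top.rend());
}
```

A bounded min-heap costs O(n log k) for n scores and k results. This is cheaper than sorting all 1000 scores when only the top 5 are needed.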

Reading the labels file is handled by a helper defined in label_image.cc itself:

// Takes a file name, and loads a list of labels from it, one per line, and
// returns a vector of the strings. It pads with empty strings so the length
// of the result is a multiple of 16, because our model expects that.
TfLiteStatus ReadLabelsFile(const string& file_name,
                            std::vector<string>* result,
                            size_t* found_label_count) {
  std::ifstream file(file_name);
  if (!file) {
    LOG(FATAL) << "Labels file " << file_name << " not found\n";
    return kTfLiteError;
  }
  result->clear();
  string line;
  while (std::getline(file, line)) {
    result->push_back(line);
  }
  *found_label_count = result->size();
  const int padding = 16;
  while (result->size() % padding) {
    result->emplace_back();
  }
  return kTfLiteOk;
}

 
