C++ TensorflowLite模型验证的过程详解

故事是这样的:

有一个手掌检测的tflite模型,需要在开发板上跑起来。手机版本的已成熟,要移植到开发板上。现在要验证tflite模型文件在板子上的运行结果和手机上一致。

前提:为了多次重复测试,在Android端使用了同一帧数据(从一个录制的mp4中固定取一张图),测试代码如下图:

C++ TensorflowLite模型验证的过程详解_第1张图片

下面是测试过程 

记录下Android版API运行推理前的图片数据文件(经过了归一化处理,所以都是-1~1之间的float数据)

C++ TensorflowLite模型验证的过程详解_第2张图片

这一步卡在了写float数据到二进制文件中,C++读出来有问题

换了个方案,直接存储float字符串

/**
 * Writes every value of {@code pfImageData} to the file "tfimg" in the public
 * Downloads directory, one value per line with four decimal places.
 *
 * The numbers are formatted with {@code Locale.US} so the decimal separator is
 * always '.', whatever the device language is; a locale that uses ',' would
 * otherwise produce text that a C-style float parser cannot read back.
 * Failures are logged and swallowed, as before.
 *
 * @param pfImageData float values to dump
 */
private void saveFile(float[] pfImageData) {
    try {
        File file = new File(Environment.getExternalStoragePublicDirectory(Environment.DIRECTORY_DOWNLOADS).getAbsolutePath() + "/tfimg");

        StringBuilder sb = new StringBuilder();
        for (float val : pfImageData) {
            // Keep 4 decimal places; change the format string for another precision.
            sb.append(String.format(java.util.Locale.US, "%.4f", val));
            sb.append("\r\n");
        }

        // try-with-resources closes the writer even when write() throws.
        try (FileWriter out = new FileWriter(file)) {
            out.write(sb.toString());
        }
    } catch (Exception e) {
        e.printStackTrace();
        Log.e("Melon", "存储文件异常," + e.getMessage());
    }
}

拿着这个文件在板子上输入到Tflite模型中

测试代码,主要是RunInference()和read_file()

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
 
#include "tensorflow/lite/examples/label_image/label_image.h"
 
#include <fcntl.h>      // NOLINT(build/include_order)
#include <getopt.h>     // NOLINT(build/include_order)
#include <sys/time.h>   // NOLINT(build/include_order)
#include <sys/types.h>  // NOLINT(build/include_order)
#include <sys/uio.h>    // NOLINT(build/include_order)
#include <unistd.h>     // NOLINT(build/include_order)
 
#include <algorithm>
#include <cstdarg>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>
 
#include "absl/memory/memory.h"
#include "tensorflow/lite/examples/label_image/bitmap_helpers.h"
#include "tensorflow/lite/examples/label_image/get_top_n.h"
#include "tensorflow/lite/examples/label_image/log.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/optional_debug_tools.h"
#include "tensorflow/lite/profiling/profiler.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/tools/command_line_flags.h"
#include "tensorflow/lite/tools/delegates/delegate_provider.h"
 
namespace tflite
{
  namespace label_image
  {
 
    // Converts `t` to a single microsecond count (seconds * 1e6 + microseconds).
    // The fields are widened to double before scaling so the multiplication
    // cannot overflow on targets where tv_sec is a 32-bit integer.
    double get_us(struct timeval t)
    {
      return static_cast<double>(t.tv_sec) * 1000000.0 +
             static_cast<double>(t.tv_usec);
    }
 
    using TfLiteDelegatePtr = tflite::Interpreter::TfLiteDelegatePtr;
    using ProvidedDelegateList = tflite::tools::ProvidedDelegateList;
 
    class DelegateProviders
    {
    public:
      DelegateProviders() : delegate_list_util_(¶ms_)
      {
        delegate_list_util_.AddAllDelegateParams();
      }
 
      // Initialize delegate-related parameters from parsing command line arguments,
      // and remove the matching arguments from (*argc, argv). Returns true if all
      // recognized arg values are parsed correctly.
      bool InitFromCmdlineArgs(int *argc, const char **argv)
      {
        std::vector flags;
        // delegate_list_util_.AppendCmdlineFlags(&flags);
 
        const bool parse_result = Flags::Parse(argc, argv, flags);
        if (!parse_result)
        {
          std::string usage = Flags::Usage(argv[0], flags);
          LOG(ERROR) << usage;
        }
        return parse_result;
      }
 
      // According to passed-in settings `s`, this function sets corresponding
      // parameters that are defined by various delegate execution providers. See
      // lite/tools/delegates/README.md for the full list of parameters defined.
      void MergeSettingsIntoParams(const Settings &s)
      {
        // Parse settings related to GPU delegate.
        // Note that GPU delegate does support OpenCL. 'gl_backend' was introduced
        // when the GPU delegate only supports OpenGL. Therefore, we consider
        // setting 'gl_backend' to true means using the GPU delegate.
        if (s.gl_backend)
        {
          if (!params_.HasParam("use_gpu"))
          {
            LOG(WARN) << "GPU deleate execution provider isn't linked or GPU "
                         "delegate isn't supported on the platform!";
          }
          else
          {
            params_.Set("use_gpu", true);
            // The parameter "gpu_inference_for_sustained_speed" isn't available for
            // iOS devices.
            if (params_.HasParam("gpu_inference_for_sustained_speed"))
            {
              params_.Set("gpu_inference_for_sustained_speed", true);
            }
            params_.Set("gpu_precision_loss_allowed", s.allow_fp16);
          }
        }
 
        // Parse settings related to NNAPI delegate.
        if (s.accel)
        {
          if (!params_.HasParam("use_nnapi"))
          {
            LOG(WARN) << "NNAPI deleate execution provider isn't linked or NNAPI "
                         "delegate isn't supported on the platform!";
          }
          else
          {
            params_.Set("use_nnapi", true);
            params_.Set("nnapi_allow_fp16", s.allow_fp16);
          }
        }
 
        // Parse settings related to Hexagon delegate.
        if (s.hexagon_delegate)
        {
          if (!params_.HasParam("use_hexagon"))
          {
            LOG(WARN) << "Hexagon deleate execution provider isn't linked or "
                         "Hexagon delegate isn't supported on the platform!";
          }
          else
          {
            params_.Set("use_hexagon", true);
            params_.Set("hexagon_profiling", s.profiling);
          }
        }
 
        // Parse settings related to XNNPACK delegate.
        if (s.xnnpack_delegate)
        {
          if (!params_.HasParam("use_xnnpack"))
          {
            LOG(WARN) << "XNNPACK deleate execution provider isn't linked or "
                         "XNNPACK delegate isn't supported on the platform!";
          }
          else
          {
            params_.Set("use_xnnpack", true);
            params_.Set("num_threads", s.number_of_threads);
          }
        }
      }
 
      // Create a list of TfLite delegates based on what have been initialized (i.e.
      // 'params_').
      std::vector CreateAllDelegates()
          const
      {
        return delegate_list_util_.CreateAllRankedDelegates();
      }
 
    private:
      // Contain delegate-related parameters that are initialized from command-line
      // flags.
      tflite::tools::ToolParams params_;
 
      // A helper to create TfLite delegates.
      ProvidedDelegateList delegate_list_util_;
    };
 
    // Earlier attempt, kept for reference: read the whole input file as raw
    // bytes. It was replaced by the text-based read_file() further down, which
    // parses one float per line.
 
    // std::vector read_file(const std::string &input_bmp_name)
    // {
    //   int begin, end;
 
    //   std::ifstream file(input_bmp_name, std::ios::in | std::ios::binary);
    //   if (!file)
    //   {
    //     LOG(FATAL) << "input file " << input_bmp_name << " not found";
    //     exit(-1);
    //   }
 
    //   begin = file.tellg();
    //   file.seekg(0, std::ios::end);
    //   end = file.tellg();
    //   size_t len = end - begin;
 
    //   LOG(INFO) << "len: " << len;
    //   std::vector img_bytes(len);
 
    //   file.seekg(0, std::ios::beg);
    //   file.read(reinterpret_cast(img_bytes.data()), len);
 
    //   return img_bytes;
    // }
 
    /**
     * 读取文件
     */
    std::vector read_file(const std::string &input_bmp_name)
    {
      int begin, end;
 
      std::ifstream file(input_bmp_name, std::ios::in | std::ios::binary);
      if (!file)
      {
        LOG(FATAL) << "input file " << input_bmp_name << " not found";
        exit(-1);
      }
 
      begin = file.tellg();
      file.seekg(0, std::ios::end);
      end = file.tellg();
      size_t len = end - begin;
 
      LOG(INFO) << "len: " << len;
      std::vector img_bytes;
 
      file.seekg(0, std::ios::beg);
 
      string strLine = "";
      float temp;
      while (getline(file, strLine))
      {
        temp = atof(strLine.c_str());
        img_bytes.push_back(temp);
      }
 
      LOG(INFO) << "文件读取完成:" << input_bmp_name;
      return img_bytes;
    }
 
    /**
     * Runs the model named in `settings` on the float values stored in the file
     * `settings->input_bmp_name` and prints the second output tensor to stdout.
     *
     * Steps: load the model, build the interpreter, read the input values with
     * read_file(), allocate tensors, copy the values into input tensor 0, call
     * Invoke() `settings->loop_count` times, then print every float of output
     * tensor 1, one line per innermost-dimension row.
     *
     * The model must have at least one input and two outputs, both tensors used
     * must be float32, and the file must hold exactly as many values as the
     * input tensor has elements. Any failure is logged and ends the process with
     * exit(-1).
     *
     * @param settings Run configuration. `model` and `input_type` are written.
     */
    void RunInference(Settings *settings)
    {
      // c_str() never returns a null pointer, so test the string itself.
      if (settings->model_name.empty())
      {
        LOG(ERROR) << "no model file name";
        exit(-1);
      }

      std::unique_ptr<tflite::FlatBufferModel> model =
          tflite::FlatBufferModel::BuildFromFile(settings->model_name.c_str());
      if (!model)
      {
        LOG(ERROR) << "Failed to mmap model " << settings->model_name;
        exit(-1);
      }
      // NOTE: `model` is owned by this function, so the pointer stored here is
      // only valid until RunInference() returns.
      settings->model = model.get();
      LOG(INFO) << "Loaded model " << settings->model_name;
      model->error_reporter();
      LOG(INFO) << "resolved reporter";

      tflite::ops::builtin::BuiltinOpResolver resolver;

      // Build the interpreter for the loaded model.
      std::unique_ptr<tflite::Interpreter> interpreter;
      tflite::InterpreterBuilder(*model, resolver)(&interpreter);
      if (!interpreter)
      {
        LOG(ERROR) << "Failed to construct interpreter";
        exit(-1);
      }

      interpreter->SetAllowFp16PrecisionForFp32(settings->allow_fp16);

      if (settings->verbose)
      {
        LOG(INFO) << "tensors size: " << interpreter->tensors_size();
        LOG(INFO) << "nodes size: " << interpreter->nodes_size();
        LOG(INFO) << "inputs: " << interpreter->inputs().size();
        LOG(INFO) << "input(0) name: " << interpreter->GetInputName(0);

        int t_size = interpreter->tensors_size();
        for (int i = 0; i < t_size; i++)
        {
          if (interpreter->tensor(i)->name)
            LOG(INFO) << i << ": " << interpreter->tensor(i)->name << ", "
                      << interpreter->tensor(i)->bytes << ", "
                      << interpreter->tensor(i)->type << ", "
                      << interpreter->tensor(i)->params.scale << ", "
                      << interpreter->tensor(i)->params.zero_point;
        }
      }

      if (settings->number_of_threads != -1)
      {
        interpreter->SetNumThreads(settings->number_of_threads);
      }

      const std::vector<float> file_bytes = read_file(settings->input_bmp_name);
      // Log the leading values (at most 100) so they can be compared with the
      // Android input.
      const size_t preview_count = std::min<size_t>(100, file_bytes.size());
      for (size_t i = 0; i < preview_count; i++)
      {
        LOG(INFO) << i << ": " << file_bytes[i];
      }

      const std::vector<int> inputs = interpreter->inputs();
      const std::vector<int> outputs = interpreter->outputs();
      // Input 0 and output 1 are used below; refuse models that lack them.
      if (inputs.empty() || outputs.size() < 2)
      {
        LOG(ERROR) << "model needs at least 1 input and 2 outputs, got "
                   << inputs.size() << " input(s) and " << outputs.size()
                   << " output(s)";
        exit(-1);
      }

      // inputs[0] is the index of the first input tensor in the interpreter's
      // tensor list.
      const int input = inputs[0];
      LOG(INFO) << "input: " << input;

      LOG(INFO) << "number of inputs: " << inputs.size();
      LOG(INFO) << "input index: " << inputs[0];
      LOG(INFO) << "number of outputs: " << outputs.size();
      LOG(INFO) << "outputs index1: " << outputs[0] << ",outputs index2: " << outputs[1];

      // Allocate memory for all tensors.
      if (interpreter->AllocateTensors() != kTfLiteOk)
      {
        LOG(ERROR) << "Failed to allocate tensors!";
        exit(-1);
      }

      if (settings->verbose)
        PrintInterpreterState(interpreter.get());

      const TfLiteTensor *input_tensor = interpreter->tensor(input);
      settings->input_type = input_tensor->type;

      // Typed view of input tensor 0; null when the tensor is not float32.
      float *inputP = interpreter->typed_input_tensor<float>(0);
      if (inputP == nullptr)
      {
        LOG(ERROR) << "input tensor 0 is not a float32 tensor";
        exit(-1);
      }

      // The file must fill the input tensor exactly: more values would write
      // past the tensor buffer, fewer would leave part of the input unset.
      const size_t input_capacity = input_tensor->bytes / sizeof(float);
      LOG(INFO) << "file_bytes size: " << file_bytes.size();
      if (file_bytes.size() != input_capacity)
      {
        LOG(ERROR) << "input file holds " << file_bytes.size()
                   << " values but the input tensor expects " << input_capacity;
        exit(-1);
      }
      std::copy(file_bytes.begin(), file_bytes.end(), inputP);

      struct timeval start_time, stop_time;
      gettimeofday(&start_time, nullptr);
      for (int i = 0; i < settings->loop_count; i++)
      {
        // Run the model.
        if (interpreter->Invoke() != kTfLiteOk)
        {
          LOG(ERROR) << "Failed to invoke tflite!";
          exit(-1);
        }
      }
      gettimeofday(&stop_time, nullptr);
      LOG(INFO) << "invoked";
      LOG(INFO) << "average time: "
                << (get_us(stop_time) - get_us(start_time)) /
                       (settings->loop_count * 1000)
                << " ms";

      const int output = outputs[1];
      LOG(INFO) << "output: " << output;
      LOG(INFO) << "interpreter->tensors_size: " << interpreter->tensors_size();

      const TfLiteTensor *tensor = interpreter->tensor(output);

      const TfLiteIntArray *output_dims = tensor->dims;
      // The innermost dimension is the row width; a tensor without dimensions
      // is printed as rows of one value.
      const int output_size =
          output_dims->size > 0 ? output_dims->data[output_dims->size - 1] : 1;
      LOG(INFO) << "索引为" << output << "的输出张量的-"
                << "output_size: " << output_size;

      for (int i = 0; i < output_dims->size; i++)
      {
        LOG(INFO) << "元数据有:" << output_dims->data[i];
      }

      const float *prediction = interpreter->typed_output_tensor<float>(1);
      if (prediction == nullptr)
      {
        LOG(ERROR) << "output tensor 1 is not a float32 tensor";
        exit(-1);
      }

      // Print the whole tensor using its real element count instead of a
      // hard-coded 896x1 buffer: each value is followed by a space, each row by
      // a newline, and a blank line closes the dump.
      const size_t element_count = tensor->bytes / sizeof(float);
      const size_t row_width =
          output_size > 0 ? static_cast<size_t>(output_size) : 1;
      for (size_t i = 0; i < element_count; i++)
      {
        std::cout << prediction[i] << ' ';
        if ((i + 1) % row_width == 0)
        {
          std::cout << '\n';
        }
      }
      std::cout << std::endl;
    }
 
    // Entry point of the tool: parses the command line into a `Settings` value,
    // merges it into the delegate parameters and calls RunInference().
    //
    // Returns EXIT_FAILURE when InitFromCmdlineArgs() fails and 0 after
    // RunInference() returns. An unknown option exits the process with -1.
    //
    // NOTE: RunInference() only receives the settings, so the delegate
    // parameters merged here are not passed on to it.
    int Main(int argc, char **argv)
    {
      DelegateProviders delegate_providers;
      bool parse_result = delegate_providers.InitFromCmdlineArgs(
          &argc, const_cast<const char **>(argv));
      if (!parse_result)
      {
        return EXIT_FAILURE;
      }

      Settings s;

      int c;
      while (true)
      {
        static struct option long_options[] = {
            {"accelerated", required_argument, nullptr, 'a'},
            {"allow_fp16", required_argument, nullptr, 'f'},
            {"count", required_argument, nullptr, 'c'},
            {"verbose", required_argument, nullptr, 'v'},
            {"image", required_argument, nullptr, 'i'},
            {"labels", required_argument, nullptr, 'l'},
            {"tflite_model", required_argument, nullptr, 'm'},
            {"profiling", required_argument, nullptr, 'p'},
            {"threads", required_argument, nullptr, 't'},
            {"input_mean", required_argument, nullptr, 'b'},
            {"input_std", required_argument, nullptr, 's'},
            {"num_results", required_argument, nullptr, 'r'},
            {"max_profiling_buffer_entries", required_argument, nullptr, 'e'},
            {"warmup_runs", required_argument, nullptr, 'w'},
            {"gl_backend", required_argument, nullptr, 'g'},
            {"hexagon_delegate", required_argument, nullptr, 'j'},
            {"xnnpack_delegate", required_argument, nullptr, 'x'},
            {nullptr, 0, nullptr, 0}};

        /* getopt_long stores the option index here. */
        int option_index = 0;

        // 'd' is accepted by the option string but has no case below, so it
        // ends in the default branch.
        c = getopt_long(argc, argv,
                        "a:b:c:d:e:f:g:i:j:l:m:p:r:s:t:v:w:x:", long_options,
                        &option_index);

        /* Detect the end of the options. */
        if (c == -1)
          break;

        switch (c)
        {
        case 'a':
          s.accel = strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
          break;
        case 'b':
          s.input_mean = strtod(optarg, nullptr);
          break;
        case 'c':
          s.loop_count =
              strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
          break;
        case 'e':
          s.max_profiling_buffer_entries =
              strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
          break;
        case 'f':
          s.allow_fp16 =
              strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
          break;
        case 'g':
          s.gl_backend =
              strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
          break;
        case 'i':
          s.input_bmp_name = optarg;
          break;
        case 'j':
          // Parse the value like the other on/off switches. Assigning `optarg`
          // itself stored a non-null pointer, so even "-j 0" read as enabled.
          // NOTE(review): assumes Settings::hexagon_delegate is a bool or
          // integer (it is tested with `if (s.hexagon_delegate)`) - confirm.
          s.hexagon_delegate =
              strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
          break;
        case 'l':
          s.labels_file_name = optarg;
          break;
        case 'm':
          s.model_name = optarg;
          break;
        case 'p':
          s.profiling =
              strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
          break;
        case 'r':
          s.number_of_results =
              strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
          break;
        case 's':
          s.input_std = strtod(optarg, nullptr);
          break;
        case 't':
          s.number_of_threads = strtol( // NOLINT(runtime/deprecated_fn)
              optarg, nullptr, 10);
          break;
        case 'v':
          s.verbose =
              strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
          break;
        case 'w':
          s.number_of_warmup_runs =
              strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
          break;
        case 'x':
          s.xnnpack_delegate =
              strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
          break;
        case 'h':
        case '?':
          /* getopt_long already printed an error message. */
          exit(-1);
        default:
          exit(-1);
        }
      }

      delegate_providers.MergeSettingsIntoParams(s);
      RunInference(&s);
      return 0;
    }
 
  } // namespace label_image
} // namespace tflite
 
// Program entry point: hands the command line to tflite::label_image::Main and
// uses its result as the process exit status.
int main(int argc, char **argv)
{
  const int exit_status = tflite::label_image::Main(argc, argv);
  return exit_status;
}

运行指令 ./ws_app --tflite_model libnewpalm_detection.tflite --image tfimg

对比推理前的输入一致

Android端

C++ TensorflowLite模型验证的过程详解_第3张图片

开发板上

C++ TensorflowLite模型验证的过程详解_第4张图片

对比推理后的输出一致 Android端

C++ TensorflowLite模型验证的过程详解_第5张图片

开发板端

C++ TensorflowLite模型验证的过程详解_第6张图片

到此这篇关于C++ TensorflowLite模型验证的文章就介绍到这了,更多相关C++ TensorflowLite模型验证内容请搜索脚本之家以前的文章或继续浏览下面的相关文章,希望大家以后多多支持脚本之家!

你可能感兴趣的:(C++ TensorflowLite模型验证的过程详解)