TensorFlow Lite Development Handbook (5): TensorFlow Lite Model Usage Example (Classification Model)

  • (1) Create a new CLion project
  • (2) Write CMakeLists.txt
  • (3) Write main.cpp
  • (4) Download a pretrained model
  • (5) Adjust the model settings
  • (6) Run the example

(1) Create a new CLion project

Download the project from (https://download.csdn.net/download/weixin_42499236/11892106) and extract it; the resulting layout is shown below:
[Figure 1: project layout after extraction]

(2) Write CMakeLists.txt

cmake_minimum_required(VERSION 3.15)
project(testlite)

set(CMAKE_CXX_STANDARD 14)

# TensorFlow source tree, plus the flatbuffers/abseil headers it vendors.
include_directories(/home/ai/CLionProjects/tensorflow-master/)
include_directories(/home/ai/CLionProjects/tensorflow-master/tensorflow/lite/tools/make/downloads/flatbuffers/include)
include_directories(/home/ai/CLionProjects/tensorflow-master/tensorflow/lite/tools/make/downloads/absl)

add_executable(testlite main.cpp bitmap_helpers.cc utils.cc)

# Link against the prebuilt TensorFlow Lite static library.
target_link_libraries(testlite /home/ai/CLionProjects/tensorflow-master/tensorflow/lite/tools/make/gen/linux_x86_64/lib/libtensorflow-lite.a -lpthread -ldl -lrt)
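Note: this assumes the TensorFlow Lite static library has already been built inside the checkout; the Makefile-based build under tensorflow/lite/tools/make is what produces the gen/linux_x86_64/lib/libtensorflow-lite.a archive and the downloads/ header directories referenced above. Adjust the /home/ai/CLionProjects paths to your own tree.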

(3) Write main.cpp

  • Include the headers
#include <fcntl.h>      // NOLINT(build/include_order)
#include <getopt.h>     // NOLINT(build/include_order)
#include <sys/time.h>   // NOLINT(build/include_order)
#include <sys/types.h>  // NOLINT(build/include_order)
#include <sys/uio.h>    // NOLINT(build/include_order)
#include <unistd.h>     // NOLINT(build/include_order)

#include <cstdarg>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>

#include "bitmap_helpers.h"
#include "get_top_n.h"

#include "tensorflow/lite/model.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/optional_debug_tools.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/profiling/profiler.h"
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
#include "absl/memory/memory.h"
#include "utils.h"

using namespace std;
  • Set up GPU/NNAPI acceleration (if no GPU is available, the CPU is used by default)
#define LOG(x) std::cerr  // minimal logging shim used throughout the example

// Convert a timeval to microseconds.
double get_us(struct timeval t) { return (t.tv_sec * 1000000 + t.tv_usec); }

using TfLiteDelegatePtr = tflite::Interpreter::TfLiteDelegatePtr;
using TfLiteDelegatePtrMap = std::map<std::string, TfLiteDelegatePtr>;

// Create the GPU delegate
TfLiteDelegatePtr CreateGPUDelegate(tflite::label_image::Settings* s) {
#if defined(__ANDROID__)
    TfLiteGpuDelegateOptionsV2 gpu_opts = TfLiteGpuDelegateOptionsV2Default();
    gpu_opts.inference_preference =
        TFLITE_GPU_INFERENCE_PREFERENCE_SUSTAINED_SPEED;
    gpu_opts.is_precision_loss_allowed = s->allow_fp16 ? 1 : 0;
    return tflite::evaluation::CreateGPUDelegate(s->model, &gpu_opts);
#else
    return tflite::evaluation::CreateGPUDelegate(s->model);
#endif
}

TfLiteDelegatePtrMap GetDelegates(tflite::label_image::Settings* s) {
    TfLiteDelegatePtrMap delegates;
    if (s->gl_backend) {
        auto delegate = CreateGPUDelegate(s);
        if (!delegate) {
            LOG(INFO) << "GPU acceleration is unsupported on this platform.";
        } else {
            delegates.emplace("GPU", std::move(delegate));
        }
    }

    if (s->accel) {
        auto delegate = tflite::evaluation::CreateNNAPIDelegate();
        if (!delegate) {
            LOG(INFO) << "NNAPI acceleration is unsupported on this platform.";
        } else {
            delegates.emplace("NNAPI", tflite::evaluation::CreateNNAPIDelegate());
        }
    }
    return delegates;
}
  • Read the labels file (the result is padded with empty strings to a multiple of 16, as the model expects)
TfLiteStatus ReadLabelsFile(const string& file_name,
                            std::vector<string>* result,
                            size_t* found_label_count) {
    std::ifstream file(file_name);
    if (!file) {
        LOG(FATAL) << "Labels file " << file_name << " not found\n";
        return kTfLiteError;
    }
    result->clear();
    string line;
    while (std::getline(file, line)) {
        result->push_back(line);
    }
    *found_label_count = result->size();
    const int padding = 16;
    while (result->size() % padding) {
        result->emplace_back();
    }
    return kTfLiteOk;
}
  • Print per-op profiling info
void PrintProfilingInfo(const tflite::profiling::ProfileEvent* e,
                        uint32_t subgraph_index, uint32_t op_index,
                        TfLiteRegistration registration) {
    // Output one line per op, e.g.:
    // time (ms), Node xxx, OpCode xxx, symbolic name
    //      5.352, Node   5, OpCode   4, DEPTHWISE_CONV_2D

    LOG(INFO) << std::fixed << std::setw(10) << std::setprecision(3)
              << (e->end_timestamp_us - e->begin_timestamp_us) / 1000.0
              << ", Subgraph " << std::setw(3) << std::setprecision(3)
              << subgraph_index << ", Node " << std::setw(3)
              << std::setprecision(3) << op_index << ", OpCode " << std::setw(3)
              << std::setprecision(3) << registration.builtin_code << ", "
              << tflite::EnumNameBuiltinOperator(
                      static_cast<tflite::BuiltinOperator>(registration.builtin_code))
              << "\n";
}
  • Define the inference function
void RunInference(tflite::label_image::Settings* s) {
    if (s->model_name.empty()) {
        LOG(ERROR) << "no model file name\n";
        exit(-1);
    }

// Load the .tflite model
    std::unique_ptr<tflite::FlatBufferModel> model;
    std::unique_ptr<tflite::Interpreter> interpreter;
    model = tflite::FlatBufferModel::BuildFromFile(s->model_name.c_str());
    if (!model) {
        LOG(FATAL) << "\nFailed to mmap model " << s->model_name << "\n";
        exit(-1);
    }
    s->model = model.get();
    LOG(INFO) << "Loaded model " << s->model_name << "\n";
    model->error_reporter();
    LOG(INFO) << "resolved reporter\n";
// Build the interpreter
    tflite::ops::builtin::BuiltinOpResolver resolver;

    tflite::InterpreterBuilder(*model, resolver)(&interpreter);
    if (!interpreter) {
        LOG(FATAL) << "Failed to construct interpreter\n";
        exit(-1);
    }

    interpreter->UseNNAPI(s->old_accel);
    interpreter->SetAllowFp16PrecisionForFp32(s->allow_fp16);
// Print interpreter info: tensor count, input node names, etc.
    if (s->verbose) {
        LOG(INFO) << "tensors size: " << interpreter->tensors_size() << "\n";
        LOG(INFO) << "nodes size: " << interpreter->nodes_size() << "\n";
        LOG(INFO) << "inputs: " << interpreter->inputs().size() << "\n";
        LOG(INFO) << "input(0) name: " << interpreter->GetInputName(0) << "\n";

        int t_size = interpreter->tensors_size();
        for (int i = 0; i < t_size; i++) {
            if (interpreter->tensor(i)->name)
                LOG(INFO) << i << ": " << interpreter->tensor(i)->name << ", "
                          << interpreter->tensor(i)->bytes << ", "
                          << interpreter->tensor(i)->type << ", "
                          << interpreter->tensor(i)->params.scale << ", "
                          << interpreter->tensor(i)->params.zero_point << "\n";
        }
    }

    if (s->number_of_threads != -1) {
        interpreter->SetNumThreads(s->number_of_threads);
    }

// Input image parameters (read_bmp overwrites these with the file's actual size)
    int image_width = 224;
    int image_height = 224;
    int image_channels = 3;
// Read the BMP image
    std::vector<uint8_t> in = tflite::label_image::read_bmp(s->input_bmp_name, &image_width,
                                       &image_height, &image_channels, s);

    int input = interpreter->inputs()[0];

    if (s->verbose) LOG(INFO) << "input: " << input << "\n";

    const std::vector<int> inputs = interpreter->inputs();
    const std::vector<int> outputs = interpreter->outputs();

    if (s->verbose) {
        LOG(INFO) << "number of inputs: " << inputs.size() << "\n";
        LOG(INFO) << "number of outputs: " << outputs.size() << "\n";
    }

// Apply any delegates to the graph
    auto delegates_ = GetDelegates(s);
    for (const auto& delegate : delegates_) {
        if (interpreter->ModifyGraphWithDelegate(delegate.second.get()) !=
            kTfLiteOk) {
            LOG(FATAL) << "Failed to apply " << delegate.first << " delegate.";
        } else {
            LOG(INFO) << "Applied " << delegate.first << " delegate.";
        }
    }

    if (interpreter->AllocateTensors() != kTfLiteOk) {
        LOG(FATAL) << "Failed to allocate tensors!";
    }

    if (s->verbose) tflite::PrintInterpreterState(interpreter.get());

// Get the input tensor's dimensions from its metadata
    TfLiteIntArray* dims = interpreter->tensor(input)->dims;
    int wanted_height = dims->data[1];
    int wanted_width = dims->data[2];
    int wanted_channels = dims->data[3];

// Resize the image to the input tensor's shape
    switch (interpreter->tensor(input)->type) {
        case kTfLiteFloat32:
            s->input_floating = true;
            tflite::label_image::resize<float>(interpreter->typed_tensor<float>(input), in.data(),
                          image_height, image_width, image_channels, wanted_height,
                          wanted_width, wanted_channels, s);
            break;
        case kTfLiteUInt8:
            tflite::label_image::resize<uint8_t>(interpreter->typed_tensor<uint8_t>(input), in.data(),
                            image_height, image_width, image_channels, wanted_height,
                            wanted_width, wanted_channels, s);
            break;
        default:
            LOG(FATAL) << "cannot handle input type "
                       << interpreter->tensor(input)->type << " yet";
            exit(-1);
    }
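    // Note (based on the label_image helpers, an assumption about bitmap_helpers.cc):
    // for float models the resize step also normalizes each pixel as
    // (in - input_mean) / input_std, so mean = std = 127.5 maps 0..255 to about [-1, 1].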

// Attach a profiler and do warm-up runs
    auto profiler =
            absl::make_unique<tflite::profiling::Profiler>(s->max_profiling_buffer_entries);
    interpreter->SetProfiler(profiler.get());

    if (s->profiling) profiler->StartProfiling();
    if (s->loop_count > 1)
        for (int i = 0; i < s->number_of_warmup_runs; i++) {
            if (interpreter->Invoke() != kTfLiteOk) {
                LOG(FATAL) << "Failed to invoke tflite!\n";
            }
        }
// Run inference and measure the elapsed time
    struct timeval start_time, stop_time;
    gettimeofday(&start_time, nullptr);
    for (int i = 0; i < s->loop_count; i++) {
        if (interpreter->Invoke() != kTfLiteOk) {
            LOG(FATAL) << "Failed to invoke tflite!\n";
        }
    }
    gettimeofday(&stop_time, nullptr);
    LOG(INFO) << "invoked \n";
    LOG(INFO) << "average time: "
              << (get_us(stop_time) - get_us(start_time)) / (s->loop_count * 1000)
              << " ms \n";
// Print profiling events
    if (s->profiling) {
        profiler->StopProfiling();
        auto profile_events = profiler->GetProfileEvents();
        for (size_t i = 0; i < profile_events.size(); i++) {
            auto subgraph_index = profile_events[i]->event_subgraph_index;
            auto op_index = profile_events[i]->event_metadata;
            const auto subgraph = interpreter->subgraph(subgraph_index);
            const auto node_and_registration =
                    subgraph->node_and_registration(op_index);
            const TfLiteRegistration registration = node_and_registration->second;
            PrintProfilingInfo(profile_events[i], subgraph_index, op_index,
                               registration);
        }
    }

    const float threshold = 0.001f;

    std::vector<std::pair<float, int>> top_results;

// Get the top-N results
    int output = interpreter->outputs()[0];
    TfLiteIntArray* output_dims = interpreter->tensor(output)->dims;
    // assume output dims to be something like (1, 1, ... ,size)
    auto output_size = output_dims->data[output_dims->size - 1];
    switch (interpreter->tensor(output)->type) {
        case kTfLiteFloat32:
            tflite::label_image::get_top_n<float>(interpreter->typed_output_tensor<float>(0), output_size,
                             s->number_of_results, threshold, &top_results, true);
            break;
        case kTfLiteUInt8:
            tflite::label_image::get_top_n<uint8_t>(interpreter->typed_output_tensor<uint8_t>(0),
                               output_size, s->number_of_results, threshold,
                               &top_results, false);
            break;
        default:
            LOG(FATAL) << "cannot handle output type "
                       << interpreter->tensor(output)->type << " yet";
            exit(-1);
    }
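    // Note (assumption about get_top_n.h): scores below `threshold` are dropped,
    // and uint8 scores are rescaled by 1/255 so they are comparable to the float case.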

    std::vector<string> labels;
    size_t label_count;

    if (ReadLabelsFile(s->labels_file_name, &labels, &label_count) != kTfLiteOk)
        exit(-1);
// Print the top-N results
    for (const auto& result : top_results) {
        const float confidence = result.first;
        const int index = result.second;
        LOG(INFO) << confidence << ": " << index << " " << labels[index] << "\n";
    }
}

int main() {
    tflite::label_image::Settings s;  // all options come from label_image.h (section (5))
    RunInference(&s);
    return 0;
}

(4) Download a pretrained model

# Get model
curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz | tar xzv -C /tmp

# Get labels
curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz  | tar xzv -C /tmp  mobilenet_v1_1.0_224/labels.txt

mv /tmp/mobilenet_v1_1.0_224/labels.txt /tmp/
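
Before adjusting the settings, a quick way to verify the download is a minimal load test. A sketch (it only checks that the flatbuffer can be memory-mapped; the path comes from the commands above):

#include <iostream>

#include "tensorflow/lite/model.h"

int main() {
    auto model = tflite::FlatBufferModel::BuildFromFile(
            "/tmp/mobilenet_v1_1.0_224.tflite");
    std::cout << (model ? "model loaded OK" : "failed to load model") << std::endl;
    return model ? 0 : 1;
}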

(5) Adjust the model settings

Modify the Settings struct in label_image.h:

struct Settings {
  bool verbose = false;
  bool accel = false;
  bool old_accel = false;
  bool input_floating = false;
  bool profiling = false;
  bool allow_fp16 = false;
  bool gl_backend = false;
  int loop_count = 1;
  float input_mean = 127.5f;
  float input_std = 127.5f;
  string model_name = "/home/ai/CLionProjects/tflite/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.tflite";
  tflite::FlatBufferModel* model = nullptr;
  string input_bmp_name = "/home/ai/CLionProjects/tflite/grace_hopper.bmp";
  string labels_file_name = "/home/ai/CLionProjects/tflite/mobilenet_v1_1.0_224/labels.txt";
  string input_layer_type = "uint8_t";
  int number_of_threads = 4;
  int number_of_results = 5;
  int max_profiling_buffer_entries = 1024;
  int number_of_warmup_runs = 2;
};
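
Instead of editing label_image.h for every experiment, the fields can also be overridden at run time before calling RunInference. A minimal sketch (the paths reuse the /tmp files from section (4); the flag choices are illustrative assumptions):

int main() {
    tflite::label_image::Settings s;  // defaults from label_image.h
    // Override paths and options for this run.
    s.model_name = "/tmp/mobilenet_v1_1.0_224.tflite";
    s.labels_file_name = "/tmp/labels.txt";
    s.input_bmp_name = "/home/ai/CLionProjects/tflite/grace_hopper.bmp";
    s.verbose = true;    // print tensor and node details
    s.profiling = true;  // print per-op timings
    RunInference(&s);
    return 0;
}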

(6) Run the example

The top-5 classification results are printed as follows (format: confidence: class-index label):

Loaded model /tmp/mobilenet_v1_1.0_224.tflite
resolved reporter
invoked
average time: 68.12 ms
0.860174: 653 653:military uniform
0.0481017: 907 907:Windsor tie
0.00786704: 466 466:bulletproof vest
0.00644932: 514 514:cornet, horn, trumpet, trump
0.00608029: 543 543:drumstick

The results show that the image is classified correctly, at an average of 68.12 ms per inference. Very fast!
