【手撕 - 深度学习】TF Lite 魔改:添加自定义 op

作者:LogM

本文原载于 https://segmentfault.com/u/logm/articles ,不允许转载~


1. 前言

Tensorflow Lite 是 Tensorflow 移动端的版本。

有关于 Tensorflow 怎么添加自定义 op,网上有很多博客都讲到了,我就不介绍了。而 Tensorflow Lite 因为相对小众一些,所以网上关于添加自定义 op 的教程很少。

刚好最近因为项目需要,我在 Tensorflow Lite 中添加了几个自定义 op。我把我的思考过程以及修改步骤记录下来,方便有相同需求的同学参考。

我花了大篇幅记录思考过程和源码阅读过程,是希望给其他小伙伴一些启发,以后遇到类似的深度学习框架魔改的问题,可以不依赖网上教程。

不关心思考过程和源码阅读的小伙伴,可以直接跳到文章的最后,我把修改的步骤做了总结。

2. 源码来源

我使用源码是 Tensorflow v1.13.2

Tensorflow Lite 位于 tensorflow/lite 目录下。

3. 官方教程

官网也有关于 Tensorflow Lite 怎么添加自定义 op 的教程,详见官方地址。

官方教程把"怎么写自定义 op 的代码"讲得很清楚,遗憾的是没有详细说明怎么把这些新写的代码放入到工程中编译。

4. 进入正题

第1步,找到目标文件夹位置

首先我们要找到源码中放置自定义 op 的文件夹位置。有多种寻找的方式:

  1. tensorflow 源码的目录结构非常清楚,有过类似框架阅读经验的同学应该马上能猜出位置;
  2. 官方教程告诉我们,自定义 op 的代码要实现 PrepareEval 这两个函数,那么我们使用 grep 命令查找有哪些代码文件中带有这两个函数。

最终,我们找到的位置是 tensorflow/lite/kernels

找到目标文件夹位置以后,把新增代码放入该文件夹就可以了吗?显然,没有这么简单。有几个方面需要考虑:

  1. 代码逻辑层面,新增代码的逻辑怎么与源码的逻辑连接起来;
  2. 编译层面,新增代码怎么参与编译。

第2步,新增代码的逻辑怎么与源码的逻辑连接起来?

有过类似深度学习框架阅读经验的同学应该很快能想到,对于"添加自定义op"这个操作,就是个"op注册"的过程,所以马上想到去寻找带"register"字样的文件。

而没有深度学习框架阅读经验的同学也不用慌,官方教程告诉我们,自定义op在使用前需要调用 AddCustom 函数。那么很明显,这个函数就起到了将自定义op的逻辑与源码逻辑连接起来的任务。所以使用 grep 命令查找有哪些代码文件中带有这个函数。

两种方式殊途同归,找到关键文件 tensorflow/lite/kernels/register.cc

// 文件:tensorflow/lite/kernels/register.cc
// 行数:22

namespace custom {

TfLiteRegistration* Register_AUDIO_SPECTROGRAM();
TfLiteRegistration* Register_LAYER_NORM_LSTM();
TfLiteRegistration* Register_MFCC();
TfLiteRegistration* Register_DETECTION_POSTPROCESS();
TfLiteRegistration* Register_RELU_1();

} 
// 文件:tensorflow/lite/kernels/register.cc
// 行数:278

  // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
  // custom ops aren't always included by default.
  AddCustom("Mfcc", tflite::ops::custom::Register_MFCC());
  AddCustom("AudioSpectrogram",
            tflite::ops::custom::Register_AUDIO_SPECTROGRAM());
  AddCustom("LayerNormLstm", tflite::ops::custom::Register_LAYER_NORM_LSTM());
  AddCustom("Relu1", tflite::ops::custom::Register_RELU_1());
  AddCustom("TFLite_Detection_PostProcess",
            tflite::ops::custom::Register_DETECTION_POSTPROCESS());

嘿嘿嘿,我们发现官方源码中也放了5个自定义op,而且官方偷懒把自定义op与内置op的注册过程写在了一起,那么我们来看看官方是怎么写自定义op的吧,比如 Relu1 这个。

// 文件:tensorflow/lite/kernels/relu1.cc

#include "tensorflow/lite/context.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace custom {
namespace relu1 {

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
  const TfLiteTensor* input = GetInput(context, node, 0);
  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
  TfLiteTensor* output = GetOutput(context, node, 0);
  output->type = input->type;
  return context->ResizeTensor(context, output,
                               TfLiteIntArrayCopy(input->dims));
}

// This is derived from lite/kernels/activations.cc.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input = GetInput(context, node, 0);
  TfLiteTensor* output = GetOutput(context, node, 0);
  const int elements = NumElements(input);
  const float* in = input->data.f;
  const float* in_end = in + elements;
  float* out = output->data.f;
  for (; in < in_end; ++in, ++out) {
    *out = std::min(std::max(0.f, *in), 1.f);
  }
  return kTfLiteOk;
}

}  // namespace relu1

TfLiteRegistration* Register_RELU_1() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
                                 relu1::Prepare, relu1::Eval};
  return &r;
}

}  // namespace custom
}  // namespace ops
}  // namespace tflite

可以看到,与官方给出的教程一样,关键点是实现 PrepareEval 这两个函数。我们自己在自定义op的代码时,可以把这个文件当做参考模板。

第3步,新增代码怎么参与编译?

这块需要一些 C++ 大工程开发的知识,Tensorflow 是用 Bazel 作工程编译的,所以关键点在目标文件夹下的 BUILD 文件。

BUILD 文件里面这么多的 library,我们的新代码应该编译到哪个 library 中呢?还记得 官方留的自定义op "Relu1" 吗?我们来看看 "Relu1" 是编译到哪个 library。

// 文件:tensorflow/lite/kernels/BUILD
// 行数:278

cc_library(
    name = "builtin_op_kernels",
    srcs = [
        ...     // 这里有很多其他的源文件
        "mfcc.cc",
        "relu1.cc",
        ...     // 把新写的代码文件加到这边就可以了
    ],
    hdrs = [
    ],
    copts = tflite_copts() + tf_opts_nortti_if_android() + EXTRA_EIGEN_COPTS,
    visibility = ["//visibility:private"],
    deps = [
        ":activation_functor",
        ":eigen_support",
        ":kernel_util",
        ":lstm_eval",
        ":op_macros",
        ":padding",
        "//tensorflow/lite:framework",
        "//tensorflow/lite:string_util",
        "//tensorflow/lite/c:c_api_internal",
        "//tensorflow/lite/kernels:gemm_support",
        "//tensorflow/lite/kernels/internal:audio_utils",
        "//tensorflow/lite/kernels/internal:kernel_utils",
        "//tensorflow/lite/kernels/internal:optimized",
        "//tensorflow/lite/kernels/internal:optimized_base",
        "//tensorflow/lite/kernels/internal:quantization_util",
        "//tensorflow/lite/kernels/internal:reference_base",
        "//tensorflow/lite/kernels/internal:tensor",
        "//tensorflow/lite/kernels/internal:tensor_utils",
        "@farmhash_archive//:farmhash",
        "@flatbuffers",
    ],
)

嘿嘿嘿,官方可真会偷懒,自定义 op 和内置 op 一起编译到 builtin_op_kernels 库。所以,我们只要把新的代码文件添加到 srcs=[] 里,新的代码就能参与到编译过程中了。

5. 总结

Tensorflow Lite v1.13.2 中,官方偷了个懒,自定义 op 与内置 op 写在同一个位置,且都是编译到 builtin_op_kernels 库。

Tensorflow Lite 的自定义 op 添加方式如下:

  1. 参照 官方教程 以及 tensorflow/lite/kernels/relu1.cc 编写 op 代码;
  2. 将 op 代码放入 tensorflow/lite/kernels 文件夹下;
  3. 修改 tensorflow/lite/kernels/register.cc,完成新增 op 在代码逻辑上的"注册";
  4. 修改 tensorflow/lite/kernels/BUILD,将新代码文件加入到 builtin_op_kernels 库的编译过程中;
  5. 参照 官方教程 重新编译整个项目。

你可能感兴趣的:(深度学习,tensorflow)