目录
背景介绍
代码示例
核心数据结构
REGISTER_OP 分析
REGISTER_KERNEL_BUILDER 分析
自定义算子运行分析
TensorFlow官网如何创建自定义算子OP的How to文档链接,本文基于该文档中的代码示例,着重分析TensorFlow框架是如何实现用户自定义算子扩展的功能,力求知其然还要知其所以然。
TensorFlow源码下载链接,本文的分析基于最新的master分支 2.3.2 版本 (Google 自己又整了个bazel编译框架,编译过程也挺折腾的... )
//TfCustomOp.cpp
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
//Define&Register OP
REGISTER_OP("ZeroOut")
.Input("to_zero: int32")
.Output("zeroed: int32")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
return Status::OK();
});
//Define OP's Kernel
class ZeroOutOp : public OpKernel {
public:
explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// Grab the input tensor
const Tensor& input_tensor = context->input(0);
auto input = input_tensor.flat();
// Create an output tensor
Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
&output_tensor));
auto output_flat = output_tensor->flat();
// Set all but the first element of the output tensor to 0.
const int N = input.size();
for (int i = 1; i < N; i++) {
output_flat(i) = 0;
}
// Preserve the first input value if possible.
if (N > 0) output_flat(0) = input(0);
}
};
//Register OP's Kernel
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
#CMakeLists.txt
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(tfcustomop_project)
# Define our library target
add_library(tfcustomop SHARED TfCustomOp.cpp print_stack.cpp)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -D_GLIBCXX_USE_CXX11_ABI=1")
include_directories(/tensorflow/include)
# Link against LibTorch
#target_link_libraries(tfcustomop "/tensorflow/libtensorflow_framework.so.2")
从示例代码可以看到创建自定义算子OP有三步
这种用宏扩展自定义OP的方式非常类似于我上一篇文章中介绍的pytorch TORCH_LIBRARY 宏,实际上经过对TensorFlow框架代码的分析也确认两者的设计思想很接近(Google和Facebook里面都是编程高手 -v- );区别在于TensorFlow的设计中,算子和Kernel计算是两个分开的概念,其中算子的语义是平台无关的,而算子对应的Kernel计算则是跟平台(CPU、GPU、TPU)相关的。这样,就可以很方便的对语义不变的算子提供不同平台的具体实现了
经过cmake编译后生成了libtfcustomop.so,使用下面的python测试程序验证结果确认自定义算子和Kernel已经生效
import tensorflow as tf
mylib = tf.load_op_library("path/to/libtfcustomop.so")
with tf.compat.v1.Session():
print(mylib.zero_out([4,4,4,4,4]).eval())
[4,0,0,0,0]
在详细结果算子与Kernel的定义与注册原理前,概况相关核心的数据结构和概念如下
有了上面的基本知识,下面逐步分析TensorFlow框架中的实现原理
REGISTER_OP宏负责定义与注册用户的OP算子,它的定义在
g++ -I /tensorflow/include -std=c++11 -E TfCustomOp.cpp
static ::tensorflow::register_op::OpDefBuilderReceiver register_op0 __attribute__((unused)) = ::tensorflow::register_op::OpDefBuilderWrapper("ZeroOut")
.Input("to_zero: int32")
.Output("zeroed: int32")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
return Status::OK();
});
可以看到REGISTER_OP宏通过构建了static OpDefBuilderReceiver变量,实现了OP的定义与注册过程,整个过程的时序图如下
REGISTER_KERNEL_BUILDER宏负责自定义Kernel的注册,它的定义在
constexpr bool should_register_1__flag = true;
static ::tensorflow::kernel_factory::OpKernelRegistrar registrar__body__1__object( should_register_1__flag ? ::tensorflow::register_kernel::Name("ZeroOut").Device(DEVICE_CPU).Build() : nullptr,
"ZeroOutOp",
[](::tensorflow::OpKernelConstruction* context) -> ::tensorflow::OpKernel*
{ return new ZeroOutOp(context); });;
通过构建了static OpKernelRegistrar变量,实现了Kernel的注册过程,整个过程的时序图如下
经过自定义OP的注册,Kernel注册,Kernel Compute实现业务逻辑计算这三步后,就可以在python程序中使用自定义的OP算子了,核心的时序如下图