TensorFlow 源码阅读[1] OpKernel的注册

OpKernel介绍

TensorFlow 源码阅读[1] OpKernel的注册_第1张图片

在TF的架构中,OpKernel是Ops和硬件的中间层,用来抽象统一各个硬件平台上的Kernel类和接口。

注册过程

我们首先大致列出OpKernel注册的过程,后面再详细分析,我们按照调用顺序,从上层往下说:

  1. 在各个xxx_op.cc文件中调用REGISTER_KERNEL_BUILDER()
  2. 调用OpKernelRegistrar的构造函数
  3. 并在该构造函数中调用OpKernelRegistrar::InitInternal
  4. 调用GlobalKernelRegistry获取保存注册信息的map
  5. 将Key和kernel保存到map中

分析

现在我们来逐个分析,在上面我们是从调用过程往下走,在这里,我们尝试从底层往上走。

1.KernelRegistration

首先我们需要关注的是KernelRegistration类,它用来保存OpKernel注册所需的信息,包括KernelDef、kernel的名字以及kernel的创建方法factory:

struct KernelRegistration {
  KernelRegistration(const KernelDef& d, StringPiece c,
                     std::unique_ptr f)
      : def(d), kernel_class_name(c), factory(std::move(f)) {}

  const KernelDef def;
  const string kernel_class_name;
  std::unique_ptr factory;
};

2.KernelRegistry

这个结构体用来保存OpKernel的注册信息KernelRegistration,并将这些信息保存到一个unordered_multimap里:

struct KernelRegistry {
  mutex mu;
  std::unordered_multimap registry
      TF_GUARDED_BY(mu);
};

这个map维持一个Key到OpKernel注册信息之间的关系,而这个Key,是这样生成的:

const string key =
        Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
            kernel_def->label());

既然是unordered_multimap,说明一个Key可以对应多个KernelRegistration。
KernelRegistry的实例是通过下面这个函数构造的:

void* GlobalKernelRegistry() {
  static KernelRegistry* global_kernel_registry = []() {
    KernelRegistry* registry = new KernelRegistry;
    OpRegistry::Global()->RegisterValidator(ValidateKernelRegistrations);
    return registry;
  }();
  return global_kernel_registry;
}

3.OpKernelRegistrar

上面我们提到了OpKernel需要保存的信息,以及这些信息是保存在一个unordered_multimap中的,下面我们要来看这个保存的过程。
我们首先来看这个类的构造函数:

// 构造函数1
OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    std::unique_ptr<OpKernelFactory> factory) {
    // Perform the check in the header to allow compile-time optimization
    // to a no-op, allowing the linker to remove the kernel symbols.
    if (kernel_def != nullptr) {
      InitInternal(kernel_def, kernel_class_name, std::move(factory));
    }
  }

//构造函数2
OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    OpKernel* (*create_fn)(OpKernelConstruction*)) {
    // Perform the check in the header to allow compile-time optimization
    // to a no-op, allowing the linker to remove the kernel symbols.
    if (kernel_def != nullptr) {
      InitInternal(kernel_def, kernel_class_name,
                   absl::make_unique<PtrOpKernelFactory>(create_fn));
    }
  }

这里涉及到另外一个类OpKernelFactory,我们也可以看下它的定义:

class OpKernelFactory {
 public:
  virtual OpKernel* Create(OpKernelConstruction* context) = 0;
  virtual ~OpKernelFactory() = default;
};

从这个类的create函数我们就可以看出,OpKernelRegistrar的亮哥构造函数其实大同小异,第一个参数是kernel_del,第二个参数是kernel_class_name,第三个参数都是创建这个kernel的函数。
我们来看一下OpKernelRegistrar构造函数的核心部分:

void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
                                     StringPiece kernel_class_name,
                                     std::unique_ptr factory) {
  // See comments in register_kernel::Name in header for info on _no_register.
  if (kernel_def->op() != "_no_register") {
    const string key =
        Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
            kernel_def->label());

	auto global_registry =
	        reinterpret_cast(GlobalKernelRegistry());
	    mutex_lock l(global_registry->mu);
	    global_registry->registry.emplace(
	        key,
	        KernelRegistration(*kernel_def, kernel_class_name, std::move(factory)));
	}
}

这个GlobalKernelRegistry我们之前已经说过了,它返回的是一个KernelRegistry实例,global_registry->registry 就是我们之前说的保存注册信息的map,也就是说,OpKernel的注册发生在OpKernelRegistrar的构造函数中!
我们顺藤摸瓜,看看这个构造函数是怎么被调用的。

4. REGISTER_KERNEL_BUILDER

OpKernelRegistrar的构造就是在REGISTER_KERNEL_BUILDER宏定义中:

#define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
  REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__)

#define REGISTER_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \
  REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)

#define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...)        \
  constexpr bool should_register_##ctr##__flag =                      \
      SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__);                        \
  static ::tensorflow::kernel_factory::OpKernelRegistrar              \
      registrar__body__##ctr##__object(                               \
          should_register_##ctr##__flag                               \
              ? ::tensorflow::register_kernel::kernel_builder.Build() \
              : nullptr,                                              \
          #__VA_ARGS__,                                               \
          [](::tensorflow::OpKernelConstruction* context)             \
              -> ::tensorflow::OpKernel* {                            \
            return new __VA_ARGS__(context);                          \
          });

宏定义理解起来往往比较麻烦,不要着急,我们一个个看。

首先做一些宏定义知识的补充,可能不是所有人都清楚(比如我-_-!):

__COUNTER__ 可以理解为一个int型计数器,初始值为0,每出现一次,值+1
#x 将x转换成一个字符串
##ctr 变量拼接,就是将ctr的值拼接到整个变量中
__VA_ARGS__可变参数

有了上面这些知识,我们再来看这些宏就没这么复杂了:

  1. 首先REGISTER_KERNEL_BUILDER接受两个参数,一个是kernel_builder,另一个是可变参数;
  2. 将这两个参数传给REGISTER_KERNEL_BUILDER_UNIQ_HELPER,而这个宏在前面的宏的基础上,增加了一个计数器,并将这三个参数传给下一个定义的宏
  3. REGISTER_KERNEL_BUILDER_UNIQ接受了这三个参数,然后定义一个临时变量should_register_##ctr##__flag,根据我们上面宏定义的知识,ctr和flag的值都会拼接到register_后面,而这个bool值的结果是SHOULD_REGISTER_OP_KERNEL(#_VA_ARGS),看字面意思就可以理解为是否需要注册这个OpKernel;然后定义了一个static的OpKernelRegistrar变量registrar__body__##ctr##__object,且调用了OpKernelRegistrar的第二类构造函数:

至此我们找到了构造OpKernelRegistrar的地方,也就是说每次使用宏REGISTER_KERNEL_BUILDER注册OpKernel,都会调用OpKernelRegistrar并将对应的Kernel信息存到map中。

  1. 我们看一下OpKernelRegistrar构造函数的参数:

1)should_register_##ctr##__flag ? ::tensorflow::register_kernel::kernel_builder.Build() : nullptr 也就是说如果需要创建这个OpKernel,就传入::tensorflow::register_kernel::kernel_builder.Build()这个参数的值我们后面会介绍,根据构造函数的三个参数,我们暂时只需要知道这一长串会返回一个KernelDef对象
2) #__VA_ARGS__ 第二个参数是可变参数变成的字符串,也就是kernel_class_name
3)[](::tensorflow::OpKernelConstruction* context) -> ::tensorflow::OpKernel* { return new __VA_ARGS__(context);这是一个lamda表达式函数,入参数OpKernelConstruction* context,返回类型是OpKernel*,这个函数指针本身构成了第三个参数,即OpKernel* (*create_fn)(OpKernelConstruction*)

到此我们应该理解了这个复杂的宏REGISTER_KERNEL_BUILDER,只需要正确使用这个宏,就可以注册一个OpKernel!!!
遗留了一个问题,就是为什么这个kernel_builder.Build(),就相当于是KernelDef对象呢?

5.如何使用这个宏?

我们看一下官方的例子:

REGISTER_KERNEL_BUILDER(Name("Test1").Device(tensorflow::DEVICE_CPU),DummyKernel);

这里我们看到第一个参数是Name("Test1").Device(tensorflow::DEVICE_CPU)这个东西为什么就是KernelDef呢?我们看一下这个Name究竟是什么,说实话这个类不太好找:


class Name : public KernelDefBuilder {
 public:
  explicit Name(const char* op)
      : KernelDefBuilder(SHOULD_REGISTER_OP(op) ? op : "_no_register") {}
};

原来这个Name类是继承自KernelDefBuilder类,且在它的构造函数中,调用了基类的构造函数,传入的是op的名字,我们再来看一下这个基类:

class KernelDefBuilder {
 public:
  explicit KernelDefBuilder(const char* op_name);
  ~KernelDefBuilder();
  KernelDefBuilder& Device(const char* device_type);
  template 
  KernelDefBuilder& AttrConstraint(const char* attr_name, gtl::ArraySlice allowed);
  template 
  KernelDefBuilder& AttrConstraint(const char* attr_name, T allowed);
  KernelDefBuilder& TypeConstraint(const char* attr_name,
gtl::ArraySlice allowed);
  KernelDefBuilder& TypeConstraint(const char* attr_name, DataType allowed);
  template 
  KernelDefBuilder& TypeConstraint(const char* attr_name);
  KernelDefBuilder& HostMemory(const char* arg_name);
  KernelDefBuilder& Label(const char* label);
  KernelDefBuilder& Priority(int32 priority);
  const KernelDef* Build();
 private:
  KernelDef* kernel_def_;
  TF_DISALLOW_COPY_AND_ASSIGN(KernelDefBuilder);
};

基类KernelDefBuilder也接受一个op_name作为构造参数,且我们现在可以看到,刚才Name(“Test1”)后面的.Device()实际上就是KernelDefBuilder的成员函数,返回的是KernelDefBuilder&类型。

在得到这个KernelDefBuilder&类型的返回值后,在通过调用kernel_builder.Build()方法,就得到了const KernelDef* 类型的返回值,这就回答了我们刚才的问题!

总结

我们花了很久的时间,就是为了搞清楚TF究竟是如何设计和实现Opkernel的注册的。我们先是简单介绍了从调用到底层实现,然后详细的从底层开始分析了每一步的实现。不得不说TF这一套东西很复杂,但是只要多看两遍,也可以理解。

对于OpKernel类来说,往下有它自身的数据类和数据管理类,以及构造辅助类,往上被封装到一个宏定义中,在后面说到Op的时候,会发现整体思路和OpKernel十分相似,所以理解其中一个,另一个理解起来是水到渠成。

参考

  1. TF源码
  2. 『深度长文』Tensorflow代码解析(三)

你可能感兴趣的:(TensorFlow,tensorflow,OpKernel,源码,注册)