tensorflow源码阅读-opkernel注册

//tensorflow/core/framework/kernel_def_builder.h
// Entry-point macro for registering a kernel. Forwards the builder expression
// and the variadic remainder (the kernel class) to the helper below, adding
// __COUNTER__ as `ctr`, which the later macros paste into the identifiers
// they declare.
#define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
  REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__)

// Pure forwarding layer. Passing `ctr` through one more macro call lets the
// preprocessor expand it (it arrives as __COUNTER__) before
// REGISTER_KERNEL_BUILDER_UNIQ uses it as an operand of ##, where arguments
// are not expanded.
#define REGISTER_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \
  REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)

// Emits the actual registration code. For a given `ctr` it declares:
//   1. should_register_<ctr>__flag: a constexpr bool computed by
//      SHOULD_REGISTER_OP_KERNEL from the stringified kernel class name.
//   2. registrar__body__<ctr>__object: a static OpKernelRegistrar constructed
//      with
//        - the KernelDef* returned by kernel_builder.Build() when the flag is
//          true, otherwise nullptr (the builder expression is looked up inside
//          ::tensorflow::register_kernel),
//        - the kernel class name as a string (#__VA_ARGS__),
//        - a captureless lambda that allocates the kernel class with `new`,
//          passing it the OpKernelConstruction* context.
// The kernel class is taken as __VA_ARGS__, so a class name that itself
// contains commas (e.g. template arguments) is accepted.
#define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...)        \
  constexpr bool should_register_##ctr##__flag =                      \
      SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__);                        \
  static ::tensorflow::kernel_factory::OpKernelRegistrar              \
      registrar__body__##ctr##__object(                               \
          should_register_##ctr##__flag                               \
              ? ::tensorflow::register_kernel::kernel_builder.Build() \
              : nullptr,                                              \
          #__VA_ARGS__,                                               \
          [](::tensorflow::OpKernelConstruction* context)             \
              -> ::tensorflow::OpKernel* {                            \
            return new __VA_ARGS__(context);                          \
          });

opkernel通过REGISTER_KERNEL_BUILDER()宏来注册。该宏接收两部分参数:第一个是kernel_builder,其余的是可变参数__VA_ARGS__(即opkernel的类名;类名的模板实参中可能含有逗号,所以用可变参数来接收)。
REGISTER_KERNEL_BUILDER宏会定义一个静态的OpKernelRegistrar对象,它的构造函数有三个参数。第一个是kernel_builder.Build()返回的KernelDef指针(不需要注册时为nullptr);第二个是opkernel的类名,#__VA_ARGS__表示把传入的参数变成字符串;第三个是用来创建opkernel对象的lambda表达式(会被转换成Factory函数指针)。

//tensorflow/core/framework/op_kernel.h
// Performs kernel registration as a side effect of construction.
class OpKernelRegistrar {
 public:
  // Pointer to a function that builds an OpKernel from its construction
  // context.
  using Factory = OpKernel* (*)(OpKernelConstruction*);

  // Registers `factory` under `kernel_def` and `kernel_class_name` via
  // InitInternal. A null `kernel_def` turns the constructor into a no-op.
  OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    Factory factory) {
    if (kernel_def == nullptr) {
      return;
    }
    InitInternal(kernel_def, kernel_class_name, factory);
  }
};

OpKernelRegistrar的第一个参数

kernel_def从REGISTER_KERNEL_BUILDER传入:

//tensorflow/core/kernels/matmul_op.cc
// Registers MatMulOp<CPUDevice, T, false> as a "MatMul" kernel for DEVICE_CPU,
// with type constraint "T" bound to T and the label "eigen". The commas inside
// the template argument list are fine because REGISTER_KERNEL_BUILDER takes
// the kernel class as variadic arguments.
#define REGISTER_CPU_EIGEN(T)                                                  \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T").Label("eigen"), \
      MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);

Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T").Label("eigen")会先调用Name("MatMul"),详细如下:

//tensorflow/core/framework/op_kernel.h
// KernelDefBuilder subclass whose constructor takes the op name.
// When SHOULD_REGISTER_OP(op) evaluates to false, the placeholder
// "_no_register" is passed to the base constructor in place of `op`.
class Name : public KernelDefBuilder {
 public:
  explicit Name(const char* op)
      : KernelDefBuilder(SHOULD_REGISTER_OP(op) ? op : "_no_register") {}
};

类Name继承KernelDefBuilder 并调用KernelDefBuilder构造方法:

//tensorflow/core/framework/kernel_def_builder.cc
// Allocates a fresh KernelDef, stamps it with `op_name`, and stores it as the
// definition this builder will keep filling in.
KernelDefBuilder::KernelDefBuilder(const char* op_name) {
  KernelDef* const fresh_def = new KernelDef;
  fresh_def->set_op(op_name);
  kernel_def_ = fresh_def;
}

KernelDefBuilder构造方法创建一个KernelDef对象,并设置op_name。这里的KernelDef是通过protobuf(tensorflow/core/framework/kernel_def.proto)生成的类,里面定义了opkernel的属性。

既然Name继承了KernelDefBuilder,那么Name对象就具有KernelDefBuilder的所有方法,因此Name("MatMul").Device(DEVICE_CPU)就是设置设备的类型,并返回对象本身。接着以链式调用的方式继续设置opkernel的其他属性。

//tensorflow/core/framework/kernel_def_builder.cc
// Stores `device_type` in the KernelDef under construction and returns this
// builder so further setters can be chained onto the call.
KernelDefBuilder& KernelDefBuilder::Device(const char* device_type) {
  KernelDef& pending_def = *kernel_def_;
  pending_def.set_device_type(device_type);
  return *this;
}

最后调用Build()方法,将KernelDef指针返回:

//tensorflow/core/framework/kernel_def_builder.cc
// Returns the KernelDef accumulated so far and clears the builder's own
// pointer, so the builder no longer refers to the returned object.
const KernelDef* KernelDefBuilder::Build() {
  KernelDef* finished_def = kernel_def_;
  kernel_def_ = nullptr;  // Detach: the builder is empty from here on.
  return finished_def;
}

OpKernelRegistrar的第二个参数

//tensorflow/core/kernels/matmul_op.cc
// Registers MatMulOp<CPUDevice, T, false> as a "MatMul" kernel for DEVICE_CPU,
// with type constraint "T" bound to T and the label "eigen". Everything after
// the builder expression (here the kernel class, commas included) is what
// REGISTER_KERNEL_BUILDER receives as __VA_ARGS__.
#define REGISTER_CPU_EIGEN(T)                                                  \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T").Label("eigen"), \
      MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);

REGISTER_KERNEL_BUILDER的第二个参数(MatMulOp<CPUDevice, T, false>)会传入到OpKernelRegistrar。这里的#__VA_ARGS__表示把传入的参数变成字符串,也就是说,OpKernelRegistrar的第二个参数是opkernel的类名。

OpKernelRegistrar的第三个参数

它是一个lambda表达式:用new创建一个MatMulOp<CPUDevice, T, false>类型的opkernel对象并返回。这个lambda会被转换成Factory类型的函数指针,其定义为 typedef OpKernel* (*Factory)(OpKernelConstruction*);

opkernel注册

//tensorflow/core/framework/op_kernel.h
// Registers a kernel when an instance is constructed.
class OpKernelRegistrar {
 public:
  // Function pointer type: takes the construction context and returns the
  // newly created OpKernel.
  using Factory = OpKernel* (*)(OpKernelConstruction*);

  // Hands all three arguments to InitInternal, unless `kernel_def` is null,
  // in which case nothing happens.
  OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    Factory factory) {
    if (kernel_def == nullptr) {
      return;
    }
    InitInternal(kernel_def, kernel_class_name, factory);
  }
};

OpKernelRegistrar 会调用InitInternal,具体如下:

//tensorflow/core/framework/op_kernel.cc
void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
                                    StringPiece kernel_class_name,
                                    Factory factory) {
 // See comments in register_kernel::Name in header for info on _no_register.
 if (kernel_def->op() != "_no_register") {
   const string key =
       Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
           kernel_def->label());
   GlobalKernelRegistryTyped()->insert(std::make_pair(
       key, KernelRegistration(*kernel_def, kernel_class_name, factory)));
 }
 delete kernel_def;
}
struct KernelRegistration {
  KernelRegistration(const KernelDef& d, StringPiece c,
                     kernel_factory::OpKernelRegistrar::Factory f)
      : def(d), kernel_class_name(c), factory(f) {}
  const KernelDef def;
  const string kernel_class_name;
  const kernel_factory::OpKernelRegistrar::Factory factory;
};

可以看到,OpKernelRegistrar这个类主要负责根据传进来的KernelDef和Factory,先依据一定规则(由op名、设备类型和label组合)生成一个key,再把对应的KernelRegistration插入到全局唯一的Kernel注册表里。注册表是一个map,但值得注意的是它是multimap(unordered_multimap),因此支持一个键对应多个kernel注册项。

// Registry container type: string key -> KernelRegistration. Being a multimap,
// it can hold several registrations under the same key.
using KernelRegistry = std::unordered_multimap<string, KernelRegistration>;

你可能感兴趣的:(c++,c)