//tensorflow/core/framework/op_kernel.h
// NOTE(review): this excerpt was originally labelled kernel_def_builder.h, but
// in TensorFlow the REGISTER_KERNEL_BUILDER macro family is defined in
// op_kernel.h (kernel_def_builder.h holds KernelDefBuilder) — confirm against
// the TF version being described.
//
// Public entry point for registering an OpKernel implementation.
//   kernel_builder: a KernelDefBuilder expression such as
//                   Name("MatMul").Device(DEVICE_CPU)...; it is written
//                   without a namespace because the expansion prefixes it with
//                   ::tensorflow::register_kernel::.
//   ...           : the kernel class type. It is taken as variadic arguments so
//                   that template instantiations containing commas (e.g.
//                   MatMulOp<CPUDevice, T, false>) are not split into several
//                   macro arguments.
// __COUNTER__ yields a different number on each expansion; it is used to give
// the generated variables unique names.
#define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__)
// Indirection layer. A macro argument is fully macro-expanded before being
// substituted, unless it is an operand of # or ##. REGISTER_KERNEL_BUILDER_UNIQ
// pastes `ctr` with ##, so passing __COUNTER__ straight to it would paste the
// literal token __COUNTER__. Routing it through this helper forces __COUNTER__
// to expand to a number first.
#define REGISTER_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \
REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)
// Expands to two declarations whose names embed the unique number `ctr`:
//  1. should_register_<ctr>__flag: a constexpr bool computed by
//     SHOULD_REGISTER_OP_KERNEL from the stringified kernel class name.
//  2. registrar__body__<ctr>__object: a static OpKernelRegistrar constructed
//     from
//       - ::tensorflow::register_kernel::<kernel_builder>.Build(), i.e. the
//         finished KernelDef, when the flag is true; otherwise nullptr, in
//         which case OpKernelRegistrar registers nothing and Build() is never
//         evaluated (the conditional operator only evaluates the chosen arm);
//       - the kernel class name as a string literal (#__VA_ARGS__);
//       - a captureless lambda that creates the kernel with
//         `new <kernel class>(context)`; a captureless lambda converts
//         implicitly to the OpKernelRegistrar::Factory function pointer.
// When expanded at namespace scope (the usual use), the registrar is a static
// object, so its constructor — and hence the registration — runs during
// program initialization, before main().
// Comments are kept out of the continued lines below on purpose: a // comment
// ending in a backslash would swallow the next line.
#define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...) \
constexpr bool should_register_##ctr##__flag = \
SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__); \
static ::tensorflow::kernel_factory::OpKernelRegistrar \
registrar__body__##ctr##__object( \
should_register_##ctr##__flag \
? ::tensorflow::register_kernel::kernel_builder.Build() \
: nullptr, \
#__VA_ARGS__, \
[](::tensorflow::OpKernelConstruction* context) \
-> ::tensorflow::OpKernel* { \
return new __VA_ARGS__(context); \
});
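OpKernel 的注册通过 REGISTER_KERNEL_BUILDER() 完成,它有两个参数:kernel_builder 和 kernel 类型。第二个参数用可变参数 ... 声明,在宏体内用 __VA_ARGS__ 引用;这样即使 kernel 类型是带逗号的模板实例(如 MatMulOp<CPUDevice, T, false>),也不会被预处理器拆成多个参数。另外,宏借助 __COUNTER__ 为每次展开生成唯一的变量名。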
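REGISTER_KERNEL_BUILDER 宏最终会定义一个 static 的 OpKernelRegistrar 对象,其构造函数有三个参数。第一个是 kernel_builder.Build() 的返回值,即一个 KernelDef 指针。注意宏会在 kernel_builder 前加上 ::tensorflow::register_kernel:: 前缀,所以 Name(...) 指的是 register_kernel::Name;若 should_register 标志为 false,则传入 nullptr,不会注册。第二个是 kernel 类名,#__VA_ARGS__ 表示把传入的参数变成字符串。第三个是一个创建 OpKernel 对象的无捕获 lambda,它被隐式转换为 Factory 函数指针。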
//tensorflow/core/framework/op_kernel.h
// Registers one kernel in the global kernel registry when constructed.
// REGISTER_KERNEL_BUILDER creates one static instance per registration.
class OpKernelRegistrar {
 public:
  // Creates a new OpKernel from its construction context. The captureless
  // lambda generated by REGISTER_KERNEL_BUILDER converts to this type.
  typedef OpKernel* (*Factory)(OpKernelConstruction*);

  // kernel_def:        the KernelDef produced by KernelDefBuilder::Build(), or
  //                    nullptr when the kernel should not be registered (then
  //                    this constructor does nothing).
  // kernel_class_name: the stringified kernel class name.
  // factory:           creates instances of the kernel.
  // A non-null kernel_def is handed to InitInternal, which deletes it.
  OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    Factory factory) {
    if (kernel_def != nullptr) {
      InitInternal(kernel_def, kernel_class_name, factory);
    }
  }

 private:
  // Defined out of line in op_kernel.cc. Inserts the registration into the
  // global kernel registry (unless the op is "_no_register") and deletes
  // kernel_def. It must be declared here for that out-of-class definition to
  // be valid.
  void InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    Factory factory);
};
kernel_def 来自 REGISTER_KERNEL_BUILDER 的第一个参数 kernel_builder。以 MatMul 的 CPU kernel 注册为例:
//tensorflow/core/kernels/matmul_op.cc
// Registers the Eigen-based CPU implementation of the MatMul op for element
// type T.
//   - TypeConstraint<T>("T") ties the op's "T" type attribute to the C++
//     type T (presumably via its corresponding DataType — see
//     KernelDefBuilder::TypeConstraint).
//   - Label("eigen") becomes part of the registry key (op, device, label), so
//     this registration is kept apart from other labels of the same op/device.
//   - MatMulOp<CPUDevice, T, false> is the kernel class; per the inline
//     comment, the third template argument (cuBLAS) is ignored on CPU.
// Comments are kept out of the continued lines except the original /* */ one.
#define REGISTER_CPU_EIGEN(T)                                                  \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T").Label("eigen"), \
      MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);
对于 Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T").Label("eigen"),它会先构造 Name("MatMul") 对象,详细如下:
//tensorflow/core/framework/op_kernel.h
// Starting point of every kernel registration: a KernelDefBuilder for `op`.
class Name : public KernelDefBuilder {
 public:
  // When SHOULD_REGISTER_OP(op) is false, the builder is created under the
  // placeholder op name "_no_register" instead of `op`;
  // OpKernelRegistrar::InitInternal skips kernels with that name.
  explicit Name(const char* op)
      : KernelDefBuilder(!(SHOULD_REGISTER_OP(op)) ? "_no_register" : op) {}
};
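类 Name 继承自 KernelDefBuilder,并在构造时调用 KernelDefBuilder 的构造函数。如果 SHOULD_REGISTER_OP(op) 为 false,传入的 op 名会被替换成 "_no_register",之后 InitInternal 会跳过这类 kernel: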
//tensorflow/core/framework/kernel_def_builder.cc
// Allocates a fresh KernelDef owned by this builder and stamps `op_name` on it.
// Ownership of the KernelDef is passed to the caller by Build().
KernelDefBuilder::KernelDefBuilder(const char* op_name)
    : kernel_def_(new KernelDef) {
  kernel_def_->set_op(op_name);
}
KernelDefBuilder 的构造函数创建一个 KernelDef 对象,并设置 op 名。这里的 KernelDef 是由 protobuf(tensorflow/core/framework/kernel_def.proto)生成的类,其中定义了 OpKernel 的各项属性。
既然 Name 继承了 KernelDefBuilder,Name 对象也就拥有 KernelDefBuilder 的所有公有方法。因此 Name("MatMul").Device(DEVICE_CPU) 会设置设备类型,并返回对象自身的引用(KernelDefBuilder&)。于是可以继续链式地设置 TypeConstraint、Label 等 OpKernel 属性。
//tensorflow/core/framework/kernel_def_builder.cc
// Sets the device type this kernel is registered for (e.g. DEVICE_CPU) and
// returns the builder itself so further setters can be chained.
KernelDefBuilder& KernelDefBuilder::Device(const char* device_type) {
  KernelDefBuilder& self = *this;
  self.kernel_def_->set_device_type(device_type);
  return self;
}
最后由宏调用 Build() 方法,返回 KernelDef 指针并把所有权交给调用者(builder 内部的 kernel_def_ 被置为 nullptr):
//tensorflow/core/framework/kernel_def_builder.cc
// Releases the accumulated KernelDef to the caller, who becomes its owner.
// The builder forgets the pointer, so a second call returns nullptr.
const KernelDef* KernelDefBuilder::Build() {
  const KernelDef* const released = kernel_def_;
  kernel_def_ = nullptr;
  return released;
}
//tensorflow/core/kernels/matmul_op.cc
// Registers the Eigen-based CPU implementation of the MatMul op for element
// type T.
//   - TypeConstraint<T>("T") ties the op's "T" type attribute to the C++
//     type T (presumably via its corresponding DataType — see
//     KernelDefBuilder::TypeConstraint).
//   - Label("eigen") becomes part of the registry key (op, device, label), so
//     this registration is kept apart from other labels of the same op/device.
//   - MatMulOp<CPUDevice, T, false> is the kernel class; per the inline
//     comment, the third template argument (cuBLAS) is ignored on CPU.
// Comments are kept out of the continued lines except the original /* */ one.
#define REGISTER_CPU_EIGEN(T)                                                  \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T").Label("eigen"), \
      MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);
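REGISTER_KERNEL_BUILDER 的第二个参数是 kernel 类 MatMulOp<CPUDevice, T, false>(模板实参中含有逗号,所以宏用可变参数 ... / __VA_ARGS__ 接收它)。
它一方面被 #__VA_ARGS__ 字符串化为 kernel_class_name;另一方面被放进 lambda [](OpKernelConstruction* context) -> OpKernel* { return new MatMulOp<CPUDevice, T, false>(context); } 中。这个无捕获的 lambda 可以隐式转换为 OpKernelRegistrar::Factory 函数指针,作为第三个参数传给 OpKernelRegistrar: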
//tensorflow/core/framework/op_kernel.h
// Registers one kernel in the global kernel registry when constructed.
// REGISTER_KERNEL_BUILDER creates one static instance per registration.
class OpKernelRegistrar {
 public:
  // Creates a new OpKernel from its construction context. The captureless
  // lambda generated by REGISTER_KERNEL_BUILDER converts to this type.
  typedef OpKernel* (*Factory)(OpKernelConstruction*);

  // kernel_def:        the KernelDef produced by KernelDefBuilder::Build(), or
  //                    nullptr when the kernel should not be registered (then
  //                    this constructor does nothing).
  // kernel_class_name: the stringified kernel class name.
  // factory:           creates instances of the kernel.
  // A non-null kernel_def is handed to InitInternal, which deletes it.
  OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    Factory factory) {
    if (kernel_def != nullptr) {
      InitInternal(kernel_def, kernel_class_name, factory);
    }
  }

 private:
  // Defined out of line in op_kernel.cc. Inserts the registration into the
  // global kernel registry (unless the op is "_no_register") and deletes
  // kernel_def. It must be declared here for that out-of-class definition to
  // be valid.
  void InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    Factory factory);
};
当 kernel_def 不为 nullptr 时,OpKernelRegistrar 的构造函数会调用 InitInternal,具体如下:
//tensorflow/core/framework/op_kernel.cc
void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
StringPiece kernel_class_name,
Factory factory) {
// See comments in register_kernel::Name in header for info on _no_register.
if (kernel_def->op() != "_no_register") {
const string key =
Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
kernel_def->label());
GlobalKernelRegistryTyped()->insert(std::make_pair(
key, KernelRegistration(*kernel_def, kernel_class_name, factory)));
}
delete kernel_def;
}
struct KernelRegistration {
KernelRegistration(const KernelDef& d, StringPiece c,
kernel_factory::OpKernelRegistrar::Factory f)
: def(d), kernel_class_name(c), factory(f) {}
const KernelDef def;
const string kernel_class_name;
const kernel_factory::OpKernelRegistrar::Factory factory;
};
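可以看到,OpKernelRegistrar 这个类主要负责根据传进来的 KernelDef 和 Factory,先由 op 名、设备类型和 label 生成一个 key,再把 KernelRegistration 插入一个全局唯一的 kernel 注册表;op 名为 "_no_register" 的 kernel 会被跳过。KernelRegistration 中保存的是 KernelDef 的副本,所以 InitInternal 最后可以安全地 delete kernel_def。注册表当然是一个 map,但值得注意的是它是 unordered_multimap,因此同一个 key 可以对应多个 kernel 注册项。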
// Type of the global kernel registry. Keys are presumably the strings built by
// Key(op, device type, label) in OpKernelRegistrar::InitInternal — confirm.
// It is a multimap, so several KernelRegistration entries may share one key.
using KernelRegistry = std::unordered_multimap<string, KernelRegistration>;