在TF的架构中,OpKernel是Ops和硬件的中间层,用来抽象统一各个硬件平台上的Kernel类和接口。
我们首先大致列出OpKernel注册的过程,后面再详细分析,我们按照调用顺序,从上层往下说:
xxx_op.cc
文件中调用REGISTER_KERNEL_BUILDER()
OpKernelRegistrar
的构造函数OpKernelRegistrar::InitInternal
GlobalKernelRegistry
获取保存注册信息的map现在我们来逐个分析,在上面我们是从调用过程往下走,在这里,我们尝试从底层往上走。
首先我们需要关注的是KernelRegistration类,它用来保存OpKernel注册所需的信息,包括KernelDef、kernel的名字以及kernel的创建方法factory:
struct KernelRegistration {
KernelRegistration(const KernelDef& d, StringPiece c,
std::unique_ptr f)
: def(d), kernel_class_name(c), factory(std::move(f)) {}
const KernelDef def;
const string kernel_class_name;
std::unique_ptr factory;
};
这个结构体用来保存OpKernel的注册信息KernelRegistration,并将这些信息保存到一个unordered_multimap里:
struct KernelRegistry {
mutex mu;
std::unordered_multimap registry
TF_GUARDED_BY(mu);
};
这个map维持一个Key到OpKernel注册信息之间的关系,而这个Key,是这样生成的:
const string key =
Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
kernel_def->label());
既然是unordered_multimap,说明一个Key可以对应多个KernelRegistration。
KernelRegistry的实例是通过下面这个函数构造的:
void* GlobalKernelRegistry() {
static KernelRegistry* global_kernel_registry = []() {
KernelRegistry* registry = new KernelRegistry;
OpRegistry::Global()->RegisterValidator(ValidateKernelRegistrations);
return registry;
}();
return global_kernel_registry;
}
上面我们提到了OpKernel需要保存的信息,以及这些信息是保存在一个unordered_multimap中的,下面我们要来看这个保存的过程。
我们首先来看这个类的构造函数:
// 构造函数1
OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
std::unique_ptr<OpKernelFactory> factory) {
// Perform the check in the header to allow compile-time optimization
// to a no-op, allowing the linker to remove the kernel symbols.
if (kernel_def != nullptr) {
InitInternal(kernel_def, kernel_class_name, std::move(factory));
}
}
//构造函数2
OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
OpKernel* (*create_fn)(OpKernelConstruction*)) {
// Perform the check in the header to allow compile-time optimization
// to a no-op, allowing the linker to remove the kernel symbols.
if (kernel_def != nullptr) {
InitInternal(kernel_def, kernel_class_name,
absl::make_unique<PtrOpKernelFactory>(create_fn));
}
}
这里涉及到另外一个类OpKernelFactory,我们也可以看下它的定义:
class OpKernelFactory {
public:
virtual OpKernel* Create(OpKernelConstruction* context) = 0;
virtual ~OpKernelFactory() = default;
};
从这个类的create函数我们就可以看出,OpKernelRegistrar的亮哥构造函数其实大同小异,第一个参数是kernel_del,第二个参数是kernel_class_name,第三个参数都是创建这个kernel的函数。
我们来看一下OpKernelRegistrar构造函数的核心部分:
void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
StringPiece kernel_class_name,
std::unique_ptr factory) {
// See comments in register_kernel::Name in header for info on _no_register.
if (kernel_def->op() != "_no_register") {
const string key =
Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
kernel_def->label());
auto global_registry =
reinterpret_cast(GlobalKernelRegistry());
mutex_lock l(global_registry->mu);
global_registry->registry.emplace(
key,
KernelRegistration(*kernel_def, kernel_class_name, std::move(factory)));
}
}
这个GlobalKernelRegistry我们之前已经说过了,它返回的是一个KernelRegistry实例,global_registry->registry 就是我们之前说的保存注册信息的map,也就是说,OpKernel的注册发生在OpKernelRegistrar的构造函数中!
我们顺藤摸瓜,看看这个构造函数是怎么被调用的。
OpKernelRegistrar的构造就是在REGISTER_KERNEL_BUILDER宏定义中:
#define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__)
#define REGISTER_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \
REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)
#define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...) \
constexpr bool should_register_##ctr##__flag = \
SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__); \
static ::tensorflow::kernel_factory::OpKernelRegistrar \
registrar__body__##ctr##__object( \
should_register_##ctr##__flag \
? ::tensorflow::register_kernel::kernel_builder.Build() \
: nullptr, \
#__VA_ARGS__, \
[](::tensorflow::OpKernelConstruction* context) \
-> ::tensorflow::OpKernel* { \
return new __VA_ARGS__(context); \
});
宏定义理解起来往往比较麻烦,不要着急,我们一个个看。
首先做一些宏定义知识的补充,可能不是所有人都清楚(比如我-_-!):
__COUNTER__
可以理解为一个int型计数器,初始值为0,每出现一次,值+1
#x
将x转换成一个字符串
##ctr
变量拼接,就是将ctr的值拼接到整个变量中
__VA_ARGS__
可变参数
有了上面这些知识,我们再来看这些宏就没这么复杂了:
should_register_##ctr##__flag
,根据我们上面宏定义的知识,ctr和flag的值都会拼接到register_后面,而这个bool值的结果是SHOULD_REGISTER_OP_KERNEL(#_VA_ARGS),看字面意思就可以理解为是否需要注册这个OpKernel;然后定义了一个static的OpKernelRegistrar变量registrar__body__##ctr##__object
,且调用了OpKernelRegistrar的第二类构造函数:至此我们找到了构造OpKernelRegistrar的地方,也就是说每次使用宏REGISTER_KERNEL_BUILDER注册OpKernel,都会调用OpKernelRegistrar并将对应的Kernel信息存到map中。
1)should_register_##ctr##__flag ? ::tensorflow::register_kernel::kernel_builder.Build() : nullptr 也就是说如果需要创建这个OpKernel,就传入
::tensorflow::register_kernel::kernel_builder.Build()
这个参数的值我们后面会介绍,根据构造函数的三个参数,我们暂时只需要知道这一长串会返回一个KernelDef对象
2)#__VA_ARGS__
第二个参数是可变参数变成的字符串,也就是kernel_class_name
3)[](::tensorflow::OpKernelConstruction* context) -> ::tensorflow::OpKernel* { return new __VA_ARGS__(context);
这是一个lamda表达式函数,入参数OpKernelConstruction* context
,返回类型是OpKernel*,这个函数指针本身构成了第三个参数,即OpKernel* (*create_fn)(OpKernelConstruction*)
到此我们应该理解了这个复杂的宏REGISTER_KERNEL_BUILDER
,只需要正确使用这个宏,就可以注册一个OpKernel!!!
遗留了一个问题,就是为什么这个kernel_builder.Build()
,就相当于是KernelDef对象呢?
我们看一下官方的例子:
REGISTER_KERNEL_BUILDER(Name("Test1").Device(tensorflow::DEVICE_CPU),DummyKernel);
这里我们看到第一个参数是Name("Test1").Device(tensorflow::DEVICE_CPU)
这个东西为什么就是KernelDef呢?我们看一下这个Name究竟是什么,说实话这个类不太好找:
class Name : public KernelDefBuilder {
public:
explicit Name(const char* op)
: KernelDefBuilder(SHOULD_REGISTER_OP(op) ? op : "_no_register") {}
};
原来这个Name类是继承自KernelDefBuilder类,且在它的构造函数中,调用了基类的构造函数,传入的是op的名字,我们再来看一下这个基类:
class KernelDefBuilder {
public:
explicit KernelDefBuilder(const char* op_name);
~KernelDefBuilder();
KernelDefBuilder& Device(const char* device_type);
template
KernelDefBuilder& AttrConstraint(const char* attr_name, gtl::ArraySlice allowed);
template
KernelDefBuilder& AttrConstraint(const char* attr_name, T allowed);
KernelDefBuilder& TypeConstraint(const char* attr_name,
gtl::ArraySlice allowed);
KernelDefBuilder& TypeConstraint(const char* attr_name, DataType allowed);
template
KernelDefBuilder& TypeConstraint(const char* attr_name);
KernelDefBuilder& HostMemory(const char* arg_name);
KernelDefBuilder& Label(const char* label);
KernelDefBuilder& Priority(int32 priority);
const KernelDef* Build();
private:
KernelDef* kernel_def_;
TF_DISALLOW_COPY_AND_ASSIGN(KernelDefBuilder);
};
基类KernelDefBuilder也接受一个op_name作为构造参数,且我们现在可以看到,刚才Name(“Test1”)后面的.Device()实际上就是KernelDefBuilder的成员函数,返回的是KernelDefBuilder&类型。
在得到这个KernelDefBuilder&类型的返回值后,在通过调用kernel_builder.Build()方法,就得到了const KernelDef*
类型的返回值,这就回答了我们刚才的问题!
我们花了很久的时间,就是为了搞清楚TF究竟是如何设计和实现Opkernel的注册的。我们先是简单介绍了从调用到底层实现,然后详细的从底层开始分析了每一步的实现。不得不说TF这一套东西很复杂,但是只要多看两遍,也可以理解。
对于OpKernel类来说,往下有它自身的数据类和数据管理类,以及构造辅助类,往上被封装到一个宏定义中,在后面说到Op的时候,会发现整体思路和OpKernel十分相似,所以理解其中一个,另一个理解起来是水到渠成。