一、前言
TVM通过PackedFunc机制实现了Python与C++之间的相互调用,即FFI(Foreign Function Interface),整体流程与原理可以参阅系列文章《【TVM系列六】PackedFunc原理》。本文将对TVM的FFI注册机制进行详细的说明,并实现一个Python端的注册中心Demo。
二、什么是FFI
FFI(Foreign Function Interface)即跨语言交互接口,比如Java中的JNI,用于实现在语言A中调用语言B的函数。通常情况下,要达到这样的目的,一种方式是将函数做成一个服务,通过进程间通信或网络协议通信,这种方式需要至少两个独立的进程才能实现;另一种就是通过 FFI 直接将其它语言的接口内嵌到本语言中,实现高效率的调用。
三、TVM的FFI实现
FFI部分的实现代码主要集中在python/tvm/_ffi/目录,以及各个目录下的_ffi_api.py文件,以下的分析均以python/tvm/runtime内的注册为例子进行说明。
1、如何注册C++端接口到Python端
runtime中的_ffi_api.py:
import tvm._ffi tvm._ffi._init_api("runtime", __name__)
# 此处的__name__ = "tvm.runtime._ffi_api"
_ffi中的registry.py:
def _init_api(namespace, target_module_name=None):
target_module_name = target_module_name if target_module_name else namespace
# 此处 target_module_name = "tvm.runtime._ffi_api",namespace = "runtime"
if namespace.startswith("tvm."):
_init_api_prefix(target_module_name, namespace[4:])
else:
_init_api_prefix(target_module_name, namespace) # 代码走这个分支
def _init_api_prefix(module_name, prefix):
module = sys.modules[module_name] # module_name = "tvm.runtime._ffi_api", prefix = "runtime"
# 每当导入新模块,全局字典sys.modules将记录该模块,当再次导入该模块时,
# 会直接到字典中查找,从而加快程序运行的速度
for name in list_global_func_names():
# list_global_func_names()通过_LIB.TVMFuncListGlobalNames()得到函数列表,打印得到的内容为: # ['tvm.contrib.sort.argsort', ... , 'runtime.module.loadbinary_AotExecutorFactory', ...
# 'tvm.graph_executor_debug.create', 'runtime.module.loadbinary_GraphRuntimeFactory', ...
# 'runtime.RPCTimeEvaluator', 'topi.nn.nll_loss', 'topi.argmax', ...]
if not name.startswith(prefix):
continue
# 过滤掉名称不是以prefix = "runtime" 开始的函数
fname = name[len(prefix) + 1 :]
# 去掉前缀,如果name = 'runtime.module.loadbinary_AotExecutorFactory',
# 则fname = 'module.loadbinary_AotExecutorFactory'
target_module = module # 此时target_module = sys.modules["tvm.runtime._ffi_api"]
if fname.find(".") != -1: # # 过滤掉非二级模块如: 'runtime.RPCTimeEvaluator'
continue*
f = get_global_func(name) # name = runtime.RPCTimeEvaluator, f 是 C++ 端实现的PackedFunc handle
ff = _get_api(f) # 将f的is_global置为true并赋值给ff
ff.__name__ = fname # ff的__name__属性设置为: 比如module.loadbinary_AotExecutorFactory
ff.__doc__ = "TVM PackedFunc %s. " % fname # 设置ff的__doc__属性
setattr(target_module, ff.__name__, ff)
# 通过setattr为target_module = sys.modules["tvm.runtime._ffi_api"]添加函数,
# 此时就完成了C++实现的函数接口加到python的模块,然后在python端使用的时候就可以
# 在runtime的其它文件中通过from . import _ffi_api导入这些函数接口。
2、如何注册Python端接口到C++端
TVM利用Python的装饰器以及PackedFunc机制实现了插件式的类型对象与全局函数的注册方式:
- 注册对象类型
# 在_ffi/_ctypes/object.py中定义全局的字典
OBJECT_TYPE = {}
def _register_object(index, cls):
if issubclass(cls, NDArrayBase):
_register_ndarray(index, cls)
return
OBJECT_TYPE[index] = cls # 全局字典用于保存注册的类型
# 在_ffi/registry.py定义类型注册装饰器
def register_object(type_key=None):
"""
Examples
# 这里的装饰器相当于 MyObject = tvm.register_object("test.MyObject")(MyObject),分两步执行:
# (1)调用tvm.register_object("test.MyObject"),返回内部的register函数指针
# (2)以MyObject为参数调用内部的register(MyObject)函数
@tvm.register_object("test.MyObject")
class MyObject(Object):
pass
# 这里的装饰器相当于 OtherObject = tvm.register_object(OtherObject),由于传入的不是字符串,所以只执行:
# 以OtherObject为入参调用内部的register()函数
@tvm.register_object
class OtherObject(Object):
pass
"""
# 作为装饰器时可以指定type_key字符串, 也可以不指定,此时会使用它所装饰的类名称 type_key.__name__ 作为它的名称
object_name = type_key if isinstance(type_key, str) else type_key.__name__
def register(cls):
"""internal register function"""
if hasattr(cls, "_type_index"):
tindex = cls._type_index
else:
tidx = ctypes.c_uint()
if not _RUNTIME_ONLY:
check_call(_LIB.TVMObjectTypeKey2Index(c_str(object_name), ctypes.byref(tidx)))
else:
# directly skip unknown objects during runtime.
ret = _LIB.TVMObjectTypeKey2Index(c_str(object_name), ctypes.byref(tidx))
if ret != 0:
return cls
tindex = tidx.value
# 根据object_name找到对应的tindex, 作为键值, 最终注册在OBJECT_TYPE的条目是{tindex: cls}
_register_object(tindex, cls)
return cls
# 如果是字符串,返回的是register函数指针
if isinstance(type_key, str):
return register
# 否则以type_key为参数调用register
return register(type_key)
- 注册全局函数
def register_func(func_name, f=None, override=False):
"""Register global function
Examples
.. code-block:: python
# 这里的装饰器相当于 my_func = tvm.register_func("my_func")(my_func),分两步执行:
# (1)调用tvm.register_func("my_func"),返回内部的register函数指针
# (2)以my_func为参数调用内部的register(my_func)函数
@tvm.register_func("my_func")
def my_func():
pass
# 这里的装饰器相当于 other_func = tvm.register_func(other_func),由于传入的不是字符串,所以只执行:
# 以other_func为入参调用内部的register(other_func)函数
@tvm.register_func
def other_func():
pass
"""
# 如果是没指定字符名称的方式: @tvm.register_func
if callable(func_name):
f = func_name # func_name为函数,赋值给f
func_name = f.__name__ # 使用函数的__name__作为函数字符名称
if not isinstance(func_name, str):
raise ValueError("expect string function name")
ioverride = ctypes.c_int(override)
def register(myf):
"""internal register function"""
if not isinstance(myf, PackedFuncBase):
myf = convert_to_tvm_func(myf)
check_call(_LIB.TVMFuncRegisterGlobal(c_str(func_name), myf.handle, ioverride))
return myf
# 没有指定字符名称的方式: @tvm.register_func,
# 将函数f作为参数传入内部register()
if f:
return register(f)
# 指定字符名称@tvm.register_func("my_func")的方式,
# 返回内部register函数指针
return register
看个实际的例子,在relay/build_module.py中:
@register_func("tvm.relay.build")
def _build_module_no_factory_impl(mod, target, target_host, params, mod_name):
return build(
mod, target=target, target_host=target_host, params=params, mod_name=mod_name
).module
# 实际调用过程为:
# (1) register_func("tvm.relay.build")返回register函数指针
# (2) register(_build_module_no_factory_impl):
# 首先通过convert_to_tvm_func(_build_module_no_factory_impl)转换得到PackedFunc;
# 然后通过_LIB.TVMFuncRegisterGlobal(c_str(func_name), myf.handle, ioverride)
# 注册到C++端的全局map中,这里的func_name就是"tvm.relay.build"
四、Python端的注册中心实现
通过修改装饰器的内部register函数,就可以把Python端的注册机从TVM中剥离开来,具体的代码说明及使用已上传在github,有兴趣的读者进一步参阅:
https://github.com/Oreobird/demo/tree/master/python/registry
# !/usr/bin/env python
# -*- coding:utf-8 -*-
from os.path import basename, isfile, join
import glob
FUNC_DICT = {}
OBJ_DICT = {}
def register_func(func_name, func=None):
"""Register global function
Parameters
----------
func_name : str or function
The function name
func : function, optional
The function to be registered.
Returns
-------
fregister : function
Register function if f is not specified.
Examples
--------
The following code registers my_func.
.. code-block:: python
targs = (10, "hello")
@register_func
def my_func(*args):
return 10
# Get it out from register function table
f = get_register_func("my_func")
y = f(*targs)
assert y == 10
"""
if callable(func_name):
func = func_name
func_name = func.__name__
if not isinstance(func_name, str):
raise ValueError("expect string function name")
def register(f):
FUNC_DICT[func_name] = f
return f
if func:
return register(func)
return register
def get_reg_func(name):
"""Get a register function by name
Parameters
----------
name : str
The name of the registerd function
Returns
-------
fregister : function
The function to be returned, None if function is missing.
"""
if name not in FUNC_DICT.keys():
return None
return FUNC_DICT[name]
def get_reg_func_list():
"""Get a register function by name
Returns
-------
fregister_list : List
The function list to be returned, [] if no function registered.
"""
func_list = []
for _, v in FUNC_DICT.items():
func_list.append(v)
return func_list
def register_obj(obj_type=None):
"""register object type.
Parameters
----------
obj_type : str or cls
The name of the node
Examples
--------
The following code registers MyObject
using type key "test.MyObject"
.. code-block:: python
@tvm.register_object("test.MyObject")
class MyObject(Object):
pass
"""
def register(cls):
obj_name = obj_type if isinstance(obj_type, str) else cls.__name__
OBJ_DICT[obj_name] = cls
return cls
if isinstance(obj_type, str):
return register
return register(obj_type)
def get_reg_obj(obj_type):
"""Get a register object by type
Parameters
----------
obj_type : str
The obj_type of the register object
Returns
-------
object : str or cls
The object to be returned, None if object is missing.
"""
obj_name = obj_type if isinstance(obj_type, str) else obj_type.__name__
if obj_name not in OBJ_DICT.keys():
return None
return OBJ_DICT[obj_name]()
def get_reg_obj_list():
"""Get a register Obj by name
Returns
-------
fregister_list : List
The Obj list to be returned, [] if no obj registered.
"""
obj_list = []
for _, v in OBJ_DICT.items():
obj_list.append(v)
return obj_list
def reg_modules(dir):
"""Import all package modules in dir, call in __init__.py of a package
Parameters
----------
dir : str
The directory of the
Returns
-------
modules : List
The modules list to be imported, set to __all__ in __init__.py
Examples
--------
.. code-block:: python
from os.path import dirname
from registry import reg_modules
__all__ = reg_modules(dirname(__file__))
from . import *
"""
modules = glob.glob(join(dir, "*.py"))
return [basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')]