【TVM系列九】FFI注册机制

一、前言

TVM通过PackedFunc机制实现了Python与C++之间的相互调用,即FFI(Foreign Function Interface),整体流程与原理可以参阅系列文章《【TVM系列六】PackedFunc原理》。本文将对TVM的FFI注册机制进行详细的说明,并实现一个Python端的注册中心Demo。

二、什么是FFI

FFI(Foreign Function Interface)即跨语言交互接口,比如Java中的JNI,用于实现在语言A中调用语言B的函数。通常情况下,要达到这样的目的,一种方式是将函数做成一个服务,通过进程间通信或网络协议通信,这种方式需要至少两个独立的进程才能实现;另一种就是通过 FFI 直接将其它语言的接口内嵌到本语言中,实现高效率的调用。

三、TVM的FFI实现

FFI部分的实现代码主要集中在python/tvm/_ffi/目录,以及各个目录下的_ffi_api.py文件,以下的分析均以python/tvm/runtime内的注册为例子进行说明。

1、如何注册C++端接口到Python端

runtime中的_ffi_api.py:

import tvm._ffi tvm._ffi._init_api("runtime", __name__) 
# 此处的__name__ = "tvm.runtime._ffi_api"

_ffi中的registry.py:

def _init_api(namespace, target_module_name=None):
    target_module_name = target_module_name if target_module_name else namespace
    # 此处 target_module_name = "tvm.runtime._ffi_api",namespace = "runtime"
    if namespace.startswith("tvm."):
        _init_api_prefix(target_module_name, namespace[4:])
    else:
        _init_api_prefix(target_module_name, namespace) # 代码走这个分支
        
def _init_api_prefix(module_name, prefix):
    module = sys.modules[module_name]  # module_name = "tvm.runtime._ffi_api", prefix = "runtime"
    # 每当导入新模块,全局字典sys.modules将记录该模块,当再次导入该模块时,
    # 会直接到字典中查找,从而加快程序运行的速度
    
    for name in list_global_func_names(): 
        # list_global_func_names()通过_LIB.TVMFuncListGlobalNames()得到函数列表,打印得到的内容为:       # ['tvm.contrib.sort.argsort', ... , 'runtime.module.loadbinary_AotExecutorFactory', ...
        #  'tvm.graph_executor_debug.create', 'runtime.module.loadbinary_GraphRuntimeFactory', ...
        #  'runtime.RPCTimeEvaluator', 'topi.nn.nll_loss', 'topi.argmax', ...]
        
        if not name.startswith(prefix):
            continue
        # 过滤掉名称不是以prefix = "runtime" 开始的函数
        
        fname = name[len(prefix) + 1 :]
        # 去掉前缀,如果name = 'runtime.module.loadbinary_AotExecutorFactory',
        # 则fname = 'module.loadbinary_AotExecutorFactory'
        
        target_module = module  # 此时target_module = sys.modules["tvm.runtime._ffi_api"]

        if fname.find(".") != -1: # # 过滤掉非二级模块如: 'runtime.RPCTimeEvaluator'
            continue*
        f = get_global_func(name)  # name = runtime.RPCTimeEvaluator, f 是 C++ 端实现的PackedFunc handle
        ff = _get_api(f)  # 将f的is_global置为true并赋值给ff
        ff.__name__ = fname  # ff的__name__属性设置为: 比如module.loadbinary_AotExecutorFactory
        ff.__doc__ = "TVM PackedFunc %s. " % fname  # 设置ff的__doc__属性
        setattr(target_module, ff.__name__, ff)
        # 通过setattr为target_module = sys.modules["tvm.runtime._ffi_api"]添加函数,
        # 此时就完成了C++实现的函数接口加到python的模块,然后在python端使用的时候就可以
        # 在runtime的其它文件中通过from . import _ffi_api导入这些函数接口。
2、如何注册Python端接口到C++端

TVM利用Python的装饰器以及PackedFunc机制实现了插件式的类型对象与全局函数的注册方式:

  • 注册对象类型
# 在_ffi/_ctypes/object.py中定义全局的字典
OBJECT_TYPE = {}   

def _register_object(index, cls):
    if issubclass(cls, NDArrayBase):
        _register_ndarray(index, cls)
        return
    OBJECT_TYPE[index] = cls   # 全局字典用于保存注册的类型

# 在_ffi/registry.py定义类型注册装饰器
def register_object(type_key=None):
    """
    Examples
    # 这里的装饰器相当于 MyObject = tvm.register_object("test.MyObject")(MyObject),分两步执行:
    # (1)调用tvm.register_object("test.MyObject"),返回内部的register函数指针
    # (2)以MyObject为参数调用内部的register(MyObject)函数
      @tvm.register_object("test.MyObject") 
      class MyObject(Object):
          pass
          
    # 这里的装饰器相当于 OtherObject = tvm.register_object(OtherObject),由于传入的不是字符串,所以只执行:
    # 以OtherObject为入参调用内部的register()函数
      @tvm.register_object
      class OtherObject(Object):
          pass
    """
    
    # 作为装饰器时可以指定type_key字符串, 也可以不指定,此时会使用它所装饰的类名称 type_key.__name__ 作为它的名称
    object_name = type_key if isinstance(type_key, str) else type_key.__name__

    def register(cls):
        """internal register function"""
        if hasattr(cls, "_type_index"):
            tindex = cls._type_index
        else:
            tidx = ctypes.c_uint()
            if not _RUNTIME_ONLY:
                check_call(_LIB.TVMObjectTypeKey2Index(c_str(object_name), ctypes.byref(tidx)))
            else:
                # directly skip unknown objects during runtime.
                ret = _LIB.TVMObjectTypeKey2Index(c_str(object_name), ctypes.byref(tidx))
                if ret != 0:
                    return cls
            tindex = tidx.value
        # 根据object_name找到对应的tindex, 作为键值, 最终注册在OBJECT_TYPE的条目是{tindex: cls}
        _register_object(tindex, cls)
        return cls

    # 如果是字符串,返回的是register函数指针
    if isinstance(type_key, str):
        return register

    # 否则以type_key为参数调用register
    return register(type_key)
  • 注册全局函数
def register_func(func_name, f=None, override=False):
    """Register global function
    Examples
    .. code-block:: python
    # 这里的装饰器相当于 my_func = tvm.register_func("my_func")(my_func),分两步执行:
    # (1)调用tvm.register_func("my_func"),返回内部的register函数指针
    # (2)以my_func为参数调用内部的register(my_func)函数
     @tvm.register_func("my_func")
     def my_func():
         pass
          
    # 这里的装饰器相当于 other_func = tvm.register_func(other_func),由于传入的不是字符串,所以只执行:
    # 以other_func为入参调用内部的register(other_func)函数     
     @tvm.register_func
     def other_func():
        pass
    """
    
    # 如果是没指定字符名称的方式: @tvm.register_func
    if callable(func_name):
        f = func_name    # func_name为函数,赋值给f
        func_name = f.__name__  # 使用函数的__name__作为函数字符名称 

    if not isinstance(func_name, str):
        raise ValueError("expect string function name")

    ioverride = ctypes.c_int(override)

    def register(myf):
        """internal register function"""
        if not isinstance(myf, PackedFuncBase):
            myf = convert_to_tvm_func(myf)
        check_call(_LIB.TVMFuncRegisterGlobal(c_str(func_name), myf.handle, ioverride))
        return myf
    
    # 没有指定字符名称的方式: @tvm.register_func,
    # 将函数f作为参数传入内部register()
    if f:
        return register(f)
        
    # 指定字符名称@tvm.register_func("my_func")的方式,
    # 返回内部register函数指针
    return register

看个实际的例子,在relay/build_module.py中:

@register_func("tvm.relay.build")
def _build_module_no_factory_impl(mod, target, target_host, params, mod_name):
    return build(
        mod, target=target, target_host=target_host, params=params, mod_name=mod_name
    ).module
# 实际调用过程为:
# (1) register_func("tvm.relay.build")返回register函数指针
# (2) register(_build_module_no_factory_impl):
# 首先通过convert_to_tvm_func(_build_module_no_factory_impl)转换得到PackedFunc;
# 然后通过_LIB.TVMFuncRegisterGlobal(c_str(func_name), myf.handle, ioverride)
# 注册到C++端的全局map中,这里的func_name就是"tvm.relay.build"

四、Python端的注册中心实现

通过修改装饰器的内部register函数,就可以把Python端的注册机从TVM中剥离开来,具体的代码说明及使用已上传在github,有兴趣的读者进一步参阅:

https://github.com/Oreobird/demo/tree/master/python/registry

# !/usr/bin/env python
# -*- coding:utf-8 -*-

from os.path import basename, isfile, join
import glob

FUNC_DICT = {}
OBJ_DICT = {}

def register_func(func_name, func=None):
    """Register global function

    Parameters
    ----------
    func_name : str or function
        The function name

    func : function, optional
        The function to be registered.

    Returns
    -------
    fregister : function
        Register function if f is not specified.

    Examples
    --------
    The following code registers my_func.

    .. code-block:: python
      targs = (10, "hello")
      @register_func
      def my_func(*args):
          return 10
      # Get it out from register function table
      f = get_register_func("my_func")
      y = f(*targs)
      assert y == 10
    """
    if callable(func_name):
        func = func_name
        func_name = func.__name__

    if not isinstance(func_name, str):
        raise ValueError("expect string function name")

    def register(f):
        FUNC_DICT[func_name] = f
        return f

    if func:
        return register(func)

    return register

def get_reg_func(name):
    """Get a register function by name

    Parameters
    ----------
    name : str
        The name of the registerd function

    Returns
    -------
    fregister : function
        The function to be returned, None if function is missing.
    """
    if name not in FUNC_DICT.keys():
        return None
    return FUNC_DICT[name]

def get_reg_func_list():
    """Get a register function by name

    Returns
    -------
    fregister_list : List
        The function list to be returned, [] if no function registered.
    """
    func_list = []
    for _, v in FUNC_DICT.items():
        func_list.append(v)
    return func_list


def register_obj(obj_type=None):
    """register object type.

    Parameters
    ----------
    obj_type : str or cls
        The name of the node

    Examples
    --------
    The following code registers MyObject
    using type key "test.MyObject"

    .. code-block:: python

      @tvm.register_object("test.MyObject")
      class MyObject(Object):
          pass
    """

    def register(cls):
        obj_name = obj_type if isinstance(obj_type, str) else cls.__name__
        OBJ_DICT[obj_name] = cls
        return cls

    if isinstance(obj_type, str):
        return register

    return register(obj_type)


def get_reg_obj(obj_type):
    """Get a register object by type

    Parameters
    ----------
    obj_type : str
        The obj_type of the register object

    Returns
    -------
    object : str or cls
        The object to be returned, None if object is missing.
    """
    obj_name = obj_type if isinstance(obj_type, str) else obj_type.__name__
    if obj_name not in OBJ_DICT.keys():
        return None
    return OBJ_DICT[obj_name]()

def get_reg_obj_list():
    """Get a register Obj by name

    Returns
    -------
    fregister_list : List
        The Obj list to be returned, [] if no obj registered.
    """
    obj_list = []
    for _, v in OBJ_DICT.items():
        obj_list.append(v)
    return obj_list

def reg_modules(dir):
    """Import all package modules in dir, call in __init__.py of a package

    Parameters
    ----------
    dir : str
        The directory of the

    Returns
    -------
    modules : List
        The modules list to be imported, set to __all__ in __init__.py

    Examples
    --------
    .. code-block:: python
        from os.path import dirname
        from registry import reg_modules
        __all__ = reg_modules(dirname(__file__))
        from . import *
    """
    modules = glob.glob(join(dir, "*.py"))
    return [basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')]

你可能感兴趣的:(【TVM系列九】FFI注册机制)