【TVM源码学习笔记】3 模型编译

 在我们的模型编译运行脚本中,使用relay.build编译模型:

# 设置优化级别
with tvm.transform.PassContext(opt_level=3):
    #编译模型
    lib = relay.build(mod, target, params=params)

因为在python/tvm/relay/__init__.py中有:

from .build_module import build, create_executor, optimize

所以这里build直接调用的是python/tvm/relay/build_module.py中的build函数。我们传入了三个参数,其中第一个mod是一个IRModule实例。excutor参数使用的是默认的graph。函数开头部分的autotvm是TVM的优化工具。所以函数开头部分我们可以先不深究,知道流程走的下面的代码即可:

def build(
    ir_mod,
    target=None,
    target_host=None,
    executor=Executor("graph"),
    runtime=Runtime("cpp"),
    workspace_memory_pools=None,
    constant_memory_pools=None,
    params=None,
    mod_name="default",
):
    ...

    with tophub_context:
        bld_mod = BuildModule()
        graph_json, runtime_mod, params = bld_mod.build(
            mod=ir_mod,
            target=raw_targets,
            params=params,
            executor=executor,
            runtime=runtime,
            workspace_memory_pools=workspace_memory_pools,
            constant_memory_pools=constant_memory_pools,
            mod_name=mod_name,
        )
        func_metadata = bld_mod.get_function_metadata()
        devices = bld_mod.get_devices()
        lowered_ir_mods = bld_mod.get_irmodule()
        executor_codegen_metadata = bld_mod.get_executor_codegen_metadata()

        ...

        elif executor.name == "graph":
            executor_factory = _executor_factory.GraphExecutorFactoryModule(
                ir_mod,
                raw_targets,
                executor,
                graph_json,
                runtime_mod,
                mod_name,
                params,
                func_metadata,
            )
        else:
            assert False, "Executor " + executor + " not supported"

        return executor_factory

代码中先实例化了一个BuildModule对象,然后调用BuildModule的build方法编译模型;编译完后读取编译后的模型数据,调用_executor_factory.GraphExecutorFactoryModule创建了一个执行器。

BuildModule.build方法里面,调用的BuildModule._build,这个 _build是在BuildModule.__init__方法中赋值的:

    def __init__(self):
        self.mod = _build_module._BuildModule()
        self._get_graph_json = self.mod["get_graph_json"]
        self._get_module = self.mod["get_module"]
        self._build = self.mod["build"]

self._build挂载的是_build_module.py中_BuildModule()对象的成员。_build_module.py:

import tvm._ffi

tvm._ffi._init_api("relay.build_module", __name__)

我们直接搜索relay.build_module._BuildModule,可以找到接口的注册:

runtime::Module RelayBuildCreate() {
  auto exec = make_object();
  return runtime::Module(exec);
}

TVM_REGISTER_GLOBAL("relay.build_module._BuildModule").set_body([](TVMArgs args, TVMRetValue* rv) {
  *rv = RelayBuildCreate();
});

这里只是创建了一个RelayBuildModule,然后包再runtime::Module里面,返回runtime::Module实例。

python前端的BuildModule.mod["build"]获取了runtime::Module的build属性,在C++端会首先调用到runtime::Module::GetFunction方法。这个调用过程涉及到TVM的PackedFunc机制,可以参考: 

TVM PackedFunc实现机制 | Don't Respond

runtime::Module::GetFunction的实现:

inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) {
  return (*this)->GetFunction(name, query_imports);
}

而 runtime::Module对取成员操作符->做了重载:

inline ModuleNode* Module::operator->() { return static_cast(get_mutable()); }

get_mutable是Module祖先类ObjectRef的方法,获取自己的_data字段,是对应的Object实例。而在RelayBuildCreate里面,创建Module实例的时候,传入的是一个RelayBuildModule,继承自Object。所以这里的_data就是这个RelayBuildModule,调用到的(*this)->GetFunction调用到的也就是RelayBuildModule::GetFunction:

class RelayBuildModule : public runtime::ModuleNode {
 public:

  ...

  PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final {

    ...

    else if (name == "build") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        ICHECK_EQ(args.num_args, 8);
        this->Build(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7]);
      });

    ... 
  

最终会调用到:

/*!
   * \brief Compile a Relay IR module to runtime module.
   *
   * \param relay_module The Relay IR module.
   * \param params The parameters.
   */
  void BuildRelay(IRModule relay_module, const String& mod_name) {
    // Relay IRModule -> IRModule optimizations.
    //1. 对relay ir做优化,执行优化pass
    IRModule module = WithAttrs(
        relay_module, {{tvm::attr::kExecutor, executor_}, {tvm::attr::kRuntime, runtime_}});
    relay_module = OptimizeImpl(std::move(module));

    // Get the updated function and new IRModule to build.
    // Instead of recreating the IRModule, we should look at the differences between this and the
    // incoming IRModule to see if we can just pass (IRModule, Function) to the code generator.
    //2. 按注释看是希望对IRModule做增量编译,而不是全部重新编译
    Function func = Downcast(relay_module->Lookup("main"));
    IRModule func_module = WithAttrs(IRModule::FromExpr(func),
                                     {{tvm::attr::kExecutor, executor_},
                                      {tvm::attr::kRuntime, runtime_},
                                      {tvm::attr::kWorkspaceMemoryPools, workspace_memory_pools_},
                                      {tvm::attr::kConstantMemoryPools, constant_memory_pools_}});

    // Generate code for the updated function.
    // 3.创建执行器对应的代码生成器
    executor_codegen_ = MakeExecutorCodegen(executor_->name);
    // 4. 初始化代码生成器
    executor_codegen_->Init(nullptr, config_->primitive_targets);
    // 5. 对找到的main函数生成代码
    executor_codegen_->Codegen(func_module, func, mod_name);
    // 6. 更新输出
    executor_codegen_->UpdateOutput(&ret_);
    // 7. 获取参数
    ret_.params = executor_codegen_->GetParams();
    
    auto lowered_funcs = executor_codegen_->GetIRModule();

    // No need to build for external functions.
    Target ext_dev("ext_dev");
    if (lowered_funcs.find(ext_dev) != lowered_funcs.end()) {
      lowered_funcs.Set(ext_dev, IRModule());
    }

    const Target& host_target = config_->host_virtual_device->target;
    const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate");
    // When there is no lowered_funcs due to reasons such as optimization.
    if (lowered_funcs.size() == 0) {
      if (host_target->kind->name == "llvm") {
        CHECK(pf != nullptr) << "Unable to create empty module for llvm without llvm codegen.";
        // If we can decide the target is LLVM, we then create an empty LLVM module.
        ret_.mod = (*pf)(host_target->str(), "empty_module");
      } else {
        // If we cannot decide the target is LLVM, we create an empty CSourceModule.
        // The code content is initialized with ";" to prevent complaining
        // from CSourceModuleNode::SaveToFile.
        ret_.mod = tvm::codegen::CSourceModuleCreate(";", "", Array{});
      }
    } else {
      // 8. 打包tir运行时
      ret_.mod = tvm::TIRToRuntime(lowered_funcs, host_target);
    }

    auto ext_mods = executor_codegen_->GetExternalModules();
    ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods, host_target,
                                                  runtime_, executor_,
                                                  executor_codegen_->GetExecutorCodegenMetadata());
    // Remove external params which were stored in metadata module.
    for (tvm::runtime::Module mod : ext_mods) {
      auto pf_var = mod.GetFunction("get_const_vars");
      if (pf_var != nullptr) {
        // 9. 删除常量
        Array variables = pf_var();
        for (size_t i = 0; i < variables.size(); i++) {
          auto it = ret_.params.find(variables[i].operator std::string());
          if (it != ret_.params.end()) {
            VLOG(1) << "constant '" << variables[i] << "' has been captured in external module";
            ret_.params.erase(it);
          }
        }
      }
    }
  }

1. 执行优化pass。这里是根据我们在模型编译运行脚本中设置的优化上下文:

# 设置优化级别
with tvm.transform.PassContext(opt_level=3):
    #编译模型

对relay ir做优化。优化后面再专门研究吧。

2. 查找relay ir中的main函数,然后使用main函数创建一个IRModule,并设置属性。这个将是后面会编译的IRModule。按注释这里是实现了增量编译的,怎么实现的呢?

3. 生成执行器代码生成器类实例。我们在编译时没有传入执行器参数,直接使用了默认的graph执行器,得到的将是relay.build_module._GraphExecutorCodegen对应的类GraphExecutorCodegenModule,定义在src/relay/backend/graph_executor_codegen.cc中。这个类是在代码生成器的基础上做了一层包装。真正代码生成器是它的codegen_成员;

4. 执行GraphExecutorCodegenModule的初始化函数,这里会生成代码生成器,赋给codegen_,类型是GraphExecutorCodegen

5. 代码生成。这里是将relay ir 低级化为 tir形式。后面专门讨论;

6. UpdataOuput只是重新生成了编译得到的json文件;

7. 获取生成的模型tir输出参数。这里都是一些常量值的名字;

8. 根据设置的模型运行目标(target)和当前的host,分别创建两者上运行的模块;

9. 删除常量参数。为什么?

你可能感兴趣的:(TVM源码分析,python,深度学习,机器学习)