LLVM官方教程Kaleidoscope 4

[TOC]

参考

4. Kaleidoscope: Adding JIT and Optimizer Support

1. 前言

之前的3章,实现了一个简单的语言,并且支持了 LLVM IR 的生成。本章主要介绍两项技术:优化器和 JIT。

2. 琐碎的常量折叠

默认的 IRBuilder 包含了下文介绍的常量折叠技术。

第三章的示范很优雅,便于拓展,但是不能生产绝妙的代码。比如说,编译如下的简单代码,并不能获得显而易见的优化:

ready> def test(x) 1+2+x;
Read function definition:
define double @test(double %x) {
entry:
        %addtmp = fadd double 1.000000e+00, 2.000000e+00
        %addtmp1 = fadd double %addtmp, %x
        ret double %addtmp1
}

上述生成的代码仅仅是构建后的 AST 的文字转录,没有常量折叠的优化。常量折叠是最常用和普遍的优化手段。

使用 LLVM,你不需要在 AST 支持常量折叠。因为所有创建 LLVM IR 的调用,都会经过 LLVM Builder,Builder 自身去检查是否有常量折叠的机会才是完美的做法。如果可以折叠,Builder 就会执行常量折叠并且返回常量值,而不是创建指令。这就是LLVMFoldingBuilder类做的事情。

我们所有需要做的就是从 LLVMBuilder 切换到 LLVMFoldingBuilder。尽管没有改变其他代码,我们的所有指令都会隐式地做常量折叠优化。比如说:

ready> def test(x) 1+2+x;
Read function definition:
define double @test(double %x) {
entry:
        %addtmp = fadd double 3.000000e+00, %x
        ret double %addtmp
}

可以看见,常量折叠的优化很简单。实际上,我们推荐总是用LLVMFoldingBuilder生成这样的代码。LLVMFoldingBuilder不会引入语法上的开销,并可以在某些场景大量减少指令数。

另一方面,LLVMFoldingBuilder 受以下事实限制:它在生成代码时会与代码内联进行所有分析。看一个复杂一点的例子:

ready> def test(x) (1+2+x)*(x+(1+2));
ready> Read function definition:
define double @test(double %x) {
entry:
        %addtmp = fadd double 3.000000e+00, %x
        %addtmp1 = fadd double %x, 3.000000e+00
        %multmp = fmul double %addtmp, %addtmp1
        ret double %multmp
}

这个例子中,乘号的左右操作符是一样的。我们真正希望看到是tmp = x+3; result = tmp*tmp;,而不是计算x*3两次。

不幸的是,没有局部分析可以检查和修正这个问题。这需要两个装换:表达式的重新关联(以使加法操作在词法上相同)和常见子表达式裁剪(CSE)。幸运的是,LLVM 通过路径("passes")的形式提供了很多优化可用。

3. LLVM optimize passes

LLVM 提供了很多优化pass,可以做很多种事情,有不同的权衡。不像其他系统,LLVM 不会错误地认为一组优化适用于所有的语言和情况。LLVM 允许编译器实现者完全决定用什么优化、以什么顺序优化以及优化的场景。

举一个具体的例子。LLVM 支持两个 'whole module' 路径,他遍历了尽可能多的代码主体(通常是整个文件,但是如果在链接时运行,则这可能是整个程序的重要部分)。它同样支持和包含了'per-function' pass,一次只操作一个函数,不会关注其他的函数。更多关于 pass 的内容,参见 How to Write a Pass和 List of LLVM Passes。

对Kaleidoscope而言,我们在运行中产生函数,在用户输入的时候,一次生成一个。我们不会做极致的优化,但是也会在可能的位置,做一些简单和快速的优化。因此我们选择在用户输入函数的时候,使用 'per-function' 优化。如果我们想做一个静态的Kaleidoscope编译器,我们将完全使用现在拥有的代码,除了我们将推迟运行优化器,直到整个文件解析完为止。

为了继续进行 per-function 优化,我们需要设置Llvm.PassManager以持有和管理我们想要运行的优化。代码如下:

void InitializeModuleAndPassManager(void) {
  // Open a new module.
  TheModule = std::make_unique("my cool jit", TheContext);

  // Create a new pass manager attached to it.
  TheFPM = std::make_unique(TheModule.get());

  // Do simple "peephole" optimizations and bit-twiddling optzns.
  TheFPM->add(createInstructionCombiningPass());
  // Reassociate expressions.
  TheFPM->add(createReassociatePass());
  // Eliminate Common SubExpressions.
  TheFPM->add(createGVNPass());
  // Simplify the control flow graph (deleting unreachable blocks, etc).
  TheFPM->add(createCFGSimplificationPass());

  TheFPM->doInitialization();
}

上述代码定义了全局的TheModuleTheFPM。设置完成后,我们调用一系列的add来添加一堆LLVM pass。

这个 case 中,我们选择增加了4个 pass,这里我们选择的 pass 是一组相当标准的 'cleanup' 优化,这些优化对很多代码都有用。这里不会介绍它内部执行逻辑。

PassManager设置之后,我们需要使用它了。我们在构造新创建的函数(FunctionAST::codegen())之后,在返回给 client 之前,来执行此操作:

if (Value *RetVal = Body->codegen()) {
  // Finish off the function.
  Builder.CreateRet(RetVal);

  // Validate the generated code, checking for consistency.
  verifyFunction(*TheFunction);

  // Optimize the function.
  TheFPM->run(*TheFunction);

  return TheFunction;
}

如上所示,很简单明了。the_fpm在适当位置优化和更新了LLVM Function*,以改善(希望)其主体。有了这个,我们可以再次进行上面的测试:

ready> def test(x) (1+2+x)*(x+(1+2));
ready> Read function definition:
define double @test(double %x) {
entry:
        %addtmp = fadd double %x, 3.000000e+00
        %multmp = fmul double %addtmp, %addtmp
        ret double %multmp
}

我们获得了期望之中的优化代码,减少了这个函数每次执行时的浮点加法执行。

LLVM 提供了很多种优化,可用于某种特定的场景。可以在documentation about the various passes进行查看,但这个页面并不是很全。查看 Clang 的pass 使用,可能是入门的另一个好想法来源。"opt"工具允许你在命令行中实验 pass,所以你可以看到pass做的任何事情。

截至目前,我们已经从我们的前端中产出了合理的代码,现在开始讨论如何执行他们。

4. 添加 JIT 编译器

LLVM IR 中可用的代码可以适用于多种工具。比如说,你可以在 IR 上进行优化(如我们前文所做的),你可以以文本或二进制的格式dump IR,你也可以编译代码成某个目标的一个组装文件(.s),你还可以使用 JIT 编译它。LLVM IR的好处在于,他是编译器各个不通部分的“通用货币”。

在这一小节,我们将添加一个 JIT 编译器支持到我们的翻译器。我们的基本想法是,用户像现在这样输入函数的主体,但是能够立即计算出 top-level 的表达式。比如说,输入了1+2,我们应该计算并打印出结果3。如果他们定义了一个函数,他们应该可以从命令行直接调用。

为了实现这个目标,我们首先声明和初始化 JIT。这一步通过在 main 方法中,添加一个全局变量TheJIT,并调用InitializeNativeTarget*这类函数完成:

static std::unique_ptr TheJIT;
...
int main() {
  InitializeNativeTarget();
  InitializeNativeTargetAsmPrinter();
  InitializeNativeTargetAsmParser();

  // Install standard binary operators.
  // 1 is lowest precedence.
  BinopPrecedence['<'] = 10;
  BinopPrecedence['+'] = 20;
  BinopPrecedence['-'] = 20;
  BinopPrecedence['*'] = 40; // highest.

  // Prime the first token.
  fprintf(stderr, "ready> ");
  getNextToken();

  TheJIT = std::make_unique();

  // Run the main "interpreter loop" now.
  MainLoop();

  return 0;
}

然后我们还需要为 JIT 设置 data layout:

void InitializeModuleAndPassManager(void) {
  // Open a new module.
  TheModule = std::make_unique("my cool jit", TheContext);
  TheModule->setDataLayout(TheJIT->getTargetMachine().createDataLayout());

  // Create a new pass manager attached to it.
  TheFPM = std::make_unique(TheModule.get());
  ...

KaleidoscopeJIT类是专门为教程实现的一个简化版 JIT,具体实现可在 LLVM源码路径llvm-src/examples/Kaleidoscope/include/KaleidoscopeJIT.h下找到。后续的章节,我们会了解他是如何运转和拓展新功能。他的接口非常简单,addModule增加一个 LLVM IR module 到 JIT,使其函数对执行可见;removeModule删除一个 module,释放和 module 中的代码相关联的所有内存;findSymbol允许我们查找已编译代码的指针。

有了以上的代码,我们可以改变解析 top-level 表达式的代码如下:

static void HandleTopLevelExpression() {
  // Evaluate a top-level expression into an anonymous function.
  if (auto FnAST = ParseTopLevelExpr()) {
    if (FnAST->codegen()) {

      // JIT the module containing the anonymous expression, keeping a handle so
      // we can free it later.
      auto H = TheJIT->addModule(std::move(TheModule));
      InitializeModuleAndPassManager();

      // Search the JIT for the __anon_expr symbol.
      auto ExprSymbol = TheJIT->findSymbol("__anon_expr");
      assert(ExprSymbol && "Function not found");

      // Get the symbol's address and cast it to the right type (takes no
      // arguments, returns a double) so we can call it as a native function.
      double (*FP)() = (double (*)())(intptr_t)ExprSymbol.getAddress();
      fprintf(stderr, "Evaluated to %f\n", FP());

      // Delete the anonymous expression module from the JIT.
      TheJIT->removeModule(H);
    }

如果解析并生成代码成功,下一步就是把还有 top-level 表达式的 module 添加到 JIT。我们通过调用addModule方法来完成添加,这个方法会触发 module 中所有函数的代码生成,并返回一个句柄,用于稍后从 JIT 中删除 module。一旦 module 被添加到JIT,就不能再被修改,所以我们通过调用InitializeModuleAndPassManager来打开一个新的 module ,用于持有后续的代码。

添加 module 到 JIT 之后,我们需要获取一个指针,指向最终生成的代码。通过调用 JIT 的findSymbol方法,并传入top-level 表达式函数名(__anon_expr),就可以获得代码指针。

下一步,我们通过调用getAddress获取__anon_expr函数的内存地址。回忆一下,我们之前把 top-level 表达式编译到一个独立的 LLVM 函数中,这个函数没有参数,返回计算后的 double 类型数据。因为 LLVM JIT 编译器和本地平台的 API 匹配,意味着你可以把结果指针转换成那种类型的函数指针并调用这个函数。这也表明,JIT 编译的代码和静态链接到应用程序的本地机器码没有区别。

最后,由于我们不支持 top-level 表达式的复算,结束的时候,我们从 JIT 把 module 删除,释放相关联的内存。但是我们几行代码之前创建module 还打开着,可以添加新的代码进去。

通过这两个变化,看下Kaleidoscope现在怎么工作。

ready> 4+5;
define double @""() {
entry:
        ret double 9.000000e+00
}

Evaluated to 9.000000

这看起来基本能用了。dump 出来的 function 展示了“总是返回 double 的无参数函数”,我们会为输入的每一个 top-level 函数合乘该函数。这演示了很基本的功能,但是我们能做得更多事情么?

ready> def testfunc(x y) x + y*2;
Read function definition:
define double @testfunc(double %x, double %y) {
entry:
        %multmp = fmul double %y, 2.000000e+00
        %addtmp = fadd double %multmp, %x
        ret double %addtmp
}

ready> testfunc(4, 10);
define double @""() {
entry:
        %calltmp = call double @testfunc(double 4.000000e+00, double 1.000000e+01)
        ret double %calltmp
}

Evaluated to 24.000000

ready> testfunc(5, 10);
ready> LLVM ERROR: Program used external function 'testfunc' which could not be resolved!

函数定义和调用也可以工作了,但是最后一行出错了。最后一行看起来是没有问题的,问题在哪呢?从 API 来看,module 是JIT的分配单元,testfunc是同一个包含匿名表达式的模块的一部分。当我们从 JIT 删除 module 之后,匿名表达式的内存就被释放了,testfunc的定义也就随之删除了。再次调用testfunc的时候,JIT 就找不到这个函数了。

最简单的修复方法是将匿名表达式放在一个独立的 module,和函数定义的其他部分区分开来。JIT 将乐于解决跨模块边界的函数调用,只要每一个被调函数有 prototype,且在调用前被添加到了JIT 中。通过将匿名表达式放到不同的module,我们可以在不影响函数其他部分的情况下删除 module。

实际上,我们将更进一步地把每个函数放到自己的 module 中。这样做,可以让我们探索KaleidoscopeJIT更有用的特性,让我们的环境更像 REPL:函数可以被多次添加到 JIT(不像只有一个 module的时候,函数需要有一个唯一的定义)。当你在KaleidoscopeJIT查找一个符号的时候,总是返回最近的定义给你。

ready> def foo(x) x + 1;
Read function definition:
define double @foo(double %x) {
entry:
  %addtmp = fadd double %x, 1.000000e+00
  ret double %addtmp
}

ready> foo(2);
Evaluated to 3.000000

ready> def foo(x) x + 2;
define double @foo(double %x) {
entry:
  %addtmp = fadd double %x, 2.000000e+00
  ret double %addtmp
}

ready> foo(2);
Evaluated to 4.000000

为了允许每一个函数的生命周期在他自己的模块内,我们需要一个方法,重新生成之前的函数声明到打开的每一个新 module 中:

static std::unique_ptr TheJIT;

...

Function *getFunction(std::string Name) {
  // First, see if the function has already been added to the current module.
  if (auto *F = TheModule->getFunction(Name))
    return F;

  // If not, check whether we can codegen the declaration from some existing
  // prototype.
  auto FI = FunctionProtos.find(Name);
  if (FI != FunctionProtos.end())
    return FI->second->codegen();

  // If no existing prototype exists, return null.
  return nullptr;
}

...

Value *CallExprAST::codegen() {
  // Look up the name in the global module table.
  Function *CalleeF = getFunction(Callee);

...

Function *FunctionAST::codegen() {
  // Transfer ownership of the prototype to the FunctionProtos map, but keep a
  // reference to it for use below.
  auto &P = *Proto;
  FunctionProtos[Proto->getName()] = std::move(Proto);
  Function *TheFunction = getFunction(P.getName());
  if (!TheFunction)
    return nullptr;

这个功能开启之后,我们先添加一个新的全局FunctionProtos,持有每一个函数最近的 prototype。另外增加了一个便捷的方法,getFunction(),替代为TheModule->getFunction()的调用。我们的便捷方法在TheModule中查找存在的函数声明,如果函数声明不存在, 就进行回退操作——从FunctionProtos中生成新的声明。在CallExprAST::codegen()中,我们仅仅需要替换成TheModule->getFunction()的调用。在FunctionAST::codegen()中,我们首先需要更新FunctionProtos表,然后调用getFunction()。做完以上工作,我们总是能够从当前的 module 中,获取到先迁声明过的函数声明。

同样我们需要更新HandleDefinitionHandleExtern

static void HandleDefinition() {
  if (auto FnAST = ParseDefinition()) {
    if (auto *FnIR = FnAST->codegen()) {
      fprintf(stderr, "Read function definition:");
      FnIR->print(errs());
      fprintf(stderr, "\n");
      TheJIT->addModule(std::move(TheModule));
      InitializeModuleAndPassManager();
    }
  } else {
    // Skip token for error recovery.
     getNextToken();
  }
}

static void HandleExtern() {
  if (auto ProtoAST = ParseExtern()) {
    if (auto *FnIR = ProtoAST->codegen()) {
      fprintf(stderr, "Read extern: ");
      FnIR->print(errs());
      fprintf(stderr, "\n");
      FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
    }
  } else {
    // Skip token for error recovery.
    getNextToken();
  }
}

HandleDefinition中,我们增加了两行代码,将新定义的函数转换到 JIT ,并打开一个新的 module。在HandleExtern中,我们增加了一行代码,将 prototype 添加到FunctionProtos

做完以上变动,再次尝试 REPL(省略了dump 的内容)。

ready> def foo(x) x + 1;
ready> foo(2);
Evaluated to 3.000000

ready> def foo(x) x + 2;
ready> foo(2);
Evaluated to 4.000000

再看看其他的牛逼功能:

ready> extern sin(x);
Read extern:
declare double @sin(double)

ready> extern cos(x);
Read extern:
declare double @cos(double)

ready> sin(1.0);
Read top-level expression:
define double @2() {
entry:
  ret double 0x3FEAED548F090CEE
}

Evaluated to 0.841471

ready> def foo(x) sin(x)*sin(x) + cos(x)*cos(x);
Read function definition:
define double @foo(double %x) {
entry:
  %calltmp = call double @sin(double %x)
  %multmp = fmul double %calltmp, %calltmp
  %calltmp2 = call double @cos(double %x)
  %multmp4 = fmul double %calltmp2, %calltmp2
  %addtmp = fadd double %multmp, %multmp4
  ret double %addtmp
}

ready> foo(4.0);
Read top-level expression:
define double @3() {
entry:
  %calltmp = call double @foo(double 4.000000e+00)
  ret double %calltmp
}

Evaluated to 1.000000

JIT 是怎么知道sincos的呢?答案很简单:这个例子里面,KaleidoscopeJIT有一个简单明了的符号解析规则用于查找在任意给定 module 中都不可用的符号:首先在 添加到JIT的所有 module 中查找,从新到旧,查找最新的定义。如果 JIT 中没有找到,执行回退操作,在Kaleidoscope进程上调用dlsym("sin") 。由于sin在 JIT 的地址空间有定义,因此它只需要修补模块中的调用,即可调用libm 版本的sin。但是在一些场景中,这可以更进一步:因为 sin 和 cos 都是标准的数学函数,当时用常量调用的时候,像上文的 sin(1.0),常量折叠会直接计算函数调用到正确的值。

未来,我们会看到如何稍微调整这个符号解析规则,用于使能各种有用的特性,包括安全性,基于符号名字动态生成代码以及懒编译机制。

符号解析规则的中间一个直接好处是我们能通过写任意的c++语言实现运算。比如:

#ifdef _WIN32
#define DLLEXPORT __declspec(dllexport)
#else
#define DLLEXPORT
#endif

/// putchard - putchar that takes a double and returns 0.
extern "C" DLLEXPORT double putchard(double X) {
  fputc((char)X, stderr);
  return 0;
}

注意,在 windows 系统上,我们需要准确地将函数导出,因为动态符号加载器会使用GetProcAddress来查找符号。

现在我们可以产出简单的输出到命令行界面,比如通过使用extern putchard(x); putchard(120);这样的代码,这段代码输出小写的x(120是x的ascii 码)。类似的代码可以用于实现文件I/O、命令行输出、在Kaleidoscope中的许多其他能力。

现在就完成了Kaleidoscope中 JIT 和优化器的介绍。现在我们可以编译非图灵机完备的语言,以用户主导的方式优化并使用 JIT 编译它。稍后我们会继续介绍控制流结构体的扩展 extending the language with control flow constructs,顺便处理一些有趣的 LLVM IR 问题。

5. 代码清单

全部代码见代码清单

你可能感兴趣的:(LLVM官方教程Kaleidoscope 4)