【从零开始学深度学习编译器】二十,MLIR的Pattern重写机制

0x0. 前言

这篇文章对MLIR的Pattern重写机制进行梳理和汇总,会结合实际例子把MLIR的两篇文档转化成容易看懂的形式。这两篇文档分别是https://mlir.llvm.org/docs/PatternRewriter/https://mlir.llvm.org/docs/Rationale/RationaleGenericDAGRewriter/ 。做这件事的动机是因为在我的开发过程中已经大量使用了MLIR的这个Pattern Rewrite机制,也经常回看这两篇文档所以翻译+梳理+总结一下。

0x1. Generic DAG Rewriter Infrastructure Rationale

题目可以翻译为通用的Dag重写架构的基本原理。对应 https://mlir.llvm.org/docs/Rationale/RationaleGenericDAGRewriter/ 这篇文档的内容。这里主要介绍了用于MLIR的通用Dag-to-Dag重写架构背后的基本原理。

0x1.1 介绍和动机

编译器IR目标是在各种抽象级别上表示代码 ,这在表示能力和易于变换方面提出了不同的折衷。 但是,表示代码的能力本身并不是很有用——您还需要能够实现这些变换。

编译器的变换有很多,这里主要介绍的是一种对MLIR目标非常重要且反复出现的变换:匹配一系列Op组成的Dag,然后将其替换为另外一个Dag。这是很多学习编译器不可或缺的一部分,对于诸如“消除identity(直连)节点”或者使用"x"替换"x+0"这种优化,通用规范化框架(比如LLVM的指令组合(Instruction Combiner)),以及为编译器在多个中间IR上实现优化算法提供了一个有用的抽象。

MLIR 的一个特殊优势(以及与 LLVM、GCC、XLA、TensorFlow 等其他编译器基础架构的主要区别)是它使用单个编译器 IR 来表示多个抽象级别的代码:MLIR 操作可以是“TensorFlow operation”、“XLA HLO”、仿射循环嵌套、LLVM IR 指令(可传递地包括 X86、Lanai、PTX 和其他目标特定指令)或 MLIR 算子系统可以合理表达的任何其它内容。 鉴于 MLIR 跨越了如此广泛的不同问题范围,用于执行图到图重写的单一基础架构可以帮助解决许多不同的领域挑战。

像 MLIR 这样的基于静态单赋值 (SSA) 的IR可以轻松访问Op的操作数和“users”。 因此,这些图到图重写的自然抽象是 DAG 模式匹配的抽象:客户端定义 DAG tile模式(其中tile是定义 DAG 子图的一系列Op),并且每个模式都包含一个产生的结果 DAG 和产生它的成本(或者相反,叫作进行替换的好处(benifit))。 一个通用的基础设施可以有效地找到并执行重写。

虽然上面提到的概念很简单,但细节很微妙。 这篇文档里定义并探索了可以解决范围广泛的不同问题的一组抽象,并预计可以应用于 MLIR 随着时间的推移将面临的许多不同类型的问题。

常量折叠(Constant Folding)

DAG 到 DAG 模式匹配的一个退化但常见的情况是常量折叠:操作数包含常量的Op通常可以折叠为结果常量值。

MLIR 的Op可能会覆盖fold来实现,与一般的 DAG 到 DAG 模式匹配器相比,它暴露了一个更简单的 API,并适用于通用的匹配器不适用的情况。 例如,DAG 重写可以删除当前函数中的任意节点,这可能会使迭代器无效。 作为 API 的常量折叠则不会删除任何节点,它只是提供一个常量值(列表)并允许客户端根据需要更新其数据结构。

关于常量折叠请看一下后面的0X3节的示例讲解,是这篇https://mlir.llvm.org/docs/Canonicalization 文档的翻译。

相关工作

考虑到几乎每个现有的编译器都必须多次解决这个问题,因此需要考虑大量相关工作。 一个统一的问题是,所有这些系统都旨在解决一个特定的、通常是狭窄的问题:另一方面,MLIR 希望在单个基础设施中解决许多这些问题。 以下是一些相关的Pattern Rewriter系统,以及它们工作的优缺点(与 MLIR 中存在的基础设施最相似的设计是 LLVM DAG-to-DAG 指令选择算法)。

  • AST 级模式匹配器:文本中存在大量的source-to-source的翻译器用来做等价变换以提升性能(比如把x*0变成0)。一个较大的例子是GCC fold函数,它对AST进行了很多优化。Clang具有应用于表达式的简单常量折叠的类似例子(C++的要求),但并不会对AST执行常见的优化。
    AST 优化器的主要缺点是我们无法看到具有多种用途的Op。 众所周知,DAG 模式匹配比树模式匹配更强大,但另一方面,DAG 模式匹配会导致重复计算。
  • 第二种就不介绍了,感兴趣可以看官方文档。
  • LLVM’s DAG-to-DAG Instruction Selection Infrastructure:LLVM 中的指令选择子系统是多年迭代和研究的结果,这是由于 LLVM 需要支持大量的目标代码生成、现代指令集(例如 X86)的代码生成器的复杂性以及狂热的追求跨目标重用代码。 Eli Bendersky 写了一篇关于它如何工作的简短概述,LLVM 文档更深入地描述了它,包括它的优点和局限性。 它允许编写这样的模式。
def : Pat<(or GR64:$src, (not (add GR64:$src, 1))),
          (BLCI64rr GR64:$src)>;

此示例为 X86 目标描述中的“blci”指令定义了一个匹配器,该文件中还有许多其他指令(查找 Pat<> 模式,因为它们没有纠缠于编译器的细节,如汇编器/反汇编器生成逻辑)。

下面说了一些LLVM的这个DAG-to-DAG 指令选择机制的好处和坏处,截图放在下方。

【从零开始学深度学习编译器】二十,MLIR的Pattern重写机制_第1张图片

小结

MLIR 面临着广泛的模式匹配和图重写问题,在多个级别上使用通用代码表示的主要优势之一是它允许投资并高度利用单一基础设施来完成此类工作。

这里后续还介绍了一些Dag重写机制的目标,包括它解决了哪些问题以及使用的匹配策略,以及良好的报错信息等等。

0x2. Pattern Rewriting : Generic DAG-to-DAG Rewriting

本文档详细介绍了 MLIR(通用 DAG 到 DAG 转换框架)中存在的模式重写基础设施的设计和 API。 该框架在整个 MLIR 中广泛用于规范化、转换(conversion)和通用变换(transformation)。

介绍

模式重写框架在很大程度上可以分解为两部分:模式定义和模式应用。

模式定义

模式是通过继承 RewritePattern 类来定义的。 此类表示 MLIR 中所有重写模式的基类,由以下组件组成:

Benefit

这是应用给定模式的预期收益。 这种收益在模式构建时是静态的,但可以在模式初始化时动态计算,例如允许从特定领域的信息(如目标架构)中获得收益。 这种限制允许执行模式融合并将模式编译成一个高效的状态机,并且 Thier、Ertl 和 Krall 已经证明,匹配谓词在几乎所有情况下都不需要动态计算成本:我们可以简单地为每个可能的收益实例化一次相同的模式,并使用谓词来保护匹配。

Root Operation Name(可选)

此模式匹配的根操作的名称。 如果指定,只有具有给定根名称的操作才需要提供matchrewrite实现。 如果没有指定,可以提供任何操作类型。 应尽可能提供根操作名称,因为它可以在应用代价模型时简化模式分析。 要匹配任何操作类型,必须提供一个特殊标签来明确意图:MatchAnyOpTypeTag

match and rewrite 实现

这是与给定根操作匹配并执行 IR 重写的代码块。 RewritePattern 可以通过单独的 match 和 rewrite 方法或通过组合的 matchAndRewrite 方法来指定此实现。 使用组合 matchAndRewrite 方法时,在匹配成功之前不应发生 IR 突变。 当匹配和重写阶段需要non-trivially的可重计算信息时,组合的 matchAndRewrite 很有用。 请参阅下面的示例:

class MyPattern : public RewritePattern {
public:
  /// This overload constructs a pattern that only matches operations with the
  /// root name of `MyOp`.
  MyPattern(PatternBenefit benefit, MLIRContext *context)
      : RewritePattern(MyOp::getOperationName(), benefit, context) {}
  /// This overload constructs a pattern that matches any operation type.
  MyPattern(PatternBenefit benefit)
      : RewritePattern(benefit, MatchAnyOpTypeTag()) {}

  /// In this section, the `match` and `rewrite` implementation is specified
  /// using the separate hooks.
  LogicalResult match(Operation *op) const override {
    // The `match` method returns `success()` if the pattern is a match, failure
    // otherwise.
    // ...
  }
  void rewrite(Operation *op, PatternRewriter &rewriter) {
    // The `rewrite` method performs mutations on the IR rooted at `op` using
    // the provided rewriter. All mutations must go through the provided
    // rewriter.
  }

  /// In this section, the `match` and `rewrite` implementation is specified
  /// using a single hook.
  LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) {
    // The `matchAndRewrite` method performs both the matching and the mutation.
    // Note that the match must reach a successful point before IR mutation may
    // take place.
  }
};

限制

在模式的match部分中,应用以下约束:

  • 不允许IR突变。
    在模式的rewriter部分中,应用以下约束:
  • 所有 IR 突变,包括创建,都必须由给定的 PatternRewriter 执行。 此类提供了用于执行模式中可能发生的所有可能突变的钩子。 例如,这意味着不应通过其erase方法来删除操作。 要删除操作,应使用适当的 PatternRewriter 钩子(在本例中为 eraseOp)。
  • 根操作必须是:inplace更新、替换或删除。

递归应用

递归是模式重写上下文中的一个重点主题,因为一个模式通常对自己的结果也是适用的。

0x3. Operation Canonicalization(操作规范化)

规范化是编译器 IR 设计的重要组成部分:它使实现可靠的编译器转换和推理代码中的优劣变得更加容易,并引发了有关特定 IR 级别目标的有趣讨论。 Dan Gohman 写了一篇文章探讨这些问题; 如果你不熟悉这些概念,则值得一读。文章地址为:https://sunfishcode.github.io/blog/2018/10/22/Canonicalization.html

大多数编译器都有规范化pass,有时它们还有许多不同类型的pass(例如 LLVM 中的 instcombine、dag combine 等)。 因为 MLIR 是一个多级 IR,我们可以提供一个单一的规范化基础设施,并在它所代表的许多不同的IR中重用它。这一节描述了通用的全局规范化方法,并提供了部分用来捕获特定于IR的规则以供参考。

通用设计

MLIR 有一个单一的规范化pass,它以贪心的方式迭代地应用规范化变换,直到IR收敛。 这些变换由Op本身定义,允许每个方言一起定义自己的Op和规范化集合。规范化Pattern需要考虑的几点:

  • Pattern的重复应用应该收敛。 不稳定或循环重写将导致规范化程序中的无限循环。
  • 当操作数重复时,朝着值使用较少的Op进行规范化通常会更好,因为某些模式仅在值具有单个user时才匹配。 例如,将“x + x”规范化为“x * 2”通常是好的,因为这会将 x 的使用次数减少一。
  • 在可能的情况下完全消除Op总是好的,例如 通过折叠已知的恒等(如“x + 0 = x”)。

全局应用规则

这些变换被应用于所有级别的IR:

  • 消除无副作用、无用处的Op。
  • 常量折叠 - 例如 “(addi 1, 2)”到“3”。 常量折叠钩子由Op指定。
  • 将常量操作数移动到右侧的可交换运算符 - 例如 “(addi 4, x)”到“(addi x, 4)”。
  • constant-like Op是唯一的,并被提升到第一个父barrier区域的入口块中。这是一个和上方隔离的区域,如函数的入口块,或者通过DialectFoldInterface上的shouldMaterializeInto方法标记为barrier的入口块。

定义Canonicalizations

有两种机制可用于定义规范化; 一般的 RewritePatterns 和 fold 方法。

Canonicalizing with RewritePattern

这种机制允许将规范化作为一组 RewritePatterns 提供,或者在 C++ 中强制定义或作为声明性重写规则(DRR)声明。 模式重写基础结构允许表达许多不同类型的规范化。 这些转换可能就像用移位替换乘法一样简单,甚至可以用无条件分支替换条件分支。

在ODS中,Op可以通过设置hasCanonicalizer位或者hasCanonicalizeMethod位以生成getCanonicalizationPatterns方法。

def MyOp : ... {
  // I want to define a fully general set of patterns for this op.
  let hasCanonicalizer = 1;
}

def OtherOp : ... {
  // A single "matchAndRewrite" style RewritePattern implemented as a method
  // is good enough for me.
  let hasCanonicalizeMethod = 1;
}

然后可以在源文件中提供规范化Pattern(这个代码是生成的):

void MyOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                       MLIRContext *context) {
  patterns.add<...>(...);
}

LogicalResult OtherOp::canonicalize(OtherOp op, PatternRewriter &rewriter) {
  // patterns and rewrites go here.
  return failure();
}

Canonicalizing with fold 方法

fold机制是一种有意限制但功能强大的机制,它允许在整个编译器的许多地方应用规范化。例如,在规范化pass之外 ,fold在Dialect Conversion基础架构中用作合法化机制,并且可以通过OpBuilder::createOrFold在任何地方使用OpBuilder直接调用。

fold 的限制是不能创建新的Op,只能替换根Op(但不能删除)。 它允许原地更新Op,或返回一组预先存在的值(或属性)以替换Op。 这确保了fold方法是一个真正的“本地”转换,并且可以在不需要Pattern Rewriter的情况下调用。

在 ODS 中,Op可以设置hasFolder位以生成fold方法的声明。 此方法采用不同的形式,具体取决于Op的结构。

def MyOp : ... {
  let hasFolder = 1;
}

如果Op只有一个结果,将生成以下内容:

/// Implementations of this hook can only perform the following changes to the
/// operation:
///
///  1. They can leave the operation alone and without changing the IR, and
///     return nullptr.
///  2. They can mutate the operation in place, without changing anything else
///     in the IR. In this case, return the operation itself.
///  3. They can return an existing value or attribute that can be used instead
///     of the operation. The caller will remove the operation and use that
///     result instead.
///
OpFoldResult MyOp::fold(ArrayRef<Attribute> operands) {
  ...
}

否则将生成下面的内容:

/// Implementations of this hook can only perform the following changes to the
/// operation:
///
///  1. They can leave the operation alone and without changing the IR, and
///     return failure.
///  2. They can mutate the operation in place, without changing anything else
///     in the IR. In this case, return success.
///  3. They can return a list of existing values or attribute that can be used
///     instead of the operation. In this case, fill in the results list and
///     return success. The results list must correspond 1-1 with the results of
///     the operation, partial folding is not supported. The caller will remove
///     the operation and use those results instead.
///
/// Note that this mechanism cannot be used to remove 0-result operations.
LogicalResult MyOp::fold(ArrayRef<Attribute> operands,
                         SmallVectorImpl<OpFoldResult> &results) {
  ...
}

在上面,为每个方法提供了一个 ArrayRef,它对应于每个操作数的常量属性值。 这些操作数是那些实现 ConstantLike 特征的操作数。 如果任何操作数是非常量,则提供 null Attribute 值。 例如,如果 MyOp 提供了三个操作数 [a, b, c],但只有 b 是常量,则操作数的格式为 [Attribute(), b-value, Attribute()]。

上面还展示了OpFoldResult的应用。此类表示fold一个op的可能结果:SSA ValueAttribute(对于常量结果)。 如果提供了 SSA Value,则它必须对应于现有值。 fold 方法不允许生成新Value。 返回的 Attribute 值的形式没有特定的限制,但重要的是要确保特定 TypeAttribute 表示形式是一致的。

当Op上的fold钩子不成功时,Dialect可以通过实现 DialectFoldInterface 并覆盖fold钩子来提供fallback。

从属性产生常量

fold 方法返回一个 Attribute 作为结果时,它表示这个结果是“常量”。 Attribute是值的常量表示。 fold 方法的使用者,例如 canonicalizer pass,将获取这些 Attributes 并在 IR 中实现常量Op来表示它们。 要启用此实现,Op的Dialect必须实现 materializeConstant 钩子。 这个钩子接受一个Attribute值,通常由fold返回,并产生一个“constant-like”的Op来表示该值。

在 ODS 中,Dialect可以设置 hasConstantMaterializer 位以生成 materializeConstant 方法的声明。

def MyDialect_Dialect : ... {
  let hasConstantMaterializer = 1;
}

然后可以在源文件中具体化常量:

/// Hook to materialize a single constant operation from a given attribute value
/// with the desired resultant type. This method should use the provided builder
/// to create the operation without changing the insertion position. The
/// generated operation is expected to be constant-like. On success, this hook
/// should return the value generated to represent the constant value.
/// Otherwise, it should return nullptr on failure.
Operation *MyDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                          Type type, Location loc) {
  ...
}

你可能感兴趣的:(人工智能,深度学习)