深度学习编译器之公共子表达式消除和死代码消除实现

0x0. 前言

更多的深度学习编译器知识可以在 https://github.com/BBuf/tvm_mlir_learn 找到。同时也维护了一个cuda学习仓库 https://github.com/BBuf/how-to-optim-algorithm-in-cuda 以及一个如何学习深度学习框架(PyTorch和OneFlow)的学习仓库,https://github.com/BBuf/how-to-learn-deep-learning-framework , 有需要的小伙伴可以点一点star 。在https://github.com/BBuf/how-to-optim-algorithm-in-cuda/tree/master/large-language-model-note 这个目录下收集了一系列和LLM训练,推理相关的文章。

【省流】上次介绍了深度学习编译器之Layerout Transform优化 ,在这篇文章中提到还会介绍常量折叠优化Pass的实现,但在介绍常量折叠Pass之前我想再介绍一个类似的优化方法也就是公共子表达式消除实现(CSE)。仍然是以OneFlow中基于MLIR进行实现的CSE Pass为例子来讲解。在解析代码实现的过程中,我发现基于MLIR来做公共子表达式消除的时候还顺带做了死代码消除的功能。另外,在考虑公共子表达式消除的时候需要保证两个重复的操作处于同一个基本块中以及两个重复操作之间没有其它具有副作用的操作才可以消除。在OneFlow的实现中只是对OneFlow的UserOp的特殊属性即OpName和SymbolID进行了擦除,用一个魔法属性来代替,这是因为这两个属性不应该去影响公共子表达式的消除。这个优化还是比较有用的,在OneFlow的Stable Diffusion优化中发挥了不小的作用。

0x1. 效果

公共子表达式消除的作用很简单,就是把公共的表达式折叠为1个表达式来避免重复的计算开销。我们以OneFlow针对CSE Pass写的2个测试为例子来进行说明。这两个例子在 https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/ir/test/OneFlow/cse.mlir ,这里提供了一个 MLIR Module,包含两个函数:@Cast_1__FUSE__ScalarMulByTensor_2 和 @f2。

其中,第一个函数 @Cast_1__FUSE__ScalarMulByTensor_2 接受一个形状为 96x96xi64 的张量作为输入,并执行两个类型转换操作,将输入转换为 96x96xf32 张量。然后,它使用 oneflow.add_n 操作将两个结果张量相加,并返回结果 96x96xf32 张量。FileCheck 命令验证了具有 “ScalarMulByTensor_2” op_name 属性的 “oneflow.cast” 和 “oneflow.add_n2” 操作的存在。这里再解释一下 CHECK 指定,比如CHECK: %[[OUT:[a-zA-Z0-9_]+]] = “oneflow.cast” 是一个 FileCheck 指令,用于验证生成的代码是否符合预期。FileCheck 是 LLVM 项目的一部分,用于为编译器测试提供模式匹配功能。%[[OUT:[a-zA-Z0-9_]+]] 是一个正则表达式捕获组,用于捕获一个以 % 开头、后跟一系列字母、数字或下划线的字符串。这个字符串对应于 MLIR 中的一个值名称。“oneflow.cast” 表示我们希望找到一个名为 “oneflow.cast” 的操作。

第二个函数 @f2 接受三个输入张量:一个形状为 2x64x64x320xf16 的张量,一个形状为 320x320x3x3xf16 的张量,和一个形状为 320xf16 的张量。它将第二个输入张量转置两次,并使用转置后的张量、第一个输入张量和第三个输入张量执行两个 conv2d 操作。该函数返回两个形状为 2x64x64x320xf16 的结果张量。FileCheck 命令验证了具有等于 163 的 scope_symbol_id 属性的 “oneflow.conv2d” 操作的存在,并检查输出的两个结果张量。

这两个函数有一个共同点,那就是它们都存在一个完全相同的公共Op,我们可以编译oneflow之后使用下面的命令将CSE Pass添加到opt pass pipline里面来运行这个mlir表达式做变换,我们可以关注变换后的表达式。命令如下:

oneflow/build/oneflow/ir/bin/oneflow-opt oneflow/oneflow/ir/test/OneFlow/cse.mlir -cse-with-attributes-ignored -cse -cse-put-attributes -canonicalize

解释一下这里的几个选项:

  • cse-with-attributes-ignored: 此参数告诉优化器在执行公共子表达式消除(CSE)时忽略OneFlow IR特有的会影响CSE的属性(这里是OpName和SymbolID)。
  • cse: 这个参数开启公共子表达式消除(CSE)优化。CSE 是一种编译器优化技术,用于删除冗余的子表达式,从而减少计算量和提高程序运行速度。
  • cse-put-attributes: 此参数指示优化器在执行 CSE 之后,将原始属性放回原始操作。这有助于确保在优化过程中保留操作的属性信息。(也暗示我们必须把原始的属性保存下来)
  • canonicalize: 这个参数开启规范化优化。规范化优化会将程序中的操作和表达式转换为一种统一的标准形式,从而简化后续优化的实现和提高效率。(这两个给定的例子里,不开启canonicalize也不会影响输出IR的表达)

接下来是运行上述命令后输出的MLIR Module。

module {
  func.func @Cast_1__FUSE__ScalarMulByTensor_2(%arg0: tensor<96x96xi64>) -> tensor<96x96xf32> {
    %0 = "oneflow.cast"(%arg0) {device_name = ["0:0"], device_tag = "cpu", dtype = 2 : i32, hierarchy = [1], op_name = "Cast_1", op_type_name = "cast", pin_memory = false, scope_symbol_id = 4611686018427416574 : i64} : (tensor<96x96xi64>) -> tensor<96x96xf32>
    %1 = "oneflow.add_n2"(%0, %0) {device_name = ["0:0"], device_tag = "cpu", hierarchy = [1], op_name = "ScalarMulByTensor_2", op_type_name = "add_n", scope_symbol_id = 4611686018427416574 : i64} : (tensor<96x96xf32>, tensor<96x96xf32>) -> tensor<96x96xf32>
    return %1 : tensor<96x96xf32>
  }
  func.func @f2(%arg0: tensor<2x64x64x320xf16>, %arg1: tensor<320x320x3x3xf16>, %arg2: tensor<320xf16>) -> (tensor<2x64x64x320xf16>, tensor<2x64x64x320xf16>) {
    %0 = "oneflow.transpose"(%arg1) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.down_blocks.0.resnets.0.conv1-conv2d-31_transpose_input_1", perm = [0 : si32, 2 : si32, 3 : si32, 1 : si32], scope_symbol_id = 163 : i64} : (tensor<320x320x3x3xf16>) -> tensor<320x3x3x320xf16>
    %1 = "oneflow.conv2d"(%arg0, %0, %arg2) {data_format = "channels_last", device_name = ["@0:0"], device_tag = "cuda", dilation_rate = [1 : si32, 1 : si32], filters = 320 : si32, groups = 1 : si32, hierarchy = [1], kernel_size = [3 : si32, 3 : si32], op_name = "unet.down_blocks.0.resnets.0.conv1-conv2d-31", operand_segment_sizes = array<i32: 1, 1, 1, 0>, padding_before = [1 : si32, 1 : si32], scope_symbol_id = 163 : i64, strides = [1 : si32, 1 : si32], tuning_cache = ""} : (tensor<2x64x64x320xf16>, tensor<320x3x3x320xf16>, tensor<320xf16>) -> tensor<2x64x64x320xf16>
    return %1, %1 : tensor<2x64x64x320xf16>, tensor<2x64x64x320xf16>
  }
}

和原始的MLIR ModuleOp对比,我们发现这两个函数里面的公共子表达式(cast和transpose)都只保留了一个,实现了公共子表达式消除的目的。在OneFlow编译器中,这个优化率先在OneFlow的Stable Diffusion引人,加速了模型的推理速度。

0x2. 原理&代码实现

基于 OneFlow 实现 CSE 的原理是,我们需要先消除 OneFlow 的 UserOp 的 OpName 和 SymbolID 这两个属性,这两个属性对 CSE 来说是没影响的,但是是由 OneFlow 系统添加的,所以我们需要做个预处理忽略掉这两个不一致。然后调用MLIR系统的 CSE Pass 之后我们需要把这个忽略的属性加回来。这样才可以保证优化后的IR可以转回OneFlow的图并正确执行。

首先基于ODS在https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/ir/include/OneFlow/OneFlowPasses.td#L156-L172 定义了两个CSE相关的Pass类,MLIR会自动生成这两个Pass的定义。我们详细看一下细节:

def CSEWithAttributesIgnored : Pass<"cse-with-attributes-ignored", "ModuleOp"> { //  定义了一个名为 "cse-with-attributes-ignored" 的 Pass,它作用在 MLIR 中的模块操作(ModuleOp)上。
  let summary = "ignore oneflow attributes to have cse work"; // summary 和 description: 提供了有关 Pass 功能的简短描述和详细说明。这个 Pass 的目的是执行 CSE 优化,同时忽略 OneFlow 属性(如操作名、符号 ID 等)。
  let description = [{
    cse and ignore oneflow attributes like op name, symbol id, etc.
  }];
  let constructor = "mlir::oneflow::createCSEWithAttributesIgnored()"; // 指定用于创建这个 Pass 的函数,即 mlir::oneflow::createCSEWithAttributesIgnored()。
  let dependentDialects = []; // 列出这个 Pass 依赖的其他方言。在这种情况下,它是空的,表示没有依赖关系。
}

def CSEPutAttributes : Pass<"cse-put-attributes", "ModuleOp"> {
  let summary = "cse and ignore oneflow attributes";
  let description = [{
    put back oneflow attributes like op name, symbol id, etc.
  }];
  let constructor = "mlir::oneflow::createCSEPutAttributes()";
  let dependentDialects = [];
}

可以看到 CSE 的预处理和后处理 Pass 主要就是实现 createCSEWithAttributesIgnored 和 createCSEPutAttributes 这两个函数。它们的定义在:https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/ir/include/OneFlow/Transform/CSEWithAttributesIgnored.h#L25-L33

// CSEState 结构体包含两个成员:
// scopeSymbolIDs:一个 llvm::DenseMap,将 Operation* 类型的指针映射到 IntegerAttr 类型的属性。这个映射可能用于存储操作的范围符号ID。
// opNames:一个 llvm::DenseMap,将 Operation* 类型的指针映射到 StringAttr 类型的属性。这个映射可能用于存储操作的名称。
struct CSEState {
  llvm::DenseMap<Operation*, IntegerAttr> scopeSymbolIDs;
  llvm::DenseMap<Operation*, StringAttr> opNames;
};
// 这个函数返回一个 std::unique_ptr 类型的对象。根据函数名称,这个函数创建一个CSE Pass,其中忽略了属性。
std::unique_ptr<mlir::Pass> createCSEWithAttributesIgnored();
// 这个函数也返回一个 std::unique_ptr 类型的对象。根据函数名称,这个函数创建一个CSE Pass,会处理或放置属性。
std::unique_ptr<mlir::Pass> createCSEPutAttributes();
// 这个函数接受一个 std::shared_ptr 类型的参数,并返回一个 std::pair,其中包含两个 std::unique_ptr 类型的对象。这个函数创建一对CSE Pass,它们共享给定的 CSEState。
std::pair<std::unique_ptr<Pass>, std::unique_ptr<Pass>> createCSEPasses(
    std::shared_ptr<CSEState> state);
// 这个函数接受一个 std::shared_ptr 类型的参数。根据函数名称,这个函数可能会注册一组CSE Pass,它们共享给定的 CSEState。
void registerCSEPasses(std::shared_ptr<CSEState> state);

接下来看下这几个 Pass 的具体实现。代码在 https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/ir/lib/OneFlow/Transform/CSEWithAttributesIgnored.cpp

首先来看createCSEWithAttributesIgnored:


struct EraseAttributes : public mlir::OpInterfaceRewritePattern<UserOpCompatible> {
  explicit EraseAttributes(mlir::MLIRContext* context, std::shared_ptr<CSEState> state)
      : OpInterfaceRewritePattern<UserOpCompatible>(context, /*benefit=*/1), state_{state} {}
  mlir::LogicalResult matchAndRewrite(UserOpCompatible op,
                                      mlir::PatternRewriter& rewriter) const override {
    if (op->getAttrOfType<StringAttr>(OpTrait::IsOpConfCompatible<void>::getOpNameAttr())
            .getValue()
            .str()
        != MAGIC_OP_NAME) {
      if (state_) {
        state_->opNames[op] =
            op->getAttrOfType<StringAttr>(OpTrait::IsOpConfCompatible<void>::getOpNameAttr());
        state_->scopeSymbolIDs[op] = op->getAttrOfType<IntegerAttr>(
            OpTrait::IsOpConfCompatible<void>::getScopeSymbolIDAttr());
      }
      op->setAttr(OpTrait::IsOpConfCompatible<void>::getOpNameAttr(),
                  rewriter.getStringAttr(MAGIC_OP_NAME));
      op->setAttr(OpTrait::IsOpConfCompatible<void>::getScopeSymbolIDAttr(),
                  rewriter.getI64IntegerAttr(MAGIC_SCOPE_SYMBOL_ID));
      return success();
    } else {
      return failure();
    }
  }

 private:
  std::shared_ptr<CSEState> state_;
};

class CSEWithAttributesIgnored : public CSEWithAttributesIgnoredBase<CSEWithAttributesIgnored> {
 public:
  explicit CSEWithAttributesIgnored() {}
  explicit CSEWithAttributesIgnored(std::shared_ptr<CSEState> state) : state_(state) {}
  void runOnOperation() override {
    Operation* op = getOperation();
    RewritePatternSet patterns(op->getContext());
    patterns.add<EraseAttributes>(op->getContext(), state_);
    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
  }

 private:
  std::shared_ptr<CSEState> state_;
};

std::unique_ptr<Pass> createCSEWithAttributesIgnored() {
  return std::make_unique<CSEWithAttributesIgnored>();
}

这段代码定义了一个 EraseAttributes 重写类, 它会移除 op 中的某些属性。它继承自 OpInterfaceRewritePattern, 意味着它可以匹配实现了 UserOpCompatible 这个 OpInterface 的 op。然后 EraseAttributes 构造函数接受一个 MLIRContext* 和一个shared_ptr。CSEState 用于跟踪已重写的 op 的属性。matchAndRewrite 方法检查 op 是否有名为 OpNameAttr 的 StringAttr 属性, 如果有, 并且其值不等于 MAGIC_OP_NAME, 则该方法会:

  • 将 op 的 OpNameAttr 和 ScopeSymbolIDAttr 属性记录在 CSEState 中。
  • 将 OpNameAttr 设置为 MAGIC_OP_NAME, 将 ScopeSymbolIDAttr 设置为 MAGIC_SCOPE_SYMBOL_ID。

然后,CSEWithAttributesIgnored 继承自 CSEWithAttributesIgnoredBase, 重写了其 runOnOperation 方法。该方法会实例化一个 RewritePatternSet, 添加 EraseAttributes 这个匹配重写模板, 然后应用该模板, 从而移除user op 中的属性。它还保存一个指向CSEState 的 shared_ptr , 可以在 EraseAttributes 中使用。注意这里的 CSEWithAttributesIgnoredBase 是通过ODS自动生成的 Pass 类定义。createCSEWithAttributesIgnored 函数会创建一个 CSEWithAttributesIgnored pass 并返回。

接着看一下 createCSEPutAttributes 的实现,

struct PutAttributes : public mlir::OpInterfaceRewritePattern<UserOpCompatible> {
  explicit PutAttributes(mlir::MLIRContext* context, std::shared_ptr<CSEState> state)
      : OpInterfaceRewritePattern<UserOpCompatible>(context, /*benefit=*/1), state_{state} {}
  mlir::LogicalResult matchAndRewrite(UserOpCompatible op,
                                      mlir::PatternRewriter& rewriter) const override {
    if (op->getAttrOfType<StringAttr>(OpTrait::IsOpConfCompatible<void>::getOpNameAttr())
            .getValue()
            .str()
        == MAGIC_OP_NAME) {
      if (state_) {
        op->setAttr(OpTrait::IsOpConfCompatible<void>::getOpNameAttr(), state_->opNames[op]);
        op->setAttr(OpTrait::IsOpConfCompatible<void>::getScopeSymbolIDAttr(),
                    state_->scopeSymbolIDs[op]);
      }
      return success();
    } else {
      return failure();
    }
  }

 private:
  std::shared_ptr<CSEState> state_;
};

class CSEPutAttributes : public CSEPutAttributesBase<CSEPutAttributes> {
 public:
  explicit CSEPutAttributes() {}
  explicit CSEPutAttributes(std::shared_ptr<CSEState> state) { state_ = state; }

  void runOnOperation() override {
    Operation* op = getOperation();
    RewritePatternSet patterns(op->getContext());
    patterns.add<PutAttributes>(op->getContext(), state_);
    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
  }

 private:
  std::shared_ptr<CSEState> state_;
};


std::unique_ptr<Pass> createCSEPutAttributes() { return std::make_unique<CSEPutAttributes>(); }

这个 PutAttributes 重写模板与 EraseAttributes 相反, 它会将先前删除的属性恢复回 op。PutAttributes 构造函数也接受一个 MLIRContext* 和一个 shared_ptr。它使用 CSEState 来查找先前删除的属性值。matchAndRewrite 方法检查 op 是否有一个名为 OpNameAttr 的 StringAttr 属性,其值等 于 MAGIC_OP_NAME 。如果是,它会从 CSEState 中查找原先的 OpNameAttr 和 ScopeSymbolIDAttr 属性值。将 OpNameAttr 设置为原先的值,将 ScopeSymbolIDAttr 设置为原先的值。

上面的2个Pass都是OneFlow中的预处理和后处理,而真的CSE Pass则是MLIR自带的CSE Pass(oneflow/build/oneflow/ir/llvm_monorepo-src/mlir/lib/Transforms/CSE.cpp), 我们来解析一下。

struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
  static unsigned getHashValue(const Operation *opC) {
    return OperationEquivalence::computeHash(
        const_cast<Operation *>(opC),
        /*hashOperands=*/OperationEquivalence::directHashValue,
        /*hashResults=*/OperationEquivalence::ignoreHashValue,
        OperationEquivalence::IgnoreLocations);
  }
  static bool isEqual(const Operation *lhsC, const Operation *rhsC) {
    auto *lhs = const_cast<Operation *>(lhsC);
    auto *rhs = const_cast<Operation *>(rhsC);
    if (lhs == rhs)
      return true;
    if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
        rhs == getTombstoneKey() || rhs == getEmptyKey())
      return false;
    return OperationEquivalence::isEquivalentTo(
        const_cast<Operation *>(lhsC), const_cast<Operation *>(rhsC),
        OperationEquivalence::IgnoreLocations);
  }
};

SimpleOperationInfo 这个结构体继承自 llvm::DenseMapInfo。此结构体旨在为用于 LLVM DenseMap 中的 Operation 对象提供自定义的哈希和相等性函数。它重载了两个方法:

  • getHashValue: 为 Operation* 计算哈希值。它使用 OperationEquivalence::computeHash 来计算哈希值,并传递 hashOperands=directHashValue 和 hashResults=ignoreHashValue。这意味着它会直接对 op 的操作数计算哈希值,但会忽略结果。
  • isEqual: 检查两个 Operation* 是否相等。它首先检查是否是相同的 op , 如果是,则返回 true。否则,它使用OperationEquivalence::isEquivalentTo 检查两个 op 是否等价。同样,它传递了 IgnoreLocations, 意味着它会忽略 op 的位置信息。

所以, 这个 DenseMapInfo 允许以忽略结果和位置的方式将 Operation* 用作 DenseMap 的键。操作数用于等价性检查和哈希值计算。

/// Simple common sub-expression elimination.
// 这是一个名为CSE(Common Sub-expression Elimination,公共子表达式消除)的结构体定义,用于执行简单的公共子表达式消除。CSE是一种编译器优化技术,用于消除程序中的重复计算,提高执行效率。
struct CSE : public impl::CSEBase<CSE> {
  /// Shared implementation of operation elimination and scoped map definitions.
  // 使用AllocatorTy和ScopedMapTy来定义分配器和作用域映射。ScopedMapTy是一个散列表,用于存储操作之间的映射关系。
  using AllocatorTy = llvm::RecyclingAllocator<
      llvm::BumpPtrAllocator,
      llvm::ScopedHashTableVal<Operation *, Operation *>>;
  using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *,
                                            SimpleOperationInfo, AllocatorTy>;

  /// Cache holding MemoryEffects information between two operations. The first
  /// operation is stored has the key. The second operation is stored inside a
  /// pair in the value. The pair also hold the MemoryEffects between those
  /// two operations. If the MemoryEffects is nullptr then we assume there is
  /// no operation with MemoryEffects::Write between the two operations.
  // MemEffectsCache 用于在两个操作之间缓存 MemoryEffects 信息。MemoryEffects 表示某个操作对内存的影响。
  using MemEffectsCache =
      DenseMap<Operation *, std::pair<Operation *, MemoryEffects::Effect *>>;

  /// Represents a single entry in the depth first traversal of a CFG.
  // CFGStackNode结构体表示控制流图(CFG)深度优先遍历中的一个节点。包括作用域、节点、子节点迭代器等信息。
  struct CFGStackNode {
    CFGStackNode(ScopedMapTy &knownValues, DominanceInfoNode *node)
        : scope(knownValues), node(node), childIterator(node->begin()) {}

    /// Scope for the known values.
    ScopedMapTy::ScopeTy scope;

    DominanceInfoNode *node;
    DominanceInfoNode::const_iterator childIterator;

    /// If this node has been fully processed yet or not.
    bool processed = false;
  };

  /// Attempt to eliminate a redundant operation. Returns success if the
  /// operation was marked for removal, failure otherwise.
  // simplifyOperation 函数尝试消除冗余操作。如果操作被标记为移除,则返回成功,否则返回失败。
  LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op,
                                  bool hasSSADominance);
  // simplifyBlock函数简化指定的基本块(Block)。
  void simplifyBlock(ScopedMapTy &knownValues, Block *bb, bool hasSSADominance);
  // simplifyRegion函数简化指定的区域(Region)。
  void simplifyRegion(ScopedMapTy &knownValues, Region &region);
	
	// runOnOperation函数是重写的基类方法,用于执行CSE优化。
  void runOnOperation() override;

private:
	// replaceUsesAndDelete函数用于替换操作的使用和删除操作。
  void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
                            Operation *existing, bool hasSSADominance);

  /// Check if there is side-effecting operations other than the given effect
  /// between the two operations.
  // hasOtherSideEffectingOpInBetween函数检查给定操作之间是否存在其他具有副作用的操作。
  bool hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp);
	
  /// Operations marked as dead and to be erased.
  // opsToErase是一个用于存储将要删除的操作的向量。
  std::vector<Operation *> opsToErase;
  // domInfo是一个指向支配信息(DominanceInfo)的指针。
  DominanceInfo *domInfo = nullptr;
  // memEffectsCache是一个缓存,用于存储操作之间的内存效果信息。
  MemEffectsCache memEffectsCache;
};
} // namespace

我们先看一下核心的runOperation方法。

void CSE::runOnOperation() {
  /// A scoped hash table of defining operations within a region.
  // 定义一个名为knownValues的局部变量。它是一个作用域内的哈希表,用于存储在一个区域内定义的操作。
  ScopedMapTy knownValues;
	
	// 从DominanceInfo分析中获取支配关系信息,并将其存储在名为domInfo的变量中。
  domInfo = &getAnalysis<DominanceInfo>();
  // 获取当前操作(rootOp),并遍历其所有区域。对每个区域执行简化操作(simplifyRegion)。
  Operation *rootOp = getOperation();

  for (auto &region : rootOp->getRegions())
    simplifyRegion(knownValues, region);
	
	// 如果opsToErase(要删除的操作)为空,说明没有操作被删除,因此保留所有分析。
  // If no operations were erased, then we mark all analyses as preserved.
  if (opsToErase.empty())
    return markAllAnalysesPreserved();

  /// Erase any operations that were marked as dead during simplification.
  // 如果opsToErase中有操作,遍历opsToErase并删除其中的操作。然后清空opsToErase。
  for (auto *op : opsToErase)
    op->erase();
  opsToErase.clear();

  // We currently don't remove region operations, so mark dominance as
  // preserved.
  // 由于当前代码不会删除区域操作,因此将支配关系信息(DominanceInfo)和后支配关系信息(PostDominanceInfo)标记为已保留。将domInfo设置为nullptr。
  markAnalysesPreserved<DominanceInfo, PostDominanceInfo>();
  domInfo = nullptr;
}

这里首先会获取当前 ModuleOp 中 Region 里的支配关系,以便后续执行完 CSE 之后删除 Op 后可以更新支配信息。这里的重点是 simplifyRegion 函数,这是执行 CSE 的具体细节。这个函数主要使用支配树遍历区域中的基本块,并调用 simplifyBlock() 函数对每个基本块进行简化。

// 函数接受一个类型为ScopedMapTy的引用knownValues和一个类型为Region的引用region作为参数。
void CSE::simplifyRegion(ScopedMapTy &knownValues, Region &region) {
  // If the region is empty there is nothing to do.
  if (region.empty())
    return;
	// 判断区域是否具有SSA支配关系(Static Single Assignment Dominance),并将结果存储在变量hasSSADominance中。
  bool hasSSADominance = domInfo->hasSSADominance(&region);

  // If the region only contains one block, then simplify it directly.
  // 如果区域只包含一个基本块,那么直接对其进行简化。创建一个名为scope的ScopedMapTy::ScopeTy对象,然后调用simplifyBlock()函数对该基本块进行简化。
  if (region.hasOneBlock()) {
    ScopedMapTy::ScopeTy scope(knownValues);
    simplifyBlock(knownValues, &region.front(), hasSSADominance);
    return;
  }

  // If the region does not have dominanceInfo, then skip it.
  // TODO: Regions without SSA dominance should define a different
  // traversal order which is appropriate and can be used here.
  // 如果区域没有支配关系信息(hasSSADominance为false),则跳过它。此处提到了一个TODO:对于没有SSA支配关系的区域,应该定义一个不同的遍历顺序。
  if (!hasSSADominance)
    return;

  // Note, deque is being used here because there was significant performance
  // gains over vector when the container becomes very large due to the
  // specific access patterns. If/when these performance issues are no
  // longer a problem we can change this to vector. For more information see
  // the llvm mailing list discussion on this:
  // http://lists.llvm.org/pipermail/llvm-commits/Week-of-Mon-20120116/135228.html
  // 定义一个名为stack的std::deque容器,用于存储CFGStackNode的std::unique_ptr。这里使用deque是因为它在容器变大时具有更好的性能表现。
  std::deque<std::unique_ptr<CFGStackNode>> stack;

  // Process the nodes of the dom tree for this region.
  // 处理这个区域的支配树节点。将区域的根节点压入栈中。
  stack.emplace_back(std::make_unique<CFGStackNode>(
      knownValues, domInfo->getRootNode(&region)));
	// 当栈不为空时,执行以下循环操作:
  while (!stack.empty()) {
    // 获取栈顶的当前节点(currentNode)。
    auto &currentNode = stack.back();

    // Check to see if we need to process this node.
    // 检查当前节点是否需要被处理。如果未处理,则将其标记为已处理,并调用simplifyBlock()函数对当前节点所在的基本块进行简化。
    if (!currentNode->processed) {
      currentNode->processed = true;
      simplifyBlock(knownValues, currentNode->node->getBlock(),
                    hasSSADominance);
    }

    // Otherwise, check to see if we need to process a child node.
    // 检查是否需要处理子节点。如果当前节点的子节点迭代器未到达末尾,将子节点压入栈中。
    if (currentNode->childIterator != currentNode->node->end()) {
      auto *childNode = *(currentNode->childIterator++);
      stack.emplace_back(
          std::make_unique<CFGStackNode>(knownValues, childNode));
    } else {
      // Finally, if the node and all of its children have been processed
      // then we delete the node.
      // 如果当前节点及其所有子节点都已处理完毕,则将节点从栈中弹出。
      stack.pop_back();
    }
  }
}

函数的执行流程请看注释,到这一步之后CSE的具体实现实际上就在 simplifyBlock 函数了,我们继续追踪。函数接受一个类型为 ScopedMapTy 的引用 knownValues,一个类型为 Block 的指针 bb,以及一个布尔值 hasSSADominance 作为参数。从代码中可以推测,该函数的目的是简化一个给定的基本块。

void CSE::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
                        bool hasSSADominance) {
  // 遍历基本块bb中的所有操作(op)
  for (auto &op : *bb) {
    // Most operations don't have regions, so fast path that case.
    // 检查操作是否包含区域。如果操作包含区域,执行以下操作:
    if (op.getNumRegions() != 0) {
      // If this operation is isolated above, we can't process nested regions
      // with the given 'knownValues' map. This would cause the insertion of
      // implicit captures in explicit capture only regions.
      // 如果操作具有IsIsolatedFromAbove特性,那么我们不能使用给定的knownValues映射来处理嵌套区域,
      // 因为这可能导致在仅显式捕获的区域中插入隐式捕获。在这种情况下,创建一个新的nestedKnownValues映射,
      // 并对操作的每个区域调用simplifyRegion()函数。
      if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>()) {
        ScopedMapTy nestedKnownValues;
        for (auto &region : op.getRegions())
          simplifyRegion(nestedKnownValues, region);
      } else {
        // Otherwise, process nested regions normally.
        // 如果操作没有IsIsolatedFromAbove特性,那么正常处理嵌套区域。
        // 对操作的每个区域调用simplifyRegion()函数,传入knownValues映射。
        for (auto &region : op.getRegions())
          simplifyRegion(knownValues, region);
      }
    }
		// 如果操作被简化(调用simplifyOperation()函数并检查其返回值),则不处理操作包含的任何区域,继续处理下一个操作。
    // If the operation is simplified, we don't process any held regions.
    if (succeeded(simplifyOperation(knownValues, &op, hasSSADominance)))
      continue;
  }
  // Clear the MemoryEffects cache since its usage is by block only.
  // 在处理完所有操作后,清空memEffectsCache,因为它的使用仅限于单个基本块。
  memEffectsCache.clear();
}

在 simplifyBlock 中会进一步调用到 simplifyOperation 来对 Operation 做优化。我们最后跟进这个函数看一下。
函数的参数和 simplifyBlock 一样,接受一个类型为 ScopedMapTy 的引用 knownValues,一个类型为 Operation 的指针op,以及一个布尔值 hasSSADominance 作为参数。

/// Attempt to eliminate a redundant operation.
LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op,
                                     bool hasSSADominance) {
  // Don't simplify terminator operations.
  // 如果操作是终止操作(具有IsTerminator特性),则不对其进行简化。
  if (op->hasTrait<OpTrait::IsTerminator>())
    return failure();

  // If the operation is already trivially dead just add it to the erase list.
  // 如果操作已经是无关紧要的死代码,将其添加到待擦除操作列表opsToErase中,增加死代码消除计数,然后返回成功。
  if (isOpTriviallyDead(op)) {
    opsToErase.push_back(op);
    ++numDCE;
    return success();
  }

  // Don't simplify operations with regions that have multiple blocks.
  // TODO: We need additional tests to verify that we handle such IR correctly.
  // 不简化具有多个基本块的区域中的操作。这里提到了一个TODO:需要额外的测试来验证处理此类IR的正确性。
  if (!llvm::all_of(op->getRegions(), [](Region &r) {
        return r.getBlocks().empty() || llvm::hasSingleElement(r.getBlocks());
      }))
    return failure();

  // Some simple use case of operation with memory side-effect are dealt with
  // here. Operations with no side-effect are done after.
  // 首先处理具有内存副作用的简单操作。没有副作用的操作会在后面处理。
  if (!isMemoryEffectFree(op)) {
    auto memEffects = dyn_cast<MemoryEffectOpInterface>(op);
    // TODO: Only basic use case for operations with MemoryEffects::Read can be
    // eleminated now. More work needs to be done for more complicated patterns
    // and other side-effects.
    // 如果操作不是无内存副作用的,尝试获取其MemoryEffectOpInterface。
    // 如果操作没有MemoryEffectOpInterface,或者它不仅仅具有MemoryEffects::Read副作用,则返回失败。
    if (!memEffects || !memEffects.onlyHasEffect<MemoryEffects::Read>())
      return failure();

    // Look for an existing definition for the operation.
    // 查找操作的现有定义。如果找到现有定义,并且操作在同一个基本块中,并且两者之间没有其它具有副作用的操作,
    // 则可以删除冗余操作。调用replaceUsesAndDelete()函数替换使用并删除操作。
    if (auto *existing = knownValues.lookup(op)) {
      if (existing->getBlock() == op->getBlock() &&
          !hasOtherSideEffectingOpInBetween(existing, op)) {
        // The operation that can be deleted has been reach with no
        // side-effecting operations in between the existing operation and
        // this one so we can remove the duplicate.
        replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
        return success();
      }
    }
    // 将操作插入knownValues映射中,并返回失败。
    knownValues.insert(op, op);
    return failure();
  }

  // Look for an existing definition for the operation.
  // 查找操作的现有定义。如果找到现有定义,调用replaceUsesAndDelete()函数替换使用并删除操作,
  // 增加公共子表达式消除计数,并返回成功。
  if (auto *existing = knownValues.lookup(op)) {
    replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
    ++numCSE;
    return success();
  }

  // Otherwise, we add this operation to the known values map.
  // 否则,将此操作添加到knownValues映射中,并返回失败。
  knownValues.insert(op, op);
  return failure();
}

我们可以看到在 simplifyOperation 中,不仅仅包含公共子表达式消除(CSE),而且包含了死代码消除(DCE)。此外,在处理 Operation 时,它会考虑 Operation 的内存副作用以及 Operation 是否在具有多个基本块的区域中。

0x3. 总结

在阅读代码实现的过程中,我发现基于MLIR来做公共子表达式消除的时候还顺带做了死代码消除的功能。另外,在考虑公共子表达式消除的时候需要保证两个重复的操作处于同一个基本块中以及两个重复操作之间没有其它具有副作用的操作才可以消除。在OneFlow的实现中只是对OneFlow的UserOp的特殊属性即OpName和SymbolID进行了擦除,用一个魔法属性来代替,这是因为这两个属性不应该去影响公共子表达式的消除。这个优化还是比较有用的,在OneFlow的Stable Diffusion优化中发挥了不小的作用。

0x4. 相关链接

  • TVM的CSE Pass实现解析:https://blog.csdn.net/Eurypterid/article/details/123118666

你可能感兴趣的:(深度学习,人工智能,机器学习)