XLA all reduce combiner pass 分析

这个pass是hlo层对多个all reduce instruction判断是否需要进行合并的优化pass.也就是tensor fusion了。
首先有一个结构体:

using InstructionGroups =
    std::vector>>;

可以看到是三个vector的嵌套,乍一看不知道是干啥的,所以从创造他的函数CreateComputationGroups入手分析一下:
这个函数首先遍历了一下computation的所有all reduce instruction.然后创建了一个 opcode_groups.

std::map> opcode_groups;

这个是对于不同all reduce的类型(sum, mean 等)分组,比较容易理解。
接下来又基于opcode_groups创建了一个all_reduce_sets:

std::map>>
      all_reduce_sets;
  int64 group_id = 0;
  for (auto& domain_groups : opcode_groups) {
    for (HloInstruction* hlo : domain_groups.second) {
      all_reduce_sets[channel_id(hlo)].emplace_back(group_id, hlo);
    }
    ++group_id;
  }

对每一个op_code group按照遍历升序给了一个group_id。然后又按照all reduce instruction的 all_reduce_id分组。这次分组元素里不仅仅是instruction的指针了,而是group_id,instruction指针的pair.

再紧接着创建了一个all_reduce的group map:

  std::map, std::vector>>
      all_reduce_group_map;
  for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
    if (instruction->opcode() != HloOpcode::kAllReduce) {
      continue;
    }
    if (instruction->to_apply()->instruction_count() != 3 ||
        instruction->to_apply()->num_parameters() != 2) {
      VLOG(1) << "Skipping due to non-trivial reduction function.";
      continue;
    }

    int64 arid = channel_id(instruction);
    if (all_reduce_sets.count(arid) == 0) {
      // Already processed.
      continue;
    }

    std::vector group_ids;
    std::vector instructions;
    for (const auto& hlo : all_reduce_sets[arid]) {
      group_ids.push_back(hlo.first);
      instructions.push_back(hlo.second);
    }
    all_reduce_group_map[group_ids].push_back(std::move(instructions));
    all_reduce_sets.erase(arid);
  }
  CHECK(all_reduce_sets.empty());

这个map的key是group id的序列,value是instruction的指针的二维数组。最后那一维数组中的所有instruction都是属于同一个all_reduce id的。

最后整个函数返回了 InstructionGroups groups;

  InstructionGroups groups;
  for (const auto& all_reduce_group : all_reduce_group_map) {
    groups.push_back(all_reduce_group.second);
  }
  return std::move(groups);
}

可以看到InstructionGroups这个结构体的实际含义是一个数组,数组的每个元素代表的是具有相同group_id序列的instruction组成的二维数组。二维数组的每一行的所有instruction都有相同的all_reduce id.
每一列的所有instruction都有相同的group_id.

你可能感兴趣的:(XLA all reduce combiner pass 分析)