代码位置在 `paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc`。
因为 Paddle 的 CINN 和 PIR 部分依旧在高频更新,所以各位看到的代码可能和我的不一样。
inline bool CheckNodeExisted(const std::unordered_set& nodes,
const std::string& op_name) {
return std::find_if(nodes.begin(), nodes.end(), [&op_name](const Node* node) {
return node->Name() == op_name;
}) != nodes.end();
}
用一个内联函数去检查一个 unordered_set(一系列节点)中是否存在名字恰好等于 op_name 的 node。实现用的是 std::find_if,第三个参数传入的是一个匿名函数(lambda)。
其中 [&op_name] 是 lambda 的捕获列表:它写在 lambda 表达式开头的方括号 [] 内,用来指定外部变量按值还是按引用被捕获,这里 op_name 是按引用捕获的。
lambda 表达式与闭包可以参考这篇文章: https://www.cnblogs.com/pzhfei/archive/2013/01/14/lambda_expression.html
接下来的 CountNode 返回名字恰好等于 op_name 的 node 数量(用 std::count_if 实现):
inline int CountNode(const std::unordered_set& nodes,
const std::string& op_name) {
return std::count_if(
nodes.begin(), nodes.end(), [&op_name](const Node* node) {
return node->Name() == op_name;
});
}
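接下来的 GetNode 返回名字中包含 op_name 的节点。注意这里用的是子串匹配(`Name().find(op_name) != std::string::npos`),而不是前两个函数那样的完全相等。
另外,std::find_if 前面为啥有 * 呢?因为 find_if 返回的是一个迭代器,对迭代器解引用(*迭代器)才得到 Node*。
也正因为直接解引用,如果找不到对应节点,就是对 end() 迭代器解引用,属于未定义行为,所以调用方需要保证节点存在。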
// Returns a node in `nodes` whose name *contains* `op_name`, or nullptr if
// there is none.
//
// Unlike CheckNodeExisted/CountNode this is a substring match, which lets
// callers look a node up by a base name such as Node::kControlDepVarName
// (presumably generated names only contain that base name — TODO confirm).
// It also means "var1" would match "var10". Because `nodes` is an
// unordered_set, which node is returned is unspecified when several match.
inline Node* GetNode(const std::unordered_set<Node*>& nodes,
                     const std::string& op_name) {
  auto it = std::find_if(
      nodes.begin(), nodes.end(), [&op_name](const Node* node) {
        return node->Name().find(op_name) != std::string::npos;
      });
  // Dereferencing end() is undefined behavior; report "not found" as nullptr.
  return it != nodes.end() ? *it : nullptr;
}
CheckGraphIndependence 内部定义了一个 check_node_ok 匿名函数,其中 n1 和 n2 都是 Node 指针。
(说明一下:在 Paddle PIR 之前的旧 IR 中,Op 和 Var 都用 ir::Node 表示。)
只有当 n1、n2 一个是 Op、一个是 Var,并且 n2 也在 nodes 集合中时,它才返回 true;
inline bool CheckGraphIndependence(const std::unordered_set& nodes) {
auto check_node_ok = [&nodes](Node* n1, Node* n2) -> bool {
if (n1->IsOp() && !n2->IsVar()) {
return false;
}
if (n1->IsVar() && !n2->IsOp()) {
return false;
}
if (nodes.count(n2) == 0) {
return false;
}
return true;
};
for (auto node : nodes) {
for (auto in : node->inputs) {
if (!check_node_ok(node, in)) {
return false;
}
}
for (auto out : node->outputs) {
if (!check_node_ok(node, out)) {
return false;
}
}
}
return true;
}
这里需要说明一下:由于 Paddle PIR 之前 Op 和 Var 都是 node,所以边是这样定义的:
var1 -> op1 -> var2
op3 -> var3 -> op4
第一行中,op1 的输入是 var1,输出是 var2;第二行中,var3 的输入是 op3,输出是 op4。也就是说变量节点也有 inputs/outputs,这样写有点儿诡异,不过确实是这样定义的。
所以 CheckGraphIndependence 做了两件事:首先检查每条边是不是 op->var 或 var->op 的关系;其次检查每个节点的输入/输出邻居是否都在当前 Graph 的节点集合(unordered_set)中,即没有指向图外部的边。
可以看到,之后的调用都是把计算图的节点集合 g->Nodes() 传入 CheckGraphIndependence。如果返回 false,ASSERT_TRUE 就会让单测失败:
// Typical use: assert that a graph's node set is still self-contained.
ASSERT_TRUE(CheckGraphIndependence(g->Nodes()));
这个函数把图中每个 kCinnLaunchOp 节点的 operators::kCompilationKey 属性取出来,放到 compilation_keys 这个 vector 中。后面的单测(例如 AllOpSupportCinn)会用这些 key 调用 CinnCompiler::FindGraph,取回对应的 CINN 子图。
// Get compilation_key values
std::vector GetCompilationKeys(const Graph& graph) {
std::vector compilation_keys;
for (auto& node : graph.Nodes()) {
if (node->IsOp() && node->Name() == kCinnLaunchOp) {
compilation_keys.emplace_back(PADDLE_GET_CONST(
int64_t, node->Op()->GetAttr(operators::kCompilationKey)));
}
}
return compilation_keys;
}
接下来创建一个不包含 CINN 子图的计算图:先创建一个空图 Graph,之后依次添加 op 和 var。其中的 fake1、fake2 都是 CINN 不支持的假 Op。
std::unique_ptr BuildNoCinnSubgraph() {
ProgramDesc prog;
auto g = std::make_unique(prog);
// var1 --
// | --> fake1 --> var3 --> fake2 --> var4
// var2 --
// *Desc 是之后用来创建 OpNode 和 VarNode 的类
OpDesc fake1_op;
fake1_op.SetType("fake1");
OpDesc fake2_op;
fake2_op.SetType("fake2");
VarDesc var1("var1");
VarDesc var2("var2");
var2.SetPersistable(true);
var2.SetIsParameter(true);
VarDesc var3("var3");
VarDesc var4("var4");
// 之后用 graph 的 Create*Node 来创建对应的 ir::Node
ir::Node* fake1 = g->CreateOpNode(&fake1_op);
ir::Node* fake2 = g->CreateOpNode(&fake2_op);
ir::Node* v1 = g->CreateVarNode(&var1);
ir::Node* v2 = g->CreateVarNode(&var2);
ir::Node* v3 = g->CreateVarNode(&var3);
ir::Node* v4 = g->CreateVarNode(&var4);
// ----------- 创建完 node 之后, 把 op/var 串起来
// fill op node
fake1->inputs = {v1, v2};
fake1->outputs = {v3};
fake2->inputs = {v3};
fake2->outputs = {v4};
// fill variable node
v1->outputs = {fake1};
v2->outputs = {fake1};
v3->inputs = {fake1};
v3->outputs = {fake2};
v4->inputs = {fake2};
return g;
}
接下来出现第一个单测
// A graph containing only CINN-unsupported ops must pass through
// build_cinn_pass unchanged, and no CINN subgraph may be created.
TEST(BuildCinnPassTest, NoCinnSubgraph) {
  auto graph = BuildNoCinnSubgraph();
  // Copy (not reference) the node set so it can be compared after the pass.
  const auto nodes_before_pass = graph->Nodes();

  // Fetch build_cinn_pass from the ir::PassRegistry and run it on the graph.
  auto& registry = paddle::framework::ir::PassRegistry::Instance();
  auto pass = registry.Get("build_cinn_pass");
  pass->Apply(graph.get());

  // After search, origin graph should no change.
  ASSERT_EQ(nodes_before_pass, graph->Nodes());
  ASSERT_TRUE(CheckGraphIndependence(graph->Nodes()));

  // No kCinnLaunchOp was inserted, so there is no compilation key and hence
  // no CINN subgraph.
  ASSERT_TRUE(GetCompilationKeys(*graph).empty());
}
接下来是 BuildAllOpSupportCinnGraph,它与上一个建图函数没啥太大区别。fake Op 换成了 CINN 支持的 elementwise_add、mul、relu。另外还多了两个节点:一个是用 CreateEmptyNode 创建的 var0,另一个是用 CreateControlDepVar 创建的控制依赖变量 v7。
std::unique_ptr BuildAllOpSupportCinnGraph() {
ProgramDesc prog;
auto g = std::make_unique(prog);
// v1 --
// | --> mul --> v3 --
// v2 -- | --> add --> v5 --> relu --> v6
// v4 --
OpDesc add_op;
add_op.SetType("elementwise_add");
OpDesc mul_op;
mul_op.SetType("mul");
OpDesc relu_op;
relu_op.SetType("relu");
VarDesc var1("var1");
VarDesc var2("var2");
var2.SetPersistable(true);
var2.SetIsParameter(true);
VarDesc var3("var3");
VarDesc var4("var4");
VarDesc var5("var5");
VarDesc var6("var6");
ir::Node* add = g->CreateOpNode(&add_op);
ir::Node* mul = g->CreateOpNode(&mul_op);
ir::Node* relu = g->CreateOpNode(&relu_op);
ir::Node* v0 = g->CreateEmptyNode("var0", Node::Type::kVariable); // 创建空节点用意是?
ir::Node* v1 = g->CreateVarNode(&var1);
ir::Node* v2 = g->CreateVarNode(&var2);
ir::Node* v3 = g->CreateVarNode(&var3);
ir::Node* v4 = g->CreateVarNode(&var4);
ir::Node* v5 = g->CreateVarNode(&var5);
ir::Node* v6 = g->CreateVarNode(&var6);
ir::Node* v7 = g->CreateControlDepVar();
// fill op node
mul->inputs = {v0, v1, v2};
mul->outputs = {v3};
add->inputs = {v3, v4};
add->outputs = {v5};
relu->inputs = {v5};
relu->outputs = {v6, v7};
// fill variable node
v0->outputs = {mul};
v1->outputs = {mul};
v2->outputs = {mul};
v3->inputs = {mul};
v3->outputs = {add};
v4->outputs = {add};
v5->inputs = {add};
v5->outputs = {relu};
v6->inputs = {relu};
v7->inputs = {relu};
return g;
}
上边这个注释有点儿问题:
// v1 --
// | --> mul --> v3 --
// v2 -- | --> add --> v5 --> relu --> v6
// v4 --
应该改成(补上 v0 与 v7):
// v0 --|
// v1 --|
// v2 --|--> mul --> v3 --|
//                  v4 --|--> add --> v5 --> relu --|--> v6
//                                                  |--> v7
接下来的 TEST 流程和之前一样。不同的是这次所有 Op 都被 CINN 支持,所以 pass 之后 mul、elementwise_add、relu 三个 Op 会被合并成一个 kCinnLaunchOp:
// When every op is supported by CINN, the pass must replace them all with a
// single kCinnLaunchOp and register exactly one CINN subgraph.
TEST(BuildCinnPassTest, AllOpSupportCinn) {
  auto g = BuildAllOpSupportCinnGraph();

  auto pass =
      paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass");
  pass->Apply(g.get());

  // After search, the graph should be as follows:
  // v0 --|
  // v1 --|                   |--> v6
  // v2 --| --> kCinnLaunchOp |--> v7
  // v4 --|
  const auto& nodes = g->Nodes();
  // 4 input vars + 2 output vars + 1 kCinnLaunchOp = 7 nodes.
  ASSERT_EQ(nodes.size(), static_cast<size_t>(7));
  ASSERT_TRUE(CheckGraphIndependence(nodes));

  // A new op named kCinnLaunchOp should be added.
  ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp));
  auto* cinn_op = GetNode(nodes, kCinnLaunchOp);
  auto* v0 = GetNode(nodes, "var0");
  auto* v1 = GetNode(nodes, "var1");
  auto* v2 = GetNode(nodes, "var2");
  auto* v4 = GetNode(nodes, "var4");
  auto* v6 = GetNode(nodes, "var6");
  auto* v7 = GetNode(nodes, Node::kControlDepVarName);

  // The launch op consumes {v0, v1, v2, v4} and produces {v6, v7}, and the
  // boundary var nodes are rewired to it.
  ASSERT_EQ(
      std::unordered_set<Node*>(cinn_op->inputs.begin(), cinn_op->inputs.end()),
      std::unordered_set<Node*>({v0, v1, v2, v4}));
  ASSERT_EQ(std::unordered_set<Node*>(cinn_op->outputs.begin(),
                                      cinn_op->outputs.end()),
            std::unordered_set<Node*>({v6, v7}));
  ASSERT_EQ(v1->outputs, std::vector<Node*>({cinn_op}));
  ASSERT_EQ(v6->inputs, std::vector<Node*>({cinn_op}));

  // The previous ops (mul, add, relu) were folded into the launch op and
  // must all be removed from the outer graph.
  ASSERT_FALSE(CheckNodeExisted(nodes, "mul"));
  ASSERT_FALSE(CheckNodeExisted(nodes, "elementwise_add"));
  ASSERT_FALSE(CheckNodeExisted(nodes, "relu"));

  // After search, there should be exactly one CINN subgraph:
  // feed --> v1 --
  //                | --> mul --> v3 --
  // feed --> v2 --                    | --> add --> v5 --> relu --> v6 --> fetch
  //                          feed --> v4 --
  // The launch op's compilation key is used to look that subgraph up.
  auto compilation_keys = GetCompilationKeys(*g);
  ASSERT_EQ(compilation_keys.size(), static_cast<size_t>(1));
  auto* cinn_compiler = CinnCompiler::GetInstance();
  const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]);

  const auto& subnodes = subgraph.Nodes();
  // 3 ops + 6 vars (var1..var6) + 3 feeds + 1 fetch = 13 nodes.
  ASSERT_EQ(subnodes.size(), static_cast<size_t>(13));
  ASSERT_TRUE(CheckGraphIndependence(subnodes));

  // The subgraph holds the original mul / elementwise_add / relu ops.
  ASSERT_TRUE(CheckNodeExisted(subnodes, "mul"));
  ASSERT_TRUE(CheckNodeExisted(subnodes, "elementwise_add"));
  ASSERT_TRUE(CheckNodeExisted(subnodes, "relu"));
  ASSERT_EQ(CountNode(subnodes, "feed"), 3);
  ASSERT_EQ(CountNode(subnodes, "fetch"), 1);

  // No-parameter input should have a feed op.
  auto new_v1 = GetNode(subnodes, "var1");
  ASSERT_EQ(new_v1->inputs.size(), static_cast<size_t>(1));
  ASSERT_EQ(new_v1->outputs.size(), static_cast<size_t>(1));
  ASSERT_EQ(new_v1->inputs[0]->Name(), "feed");
  ASSERT_EQ(new_v1->outputs[0]->Name(), "mul");

  // Parameter input should also have a feed op.
  auto new_v2 = GetNode(subnodes, "var2");
  ASSERT_EQ(new_v2->inputs.size(), static_cast<size_t>(1));
  ASSERT_EQ(new_v2->inputs[0]->Name(), "feed");
  ASSERT_EQ(new_v2->outputs.size(), static_cast<size_t>(1));
  ASSERT_EQ(new_v2->outputs[0]->Name(), "mul");

  // The subgraph output should have a fetch op.
  auto new_v6 = GetNode(subnodes, "var6");
  ASSERT_EQ(new_v6->inputs.size(), static_cast<size_t>(1));
  ASSERT_EQ(new_v6->outputs.size(), static_cast<size_t>(1));
  ASSERT_EQ(new_v6->inputs[0]->Name(), "relu");
  ASSERT_EQ(new_v6->outputs[0]->Name(), "fetch");
}
第一个单测里只有 fake Op,CINN pass 无法处理;第二个单测里所有 Op 都支持 CINN。下一个单测则是混合情况:一部分是 fake Op,另一部分是 CINN 支持的 Op:
std::unique_ptr BuildGraphWithOneCinnSubgraph() {
ProgramDesc prog;
auto g = std::make_unique(prog);
// fake1 --> v1 --
// | --> mul --> v3 --> relu --> v4 --> fake2
// v2 --
OpDesc fake1_op;
fake1_op.SetType("fake1");
OpDesc mul_op;
mul_op.SetType("mul");
OpDesc relu_op;
relu_op.SetType("relu");
OpDesc fake2_op;
fake2_op.SetType("fake2");
VarDesc var1("var1");
VarDesc var2("var2");
var2.SetPersistable(true);
var2.SetIsParameter(true);
VarDesc var3("var3");
VarDesc var4("var4");
ir::Node* fake1 = g->CreateOpNode(&fake1_op);
ir::Node* mul = g->CreateOpNode(&mul_op);
ir::Node* relu = g->CreateOpNode(&relu_op);
ir::Node* fake2 = g->CreateOpNode(&fake2_op);
ir::Node* v1 = g->CreateVarNode(&var1);
ir::Node* v2 = g->CreateVarNode(&var2);
ir::Node* v3 = g->CreateVarNode(&var3);
ir::Node* v4 = g->CreateVarNode(&var4);
// fill op node
fake1->outputs = {v1};
mul->inputs = {v2, v1};
mul->outputs = {v3};
relu->inputs = {v3};
relu->outputs = {v4};
fake2->inputs = {v4};
// fill variable node
v2->outputs = {mul};
v1->inputs = {fake1};
v1->outputs = {mul};
v3->inputs = {mul};
v3->outputs = {relu};
v4->inputs = {relu};
v4->outputs = {fake2};
return g;
}
上边的函数就是建立了一个这样的一个图
// fake1 --> v1 --
// | --> mul --> v3 --> relu --> v4 --> fake2
// v2 --
通过 cinn pass 之后这个图的节点变成下边儿这样:
// fake1 --> v1 --
// | --> kCinnLaunchOp --> v4 --> fake2
// v2 --
pass 之后只有一个 kCinnLaunchOp,其子图如下,共 9 个节点(2 个 feed、v1、v2、mul、v3、relu、v4、1 个 fetch):
// feed --> v1 --
// | --> mul --> v3 --> relu --> v4 --> fetch
// feed --> v2 --
之前的图是单个 cinn op,下一个单测是多个 cinn op 的情况:
std::unique_ptr BuildGraphWithMultiCinnSubgraph() {
ProgramDesc prog;
auto g = std::make_unique(prog);
// fake1 --> v1 --
// | --> mul --> v3 --> fake2 --> v4 --> relu --> v5 --> fake3
// v2 --
OpDesc fake1_op;
fake1_op.SetType("fake1");
OpDesc mul_op;
mul_op.SetType("mul");
OpDesc relu_op;
relu_op.SetType("relu");
OpDesc fake2_op;
fake2_op.SetType("fake2");
OpDesc fake3_op;
fake3_op.SetType("fake3");
VarDesc var1("var1");
VarDesc var2("var2");
var2.SetPersistable(true);
var2.SetIsParameter(true);
VarDesc var3("var3");
VarDesc var4("var4");
VarDesc var5("var5");
ir::Node* fake1 = g->CreateOpNode(&fake1_op);
ir::Node* mul = g->CreateOpNode(&mul_op);
ir::Node* relu = g->CreateOpNode(&relu_op);
ir::Node* fake2 = g->CreateOpNode(&fake2_op);
ir::Node* fake3 = g->CreateOpNode(&fake3_op);
ir::Node* v1 = g->CreateVarNode(&var1);
ir::Node* v2 = g->CreateVarNode(&var2);
ir::Node* v3 = g->CreateVarNode(&var3);
ir::Node* v4 = g->CreateVarNode(&var4);
ir::Node* v5 = g->CreateVarNode(&var5);
// fill op node
fake1->outputs = {v1};
mul->inputs = {v2, v1};
mul->outputs = {v3};
fake2->inputs = {v3};
fake2->outputs = {v4};
relu->inputs = {v4};
relu->outputs = {v5};
fake3->inputs = {v5};
// fill variable node
v2->outputs = {mul};
v1->inputs = {fake1};
v1->outputs = {mul};
v3->inputs = {mul};
v3->outputs = {fake2};
v4->inputs = {fake2};
v4->outputs = {relu};
v5->inputs = {relu};
v5->outputs = {fake3};
return g;
}
以上代码建立一个这样的图:
// fake1 --> v1 --
// | --> mul --> v3 --> fake2 --> v4 --> relu --> v5 --> fake3
// v2 --
以 CINN 不支持的 fake2 为界,pass 会在它两侧各生成一个 kCinnLaunchOp:
// fake1 -> v1 -
// | -> CinnOp -> v3 -> fake2 -> v4 -> CinnOp ->v5 -> fake3
// v2 -
应用 cinn pass 只需要两句代码:
// Fetch build_cinn_pass from the ir::PassRegistry, then run it on the graph;
// the tests afterwards inspect g->Nodes() to see the rewritten graph.
auto pass =
paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass");
pass->Apply(g.get());
此处是检验有两个 cinn pass Op 的代码:
// A new op named kCinnLaunchOp should be added -- two of them here, because
// the unsupported fake2 separates mul and relu into two CINN clusters.
ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp));
ASSERT_EQ(CountNode(nodes, kCinnLaunchOp), 2);
最后的编译结果是 cinn pass 之后有两个 子图:
// subgraph1:
// feed --> v4 --> relu --> v5 --> fetch
// subgraph2:
// feed --> v1 --
// | --> mul --> v3 --> fetch
// v2 --
BuildGraphWithNoNeedBufferInput 建立的是这样一个计算图:
// fake1 --> v1 -- --> v4 --> relu_grad --> v6
// v2 -- | --> add_grad |
// v3 -- --> v5 --> fake2
BuildGraphWithNoNeedBufferInput 与之前不同的是,add_grad_op 用 SetInput 显式地给各个输入参数绑定了变量名:GradVarName("Out")(即 Out 的梯度)对应 var1,X 对应 var2,Y 对应 var3。
// elementwise_add_grad with explicitly named input arguments:
//   GradVarName("Out") (the gradient of Out) -> var1, X -> var2, Y -> var3.
OpDesc add_grad_op;
add_grad_op.SetType("elementwise_add_grad");
add_grad_op.SetInput(::paddle::framework::GradVarName("Out"), {"var1"});
add_grad_op.SetInput("X", {"var2"});
add_grad_op.SetInput("Y", {"var3"});
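之后的单测检查了 no_need_buffer_x。var2、var3 分别是 elementwise_add_grad 的 X、Y 输入,它们没有放进 cinn_launch 的普通输入 kX,而是放进了 kNoNeedBufferX。所谓 no need buffer,是指算子只需要这些输入的元信息(如 shape),不需要读取它们的实际数据(buffer)。从断言看,elementwise_add_grad 的 X、Y 就属于这种情况。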
// A new op named kCinnLaunchOp should be added and
// its input arguments are set correctly.
ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp));
ASSERT_EQ(CountNode(nodes, kCinnLaunchOp), 1);
auto* cinn_op_node = GetNode(nodes, kCinnLaunchOp);
// var1 (the gradient of Out) is a regular input of the launch op ...
ASSERT_EQ(cinn_op_node->Op()->Input(operators::kX),
          std::vector<std::string>({"var1"}));
// ... while var2/var3 (X and Y of elementwise_add_grad) are routed through
// the no-need-buffer input slot.
auto& no_need_buffer_x = cinn_op_node->Op()->Input(operators::kNoNeedBufferX);
ASSERT_EQ(std::unordered_set<std::string>(no_need_buffer_x.begin(),
                                          no_need_buffer_x.end()),
          std::unordered_set<std::string>({"var2", "var3"}));
这里的 no_need_buffer_feeds 是子图上的 kNoNeedBufferFeeds 属性,记录子图的 feed 变量中哪些是 no need buffer 的。从断言看,正好就是 var2、var3。
ASSERT_TRUE(CheckNodeExisted(subnodes, "elementwise_add_grad"));
ASSERT_TRUE(CheckNodeExisted(subnodes, "relu_grad"));
ASSERT_EQ(CountNode(subnodes, "feed"), 3);
ASSERT_EQ(CountNode(subnodes, "fetch"), 2);
const auto& no_need_buffer_feeds =
subgraph.Get>(kNoNeedBufferFeeds);
ASSERT_EQ(no_need_buffer_feeds.size(), 2);
ASSERT_EQ(no_need_buffer_feeds,
std::unordered_set({"var2", "var3"}));
// check the attributes of variable lists are saved correctly
ASSERT_TRUE(subgraph.Has(kInputVars));
EXPECT_EQ(subgraph.Get>(kInputVars),
std::vector({"var1"}));
ASSERT_TRUE(subgraph.Has(kInternalVars));
EXPECT_EQ(subgraph.Get>(kInternalVars),
std::vector({"var4"}));
ASSERT_TRUE(subgraph.Has(kOutputVars));
const auto& output_vars = subgraph.Get>(kOutputVars);
EXPECT_EQ(
std::unordered_set(output_vars.begin(), output_vars.end()),
std::unordered_set({"var5", "var6"}));
// Variables listed in kSkipGcVarNames must remain reachable outside the CINN
// subgraph, so the pass has to fetch them out of it.
TEST(BuildCinnPassTest, TestSkipGcVars) {
  auto g = BuildGraphWithOneCinnSubgraph();

  // Judging by the name, kSkipGcVarNames lists variables the garbage collector
  // should skip (keep alive). The graph does not take ownership of the set
  // (SetNotOwned); it outlives every use below.
  std::unordered_set<std::string> all_skip_gc_vars = {"var1", "var3"};
  g->SetNotOwned(kSkipGcVarNames, &all_skip_gc_vars);

  auto pass =
      paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass");
  pass->Apply(g.get());

  // After search, the graph should be as follows:
  // fake1 --> v1 --                      |--> v4 --> fake2
  //                | --> kCinnLaunchOp --|
  //           v2 --                      |--> v3
  // i.e. 7 nodes: fake1, v1, v2, kCinnLaunchOp, v3, v4, fake2. Compared with
  // the plain one-subgraph case, v3 is an extra output of the launch op
  // because it is listed in kSkipGcVarNames.
  const auto& nodes = g->Nodes();
  ASSERT_EQ(nodes.size(), static_cast<size_t>(7));
  ASSERT_TRUE(CheckGraphIndependence(nodes));

  // A new op named kCinnLaunchOp should be added.
  ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp));

  // After search, there should be exactly one CINN subgraph.
  // Note v3 is fetched because v3 is in kSkipGcVarNames, while v1 is a feed
  // var, so it needs no fetch even though it is in kSkipGcVarNames:
  // feed --> v1 --
  //                | --> mul --> v3 --> relu --> v4 --> fetch
  // feed --> v2 --           |
  //                          --> fetch
  auto compilation_keys = GetCompilationKeys(*g);
  ASSERT_EQ(compilation_keys.size(), static_cast<size_t>(1));
  auto* cinn_compiler = CinnCompiler::GetInstance();
  const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]);

  const auto& subnodes = subgraph.Nodes();
  // 2 feeds + v1, v2, mul, v3, relu, v4 + 2 fetches = 10 nodes.
  ASSERT_EQ(subnodes.size(), static_cast<size_t>(10));
  ASSERT_TRUE(CheckGraphIndependence(subnodes));

  ASSERT_EQ(CountNode(subnodes, "feed"), 2);
  // var3 and var4 should each have a fetch op.
  ASSERT_EQ(CountNode(subnodes, "fetch"), 2);
}
最后两个 TEST
没看懂,留下问题