Calcite Custom Optimizer Rules

1)Summary
1. Create CSVProjectRule, which extends RelRule.
a) Define the match pattern in the CSVProjectRule.Config interface:
Config DEFAULT = EMPTY
                .withOperandSupplier(b0 -> b0.operand(LogicalProject.class).anyInputs())
                .as(Config.class);
b) In the CSVProjectRule class, convert the matched node when the rule fires:
    @Override
    public void onMatch(RelOptRuleCall call) {
        final LogicalProject project = call.rel(0);
        final RelNode converted = convert(project);
        if (converted != null) {
            call.transformTo(converted);
        }
    }

    ------------------------------------------------

    public RelNode convert(RelNode rel) {
        final LogicalProject project = (LogicalProject) rel;
        final RelTraitSet traitSet = project.getTraitSet();
        return new CSVProject(project.getCluster(), traitSet,
                project.getInput(), project.getProjects(),
                project.getRowType());
    }
2. Create the RelNode produced by the conversion, i.e. CSVProject.
The code below provides two target nodes that differ only in their cost model: CSVProject reports zero cost while CSVProjectWithCost reports infinite cost, which makes it easy to observe how rule order and cost influence the resulting plan.
2)Code Examples

CSVProjectRule

package cn.com.ptpress.cdm.optimization.RelBuilder.optimizer;

import cn.com.ptpress.cdm.optimization.RelBuilder.csvRelNode.CSVProject;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.logical.LogicalProject;

public class CSVProjectRule extends RelRule<CSVProjectRule.Config> {

    @Override
    public void onMatch(RelOptRuleCall call) {
        final LogicalProject project = call.rel(0);
        final RelNode converted = convert(project);
        if (converted != null) {
            call.transformTo(converted);
        }
    }

    /** Rule configuration. */
    public interface Config extends RelRule.Config {
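        // Match any LogicalProject node, regardless of its inputs.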
        Config DEFAULT = EMPTY
                .withOperandSupplier(b0 ->
                        b0.operand(LogicalProject.class).anyInputs())
                .as(Config.class);

        @Override default CSVProjectRule toRule() {
            return new CSVProjectRule(this);
        }
    }

    private CSVProjectRule(Config config) {
        super(config);
    }


    public RelNode convert(RelNode rel) {
        final LogicalProject project = (LogicalProject) rel;
        final RelTraitSet traitSet = project.getTraitSet();
        return new CSVProject(project.getCluster(), traitSet,
                project.getInput(), project.getProjects(),
                project.getRowType());
    }
}

CSVProjectRuleWithCost

package cn.com.ptpress.cdm.optimization.RelBuilder.optimizer;

import cn.com.ptpress.cdm.optimization.RelBuilder.csvRelNode.CSVProject;
import cn.com.ptpress.cdm.optimization.RelBuilder.csvRelNode.CSVProjectWithCost;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.logical.LogicalProject;

public class CSVProjectRuleWithCost extends RelRule<CSVProjectRuleWithCost.Config> {

    @Override
    public void onMatch(RelOptRuleCall call) {
        final LogicalProject project = call.rel(0);
        final RelNode converted = convert(project);
        if (converted != null) {
            call.transformTo(converted);
        }
    }

    /** Rule configuration. */
    public interface Config extends RelRule.Config {
        Config DEFAULT = EMPTY
                .withOperandSupplier(b0 ->
                        b0.operand(LogicalProject.class).anyInputs())
                .as(Config.class);

        @Override default CSVProjectRuleWithCost toRule() {
            return new CSVProjectRuleWithCost(this);
        }
    }

    private CSVProjectRuleWithCost(Config config) {
        super(config);
    }


    public RelNode convert(RelNode rel) {
        final LogicalProject project = (LogicalProject) rel;
        final RelTraitSet traitSet = project.getTraitSet();
        return new CSVProjectWithCost(project.getCluster(), traitSet,
                project.getInput(), project.getProjects(),
                project.getRowType());
    }
}

CSVProject

package cn.com.ptpress.cdm.optimization.RelBuilder.csvRelNode;

import com.google.common.collect.ImmutableList;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexNode;

import java.util.List;

public class CSVProject extends Project {

    public CSVProject(RelOptCluster cluster, RelTraitSet traits, RelNode input, List<RexNode> projects, RelDataType rowType) {
        super(cluster,traits, ImmutableList.of(),input,projects,rowType);
    }

    @Override
    public Project copy(RelTraitSet traitSet, RelNode input, List<RexNode> projects, RelDataType rowType) {
        return new CSVProject(getCluster(),traitSet,input,projects,rowType);
    }

    @Override
    public RelOptCost computeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) {
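        // Report zero cost so the planner always treats CSVProject as the cheapest alternative.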
        return planner.getCostFactory().makeZeroCost();
    }
}

CSVProjectWithCost

package cn.com.ptpress.cdm.optimization.RelBuilder.csvRelNode;

import com.google.common.collect.ImmutableList;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexNode;

import java.util.List;

public class CSVProjectWithCost extends Project {
    public CSVProjectWithCost(RelOptCluster cluster, RelTraitSet traits, RelNode input, List<RexNode> projects, RelDataType rowType) {
        super(cluster,traits, ImmutableList.of(),input,projects,rowType);
    }

    @Override
    public Project copy(RelTraitSet traitSet, RelNode input, List<RexNode> projects, RelDataType rowType) {
        return new CSVProjectWithCost(getCluster(),traitSet,input,projects,rowType);
    }

    @Override
    public RelOptCost computeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) {
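        // Report infinite cost so the planner never prefers this node when a cheaper alternative exists.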
        return planner.getCostFactory().makeInfiniteCost();
    }
}

SqlToRelNode
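
The original SqlToRelNode source is not included in this post. The sketch below shows one possible implementation of the getSqlNode(sql, planner) helper that PlannerTest calls, reusing CatalogReaderUtil for the "csv" schema: it parses and validates the SQL, then converts it to a RelNode on a cluster bound to the given planner. The parser settings, validator setup, and converter configuration are assumptions for illustration, not the book's original code.

package cn.com.ptpress.cdm.optimization.RelBuilder.Utils;

import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.prepare.CalciteCatalogReader;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql.validate.SqlValidatorUtil;
import org.apache.calcite.sql2rel.SqlToRelConverter;
import org.apache.calcite.sql2rel.StandardConvertletTable;

public class SqlToRelNode {

    // Parse, validate, and convert a SQL string into a RelNode tree whose cluster
    // uses the supplied planner, so relNode.getCluster().getPlanner() returns it.
    public static RelNode getSqlNode(String sql, RelOptPlanner planner) throws SqlParseException {
        // Case-insensitive parsing (assumption), so unquoted identifiers match the CSV columns.
        SqlParser.Config parserConfig = SqlParser.config().withCaseSensitive(false);
        SqlNode parsed = SqlParser.create(sql, parserConfig).parseStmt();

        // Validate against the "csv" schema registered by CatalogReaderUtil.
        JavaTypeFactoryImpl typeFactory = new JavaTypeFactoryImpl();
        CalciteCatalogReader catalogReader = CatalogReaderUtil.createCatalogReader(parserConfig);
        SqlValidator validator = SqlValidatorUtil.newValidator(
                SqlStdOperatorTable.instance(), catalogReader, typeFactory,
                SqlValidator.Config.DEFAULT);
        SqlNode validated = validator.validate(parsed);

        // Build a cluster on the given planner and convert SqlNode -> RelNode.
        RelOptCluster cluster = RelOptCluster.create(planner, new RexBuilder(typeFactory));
        SqlToRelConverter converter = new SqlToRelConverter(
                null /* no view expander needed for these tests */, validator,
                catalogReader, cluster, StandardConvertletTable.INSTANCE,
                SqlToRelConverter.config());
        return converter.convertQuery(validated, false, true).rel;
    }
}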


CatalogReaderUtil

package cn.com.ptpress.cdm.optimization.RelBuilder.Utils;

import cn.com.ptpress.cdm.ds.csv.CsvSchema;
import org.apache.calcite.config.CalciteConnectionConfigImpl;
import org.apache.calcite.config.CalciteConnectionProperty;
import org.apache.calcite.jdbc.CalciteSchema;
import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
import org.apache.calcite.prepare.CalciteCatalogReader;
import org.apache.calcite.schema.SchemaPlus;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.tools.Frameworks;

import java.util.Properties;

public class CatalogReaderUtil {
    public static CalciteCatalogReader createCatalogReader(SqlParser.Config parserConfig) {
        SchemaPlus rootSchema = Frameworks.createRootSchema(true);
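        // Register the CSV schema under the name "csv"; in the plans its table appears as [csv, data].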
        rootSchema.add("csv", new CsvSchema("data.csv"));
        return createCatalogReader(parserConfig, rootSchema);
    }

    public static CalciteCatalogReader createCatalogReader(SqlParser.Config parserConfig, SchemaPlus rootSchema) {

        Properties prop = new Properties();
        prop.setProperty(CalciteConnectionProperty.CASE_SENSITIVE.camelName(),
                String.valueOf(parserConfig.caseSensitive()));
        CalciteConnectionConfigImpl calciteConnectionConfig = new CalciteConnectionConfigImpl(prop);
        return new CalciteCatalogReader(
                CalciteSchema.from(rootSchema),
                CalciteSchema.from(rootSchema).path("csv"),
                new JavaTypeFactoryImpl(),
                calciteConnectionConfig
        );
    }
}

PlannerTest

import cn.com.ptpress.cdm.optimization.RelBuilder.Utils.SqlToRelNode;
import cn.com.ptpress.cdm.optimization.RelBuilder.optimizer.CSVProjectRule;
import cn.com.ptpress.cdm.optimization.RelBuilder.optimizer.CSVProjectRuleWithCost;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.hep.HepPlanner;
import org.apache.calcite.plan.hep.HepProgram;
import org.apache.calcite.plan.hep.HepProgramBuilder;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.rules.FilterJoinRule;
import org.apache.calcite.sql.parser.SqlParseException;
import org.junit.jupiter.api.Test;


class PlannerTest {
    @Test
    public void testCustomRule() throws SqlParseException {
        final String sql = "select Id from data ";

        HepProgramBuilder programBuilder = HepProgram.builder();

        // Test swapping the order of CSVProjectRule and CSVProjectRuleWithCost:
        // HepPlanner applies rule instances in the order they are added, so the rule
        // registered first converts the LogicalProject and the other no longer matches.
        HepPlanner hepPlanner =
                new HepPlanner(
                        programBuilder
                                .addRuleInstance(CSVProjectRule.Config.DEFAULT.toRule())
                                .addRuleInstance(CSVProjectRuleWithCost.Config.DEFAULT.toRule())
                                .build());

//        HepPlanner hepPlanner =
//                new HepPlanner(
//                        programBuilder
//                                .addRuleInstance(CSVProjectRuleWithCost.Config.DEFAULT.toRule())
//                                .addRuleInstance(CSVProjectRule.Config.DEFAULT.toRule())
//                                .build());

        RelNode relNode = SqlToRelNode.getSqlNode(sql, hepPlanner);
        System.out.println(RelOptUtil.toString(relNode));

        RelOptPlanner planner = relNode.getCluster().getPlanner();
        planner.setRoot(relNode);
        RelNode bestExp = planner.findBestExp();
        System.out.println(RelOptUtil.toString(bestExp));

        RelOptPlanner relOptPlanner = relNode.getCluster().getPlanner();
        relOptPlanner.addRule(CSVProjectRule.Config.DEFAULT.toRule());
        relOptPlanner.addRule(CSVProjectRuleWithCost.Config.DEFAULT.toRule());
        relOptPlanner.setRoot(relNode);
        RelNode exp = relOptPlanner.findBestExp();
        System.out.println(RelOptUtil.toString(exp));
    }

    /**
     * Operator tree before optimization:
     * LogicalProject(ID=[$0])
     *   LogicalFilter(condition=[>(CAST($0):INTEGER NOT NULL, 1)])
     *     LogicalJoin(condition=[=($0, $3)], joinType=[inner])
     *       LogicalTableScan(table=[[csv, data]])
     *       LogicalTableScan(table=[[csv, data]])
     *
     * Operator tree after optimization:
     * LogicalProject(ID=[$0])
     *   LogicalJoin(condition=[=($0, $3)], joinType=[inner])
     *     LogicalFilter(condition=[>(CAST($0):INTEGER NOT NULL, 1)])
     *       LogicalTableScan(table=[[csv, data]])
     *     LogicalTableScan(table=[[csv, data]])
     */
    @Test
    public void testHepPlanner() throws SqlParseException {
        final String sql = "select a.Id from data as a join data b on a.Id = b.Id where a.Id>1";
        HepProgramBuilder programBuilder = HepProgram.builder();
        HepPlanner hepPlanner =
                new HepPlanner(
                        programBuilder.addRuleInstance(FilterJoinRule.FilterIntoJoinRule.Config.DEFAULT.toRule())
                                .build());
        RelNode relNode = SqlToRelNode.getSqlNode(sql, hepPlanner);
        // Operator tree before optimization
        System.out.println(RelOptUtil.toString(relNode));
        RelOptPlanner planner = relNode.getCluster().getPlanner();
        planner.setRoot(relNode);
        RelNode bestExp = planner.findBestExp();
        // Result after optimization
        System.out.println(RelOptUtil.toString(bestExp));

    }

    /**
     * Operator tree before conversion to a DAG:
     * LogicalProject(Id=[$0], Name=[$1], Score=[$2])
     *   LogicalFilter(condition=[=(CAST($0):INTEGER NOT NULL, 1)])
     *     LogicalTableScan(table=[[csv, data]])
     *
     * After conversion to a DAG:
     * Breadth-first from root:  {
     *     rel#8:HepRelVertex(rel#7:LogicalProject.(input=HepRelVertex#6,inputs=0..2)) = rel#7:LogicalProject.(input=HepRelVertex#6,inputs=0..2), rowcount=15.0, cumulative cost=130.0
     *     rel#6:HepRelVertex(rel#5:LogicalFilter.(input=HepRelVertex#4,condition==(CAST($0):INTEGER NOT NULL, 1))) = rel#5:LogicalFilter.(input=HepRelVertex#4,condition==(CAST($0):INTEGER NOT NULL, 1)), rowcount=15.0, cumulative cost=115.0
     *     rel#4:HepRelVertex(rel#1:LogicalTableScan.(table=[csv, data])) = rel#1:LogicalTableScan.(table=[csv, data]), rowcount=100.0, cumulative cost=100.0
     * }
     */
    @Test
    public void testGraph() throws SqlParseException {
        final String sql = "select * from data where Id=1";
        HepProgramBuilder programBuilder = HepProgram.builder();
        HepPlanner hepPlanner =
                new HepPlanner(
                        programBuilder.build());
        RelNode relNode = SqlToRelNode.getSqlNode(sql, hepPlanner);
        // Operator tree before DAG conversion
        System.out.println("Operator tree before DAG conversion");
        System.out.println(RelOptUtil.toString(relNode));
        // setRoot converts the tree into the planner's internal DAG representation
        System.out.println("Converting to a DAG");
        hepPlanner.setRoot(relNode);
        // To see the DAG output, set the log level in log4j.properties to TRACE
    }
}

data.csv

Id:VARCHAR Name:VARCHAR Score:INTEGER
1,小明,90
2,小红,98
3,小亮,95
