presto sql输入表、输入字段、limit、join操作解析

前言

一段时间没有写文章了,写下最近做的事情。目前我们这边有一个metabase 查询平台供运营、分析师、产品等人员使用,我们的查询都是使用 presto 引擎。并且我们的大数据组件都使用的是 emr 组件,并且涉及到中国、美西、美东、印度、欧洲、西欧等多个区域,表的权限管理就特别困难。所以就需要一个统一的权限管理来维护某些人拥有那些表的权限,避免隐私的数据泄漏。于是我们就需要一款sql解析工具来解析 presto sql 的输入表。另外还有一点,由于使用的人较多,资源较少,为了避免长查询,我们还会对含有 join 操作查询、 select * 的查询直接拒绝

sql 解析

第一种方法

presto 本身也是用的 antlr 进行 sql 语法的编辑,如果你clone了presto的源码,会在 presto-parse 模块中发现 presto/presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 文件,也就是说我们可以通过直接使用该文件生成解析的配置文件1,然后进行 sql 解析
,但是这种方法太过复杂,我尝试了下放弃了,因为从语法树中获取某些值时比较混乱,容错较小,还需要再遍历其儿子、兄弟节点,并且通过节点的 getText 方法获得节点值。
presto sql输入表、输入字段、limit、join操作解析_第1张图片

第二种方法

我们肯定很容易的就想到,presto 源码肯定也对 sql 进行了解析,何不直接使用 presto 的解析类呢?
功夫不负有心人,我在源码中发现了 SqlParser 这个类,该类在 presto-parser 模块中,通过调用 createStatement(String sql) 方法会返回一个Statement 2,后面我们只需要对 Statement 进行遍历即可

去掉注释

在 sql执行之前,我们需要进行一些预操作,比如去掉注释,分号分割多行代码

   /**
     * 替换sql注释
     *
     * @param sqlText sql
     * @return 替换后的sl
     */
    protected String replaceNotes(String sqlText) {
        StringBuilder newSql = new StringBuilder();
        String lineBreak = "\n";
        String empty = "";
        String trimLine;
        for (String line : sqlText.split(lineBreak)) {
            trimLine = line.trim();
            if (!trimLine.startsWith("--") && !trimLine.startsWith("download")) {
                //过滤掉行内注释
                line = line.replaceAll("/\\*.*\\*/", empty);
                if (org.apache.commons.lang3.StringUtils.isNotBlank(line)) {
                    newSql.append(line).append(lineBreak);
                }
            }
        }
        return newSql.toString();
    }

分号分割多段 sql


    /**
     * ;分割多段sql
     *
     * @param sqlText sql
     * @return
     */
    protected ArrayList<String> splitSql(String sqlText) {
        String[] sqlArray = sqlText.split(Constants.SEMICOLON);
        ArrayList<String> newSqlArray = new ArrayList<>(sqlArray.length);
        String command = "";
        int arrayLen = sqlArray.length;
        String oneCmd;
        for (int i = 0; i < arrayLen; i++) {
            oneCmd = sqlArray[i];
            boolean keepSemicolon = (oneCmd.endsWith("'") && i + 1 < arrayLen && sqlArray[i + 1].startsWith("'"))
                    || (oneCmd.endsWith("\"") && i + 1 < arrayLen && sqlArray[i + 1].startsWith("\""));
            if (oneCmd.endsWith("\\")) {
                command += org.apache.commons.lang.StringUtils.chop(oneCmd) + Constants.SEMICOLON;
                continue;
            } else if (keepSemicolon) {
                command += oneCmd + Constants.SEMICOLON;
                continue;
            } else {
                command += oneCmd;
            }
            if (org.apache.commons.lang3.StringUtils.isBlank(command)) {
                continue;
            }
            newSqlArray.add(command);
            command = "";
        }
        return newSqlArray;
    }

sql解析

经过预处理之后,就需要对 sql 进行解析。inputTables、outputTables、tempTables 分别表示输入表、输出表、临时表

 @Override
    protected Tuple3<HashSet<TableInfo>, HashSet<TableInfo>, HashSet<TableInfo>> parseInternal(String sqlText) throws SqlParseException {
        this.inputTables = new HashSet<>();
        this.outputTables = new HashSet<>();
        this.tempTables = new HashSet<>();
        try {
        	//ParsingOptions.DecimalLiteralTreatment.AS_DECIMAL 表示数字以DECIMAL类型解析
            check(new SqlParser().createStatement(sqlText, new ParsingOptions(ParsingOptions.DecimalLiteralTreatment.AS_DECIMAL)));
        } catch (ParsingException e) {
            throw new SqlParseException("parse sql exception:" + e.getMessage(), e);
        }
        return new Tuple3<>(inputTables, outputTables, tempTables);
    }
根节点识别

进入 check 方法进行 Statement 的遍历

   /**
     * statement 过滤 只识别select 语句
     *
     * @param statement
     * @throws SqlParseException
     */
    private void check(Statement statement) throws SqlParseException {
    	//如果根节点是查询节点 获取所有的孩子节点,深度优先搜索遍历
        if (statement instanceof Query) {
            Query query = (Query) statement;
            List<Node> children = query.getChildren();
            for (Node child : children) {
                checkNode(child);
            }
        } else if (statement instanceof Use) {
            Use use = (Use) statement;
            this.currentDb = use.getSchema().getValue();
        } else if (statement instanceof ShowColumns) {
            ShowColumns show = (ShowColumns) statement;
            String allName = show.getTable().toString().replace("hive.", "");
            inputTables.add(buildTableInfo(allName, OperatorType.READ));
        } else if (statement instanceof ShowTables) {
            ShowTables show = (ShowTables) statement;
            QualifiedName qualifiedName = show.getSchema().orElseThrow(() -> new SqlParseException("unkonw table name or db name" + statement.toString()));
            String allName = qualifiedName.toString().replace("hive.", "");
            if (allName.contains(Constants.POINT)) {
                allName += Constants.POINT + "*";
            }
            inputTables.add(buildTableInfo(allName, OperatorType.READ));

        } else {
            throw new SqlParseException("sorry,only support read statement,unSupport statement:" + statement.getClass().getName());
        }
    }
  • 如果根节点是 Query 查询节点 获取所有的孩子节点,深度优先搜索遍历
  • 如果根节点是 Use 切换数据库的节点,修改当前的数据库名称
  • 如果根节点是ShowColumns 查看表字段的节点,将该表加入输入表
  • 如果根节点是ShowTables 查看表结构的节点,将该表加入输入表
  • 否则抛出无法解析的异常

子节点遍历

主要进入 checkNode 方法,进行查询语句所有孩子节点的遍历

/**
     * node 节点的遍历
     *
     * @param node
     */
    private void checkNode(Node node) throws SqlParseException {
    	//查询子句
        if (node instanceof QuerySpecification) {
            QuerySpecification query = (QuerySpecification) node;
            //如果查询包含limit语句 直接将limit入栈
            query.getLimit().ifPresent(limit -> limitStack.push(limit));
            //遍历子节点
            loopNode(query.getChildren());
        } else if (node instanceof TableSubquery) {
            loopNode(node.getChildren());
        } else if (node instanceof AliasedRelation) {
            // 表的别名 需要放到tableAliaMap供别别名的字段解析使用
            AliasedRelation alias = (AliasedRelation) node;
            String value = alias.getAlias().getValue();
            if (alias.getChildren().size() == 1 && alias.getChildren().get(0) instanceof Table) {
                Table table = (Table) alias.getChildren().get(0);
                tableAliaMap.put(value, table.getName().toString());
            } else {
                tempTables.add(buildTableInfo(value, OperatorType.READ));
            }
            loopNode(node.getChildren());
        } else if (node instanceof Query || node instanceof SubqueryExpression
                || node instanceof Union || node instanceof With
                || node instanceof LogicalBinaryExpression || node instanceof InPredicate) {
            loopNode(node.getChildren());

        } else if (node instanceof Join) {
        	//发现join操作  设置hasJoin 为true
            hasJoin = true;
            loopNode(node.getChildren());
        }
        //基本都是where条件,过滤掉,如果需要,可以调用getColumn解析字段
        else if (node instanceof LikePredicate || node instanceof NotExpression
                || node instanceof IfExpression
                || node instanceof ComparisonExpression || node instanceof GroupBy
                || node instanceof OrderBy || node instanceof Identifier
                || node instanceof InListExpression || node instanceof DereferenceExpression
                || node instanceof IsNotNullPredicate || node instanceof IsNullPredicate
                || node instanceof FunctionCall) {
            print(node.getClass().getName());

        } else if (node instanceof WithQuery) {
        	//with 子句的临时表 
            WithQuery withQuery = (WithQuery) node;
            tempTables.add(buildTableInfo(withQuery.getName().getValue(), OperatorType.READ));
            loopNode(withQuery.getChildren());
        } else if (node instanceof Table) {
        	//发现table节点 放入输入表
            Table table = (Table) node;
            inputTables.add(buildTableInfo(table.getName().toString(), OperatorType.READ));
            loopNode(table.getChildren());
        } else if (node instanceof Select) {
        	//发现select 子句,需要调用getColumn方法从selectItems中获取select的字段
            Select select = (Select) node;
            List<SelectItem> selectItems = select.getSelectItems();
            HashSet<String> columns = new HashSet<>();
            for (SelectItem item : selectItems) {
                if (item instanceof SingleColumn) {
                    columns.add(getColumn(((SingleColumn) item).getExpression()));
                } else if (item instanceof AllColumns) {
                    columns.add(item.toString());
                } else {
                    throw new SqlParseException("unknow column type:" + item.getClass().getName());
                }
            }
            //将字段入栈
            columnsStack.push(columns);

        } else {
            throw new SqlParseException("unknow node type:" + node.getClass().getName());
        }
    }

上面需要注意的是,每次想输入表、临时表中添加表时都对应一个 column的集合从 columnsStack 出栈。
后面看从 selectItems 中获取字段的方法 getColumn.

  /**
     * select 字段表达式中获取字段
     *
     * @param expression
     * @return
     */
    private String getColumn(Expression expression) throws SqlParseException {
        if (expression instanceof IfExpression) {
            IfExpression ifExpression = (IfExpression) expression;
            List<Expression> list = new ArrayList<>();
            list.add(ifExpression.getCondition());
            list.add(ifExpression.getTrueValue());
            ifExpression.getFalseValue().ifPresent(list::add);
            return getString(list);
        } else if (expression instanceof Identifier) {
            Identifier identifier = (Identifier) expression;
            return identifier.getValue();
        } else if (expression instanceof FunctionCall) {
            FunctionCall call = (FunctionCall) expression;
            StringBuilder columns = new StringBuilder();
            List<Expression> arguments = call.getArguments();
            int size = arguments.size();
            for (int i = 0; i < size; i++) {
                Expression exp = arguments.get(i);
                if (i == 0) {
                    columns.append(getColumn(exp));
                } else {
                    columns.append(getColumn(exp)).append(columnSplit);
                }
            }
            return columns.toString();
        } else if (expression instanceof ComparisonExpression) {
            ComparisonExpression compare = (ComparisonExpression) expression;
            return getString(compare.getLeft(), compare.getRight());
        } else if (expression instanceof Literal || expression instanceof ArithmeticUnaryExpression) {
            return "";
        } else if (expression instanceof Cast) {
            Cast cast = (Cast) expression;
            return getColumn(cast.getExpression());
        } else if (expression instanceof DereferenceExpression) {
            DereferenceExpression reference = (DereferenceExpression) expression;
            return reference.toString();
        } else if (expression instanceof ArithmeticBinaryExpression) {
            ArithmeticBinaryExpression binaryExpression = (ArithmeticBinaryExpression) expression;
            return getString(binaryExpression.getLeft(), binaryExpression.getRight());
        } else if (expression instanceof SearchedCaseExpression) {
            SearchedCaseExpression caseExpression = (SearchedCaseExpression) expression;
            List<Expression> exps = caseExpression.getWhenClauses().stream().map(whenClause -> (Expression) whenClause).collect(Collectors.toList());
            caseExpression.getDefaultValue().ifPresent(exps::add);
            return getString(exps);
        } else if (expression instanceof WhenClause) {
            WhenClause whenClause = (WhenClause) expression;
            return getString(whenClause.getOperand(), whenClause.getResult());
        } else if (expression instanceof LikePredicate) {
            LikePredicate likePredicate = (LikePredicate) expression;
            return likePredicate.getValue().toString();
        } else if (expression instanceof InPredicate) {
            InPredicate predicate = (InPredicate) expression;
            return predicate.getValue().toString();
        } else if (expression instanceof SubscriptExpression) {
            SubscriptExpression subscriptExpression = (SubscriptExpression) expression;
            return getColumn(subscriptExpression.getBase());
        } else if (expression instanceof LogicalBinaryExpression) {
            LogicalBinaryExpression logicExp = (LogicalBinaryExpression) expression;
            return getString(logicExp.getLeft(), logicExp.getRight());
        } else if (expression instanceof IsNullPredicate) {
            IsNullPredicate isNullExp = (IsNullPredicate) expression;
            return getColumn(isNullExp.getValue());
        } else if (expression instanceof IsNotNullPredicate) {
            IsNotNullPredicate notNull = (IsNotNullPredicate) expression;
            return getColumn(notNull.getValue());
        } else if (expression instanceof CoalesceExpression) {
            CoalesceExpression coalesce = (CoalesceExpression) expression;
            return getString(coalesce.getOperands());
        }
        throw new SqlParseException("无法识别的表达式:" + expression.getClass().getName());
        //   return expression.toString();
    }

由于我们 select 的字段可能包含很多种函数,所以需要一一进行解析,就不在细说。

后续

其实我也实现了 spark sql、hive sql 的输入表、输出表的解析,代码放在了github 上 :https://github.com/scxwhite/parseX
分享不易,请不要吝啬你的star


  1. 生成配置文件的方式可以通过 idea 安装 antlr 插件,对 sqlBase.g4文件进行配置后生成 java 等语言的解析类 。 ↩︎

  2. Statement 可以理解为对语法树 node 节点的一层封装,方便于我们的解析 ↩︎

你可能感兴趣的:(大数据)