基于语法树解析Spark SQL 获取访问的表/字段/UDF等信息

最近有个需求:需要拆解出 Spark SQL 中访问的表、字段等信息,然后配合 Ranger 实现权限校验。

其实难度不大:对 SQL 解析得到的语法树做递归遍历,就能拆解出一段 SQL 中涉及的表、字段、UDF 等信息,再创建一些数据结构(bean 对象)来配合权限校验。

下面是部分源码(全部本人原创),不涉及业务信息。

对于Presto我也做了拆解,欢迎沟通交流。

package test.server.utils;

import com.fasterxml.jackson.databind.ObjectMapper;
import test.ranger.AccessType;
import test.ranger.ColumnAccess;
import test.ranger.ColumnsWithTable;
import test.ranger.PartitionData;
import test.server.service.impl.SparkSqlJob;
import lombok.extern.slf4j.Slf4j;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import okhttp3.ResponseBody;
import okio.Buffer;
import okio.BufferedSource;
import org.antlr.v4.runtime.*;
import org.antlr.v4.runtime.tree.ParseTree;
import org.antlr.v4.runtime.tree.RuleNode;
import org.antlr.v4.runtime.tree.TerminalNode;
import org.apache.spark.sql.catalyst.parser.SqlBaseParser;

import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.charset.UnsupportedCharsetException;
import java.sql.SQLException;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static java.lang.String.format;

/**
 * Walks Spark SQL ANTLR parse trees and records which tables, columns and
 * functions the statements access.
 *
 * <p>Call {@link #visit(ParseTree)} once per parsed statement. Results accumulate in
 * {@link #getItems()} (keyed by lower-cased {@code db.table}) and {@link #getUdfList()}
 * until {@link #reset()} is called. Names created by the visited statements themselves
 * (CTEs, cached tables, views, temporary tables) are remembered and not reported.
 *
 * <p>For every regular SELECT, the metadata service is asked for the partition columns
 * of the table returned by the FROM clause; if the first partition column is not
 * referenced in the WHERE clause a {@link SQLException} is thrown.
 *
 * <p>Instances hold mutable state (current database, collected items).
 */
@Slf4j
public class SparkSqlParser {

    /** Matches {@code SET key=value}; the key ends at the first '=' so values may contain '='. */
    private static final Pattern SET_PATTERN =
            Pattern.compile("^\\s*set\\s+(.+?)=(.+)", Pattern.CASE_INSENSITIVE);

    private static final ObjectMapper MAPPER = new ObjectMapper();

    /** One client for all parser instances, so connection and thread pools are reused. */
    private static final OkHttpClient HTTP_CLIENT = new OkHttpClient();

    /** Database used to qualify unqualified table names; updated by USE statements. */
    private String currentDB = "default";
    /** Recorded accesses keyed by {@code db.table}. */
    private final Map<String, ColumnAccess> items = new HashMap<>();
    /** Lower-cased first token of every function call seen in a select list. */
    private final Set<String> udfList = new HashSet<>();
    /** Names defined by the visited SQL itself. */
    private final Set<String> tempTable = new HashSet<>();
    private final String metadataUrl;
    private final String metadataToken;

    /**
     * @param metadataUrl   base URL of the metadata service; {@code /partitionColumns} is appended
     * @param metadataToken value sent in the {@code token} request header
     */
    public SparkSqlParser(String metadataUrl, String metadataToken) {
        this.metadataUrl = metadataUrl;
        this.metadataToken = metadataToken;
    }

    /** @return the live map of recorded accesses, keyed by {@code db.table} */
    public Map<String, ColumnAccess> getItems() {
        return items;
    }

    /** @return the live set of function names collected from select lists */
    public Set<String> getUdfList() {
        return udfList;
    }

    /**
     * Extracts {@code SET key=value} settings from a SQL script.
     *
     * @param originalSql the script, split into commands by {@code SparkSqlJob.prepare}
     * @return trimmed keys mapped to trimmed values; later duplicates overwrite earlier ones
     * @throws IOException declared for {@code SparkSqlJob.prepare}
     */
    public Map<String, String> getConfigFromSql(String originalSql) throws IOException {
        List<String> commands = SparkSqlJob.prepare(originalSql);
        Map<String, String> configMap = new HashMap<>();
        for (String command : commands) {
            Matcher m = SET_PATTERN.matcher(command);
            if (m.find()) {
                configMap.put(m.group(1).trim(), m.group(2).trim());
            }
        }
        return configMap;
    }

    /** Clears collected items, functions and temp names; the current database is left untouched. */
    public void reset() {
        items.clear();
        udfList.clear();
        tempTable.clear();
    }

    /**
     * Recursively visits a parse tree and records the accesses it implies.
     *
     * @param t root of the (sub)tree to inspect
     * @throws Exception when the partition check fails or the metadata call fails
     */
    public void visit(ParseTree t) throws Exception {
        if (t instanceof TerminalNode) {
            return; // tokens carry no access information on their own
        }
        if (t instanceof SqlBaseParser.UseContext) {
            // getText() concatenates the tokens without whitespace, e.g. "usedb".
            currentDB = normalize(t.getText()).replaceFirst("use", "").trim();
        } else if (t instanceof SqlBaseParser.QuerySpecificationContext) { // select access
            visitSelect(t, null);
        } else if (t instanceof SqlBaseParser.CtesContext) {
            SqlBaseParser.CtesContext with = (SqlBaseParser.CtesContext) t;
            for (SqlBaseParser.NamedQueryContext item : with.namedQuery()) {
                tempTable.add(normalize(item.name.getText()));
                visit(item);
            }
        } else if (t instanceof SqlBaseParser.CacheTableContext) {
            SqlBaseParser.CacheTableContext cache = (SqlBaseParser.CacheTableContext) t;
            // Inspect an optional "AS query" first so the tables it reads are recorded.
            visitChildren(t);
            tempTable.add(normalize(cache.multipartIdentifier().getText()));
        } else if (t instanceof SqlBaseParser.InsertIntoTableContext
                || t instanceof SqlBaseParser.InsertOverwriteTableContext) { // write access
            visitAndAddTable(t, AccessType.write.toString());
        } else if (t instanceof SqlBaseParser.CreateTableHeaderContext) { // create access
            SqlBaseParser.CreateTableHeaderContext header = (SqlBaseParser.CreateTableHeaderContext) t;
            String table = header.multipartIdentifier().getText();
            if (header.TEMPORARY() != null) {
                tempTable.add(normalize(table));
            } else {
                addItem(table, AccessType.create.toString(), null);
            }
        } else if (t instanceof SqlBaseParser.CreateViewContext) {
            // Inspect the defining query first so the tables it reads are recorded.
            visitChildren(t);
            tempTable.add(normalize(((SqlBaseParser.CreateViewContext) t).multipartIdentifier().getText()));
        } else if (t instanceof SqlBaseParser.CreateTableLikeContext
                || t instanceof SqlBaseParser.CreateTempViewUsingContext) {
            visitAndAddTable(t, AccessType.create.toString());
        } else if (t instanceof SqlBaseParser.DropTableContext // drop access
                || t instanceof SqlBaseParser.DropTablePartitionsContext) {
            visitAndAddTable(t, AccessType.drop.toString());
        } else if (t instanceof SqlBaseParser.RenameTableContext // update access
                || t instanceof SqlBaseParser.RenameTablePartitionContext
                || t instanceof SqlBaseParser.AlterViewQueryContext
                || t instanceof SqlBaseParser.AddTablePartitionContext
                || t instanceof SqlBaseParser.HiveChangeColumnContext) {
            visitAndAddTable(t, AccessType.update.toString());
        } else {
            visitChildren(t);
        }
    }

    /** Applies {@link #visit(ParseTree)} to every direct child of {@code t}. */
    private void visitChildren(ParseTree t) throws Exception {
        for (int i = 0, n = t.getChildCount(); i < n; i++) {
            visit(t.getChild(i));
        }
    }

    /**
     * Collects {@code SET key=value} pairs found in a parse tree into {@code configMap}.
     * Keys and values are lower-cased and stripped of spaces.
     */
    private void visitConfig(ParseTree t, Map<String, String> configMap) throws RecognitionException {
        if (t instanceof SqlBaseParser.SetConfigurationContext) {
            String text = t.getText().toLowerCase(Locale.ROOT).replace(" ", "");
            if (text.startsWith("set")) {
                // Drop only the leading keyword; keys such as "offset" must stay intact.
                text = text.substring("set".length());
            }
            String[] config = text.split("=", 2); // the value itself may contain '='
            if (config.length == 2 && !config[0].isEmpty() && !config[1].isEmpty()) {
                configMap.put(config[0], config[1]);
            }
        } else if (!(t instanceof TerminalNode)) {
            for (int i = 0, n = t.getChildCount(); i < n; i++) {
                visitConfig(t.getChild(i), configMap);
            }
        }
    }

    /**
     * Records an access for a parsed table identifier.
     *
     * @return the {@code db.table} key, or {@code ""} when the name was defined by the visited SQL
     */
    private String addItem(SqlBaseParser.TableIdentifierContext tableIdentifier, String accessType,
                           Set<String> columns) {
        String tableName = normalize(tableIdentifier.table.getText());
        boolean qualified = tableIdentifier.db != null;
        String dbName = qualified ? normalize(tableIdentifier.db.getText()) : currentDB;
        return register(dbName, tableName, qualified, accessType, columns);
    }

    /**
     * Records an access for a table given as text ({@code table}, {@code db.table}, or a longer
     * dotted name, of which the last two parts are used).
     *
     * @return the {@code db.table} key, or {@code ""} when nothing was recorded
     */
    private String addItem(String table, String accessType, Set<String> columns) {
        String[] part = normalize(table).split("\\.");
        if (part.length == 0) {
            return "";
        }
        String tableName = part[part.length - 1];
        boolean qualified = part.length > 1;
        String dbName = qualified ? part[part.length - 2] : currentDB;
        return register(dbName, tableName, qualified, accessType, columns);
    }

    /** Adds the access to {@link #items} unless the name refers to something the SQL created itself. */
    private String register(String dbName, String tableName, boolean qualified, String accessType,
                            Set<String> columns) {
        if (tableName.isEmpty()) {
            return "";
        }
        String fullName = dbName + "." + tableName;
        // A bare temp name only shadows unqualified references; an explicitly qualified
        // reference must match the qualified name it was created under.
        boolean createdHere = tempTable.contains(fullName)
                || (!qualified && tempTable.contains(tableName));
        if (createdHere) {
            return ""; // created in this sql, nothing to check
        }
        ColumnAccess existing = items.get(fullName);
        if (existing != null) {
            existing.addAccess(accessType, columns);
        } else {
            items.put(fullName, new ColumnAccess(accessType, columns));
        }
        return fullName;
    }

    /**
     * Records {@code access} for every direct {@code TableIdentifierContext} child of {@code t}
     * and visits all other children normally.
     *
     * <p>NOTE(review): only {@code TableIdentifierContext} children are registered. If the grammar
     * version in use exposes a statement's target as a {@code multipartIdentifier}, that target is
     * not captured here - confirm against the SqlBase grammar being compiled against.
     */
    private void visitAndAddTable(ParseTree t, String access) throws Exception {
        for (int i = 0, n = t.getChildCount(); i < n; i++) {
            ParseTree child = t.getChild(i);
            if (child instanceof SqlBaseParser.TableIdentifierContext) {
                addItem((SqlBaseParser.TableIdentifierContext) child, access, null);
            } else {
                visit(child);
            }
        }
    }

    /**
     * Walks a query, recording select access for the tables it reads.
     *
     * @param columns columns collected from the enclosing select list, or {@code null} when no
     *                select list has been seen yet
     * @return the key of the last table registered by the last visited child, or {@code ""}
     */
    private String visitSelect(ParseTree t, ColumnsWithTable columns) throws Exception {
        if (t instanceof TerminalNode) {
            return "";
        }
        if (t instanceof SqlBaseParser.RegularQuerySpecificationContext) {
            String table = "";
            Set<String> compareColumns = new HashSet<>();
            for (int i = 0, n = t.getChildCount(); i < n; i++) {
                ParseTree child = t.getChild(i);
                if (child instanceof SqlBaseParser.SelectClauseContext) { // select
                    if (columns == null) {
                        columns = new ColumnsWithTable();
                    }
                    visitColumn(child, columns);
                } else if (child instanceof SqlBaseParser.FromClauseContext) { // from
                    table = visitSelect(child, columns);
                } else if (child instanceof SqlBaseParser.WhereClauseContext) {
                    visitWhere(child, compareColumns);
                }
            }
            if (table != null && !table.isEmpty()) {
                checkPartition(table, compareColumns);
            }
            return table;
        }
        if (t instanceof SqlBaseParser.TableNameContext) {
            SqlBaseParser.TableNameContext table = (SqlBaseParser.TableNameContext) t;
            String tableName = table.multipartIdentifier().getText();
            if (tableName.isEmpty()) {
                return "";
            }
            String select = AccessType.select.toString();
            if (columns == null) {
                // Reached without a select list (e.g. a TRANSFORM query): record table-level access.
                return addItem(tableName, select, null);
            }
            String alias = aliasOf(table.tableAlias());
            Set<String> matched = alias.isEmpty()
                    ? columns.getMatchColumns()
                    : columns.getMatchColumns(alias);
            return addItem(tableName, select, matched);
        }
        if (t instanceof SqlBaseParser.TableIdentifierContext) {
            Set<String> matched = columns == null ? null : columns.getMatchColumns();
            return addItem((SqlBaseParser.TableIdentifierContext) t, AccessType.select.toString(), matched);
        }
        String table = "";
        for (int i = 0, n = t.getChildCount(); i < n; i++) {
            table = visitSelect(t.getChild(i), columns);
        }
        return table;
    }

    /**
     * Returns the lower-cased alias identifier, or {@code ""} when there is none. The alias rule's
     * full text is not usable: it would also contain the AS keyword and any column list.
     */
    private static String aliasOf(SqlBaseParser.TableAliasContext alias) {
        if (alias == null || alias.strictIdentifier() == null) {
            return "";
        }
        return alias.strictIdentifier().getText().toLowerCase(Locale.ROOT);
    }

    /** Collects column references and function names from a select list into {@code columns}. */
    private void visitColumn(ParseTree t, ColumnsWithTable columns) throws Exception {
        if (t instanceof TerminalNode) {
            return;
        }
        if (t instanceof SqlBaseParser.ColumnReferenceContext
                || t instanceof SqlBaseParser.DereferenceContext) {
            columns.addColumn(t.getText().toLowerCase(Locale.ROOT));
            return;
        }
        if (t instanceof SqlBaseParser.FunctionCallContext) {
            // Records only the first token of the call, lower-cased.
            udfList.add(((SqlBaseParser.FunctionCallContext) t).start.getText().toLowerCase(Locale.ROOT));
        }
        for (int i = 0, n = t.getChildCount(); i < n; i++) {
            visitColumn(t.getChild(i), columns);
        }
    }

    /** Collects the column names referenced anywhere below a WHERE clause. */
    private void visitWhere(ParseTree t, Set<String> compareColumns) throws Exception {
        if (t instanceof TerminalNode) {
            return;
        }
        if (t instanceof SqlBaseParser.ColumnReferenceContext) {
            compareColumns.add(normalize(t.getText()));
            return;
        }
        if (t instanceof SqlBaseParser.DereferenceContext) {
            // For "alias.col" the column is the field part; without this a qualified
            // partition filter would not be recognised.
            compareColumns.add(normalize(((SqlBaseParser.DereferenceContext) t).fieldName.getText()));
        }
        for (int i = 0, n = t.getChildCount(); i < n; i++) {
            visitWhere(t.getChild(i), compareColumns);
        }
    }

    /**
     * Asks the metadata service for the partition columns of {@code tableName} and requires the
     * first one to be among {@code compareColumns}.
     *
     * @param tableName      key in {@code db.table} form; other shapes are ignored
     * @param compareColumns column names referenced in the WHERE clause
     * @throws SQLException when the table has partition columns and the first is not referenced
     * @throws IOException  when the metadata request cannot be executed
     */
    private void checkPartition(String tableName, Set<String> compareColumns) throws Exception {
        String[] tablePart = tableName.split("\\.");
        if (tablePart.length != 2) {
            return;
        }
        String url = format("%s/partitionColumns?dbName=%s&tableName=%s",
                metadataUrl, encode(tablePart[0]), encode(tablePart[1]));
        Request request = new Request.Builder()
                .url(url).method("GET", null)
                .addHeader("token", metadataToken)
                .build();
        List<String> partitions;
        // Closing the response releases the underlying connection.
        try (Response response = HTTP_CLIENT.newCall(request).execute()) {
            partitions = getPartitions(response);
        }
        if (partitions != null && !partitions.isEmpty() && !compareColumns.contains(partitions.get(0))) {
            throw new SQLException(
                    format("查询分区表%s时,请带上分区(%s)作为where限制条件.", tableName, partitions.get(0)));
        }
    }

    /** URL-encodes a query parameter value so characters such as '&' cannot alter the request. */
    private static String encode(String value) throws IOException {
        return java.net.URLEncoder.encode(value, "UTF-8");
    }

    /** Strips backticks and lower-cases an identifier independently of the default locale. */
    private static String normalize(String identifier) {
        return identifier.replace("`", "").toLowerCase(Locale.ROOT);
    }

    /**
     * Reads a metadata-service response and returns the data list it carries.
     *
     * <p>The body is only read when the response has no {@code Content-Encoding} or it is
     * {@code identity}. The response is not closed here; that is left to the caller.
     * Assumes {@code PartitionData#getData()} yields the partition column names as strings -
     * TODO confirm against {@code PartitionData}.
     *
     * @param response the HTTP response to read
     * @return the data list, or {@code null} when the body is missing, cannot be parsed,
     *         or reports failure
     */
    public List<String> getPartitions(Response response) {
        try {
            ResponseBody responseBody = response.body();
            if (responseBody == null) {
                log.warn("Partition lookup returned no response body");
                return null;
            }

            String res = "";
            String contentEncoding = response.headers().get("Content-Encoding");
            if (contentEncoding == null || contentEncoding.equalsIgnoreCase("identity")) {
                BufferedSource source = responseBody.source();
                source.request(Long.MAX_VALUE); // Buffer the entire body.
                Buffer buffer = source.buffer();

                Charset charset = Charset.forName("UTF-8");
                MediaType contentType = responseBody.contentType();
                if (contentType != null) {
                    try {
                        charset = contentType.charset(charset);
                    } catch (UnsupportedCharsetException e) {
                        log.warn("Unsupported response charset, falling back to UTF-8", e);
                    }
                }

                if (responseBody.contentLength() != 0) {
                    res = buffer.clone().readString(charset);
                }
            }

            PartitionData partitionData = MAPPER.readValue(res, PartitionData.class);
            if (partitionData.isSuccess()) {
                return partitionData.getData();
            }
            log.warn(partitionData.getErrMessage());
            return null;
        } catch (Exception e) {
            log.warn(e.getMessage(), e);
            return null;
        }
    }
}

你可能感兴趣的:(JAVA,大数据,Spark,spark,antlr)