Spark Sql 简单校验的实现

        在网上参考了很多资料,都是要依赖Sparksession,这个需要spark环境,非常不友好,jdk版本也不好控制。不使用Sparksession获取上下文,利用spark和antlr的静态方法使用java 实现简单的spark sql 的语法以及内置函数的校验。

1. spark 版本 3.2.0

        
            org.apache.spark
            spark-sql_2.12
            3.2.0
        
        
            org.antlr
            antlr4-runtime
            4.8
        

2.  完整校验代码

import org.antlr.v4.runtime.*;
import org.antlr.v4.runtime.atn.PredictionMode;
import org.antlr.v4.runtime.misc.ParseCancellationException;
import org.antlr.v4.runtime.tree.ParseTreeWalker;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.catalyst.parser.SqlBaseBaseListener;
import org.apache.spark.sql.catalyst.parser.SqlBaseLexer;
import org.apache.spark.sql.catalyst.parser.SqlBaseParser;
import org.apache.spark.sql.functions;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

public class SparkSqlChecker {
    public static boolean validateSyntax(String sql) {
        System.out.println("sql: " + sql);
        // 1.预处理:替换换行符为空格,并去除非必要空格
        sql = sql.replaceAll("\\s+", " ").trim();

        // 2. 关键字转大写
        CharStream input = new UpperCaseCharStream(sql);
        SqlBaseLexer lexer = new SqlBaseLexer(input);

        // 3.语法分析
        CommonTokenStream tokens = new CommonTokenStream(lexer);
        SqlBaseParser parser = new SqlBaseParser(tokens);
        parser.getInterpreter().setPredictionMode(PredictionMode.LL_EXACT_AMBIG_DETECTION);
        parser.setErrorHandler(new BailErrorStrategy());
        parser.removeErrorListeners();
        parser.addErrorListener(new ParseErrorListener());

        try {
            SqlBaseParser.SingleStatementContext context = parser.singleStatement();
            // 检查是否解析到输入末尾
            if (parser.getInputStream().LA(1) != Token.EOF) {
                throw new ParseCancellationException("未完整解析SQL语句");
            }

            //4.函数校验
            List functions = new ArrayList<>();
            ParseTreeWalker walker = new ParseTreeWalker();
            walker.walk(new FunctionNameCollector(functions), context);
            System.out.println("sql 中的函数:"+functions.stream().collect(Collectors.joining(",")));
            for (String function : functions) {
                if(!getBuiltInFunctions().contains(function)) {
                    System.err.println(String.format("%s 函数不存在", function));
                    return false;
                }
            }

            return true;
        } catch (Exception e) {
            System.err.println(e.getMessage());
            return false;
        }
    }

    /**
     * 自定义监听器
     */
    private static class ParseErrorListener extends BaseErrorListener {
        @Override
        public void syntaxError(Recognizer recognizer,
                                Object offendingSymbol,
                                int line,
                                int charPositionInLine,
                                String msg,
                                RecognitionException e) {
            throw new ParseCancellationException(
                    String.format("Line %d, Column %d: %s",
                            line, charPositionInLine + 1, msg)
            );
        }
    }

    /**
     * 收集SQL中函数名
     */
    private static class FunctionNameCollector extends SqlBaseBaseListener {
        private final List functions;

        public FunctionNameCollector(List functions) {
            this.functions = functions;
        }

        @Override
        public void enterFunctionCall(SqlBaseParser.FunctionCallContext ctx) {
            String funcName = ctx.functionName().getText().toUpperCase();
            functions.add(funcName);
        }
    }

    /**
     * 获取spark sql 允许使用的内置函数
     * @return
     */
    private static List getBuiltInFunctions() {
        List functionList = new ArrayList<>();
        try {
            // 反射获取 FunctionRegistry 类
           // Class functionsClass = Class.forName("org.apache.spark.sql.functions");
            Method[] methods = functions.class.getMethods();
            // 过滤出返回类型为Column的静态方法(即内置函数)
            for (Method method : methods) {
                if (method.getReturnType().equals(Column.class) &&
                        java.lang.reflect.Modifier.isStatic(method.getModifiers())) {
                    functionList.add(method.getName());
                }
            }
            functionList = functionList.stream().distinct().map(s -> s.toUpperCase()).collect(Collectors.toList());
            return functionList;
        } catch (Exception e) {
            System.err.println("反射调用失败: " + e.getMessage());
        }
        return Collections.emptyList();
    }
}
/**
 * 关键字转大写,antlr 有严格的大小写区分
 */
class UpperCaseCharStream extends ANTLRInputStream {
    public UpperCaseCharStream(String input) {
        super(input);
    }

    @Override
    public int LA(int i) {
        int c = super.LA(i);
        return c == CharStream.EOF ? c : Character.toUpperCase((char) c);
    }
}

3. 测试

    public static void main(String[] args) {

        System.out.println(validateSyntax("selec "));
        System.out.println("-------------------------------------------------");
        System.out.println(validateSyntax("select m "));
        System.out.println("-------------------------------------------------");
        System.out.println(validateSyntax("SELECT m FROM"));
        System.out.println("-------------------------------------------------");
        System.out.println(validateSyntax("selec m FROM x.xx"));
        System.out.println("-------------------------------------------------");
        System.out.println(validateSyntax("SELECT * FROM x.xx"));
        System.out.println("-------------------------------------------------");
        System.out.println(validateSyntax("""
                select m  from xx.xxx
                 where a = MAX(bi_dt) and b = DATE_ADD(DATE_FORMAT(dat_parse('20250101' ,'yyyyMMdd') ,'yyyy-MM-dd') ,2)
                   and c = 'c'
                """));
        System.out.println("-------------------------------------------------");
        System.out.println(validateSyntax(" select m  from xx.xxx where a ="));
        System.out.println("-------------------------------------------------");
        System.out.println(validateSyntax(" select m  from xx.xxx where a = a b="));
        System.out.println("-------------------------------------------------");

    }

//输出:
sql: selec 
Line 1, Column 1: no viable alternative at input 'selec'
false
-------------------------------------------------
sql: select m 
true
-------------------------------------------------
sql: SELECT m FROM
true
-------------------------------------------------
sql: selec m FROM x.xx
Line 1, Column 1: no viable alternative at input 'selec'
false
-------------------------------------------------
sql: SELECT * FROM x.xx
true
-------------------------------------------------
sql: select m  from xx.xxx
 where a = MAX(bi_dt) and b = DATE_ADD(DATE_FORMAT(dat_parse('20250101' ,'yyyyMMdd') ,'yyyy-MM-dd') ,2)
   and c = 'c'
sql 中的函数:MAX,DATE_ADD,DATE_FORMAT,DAT_PARSE
DAT_PARSE 函数不存在
false
-------------------------------------------------
sql:  select m  from xx.xxx where a =
Line 1, Column 31: no viable alternative at input ''
false
-------------------------------------------------
sql:  select m  from xx.xxx where a = a b=
null
false
-------------------------------------------------

4. 基本的语法校验没有问题,其中还是有一些错误校验,比如,SELECT m FROM  这个sql ,应该校验不通过,from 后面没有表或者子查询,这种是一个很明显的错误,没有校验出来,后面再优化。

你可能感兴趣的:(spark,sql,java)