网上的很多资料在校验 Spark SQL 时都依赖 SparkSession,这需要完整的 Spark 运行环境,使用起来非常不友好,JDK 版本也不好控制。本文不通过 SparkSession 获取上下文,而是直接利用 Spark 和 ANTLR 提供的类与静态方法,用 Java 实现简单的 Spark SQL 语法校验以及内置函数校验。
1. spark 版本 3.2.0
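<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-sql_2.12</artifactId>
    <version>3.2.0</version>
</dependency>
<dependency>
    <groupId>org.antlr</groupId>
    <artifactId>antlr4-runtime</artifactId>
    <version>4.8</version>
</dependency>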
2. 完整校验代码
import org.antlr.v4.runtime.*;
import org.antlr.v4.runtime.atn.PredictionMode;
import org.antlr.v4.runtime.misc.ParseCancellationException;
import org.antlr.v4.runtime.tree.ParseTreeWalker;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.catalyst.parser.SqlBaseBaseListener;
import org.apache.spark.sql.catalyst.parser.SqlBaseLexer;
import org.apache.spark.sql.catalyst.parser.SqlBaseParser;
import org.apache.spark.sql.functions;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.stream.Collectors;
/**
 * Checks Spark SQL text without a SparkSession: the statement is parsed with
 * Spark's generated ANTLR parser, and every function call found in the parse
 * tree is compared against the names exposed by
 * {@code org.apache.spark.sql.functions}.
 */
public class SparkSqlChecker {

    /**
     * Validates the syntax of a single SQL statement and the functions it calls.
     *
     * <p>Diagnostics are written to {@code System.out} / {@code System.err};
     * the method itself never throws for invalid SQL.
     *
     * @param sql the SQL text to check; {@code null} is reported as invalid
     * @return {@code true} when the statement parses completely and every called
     *         function is a known built-in, {@code false} otherwise
     */
    public static boolean validateSyntax(String sql) {
        System.out.println("sql: " + sql);
        if (sql == null) {
            System.err.println("sql is null");
            return false;
        }
        // 1. Pre-processing: only trim. Newlines are deliberately kept, because
        //    folding them into spaces would make a "--" line comment swallow the
        //    rest of the statement and would make every error report line 1.
        String text = sql.trim();
        // 2. Present upper-cased characters to the lexer (keywords are matched in upper case).
        CharStream input = new UpperCaseCharStream(text);
        SqlBaseLexer lexer = new SqlBaseLexer(input);
        // 3. Syntax analysis.
        CommonTokenStream tokens = new CommonTokenStream(lexer);
        SqlBaseParser parser = new SqlBaseParser(tokens);
        parser.getInterpreter().setPredictionMode(PredictionMode.LL_EXACT_AMBIG_DETECTION);
        parser.setErrorHandler(new BailErrorStrategy());
        parser.removeErrorListeners();
        parser.addErrorListener(new ParseErrorListener());
        try {
            SqlBaseParser.SingleStatementContext context = parser.singleStatement();
            // Make sure the parser consumed the whole input.
            if (parser.getInputStream().LA(1) != Token.EOF) {
                throw new ParseCancellationException("未完整解析SQL语句");
            }
            // 4. Function validation.
            List<String> functionNames = new ArrayList<>();
            ParseTreeWalker walker = new ParseTreeWalker();
            walker.walk(new FunctionNameCollector(functionNames), context);
            System.out.println("sql 中的函数:" + String.join(",", functionNames));
            Set<String> builtIns = BuiltInFunctionHolder.NAMES;
            for (String name : functionNames) {
                if (!builtIns.contains(name)) {
                    System.err.println(String.format("%s 函数不存在", name));
                    return false;
                }
            }
            return true;
        } catch (Exception e) {
            System.err.println(describeFailure(e));
            return false;
        }
    }

    /**
     * Builds a printable description of a parse failure.
     *
     * <p>An exception thrown by {@link BailErrorStrategy} may carry no message
     * of its own (only a cause), which previously surfaced as the text "null".
     * In that case the position and text of the offending token are used.
     */
    private static String describeFailure(Exception e) {
        if (e.getMessage() != null) {
            return e.getMessage();
        }
        Throwable cause = e.getCause();
        if (cause instanceof RecognitionException) {
            Token offending = ((RecognitionException) cause).getOffendingToken();
            if (offending != null) {
                return String.format("Line %d, Column %d: unexpected input '%s'",
                        offending.getLine(),
                        offending.getCharPositionInLine() + 1,
                        offending.getText());
            }
        }
        return cause != null ? cause.toString() : e.getClass().getName();
    }

    /**
     * Error listener that turns every reported syntax error into a
     * {@link ParseCancellationException} carrying the 1-based position.
     */
    private static class ParseErrorListener extends BaseErrorListener {
        @Override
        public void syntaxError(Recognizer<?, ?> recognizer,
                                Object offendingSymbol,
                                int line,
                                int charPositionInLine,
                                String msg,
                                RecognitionException e) {
            throw new ParseCancellationException(
                    String.format("Line %d, Column %d: %s",
                            line, charPositionInLine + 1, msg)
            );
        }
    }

    /**
     * Parse-tree listener that records the upper-cased name of every function
     * call it enters, in encounter order, into the list supplied by the caller.
     */
    private static class FunctionNameCollector extends SqlBaseBaseListener {
        private final List<String> functions;

        public FunctionNameCollector(List<String> functions) {
            this.functions = functions;
        }

        @Override
        public void enterFunctionCall(SqlBaseParser.FunctionCallContext ctx) {
            // Locale.ROOT keeps the casing independent of the JVM default locale.
            functions.add(ctx.functionName().getText().toUpperCase(Locale.ROOT));
        }
    }

    /**
     * Lazily computed, read-only set of known function names, so the reflective
     * scan runs once instead of once per function found in the SQL.
     */
    private static final class BuiltInFunctionHolder {
        static final Set<String> NAMES =
                Collections.unmodifiableSet(new HashSet<>(getBuiltInFunctions()));
    }

    /**
     * Collects the function names accepted by the checker.
     *
     * <p>The names are the public static methods of
     * {@code org.apache.spark.sql.functions} whose return type is {@link Column},
     * upper-cased and de-duplicated.
     * NOTE(review): this is the DataFrame helper API; it may not match the set
     * of functions registered for SQL exactly — confirm before relying on it.
     *
     * @return the upper-cased names, or an empty list if reflection fails
     */
    private static List<String> getBuiltInFunctions() {
        List<String> names = new ArrayList<>();
        try {
            for (Method method : functions.class.getMethods()) {
                if (method.getReturnType().equals(Column.class)
                        && java.lang.reflect.Modifier.isStatic(method.getModifiers())) {
                    names.add(method.getName().toUpperCase(Locale.ROOT));
                }
            }
            return names.stream().distinct().collect(Collectors.toList());
        } catch (Exception e) {
            System.err.println("反射调用失败: " + e.getMessage());
        }
        return Collections.emptyList();
    }
}
/**
 * Character stream that hands upper-cased look-ahead characters to the lexer,
 * because the ANTLR grammar is strictly case-sensitive about keywords.
 */
class UpperCaseCharStream extends ANTLRInputStream {

    public UpperCaseCharStream(String input) {
        super(input);
    }

    /**
     * Returns the look-ahead symbol at offset {@code i}, upper-cased unless it
     * is the end-of-input marker.
     */
    @Override
    public int LA(int i) {
        final int symbol = super.LA(i);
        if (symbol == CharStream.EOF) {
            return symbol;
        }
        return Character.toUpperCase((char) symbol);
    }
}
3. 测试
/**
 * Demo entry point: runs {@code validateSyntax} over a fixed list of sample
 * statements, printing each result followed by a separator line.
 *
 * @param args unused
 */
public static void main(String[] args) {
    final String separator = "-------------------------------------------------";
    final String[] samples = {
        "selec ",
        "select m ",
        "SELECT m FROM",
        "selec m FROM x.xx",
        "SELECT * FROM x.xx",
        """
        select m from xx.xxx
        where a = MAX(bi_dt) and b = DATE_ADD(DATE_FORMAT(dat_parse('20250101' ,'yyyyMMdd') ,'yyyy-MM-dd') ,2)
        and c = 'c'
        """,
        " select m from xx.xxx where a =",
        " select m from xx.xxx where a = a b=",
    };
    for (String sample : samples) {
        System.out.println(validateSyntax(sample));
        System.out.println(separator);
    }
}
//输出:
sql: selec
Line 1, Column 1: no viable alternative at input 'selec'
false
-------------------------------------------------
sql: select m
true
-------------------------------------------------
sql: SELECT m FROM
true
-------------------------------------------------
sql: selec m FROM x.xx
Line 1, Column 1: no viable alternative at input 'selec'
false
-------------------------------------------------
sql: SELECT * FROM x.xx
true
-------------------------------------------------
sql: select m from xx.xxx
where a = MAX(bi_dt) and b = DATE_ADD(DATE_FORMAT(dat_parse('20250101' ,'yyyyMMdd') ,'yyyy-MM-dd') ,2)
and c = 'c'
sql 中的函数:MAX,DATE_ADD,DATE_FORMAT,DAT_PARSE
DAT_PARSE 函数不存在
false
-------------------------------------------------
sql: select m from xx.xxx where a =
Line 1, Column 31: no viable alternative at input ''
false
-------------------------------------------------
sql: select m from xx.xxx where a = a b=
null
false
-------------------------------------------------
4. 小结:基本的语法校验没有问题,但仍有一些错误没有被校验出来。比如 SELECT m FROM 这条 SQL,FROM 后面既没有表也没有子查询,是一个很明显的错误,理应校验不通过,实际却返回了 true(可能的原因是:在 Spark 默认的非 ANSI 模式下,FROM 属于非保留关键字,这里被解析成了列 m 的别名,有待确认)。这类问题后面再优化。