最近有个需求:需要解析 Spark SQL 中的表、字段等信息,再配合 Ranger 实现一些权限校验。
其实难度不大,只需根据语法树做递归拆解,就能提取出一段 SQL 中的相关信息,再创建一些 bean 数据结构用于配合校验。
下面是部分源码(全部本人原创),不涉及业务信息。
对于Presto我也做了拆解,欢迎沟通交流。
package test.server.utils;
import com.fasterxml.jackson.databind.ObjectMapper;
import test.ranger.AccessType;
import test.ranger.ColumnAccess;
import test.ranger.ColumnsWithTable;
import test.ranger.PartitionData;
import test.server.service.impl.SparkSqlJob;
import lombok.extern.slf4j.Slf4j;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import okhttp3.ResponseBody;
import okio.Buffer;
import okio.BufferedSource;
import org.antlr.v4.runtime.*;
import org.antlr.v4.runtime.tree.ParseTree;
import org.antlr.v4.runtime.tree.RuleNode;
import org.antlr.v4.runtime.tree.TerminalNode;
import org.apache.spark.sql.catalyst.parser.SqlBaseParser;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.charset.UnsupportedCharsetException;
import java.sql.SQLException;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import static java.lang.String.format;
@Slf4j
public class SparkSqlParser {
/** JSON mapper used to decode the metadata-service response in {@code getPartitions}. */
private final ObjectMapper mapper = new ObjectMapper();
/**
 * Matches a {@code SET key=value} command. The key group is reluctant so that the split happens
 * at the first '='; any further '=' characters stay inside the value.
 */
private static final Pattern pattern = Pattern.compile(".*(?i)set\\s+(.+?)=(.+)");
/** Database used to qualify table names that carry no database part; replaced by USE statements. */
private String currentDB = "default";
/** Access requirements collected so far, keyed by lower-cased {@code db.table}. */
private final Map<String, ColumnAccess> items = new HashMap<>();
/** Lower-cased leading token of every function call found in a select list. */
private final Set<String> udfList = new HashSet<>();
/** Names defined by the SQL itself (CTEs, cached tables, temporary tables, views); never checked. */
private final Set<String> tempTable = new HashSet<>();
/** HTTP client used for the partition-column lookup. */
private final OkHttpClient client = new OkHttpClient().newBuilder().build();
/** Base URL of the metadata service. */
private final String metadataUrl;
/** Value sent in the "token" header of metadata requests. */
private final String metadataToken;
/**
 * Creates a parser that looks up partition columns through the given metadata service.
 *
 * @param metadataUrl   base URL of the metadata service; "/partitionColumns" is appended to it
 * @param metadataToken value sent in the "token" request header
 */
public SparkSqlParser(String metadataUrl, String metadataToken) {
    this.metadataToken = metadataToken;
    this.metadataUrl = metadataUrl;
}
/**
 * Returns the access requirements collected so far.
 *
 * @return the parser's own map (not a copy), keyed by {@code db.table}; it is emptied by
 *         {@link #reset()}
 */
public Map<String, ColumnAccess> getItems() {
    return items;
}
/**
 * Returns the function names collected from select lists.
 *
 * @return the parser's own set (not a copy) of lower-cased names; it is emptied by
 *         {@link #reset()}
 */
public Set<String> getUdfList() {
    return udfList;
}
/**
 * Extracts {@code SET key=value} settings from a SQL script.
 *
 * <p>The script is split into commands by {@code SparkSqlJob.prepare}; every command whose
 * trimmed text starts with "set" (case-insensitive) is matched against the SET pattern and,
 * on a match, its trimmed key and value are stored. A later setting for the same key replaces
 * an earlier one.
 *
 * @param originalSql the SQL script as submitted
 * @return a new map from configuration key to value; empty when the script sets nothing
 * @throws IOException declared for {@code SparkSqlJob.prepare}
 */
public Map<String, String> getConfigFromSql(String originalSql) throws IOException {
    // NOTE(review): assumes SparkSqlJob.prepare yields the individual commands as strings — confirm.
    List<String> commands = SparkSqlJob.prepare(originalSql);
    Map<String, String> configMap = new HashMap<>();
    for (String command : commands) {
        // Locale.ROOT keeps the keyword test independent of the JVM's default locale.
        if (!command.trim().toLowerCase(Locale.ROOT).startsWith("set")) {
            continue;
        }
        Matcher m = pattern.matcher(command);
        if (m.find()) {
            configMap.put(m.group(1).trim(), m.group(2).trim());
        }
    }
    return configMap;
}
/**
 * Discards everything collected so far: the self-defined table names, the function names and
 * the access requirements. The current database is left as it is.
 */
public void reset() {
    tempTable.clear();
    udfList.clear();
    items.clear();
}
/**
 * Walks a parse tree and records which access each referenced table requires.
 *
 * <p>Per node type:
 * <ul>
 *   <li>USE: replaces the current database;</li>
 *   <li>query specification: handed to {@code visitSelect} (select access);</li>
 *   <li>CTE list, CACHE TABLE, CREATE VIEW, CREATE TEMPORARY TABLE: the defined name is remembered
 *       as self-defined so that later references to it are not checked;</li>
 *   <li>INSERT: write access; CREATE TABLE: create access; DROP: drop access;
 *       RENAME / ALTER VIEW / ADD PARTITION / CHANGE COLUMN: update access;</li>
 *   <li>anything else: its children are visited.</li>
 * </ul>
 *
 * <p>The query inside CACHE TABLE and CREATE VIEW is visited before the new name is registered,
 * so the tables that query reads are recorded as well.
 *
 * @param t root of the (sub)tree to inspect
 * @throws Exception propagated from the select handling, including the partition check
 */
public void visit(ParseTree t) throws Exception {
    if (t instanceof TerminalNode) {
        return; // a single token carries nothing to record
    }
    if (t instanceof SqlBaseParser.UseContext) {
        // getText() is the concatenated token text: drop the keyword, and the back-ticks too so
        // the database name matches the back-tick-free names used as keys elsewhere.
        currentDB = t.getText().toLowerCase(Locale.ROOT)
                .replaceFirst("use", "").replace("`", "").trim();
    } else if (t instanceof SqlBaseParser.QuerySpecificationContext) { // select access
        visitSelect(t, null);
    } else if (t instanceof SqlBaseParser.CtesContext) {
        for (SqlBaseParser.NamedQueryContext item : ((SqlBaseParser.CtesContext) t).namedQuery()) {
            tempTable.add(item.name.getText().replace("`", "").toLowerCase(Locale.ROOT));
            visit(item);
        }
    } else if (t instanceof SqlBaseParser.CacheTableContext) {
        // Check the optional "AS <query>" first; only then hide the cached name.
        visitEachChild((RuleNode) t);
        tempTable.add(((SqlBaseParser.CacheTableContext) t)
                .multipartIdentifier().getText().replace("`", "").toLowerCase(Locale.ROOT));
    } else if (t instanceof SqlBaseParser.InsertIntoTableContext
            || t instanceof SqlBaseParser.InsertOverwriteTableContext) { // write access
        visitAndAddTable(t, AccessType.write.toString());
    } else if (t instanceof SqlBaseParser.CreateTableHeaderContext) { // create access
        SqlBaseParser.CreateTableHeaderContext header = (SqlBaseParser.CreateTableHeaderContext) t;
        String table = header.multipartIdentifier().getText();
        if (header.TEMPORARY() != null) {
            tempTable.add(table.replace("`", "").toLowerCase(Locale.ROOT));
        } else {
            addItem(table, AccessType.create.toString(), null);
        }
    } else if (t instanceof SqlBaseParser.CreateViewContext) {
        // Check the defining query first; only then hide the view name.
        visitEachChild((RuleNode) t);
        tempTable.add(((SqlBaseParser.CreateViewContext) t)
                .multipartIdentifier().getText().replace("`", "").toLowerCase(Locale.ROOT));
    } else if (t instanceof SqlBaseParser.CreateTableLikeContext
            || t instanceof SqlBaseParser.CreateTempViewUsingContext) {
        visitAndAddTable(t, AccessType.create.toString());
    } else if (t instanceof SqlBaseParser.DropTableContext // drop access
            || t instanceof SqlBaseParser.DropTablePartitionsContext) {
        visitAndAddTable(t, AccessType.drop.toString());
    } else if (t instanceof SqlBaseParser.RenameTableContext // update access
            || t instanceof SqlBaseParser.RenameTablePartitionContext
            || t instanceof SqlBaseParser.AlterViewQueryContext
            || t instanceof SqlBaseParser.AddTablePartitionContext
            || t instanceof SqlBaseParser.HiveChangeColumnContext) {
        visitAndAddTable(t, AccessType.update.toString());
    } else {
        visitEachChild((RuleNode) t);
    }
}

/** Applies {@link #visit(ParseTree)} to every direct child of {@code node}, in order. */
private void visitEachChild(RuleNode node) throws Exception {
    int n = node.getChildCount();
    for (int i = 0; i < n; ++i) {
        visit(node.getChild(i));
    }
}
/**
 * Collects {@code SET key=value} settings from a parse tree.
 *
 * <p>For a set-configuration node the lower-cased node text is used: the leading "set" keyword
 * is removed (only the keyword, not "set" occurring inside a key or value), blanks are removed,
 * and the remainder is split at the first '=' so that a value may itself contain '='. Entries
 * without a value are ignored. Every other non-terminal node is searched recursively.
 *
 * @param t         root of the (sub)tree to search
 * @param configMap receives the settings found, key to value
 * @throws RecognitionException declared by the original signature
 */
private void visitConfig(ParseTree t, Map<String, String> configMap) throws RecognitionException {
    if (t instanceof TerminalNode) {
        return;
    }
    if (t instanceof SqlBaseParser.SetConfigurationContext) {
        String text = t.getText().toLowerCase(Locale.ROOT);
        if (text.startsWith("set")) {
            text = text.substring("set".length());
        }
        // limit 2: everything after the first '=' belongs to the value
        String[] config = text.replace(" ", "").split("=", 2);
        if (config.length == 2 && !config[1].isEmpty()) {
            configMap.put(config[0], config[1]);
        }
        return;
    }
    RuleNode r = (RuleNode) t;
    int n = r.getChildCount();
    for (int i = 0; i < n; ++i) {
        visitConfig(r.getChild(i), configMap);
    }
}
/**
 * Records that {@code accessType} is required on the table named by a table-identifier node.
 *
 * <p>Back-ticks are removed from, and lower case applied to, both the database and the table
 * part. A missing database part is replaced by the current database. Tables defined by the SQL
 * itself are skipped.
 *
 * @param tableIdentifier node holding the table name and an optional database name
 * @param accessType      name of the required access
 * @param columns         columns the access applies to; callers pass {@code null} when no column
 *                        list applies (element type not visible here, hence the raw type)
 * @return the {@code db.table} key that was recorded, or "" when the table was skipped
 */
private String addItem(SqlBaseParser.TableIdentifierContext tableIdentifier, String accessType, Set columns) {
    String tableName = tableIdentifier.table.getText().replace("`", "").toLowerCase(Locale.ROOT);
    if (tempTable.contains(tableName)) {
        return ""; // defined inside this SQL, nothing to check
    }
    String dbName = tableIdentifier.db == null
            ? currentDB
            // strip back-ticks here too, otherwise `db`.t and db.t end up under different keys
            : tableIdentifier.db.getText().replace("`", "").toLowerCase(Locale.ROOT);
    String table = dbName + "." + tableName;
    // Explicit cast: does not rely on the type arguments declared on the map field.
    ColumnAccess existing = (ColumnAccess) items.get(table);
    if (existing != null) {
        existing.addAccess(accessType, columns);
    } else {
        items.put(table, new ColumnAccess(accessType, columns));
    }
    return table;
}
/**
 * Records that {@code accessType} is required on the table given by its textual name.
 *
 * <p>Back-ticks are removed and the name is lower-cased, then split on '.'. The last segment is
 * the table; the segment before it, when present, is the database, otherwise the current
 * database is used. Any further leading segments are ignored, so a name with more than two
 * segments is still recorded under {@code db.table} rather than under an empty key. Tables
 * defined by the SQL itself are skipped.
 *
 * @param table      table name, optionally qualified and/or back-tick quoted
 * @param accessType name of the required access
 * @param columns    columns the access applies to; callers pass {@code null} when no column
 *                   list applies (element type not visible here, hence the raw type)
 * @return the {@code db.table} key that was recorded, or "" when nothing was recorded
 */
private String addItem(String table, String accessType, Set columns) {
    String[] part = table.replace("`", "").toLowerCase(Locale.ROOT).split("\\.");
    if (part.length == 0) {
        return ""; // the name consisted of dots only
    }
    String tableName = part[part.length - 1];
    if (tempTable.contains(tableName)) {
        return ""; // defined inside this SQL, nothing to check
    }
    String dbName = part.length >= 2 ? part[part.length - 2] : currentDB;
    String fullName = dbName + "." + tableName;
    // Explicit cast: does not rely on the type arguments declared on the map field.
    ColumnAccess existing = (ColumnAccess) items.get(fullName);
    if (existing != null) {
        existing.addAccess(accessType, columns);
    } else {
        items.put(fullName, new ColumnAccess(accessType, columns));
    }
    return fullName;
}
/**
 * Handles a statement node whose direct table-identifier children need the given access.
 *
 * <p>Each direct child that is a table identifier is recorded with {@code access} and no column
 * list; every other child goes through the general {@code visit} walk.
 *
 * <p>NOTE(review): only direct children of type TableIdentifierContext are recorded here. If the
 * grammar in use names the target of these statements through a different rule, nothing is
 * recorded for them — confirm against the grammar version.
 *
 * @param t      the statement node
 * @param access name of the access to record
 * @throws Exception propagated from {@code visit}
 */
private void visitAndAddTable(ParseTree t, String access) throws Exception {
    RuleNode node = (RuleNode) t;
    for (int idx = 0, total = node.getChildCount(); idx < total; idx++) {
        ParseTree current = node.getChild(idx);
        if (!(current instanceof SqlBaseParser.TableIdentifierContext)) {
            visit(current);
            continue;
        }
        addItem((SqlBaseParser.TableIdentifierContext) current, access, null);
    }
}
/**
 * Records select access for the tables read by a query and enforces the partition check.
 *
 * <p>For a regular query specification the select clause is scanned for columns, the from
 * clause for tables and the where clause for the columns it mentions; if the from clause
 * produced a table, {@code checkPartition} is run with the where-clause columns. A table-name
 * node is recorded with the columns matching its alias (or all collected columns when it has no
 * alias); a table-identifier node with all collected columns. Any other node is searched
 * recursively and the result of its last child is returned.
 *
 * <p>NOTE(review): the where clause is only scanned for column names, so tables read by a
 * sub-query inside it are not recorded here — confirm whether that is intended.
 *
 * @param t       root of the (sub)tree to inspect
 * @param columns columns collected from the enclosing select clause; may be {@code null} when no
 *                select clause has been seen yet, in which case tables are recorded without a
 *                column list
 * @return the {@code db.table} key of the table found, or "" when there is none
 * @throws Exception propagated from the partition check
 */
private String visitSelect(ParseTree t, ColumnsWithTable columns) throws Exception {
    if (t instanceof TerminalNode) {
        return "";
    }
    if (t instanceof SqlBaseParser.RegularQuerySpecificationContext) {
        String table = "";
        Set<String> compareColumns = new HashSet<>();
        for (ParseTree child : ((ParserRuleContext) t).children) {
            if (child instanceof SqlBaseParser.SelectClauseContext) { // select
                if (columns == null) {
                    columns = new ColumnsWithTable();
                }
                visitColumn(child, columns);
            } else if (child instanceof SqlBaseParser.FromClauseContext) { // from
                table = visitSelect(child, columns);
            } else if (child instanceof SqlBaseParser.WhereClauseContext) {
                visitWhere(child, compareColumns);
            }
        }
        if (table != null && !table.isEmpty()) {
            checkPartition(table, compareColumns);
        }
        return table;
    }
    if (t instanceof SqlBaseParser.TableNameContext) {
        SqlBaseParser.TableNameContext table = (SqlBaseParser.TableNameContext) t;
        String tableName = table.multipartIdentifier().getText();
        if (tableName.isEmpty()) {
            return "";
        }
        // Read the alias from its identifier node: the text of the whole alias rule would also
        // contain the AS keyword (and any column-alias list), e.g. "asa" for "t AS a".
        ParserRuleContext aliasId =
                table.tableAlias() == null ? null : table.tableAlias().strictIdentifier();
        String alias = aliasId == null ? "" : aliasId.getText().toLowerCase(Locale.ROOT);
        Set matched;
        if (columns == null) {
            matched = null; // no select clause seen: record the table without a column list
        } else if (alias.isEmpty()) {
            matched = columns.getMatchColumns();
        } else {
            matched = columns.getMatchColumns(alias);
        }
        return addItem(tableName, AccessType.select.toString(), matched);
    }
    if (t instanceof SqlBaseParser.TableIdentifierContext) {
        return addItem((SqlBaseParser.TableIdentifierContext) t, AccessType.select.toString(),
                columns == null ? null : columns.getMatchColumns());
    }
    RuleNode r = (RuleNode) t;
    String table = "";
    int n = r.getChildCount();
    for (int i = 0; i < n; ++i) {
        table = visitSelect(r.getChild(i), columns);
    }
    return table;
}
/**
 * Collects the columns and function names used below a select-clause node.
 *
 * <p>Column references and dereferences are added to {@code columns} as their lower-cased
 * text and are not searched further. For a function call the lower-cased text of its first
 * token is added to the function list and its children are searched; every other node is
 * searched recursively.
 *
 * <p>NOTE(review): the first token of a function call is presumably the function name; for a
 * qualified name it would only be the leading part — confirm.
 *
 * @param t       root of the (sub)tree to search
 * @param columns receives the columns found
 * @throws Exception declared by the original signature
 */
private void visitColumn(ParseTree t, ColumnsWithTable columns) throws Exception {
    if (t instanceof TerminalNode) {
        return;
    }
    boolean isColumn = t instanceof SqlBaseParser.ColumnReferenceContext
            || t instanceof SqlBaseParser.DereferenceContext;
    if (isColumn) {
        columns.addColumn(t.getText().toLowerCase());
        return;
    }
    if (t instanceof SqlBaseParser.FunctionCallContext) {
        SqlBaseParser.FunctionCallContext call = (SqlBaseParser.FunctionCallContext) t;
        udfList.add(call.start.getText().toLowerCase());
    }
    // Function calls and all remaining nodes: descend into every child.
    RuleNode node = (RuleNode) t;
    for (int idx = 0, total = node.getChildCount(); idx < total; idx++) {
        visitColumn(node.getChild(idx), columns);
    }
}
/**
 * Collects the column names mentioned below a where-clause node.
 *
 * <p>A plain column reference is added as its back-tick-free, lower-cased text. For a
 * dereference such as {@code t.dt} the field name ({@code dt}) is added as well, so that a
 * qualified filter column is recognised; its base expression is still searched, which keeps the
 * names that were collected before this change.
 *
 * @param t              root of the (sub)tree to search
 * @param compareColumns receives the column names found
 * @throws Exception declared by the original signature
 */
private void visitWhere(ParseTree t, Set<String> compareColumns) throws Exception {
    if (t instanceof TerminalNode) {
        return;
    }
    if (t instanceof SqlBaseParser.ColumnReferenceContext) {
        compareColumns.add(t.getText().replace("`", "").toLowerCase(Locale.ROOT));
        return;
    }
    if (t instanceof SqlBaseParser.DereferenceContext) {
        // "alias.column": without this only the alias would be collected.
        compareColumns.add(((SqlBaseParser.DereferenceContext) t)
                .fieldName.getText().replace("`", "").toLowerCase(Locale.ROOT));
    }
    RuleNode r = (RuleNode) t;
    int n = r.getChildCount();
    for (int i = 0; i < n; ++i) {
        visitWhere(r.getChild(i), compareColumns);
    }
}
/**
 * Rejects a query on a partitioned table whose where clause does not mention the table's first
 * partition column.
 *
 * <p>The partition columns are fetched from the metadata service with a GET request to
 * {@code <metadataUrl>/partitionColumns}. Nothing is checked when {@code tableName} is not of
 * the form {@code db.table} or when no partition columns are returned.
 *
 * @param tableName      table key in {@code db.table} form
 * @param compareColumns column names collected from the where clause
 * @throws SQLException when the first partition column is missing from {@code compareColumns}
 * @throws Exception    when the request cannot be executed
 */
private void checkPartition(String tableName, Set compareColumns) throws Exception {
    String[] tablePart = tableName.split("\\.");
    if (tablePart.length != 2) {
        return;
    }
    // The names originate from user SQL: encode them so they cannot alter the query string.
    String url = format("%s/partitionColumns?dbName=%s&tableName=%s", metadataUrl,
            java.net.URLEncoder.encode(tablePart[0], "UTF-8"),
            java.net.URLEncoder.encode(tablePart[1], "UTF-8"));
    Request request = new Request.Builder()
            .url(url).method("GET", null)
            .addHeader("token", metadataToken)
            .build();
    List<?> partitions;
    // Closing the response releases its connection; the body is fully read inside getPartitions.
    try (Response response = client.newCall(request).execute()) {
        partitions = getPartitions(response);
    }
    if (partitions != null && !partitions.isEmpty() && !compareColumns.contains(partitions.get(0))) {
        throw new SQLException(
                format("查询分区表%s时,请带上分区(%s)作为where限制条件.", tableName, partitions.get(0)));
    }
}
/**
 * Decodes the metadata-service response into the list of partition columns.
 *
 * <p>The body is read as text (using the charset of its content type, UTF-8 when none is given)
 * and parsed as {@code PartitionData}. Any failure — unreadable body, invalid JSON, or a
 * response that reports no success — is logged as a warning and yields {@code null}.
 *
 * @param response HTTP response of the partition-column request; its body is consumed
 * @return the data carried by a successful response, or {@code null} on any failure
 */
public List getPartitions(Response response) {
    try {
        ResponseBody responseBody = response.body();
        String res = responseBody == null ? "" : responseBody.string();
        PartitionData partitionData = mapper.readValue(res, PartitionData.class);
        if (partitionData.isSuccess()) {
            return partitionData.getData();
        }
        log.warn(partitionData.getErrMessage());
        return null;
    } catch (Exception e) {
        // Best effort: callers treat null as "no partition information".
        log.warn(e.getMessage(), e);
        return null;
    }
}
}