Java 使用 JSqlParser 根据入参生成可执行 SQL

话不多说,直接上验证通过的代码。

第一个例子:

package jdbc;

import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Database;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SelectExpressionItem;
import org.apache.commons.lang3.StringUtils;

public class GenSelectSqlBySqlParser {

    public static void main(String[] args) throws JSQLParserException {
        // Inputs
        String database = "test"; // database
        String table = "user"; // table
        // Select expressions: plain columns and function calls are both accepted.
        String[] fields = {"id", "name", "age", "count(id)", "avg(age)", "sum(a)", "calculate_sum(concat_ws(',', collect_list(id)))"};
        String[] fieldsAliases = {"", "", "", "c", "a", "s", "t"}; // alias per field; "" means no alias
        String where = "age > 20"; // WHERE condition
        String groupBy = "name"; // GROUP BY expression
        String tableAlias = "temp";
        Select select = getSelectStr(database, table, fields, fieldsAliases, where, groupBy, tableAlias);

        // Print the generated SQL
        System.out.println(select);
    }

    /**
     * Builds {@code SELECT <fields> FROM <database>.<table> [AS <alias>] [WHERE <where>] [GROUP BY <groupBy>]}
     * as a JSqlParser {@link Select} tree.
     *
     * @param database      database qualifier of the table
     * @param table         table name
     * @param fields        select expressions, each parsed with {@link CCJSqlParserUtil#parseExpression(String)}
     * @param fieldsAliases alias for {@code fields[i]}; a blank entry, a missing entry (array shorter than
     *                      {@code fields}) or a {@code null} array means "no alias"
     * @param where         WHERE condition; the clause is omitted when blank
     * @param groupBy       single GROUP BY expression; the clause is omitted when blank
     * @param alias         table alias; omitted when blank (an empty alias would render as a dangling {@code AS})
     * @return the assembled statement; its {@code toString()} is the SQL text
     * @throws JSQLParserException if a field, the WHERE condition or the GROUP BY expression cannot be parsed
     */
    private static Select getSelectStr(String database, String table, String[] fields, String[] fieldsAliases,
                                       String where, String groupBy, String alias) throws JSQLParserException {
        PlainSelect plainSelect = new PlainSelect();

        // FROM database.table [AS alias]; the Database constructor already sets the database name.
        Table t = new Table();
        t.setDatabase(new Database(database));
        t.setName(table);
        if (StringUtils.isNotBlank(alias)) {
            t.setAlias(new Alias(alias));
        }

        // Select list: parse every field so function expressions keep their structure.
        for (int i = 0; i < fields.length; i++) {
            Expression expr = CCJSqlParserUtil.parseExpression(fields[i]);
            SelectExpressionItem item = new SelectExpressionItem();
            item.setExpression(expr);
            // Tolerate an alias array that is null or shorter than the field array.
            String fieldAlias = (fieldsAliases != null && i < fieldsAliases.length) ? fieldsAliases[i] : null;
            if (StringUtils.isNotBlank(fieldAlias)) {
                item.setAlias(new Alias(fieldAlias));
            }
            plainSelect.addSelectItems(item);
        }
        plainSelect.setFromItem(t);

        // Optional WHERE clause.
        if (StringUtils.isNotBlank(where)) {
            plainSelect.setWhere(CCJSqlParserUtil.parseCondExpression(where));
        }

        // Optional GROUP BY clause.
        if (StringUtils.isNotBlank(groupBy)) {
            plainSelect.addGroupByColumnReference(CCJSqlParserUtil.parseExpression(groupBy));
        }

        Select select = new Select();
        select.setSelectBody(plainSelect);
        return select;
    }
}

执行结果如下:

SELECT id, name, age, count(id) AS c, avg(age) AS a, sum(a) AS s, 
calculate_sum(concat_ws(',', collect_list(id))) AS t 
FROM test.user AS temp WHERE age > 20 GROUP BY name

第二个例子:

将多条单独的 SQL 分别作为子查询,组合生成一条关联 SQL(即 JOIN ... ON)。

package jdbc;

import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.select.*;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;

import java.util.*;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

public class c1 {

    /** Sub-query aliases are single letters A..Z, so at most 26 statements can be combined. */
    private static final int MAX_SQL_COUNT = 26;

    /** Splits an arithmetic expression before and after every + - * / ( ), keeping the operators as tokens. */
    private static final Pattern OPERATOR_BOUNDARY =
            Pattern.compile("(?=[\\+\\-\\*\\/\\(\\)])|(?<=[\\+\\-\\*\\/\\(\\)])");

    /** A bare identifier such as {@code t1} or {@code total_amount} (underscores allowed). */
    private static final Pattern IDENTIFIER = Pattern.compile("[a-zA-Z_][a-zA-Z0-9_]*");


    public static void main(String[] args) throws JSQLParserException {
        // Alternative input: five metric SQLs that all expose "username".
//        String express = "total_amount * average_amount / user_count + XXX1 - TTT2";
//        List<String> sqlList = new ArrayList<>();
//        // SQL statement A
//        String sqlA = "SELECT username, SUM(total_amount) as total_amount FROM hadoop_ind.user_amount GROUP BY username";
//        // SQL statement B
//        String sqlB = "SELECT receiver_name as username, avg(total_amount) as average_amount FROM hadoop_ind.`order` GROUP BY receiver_name";
//        // SQL statement C
//        String sqlC = "SELECT username, count(*) as user_count FROM hadoop_ind.`user` GROUP BY username";
//        // SQL statement D
//        String sqlD = "SELECT username, AVG(*) as XXX1 FROM hadoop_ind.`XXX` GROUP BY username";
//        // SQL statement E
//        String sqlE = "SELECT username, MAX(*) as TTT2 FROM hadoop_ind.`TTT` GROUP BY username";
//        sqlList.add(sqlA);
//        sqlList.add(sqlB);
//        sqlList.add(sqlC);
//        sqlList.add(sqlD);
//        sqlList.add(sqlE);
        List<String> sqlList = new ArrayList<>();
        //String express = " ( ( min(a)+max(b)) + (min(c)+max(d)) )*10";
        String express = "  (t1 + t2) *10";

        String sqlA = "select min(a)+max(b) as t1 from db1.table1";
        String sqlB = "select min(c)+max(d) as t2 from db2.table2";
        sqlList.add(sqlA);
        sqlList.add(sqlB);
        // Output
        String result = combineSql(express, sqlList);
        System.out.println(result);
    }

    /**
     * Combines several derived-metric SQLs into one compound-metric SQL.
     * <p>
     * Each statement becomes a sub-query aliased A, B, C ... in list order. The first sub-query is the
     * FROM item, and every following one is joined to its predecessor (see {@link #getOnExpression(Map, int)}).
     * The outer select list has two parts:
     * <ul>
     *   <li>the plain (non-function) columns common to the sub-queries that expose plain columns,
     *       aliased {@code column1}, {@code column2}, ...;</li>
     *   <li>{@code express}, aliased {@code result}, in which every identifier is qualified with the
     *       alias of the first sub-query that outputs a column of that name.</li>
     * </ul>
     *
     * @param express arithmetic expression over the sub-queries' output columns, e.g. {@code (t1 + t2) * 10}
     * @param sqlList derived-metric SQL statements; each must be a plain SELECT
     * @return the combined SQL text
     * @throws IllegalArgumentException if {@code sqlList} is null/empty, has more than 26 entries,
     *                                  or contains something other than a plain SELECT
     * @throws JSQLParserException      if a statement or the rewritten expression cannot be parsed
     */
    public static String combineSql(String express, List<String> sqlList) throws JSQLParserException {
        if (sqlList == null || sqlList.isEmpty()) {
            throw new IllegalArgumentException("sqlList must contain at least one SQL statement");
        }
        if (sqlList.size() > MAX_SQL_COUNT) {
            throw new IllegalArgumentException(
                    "At most " + MAX_SQL_COUNT + " SQL statements are supported, got " + sqlList.size());
        }
        // Alias -> SQL text. LinkedHashMap keeps A, B, C ... in list order, which getJoins relies on
        // (a plain HashMap gives no ordering guarantee).
        Map<String, String> aliasMap = new LinkedHashMap<>();
        for (int i = 0; i < sqlList.size(); i++) {
            aliasMap.put(aliasOf(i), sqlList.get(i));
        }

        // Parse every statement once.
        Map<String, PlainSelect> plainSelectMap = getPlainSelectMap(aliasMap);

        // Wrap each statement as an aliased sub-query, in alias order.
        List<SubSelect> subSelectList = getSubSelects(plainSelectMap);

        // Outer query.
        PlainSelect parentSelect = new PlainSelect();
        List<SelectItem> selectItems = new ArrayList<>();
        // Plain columns shared by the sub-queries.
        setColumnName(plainSelectMap, selectItems);
        // The compound expression (one expression is supported).
        setFunctionName(express, plainSelectMap, selectItems, parentSelect);

        // The first sub-query is the FROM item; the rest are joined in order.
        parentSelect.setFromItem(subSelectList.get(0));
        parentSelect.setJoins(getJoins(plainSelectMap, subSelectList));

        Select finalSelect = new Select();
        finalSelect.setSelectBody(parentSelect);
        return finalSelect.toString();
    }

    /** Returns the sub-query alias for position {@code index}: 0 -> "A", 1 -> "B", ... */
    private static String aliasOf(int index) {
        return String.valueOf((char) ('A' + index));
    }

    /**
     * Builds one JOIN per sub-query after the first, each joined to its predecessor.
     * A join gets no ON clause when {@link #getOnExpression(Map, int)} finds no join column.
     */
    private static List<Join> getJoins(Map<String, PlainSelect> plainSelectMap, List<SubSelect> subSelectList) {
        List<Join> joins = new ArrayList<>();
        for (int i = 1; i < subSelectList.size(); i++) {
            Join join = new Join();
            join.setRightItem(subSelectList.get(i));
            final Expression onExpression = getOnExpression(plainSelectMap, i);
            if (Objects.nonNull(onExpression)) {
                join.addOnExpression(onExpression);
            }
            joins.add(join);
        }
        return joins;
    }

    /** Wraps each parsed statement in a sub-query aliased with its map key, preserving map order. */
    private static List<SubSelect> getSubSelects(Map<String, PlainSelect> plainSelectMap) {
        List<SubSelect> subSelectList = new ArrayList<>(plainSelectMap.size());
        for (Map.Entry<String, PlainSelect> entry : plainSelectMap.entrySet()) {
            SubSelect subSelect = new SubSelect();
            subSelect.setAlias(new Alias(entry.getKey()));
            subSelect.setSelectBody(entry.getValue());
            subSelectList.add(subSelect);
        }
        return subSelectList;
    }

    /**
     * Parses each SQL statement into a {@link PlainSelect}, keeping the alias order of {@code aliasMap}.
     *
     * @throws IllegalArgumentException if a statement is not a plain SELECT (e.g. a UNION or an UPDATE)
     * @throws JSQLParserException      if a statement cannot be parsed
     */
    private static Map<String, PlainSelect> getPlainSelectMap(Map<String, String> aliasMap) throws JSQLParserException {
        Map<String, PlainSelect> plainSelectMap = new LinkedHashMap<>();
        for (Map.Entry<String, String> entry : aliasMap.entrySet()) {
            String sql = entry.getValue();
            Statement statement = CCJSqlParserUtil.parse(sql);
            if (!(statement instanceof Select)
                    || !(((Select) statement).getSelectBody() instanceof PlainSelect)) {
                throw new IllegalArgumentException("Only plain SELECT statements are supported, got: " + sql);
            }
            plainSelectMap.put(entry.getKey(), (PlainSelect) ((Select) statement).getSelectBody());
        }
        return plainSelectMap;
    }

    /**
     * Collects the plain-column outputs of every sub-query as {@code alias.outputName}.
     * <p>
     * Only select items whose expression is a bare column are considered. The output name is the item's
     * alias if present, otherwise the column name without any table qualifier.
     *
     * @param plainSelectMap sub-query alias -> parsed statement
     * @return e.g. {@code [A.username, A.id, B.username, B.id]}
     */
    public static List<String> getOnColumnName(
            Map<String, PlainSelect> plainSelectMap
    ) {
        List<String> columnList = new ArrayList<>();
        for (Map.Entry<String, PlainSelect> entry : plainSelectMap.entrySet()) {
            String key = entry.getKey();
            for (SelectItem selectItem : entry.getValue().getSelectItems()) {
                if (selectItem instanceof SelectExpressionItem
                        && ((SelectExpressionItem) selectItem).getExpression() instanceof Column) {
                    columnList.add(key + "." + getOutputName((SelectExpressionItem) selectItem));
                }
            }
        }
        return columnList;
    }

    /**
     * Returns the name under which a select item is visible to an outer query: its alias, or, for a bare
     * column without alias, the column name. Returns {@code null} for an un-aliased non-column expression.
     */
    private static String getOutputName(SelectExpressionItem item) {
        Alias alias = item.getAlias();
        if (alias != null) {
            return alias.getName();
        }
        Expression expression = item.getExpression();
        return expression instanceof Column ? ((Column) expression).getColumnName() : null;
    }

    /**
     * Adds the plain columns common to the sub-queries to the outer select list as
     * {@code A.col AS column1}, {@code A.col2 AS column2}, ...
     *
     * @param plainSelectMap sub-query alias -> parsed statement
     * @param selectItems    outer select list to append to
     */
    private static void setColumnName(
            Map<String, PlainSelect> plainSelectMap,
            List<SelectItem> selectItems
    ) {
        List<String> columnList = getOnColumnName(plainSelectMap);
        if (CollectionUtils.isEmpty(columnList)) {
            return;
        }
        int i = 1;
        // Keep only the column names present in every group; each is emitted once, qualified by the first group.
        for (String column : compareGroups(columnList)) {
            SelectExpressionItem selectExpressionItem = new SelectExpressionItem();
            selectExpressionItem.setExpression(new Column(column));
            selectExpressionItem.setAlias(new Alias(String.format("column%d", i++)));
            selectItems.add(selectExpressionItem);
        }
    }

    /**
     * Qualifies every identifier of {@code express} with the alias of the sub-query that outputs it.
     * <p>
     * For example, {@code total_amount-average_amount} becomes {@code A.total_amount-B.average_amount}
     * when A outputs {@code total_amount} and B outputs {@code average_amount}. Identifiers are matched
     * against output column names (alias, or column name), case-insensitively, and the first sub-query in
     * alias order wins. Identifiers no sub-query outputs (e.g. function names such as {@code min}) and
     * numeric literals are kept unchanged. Whitespace around tokens is dropped.
     *
     * @param express        arithmetic expression using + - * / and parentheses
     * @param plainSelectMap sub-query alias -> parsed statement, in alias order
     * @return the rewritten expression text
     */
    private static String getFunctionName(String express, Map<String, PlainSelect> plainSelectMap) {
        StringBuilder sb = new StringBuilder();
        for (String rawPart : OPERATOR_BOUNDARY.split(express, -1)) {
            String part = rawPart.trim();
            if (IDENTIFIER.matcher(part).matches()) {
                String owner = findOwner(part, plainSelectMap);
                if (owner != null) {
                    sb.append(owner).append('.');
                }
            }
            // Operators, parentheses, literals and unresolved identifiers are appended as-is.
            sb.append(part);
        }
        return sb.toString();
    }

    /** Returns the alias of the first sub-query exposing an output column named {@code name}, or {@code null}. */
    private static String findOwner(String name, Map<String, PlainSelect> plainSelectMap) {
        for (Map.Entry<String, PlainSelect> entry : plainSelectMap.entrySet()) {
            for (SelectItem selectItem : entry.getValue().getSelectItems()) {
                if (selectItem instanceof SelectExpressionItem) {
                    String outputName = getOutputName((SelectExpressionItem) selectItem);
                    if (outputName != null && outputName.equalsIgnoreCase(name)) {
                        return entry.getKey();
                    }
                }
            }
        }
        return null;
    }

    /**
     * Appends the rewritten compound expression, aliased {@code result}, and installs the select list on
     * the outer query.
     *
     * @throws JSQLParserException if the rewritten expression cannot be parsed
     */
    private static void setFunctionName(
            String express,
            Map<String, PlainSelect> plainSelectMap,
            List<SelectItem> selectItems,
            PlainSelect parentSelect
    ) throws JSQLParserException {
        String finalResult = getFunctionName(express, plainSelectMap);
        SelectExpressionItem selectExpressionItem = new SelectExpressionItem();
        selectExpressionItem.setExpression(CCJSqlParserUtil.parseExpression(finalResult));
        selectExpressionItem.setAlias(new Alias("result"));
        selectItems.add(selectExpressionItem);
        parentSelect.setSelectItems(selectItems);
    }

    /**
     * Keeps the column names that appear in every group.
     * <p>
     * Example: for {@code [A.username, A.id, B.username, B.id, C.username]} the result is
     * {@code [A.username]}. Entries are grouped by the prefix before the dot, and entries that do not have
     * exactly one dot are ignored. The result is qualified with the first group's alias and follows that
     * group's column order, so the {@code columnN} aliases are stable.
     *
     * @param dataList entries of the form {@code alias.column}
     * @return the common columns, qualified with the first group's alias; empty if there are none
     */
    private static List<String> compareGroups(List<String> dataList) {
        Map<String, List<String>> categorizedData = dataList.stream()
                .map(item -> item.split("\\."))
                .filter(parts -> parts.length == 2)
                .collect(Collectors.groupingBy(
                        parts -> parts[0],
                        LinkedHashMap::new,
                        Collectors.mapping(parts -> parts[1], Collectors.toList())));
        if (categorizedData.isEmpty()) {
            return new ArrayList<>();
        }
        Map.Entry<String, List<String>> firstGroup = categorizedData.entrySet().iterator().next();
        // LinkedHashSet removes duplicates while keeping a deterministic order.
        Set<String> commonValues = new LinkedHashSet<>();
        for (String value : firstGroup.getValue()) {
            boolean isCommon = true;
            for (List<String> list : categorizedData.values()) {
                if (!list.contains(value)) {
                    isCommon = false;
                    break;
                }
            }
            if (isCommon) {
                commonValues.add(firstGroup.getKey() + "." + value);
            }
        }
        return new ArrayList<>(commonValues);
    }

    /**
     * Builds the ON condition that joins sub-query {@code index} to sub-query {@code index - 1}.
     * <p>
     * If the sub-queries expose a plain column, the first plain output column of each must have the same
     * name, and the condition is {@code prev.col = cur.col}. If neither exposes one, their plain GROUP BY
     * columns are paired by position (up to the shorter list) and combined with AND.
     * <p>
     * NOTE(review): the GROUP BY fallback assumes those columns are also in each sub-query's select list;
     * otherwise the outer query cannot reference them. Adjust this method to your join requirements.
     *
     * @param plainSelectMap sub-query alias -> parsed statement
     * @param index          position of the right-hand sub-query (1 for B, 2 for C, ...)
     * @return the ON expression, or {@code null} when no join column can be determined
     * @throws IllegalArgumentException if either alias is missing from {@code plainSelectMap}
     * @throws IllegalStateException    if the first plain output columns differ, including when only one
     *                                  side has one
     */
    public static Expression getOnExpression(Map<String, PlainSelect> plainSelectMap, int index) {
        final String key1 = aliasOf(index - 1);
        final String key2 = aliasOf(index);
        PlainSelect plainSelect1 = plainSelectMap.get(key1);
        PlainSelect plainSelect2 = plainSelectMap.get(key2);
        if (plainSelect1 == null || plainSelect2 == null) {
            throw new IllegalArgumentException("No sub-query for alias " + key1 + " or " + key2);
        }

        String columnName1 = getOnColumnName(plainSelect1.getSelectItems());
        String columnName2 = getOnColumnName(plainSelect2.getSelectItems());
        // The code throws when the names differ, not when they are empty.
        if (!Objects.equals(columnName1, columnName2)) {
            throw new IllegalStateException("No common column name found between " + key1 + " (" + columnName1
                    + ") and " + key2 + " (" + columnName2 + ")");
        }
        if (StringUtils.isNotEmpty(columnName1)) {
            return new EqualsTo(new Column(key1 + "." + columnName1), new Column(key2 + "." + columnName2));
        }

        // Fall back to the GROUP BY columns. Several columns become an AND of equalities; joining their
        // names with commas would produce invalid SQL such as "A.x,y = B.x,y".
        List<String> groupColumns1 = getOnGroupColumnNames(plainSelect1);
        List<String> groupColumns2 = getOnGroupColumnNames(plainSelect2);
        int pairs = Math.min(groupColumns1.size(), groupColumns2.size());
        Expression onExpression = null;
        for (int i = 0; i < pairs; i++) {
            EqualsTo equalsTo = new EqualsTo(
                    new Column(key1 + "." + groupColumns1.get(i)),
                    new Column(key2 + "." + groupColumns2.get(i)));
            onExpression = onExpression == null ? equalsTo : new AndExpression(onExpression, equalsTo);
        }
        // Null when either side has no usable GROUP BY column: the join then gets no ON clause.
        return onExpression;
    }

    /**
     * Returns the output name of the first select item that is a bare column (alias if present, otherwise
     * the column name), or {@code null} if there is none.
     */
    private static String getOnColumnName(List<SelectItem> selectItems) {
        for (SelectItem selectItem : selectItems) {
            if (selectItem instanceof SelectExpressionItem
                    && ((SelectExpressionItem) selectItem).getExpression() instanceof Column) {
                return getOutputName((SelectExpressionItem) selectItem);
            }
        }
        return null;
    }

    /**
     * Returns the names of the plain columns in the statement's GROUP BY clause, in clause order.
     * Non-column expressions (e.g. functions) are skipped because they cannot be referenced by name from
     * the outer query. Returns an empty list when there is no GROUP BY.
     */
    private static List<String> getOnGroupColumnNames(PlainSelect plainSelect) {
        List<String> names = new ArrayList<>();
        GroupByElement groupByElement = plainSelect.getGroupBy();
        if (groupByElement == null) {
            return names;
        }
        ExpressionList groupByExpressionList = groupByElement.getGroupByExpressionList();
        if (groupByExpressionList == null || groupByExpressionList.getExpressions() == null) {
            return names;
        }
        for (Expression expression : groupByExpressionList.getExpressions()) {
            if (expression instanceof Column) {
                names.add(((Column) expression).getColumnName());
            }
        }
        return names;
    }

}

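执行结果(注:下面的结果对应的是在 sqlA、sqlB 末尾都加上 `GROUP BY name` 的情况。按上面代码原样执行,不会出现 GROUP BY 子句和 ON 条件。另外,子查询的 select 中并没有输出 name 列,所以外层的 `ON A.name = B.name` 实际执行时会报列不存在;需要在子查询中同时 select name):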

SELECT
	(A.t1 + B.t2) * 10 AS result
FROM
	(
	SELECT
		min(a) + max(b) AS t1
	FROM
		db1.table1
	GROUP BY
		name) AS A
JOIN (
	SELECT
		min(c) + max(d) AS t2
	FROM
		db2.table2
	GROUP BY
		name) AS B 
ON  A.name = B.name


你可能感兴趣的:(java,sql,数据库)