Writing a Horizontal Table-Sharding Plugin

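We only consider writing a horizontal table-sharding plugin as a MyBatis plugin. Database sharding is out of scope, so there is no need to deal with issues such as proxying the data source.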

Most implementations on GitHub work as follows:

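1. Parse the routing configuration, so the plugin knows which tables need routing, which do not, and what the routing parameters are.

2. Parse the SQL with JSqlParser or Druid's SQL parser to extract, from the static SQL, the routing key and the actual value bound to it. If no routing key is found, every sub-table has to be scanned; if the key is found but its actual value is not, there is no way to know which sub-tables to route to, so again all of them must be scanned.

3. Compute the target sub-tables from the routing key and its value, and generate the new SQL.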
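The three steps can be sketched without any SQL parser. The sketch below assumes a modulo strategy over N sub-tables named `<logical>_<index>`; the naming scheme, the config format and the `ShardRouter` class are hypothetical, not taken from the plugin. When no routing key value is available it returns every sub-table, which is the full scan described in step 2.

```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Hypothetical router: logical table "t_order" split into t_order_0 .. t_order_{n-1}
final class ShardRouter {
    // Step 1: routing config, logical table -> number of sub-tables (assumed format)
    private final Map<String, Integer> tableCounts;

    ShardRouter(Map<String, Integer> tableCounts) {
        this.tableCounts = tableCounts;
    }

    // Step 3: compute target sub-tables from the routing key values found in step 2.
    // An empty collection means no value was found, so every sub-table is returned.
    List<String> route(String logicalTable, Collection<Long> keyValues) {
        Integer count = tableCounts.get(logicalTable);
        List<String> targets = new ArrayList<>();
        if (count == null) {
            // Not a sharded table: keep the original name
            targets.add(logicalTable);
            return targets;
        }
        Set<Integer> indexes = new LinkedHashSet<>();
        if (keyValues.isEmpty()) {
            for (int i = 0; i < count; i++) {
                indexes.add(i);
            }
        } else {
            for (long value : keyValues) {
                // floorMod keeps negative keys in the range 0 .. count-1
                indexes.add((int) Math.floorMod(value, (long) count));
            }
        }
        for (int index : indexes) {
            targets.add(logicalTable + "_" + index);
        }
        return targets;
    }
}
```

For example, with 4 sub-tables the key values 10 and 14 both route to t_order_2, while an empty value list yields all four sub-tables. A real plugin would take the values from the parsed WHERE clause rather than from a caller.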

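These steps only cover fairly simple single-table CRUD and cannot handle the complex SQL found in real projects, so more work is needed. Looking at how Sharding-JDBC works, there is also a "result merging" step: for a query over sharded tables to return exactly the same result as the same query on a single table, the three steps above are far from enough. MyBatis does not let us replace DefaultResultSetHandler to merge results, because it is hard-coded in MyBatis, so the only option is to intercept StatementHandler#query, take its result, and merge that.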
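A small illustration of the gap: suppose SELECT age FROM t ORDER BY age LIMIT 2 runs on two sub-tables. Each sub-table returns its own top two rows, and simply concatenating them gives four unsorted rows instead of the two the single-table query would return. The data and the `NaiveConcat` class are made up for this sketch.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;

// Hypothetical data: ages returned by "SELECT age FROM t ORDER BY age LIMIT 2"
// on two sub-tables, each already sorted and limited by MySQL
final class NaiveConcat {
    // Just appends the per-shard lists
    static List<Integer> concat(List<List<Integer>> shards) {
        List<Integer> all = new ArrayList<>();
        shards.forEach(all::addAll);
        return all;
    }

    // What the single-table query would have returned: re-sort, then re-limit
    static List<Integer> merge(List<List<Integer>> shards, int limit) {
        return concat(shards).stream()
                .sorted(Comparator.naturalOrder())
                .limit(limit)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<List<Integer>> shards = List.of(List.of(20, 35), List.of(18, 22));
        System.out.println(concat(shards));    // [20, 35, 18, 22]
        System.out.println(merge(shards, 2));  // [18, 20]
    }
}
```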

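import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;

import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.ResultHandler;
import org.springframework.stereotype.Component;

@Component
@Intercepts({
        @Signature(type = StatementHandler.class,
                method = "query", args = {Statement.class, ResultHandler.class})})
public class StatementHandlerInterceptor implements Interceptor {
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object target = invocation.getTarget();
        // Let MyBatis execute the (already rewritten) SQL and map the rows first
        Object result = invocation.proceed();
        // Single-element holder so anonymous visitors can replace the result list
        Object[] res = new Object[1];
        if (!(result instanceof List)) {
            return result;
        }
        res[0] = new ArrayList<>((List<?>) result);
        // ... merge res[0] here (re-group, re-aggregate, re-sort, re-limit)
        return res[0];
    }
}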

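How are results merged? GROUP BY results are grouped again, aggregate functions are recomputed, rows are re-sorted according to ORDER BY, and LIMIT is applied again; the details are quite involved. Before merging, the SQL must also be rewritten to add the columns the merge step needs (column completion). With Druid's SQL parser this part is fairly simple: the SQL is rewritten with the Visitor pattern.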
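AVG shows why aggregates must be recomputed rather than merged as-is: the average of per-shard averages is generally not the global average. A minimal sketch of the re-group-and-recompute step, assuming each merged row is available as a Map keyed by column alias and that the rewritten SQL also selects SUM(id) AS generate_sum and COUNT(id) AS generate_count; the `AvgMerger` class and the sample data are hypothetical.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Sketch: re-group shard rows by department and recompute AVG(id)
// from the helper columns generate_sum and generate_count
final class AvgMerger {
    static Map<String, Double> mergeAvg(List<Map<String, Object>> shardRows) {
        // department -> {total sum, total count}
        Map<String, double[]> acc = new LinkedHashMap<>();
        for (Map<String, Object> row : shardRows) {
            String dept = (String) row.get("department");
            double[] totals = acc.computeIfAbsent(dept, k -> new double[2]);
            totals[0] += ((Number) row.get("generate_sum")).doubleValue();
            totals[1] += ((Number) row.get("generate_count")).doubleValue();
        }
        Map<String, Double> avg = new LinkedHashMap<>();
        acc.forEach((dept, t) -> avg.put(dept, t[0] / t[1]));
        return avg;
    }

    public static void main(String[] args) {
        // Sub-table 0: dev has ids {11, 13}; sub-table 1: dev has id {30}
        List<Map<String, Object>> rows = List.of(
                Map.of("department", "dev", "generate_sum", 24L, "generate_count", 2L),
                Map.of("department", "dev", "generate_sum", 30L, "generate_count", 1L));
        // Averaging the per-shard averages would give (12 + 30) / 2 = 21.0
        System.out.println(mergeAvg(rows)); // {dev=18.0}
    }
}
```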

Take the following SQL:

select t.* from (select avg(id) from tbl where id>10 
group by department order by name,age limit 5,3) t

After the rewrite it becomes:

SELECT t.*
FROM (
	SELECT avg(id), sum(id) AS generate_sum
		, count(id) AS generate_count, department AS department, name AS name
		, age AS age
	FROM tbl
	WHERE id > 10
	GROUP BY department
	ORDER BY name, age
	LIMIT 0, 8
) t
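The LIMIT rewrite follows from the fact that any row of the final page can come from any sub-table: LIMIT 5, 3 becomes LIMIT 0, 8 on every sub-table, and the original offset and row count are applied again after the merged rows have been re-sorted. A sketch of both halves; the `LimitRewrite` class is hypothetical.

```java
import java.util.List;

// Sketch of the two halves of LIMIT handling on sharded tables
final class LimitRewrite {
    // Rewrite sent to each sub-table: LIMIT offset, rows -> LIMIT 0, offset + rows
    static int[] rewrite(int offset, int rowCount) {
        return new int[] {0, offset + rowCount};
    }

    // After merging and re-sorting, apply the original LIMIT offset, rows in memory
    static <T> List<T> applyLimit(List<T> sortedMerged, int offset, int rowCount) {
        int from = Math.min(offset, sortedMerged.size());
        int to = Math.min(offset + rowCount, sortedMerged.size());
        return sortedMerged.subList(from, to);
    }
}
```

With two sub-tables each returns at most 8 rows, and applyLimit(merged, 5, 3) keeps rows 6 to 8 of the merged, sorted list. The cost grows with the offset, because every sub-table must return offset + rows rows, which is why deep paging is expensive on sharded tables.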

The rewrite code, based on Druid's SQL parser, is as follows:

import com.alibaba.druid.DbType;
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLLimit;
import com.alibaba.druid.sql.ast.SQLObject;
import com.alibaba.druid.sql.ast.SQLOrderBy;
import com.alibaba.druid.sql.ast.expr.SQLAggregateExpr;
import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr;
import com.alibaba.druid.sql.ast.expr.SQLIntegerExpr;
import com.alibaba.druid.sql.ast.statement.SQLSelectGroupByClause;
import com.alibaba.druid.sql.ast.statement.SQLSelectItem;
import com.alibaba.druid.sql.ast.statement.SQLSelectOrderByItem;
import com.alibaba.druid.sql.ast.statement.SQLSelectStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSelectQueryBlock;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
import com.alibaba.druid.sql.visitor.SQLASTVisitorAdapter;
import com.alibaba.druid.util.JdbcConstants;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class Solution11 {
    public static void main(String[] args) {
        String sql = "select t.* from (select avg(id) from tbl where id>10 group by department order by name,age limit 5,3) t";

        MySqlStatementParser parser = new MySqlStatementParser(sql);
        SQLSelectStatement selectStatement = (SQLSelectStatement) parser.parseStatement();
        MySqlSelectQueryBlock query = (MySqlSelectQueryBlock) selectStatement.getSelect().getQuery();

        query.accept(new SQLASTVisitorAdapter() {

            // LIMIT offset, n -> LIMIT 0, offset + n: every sub-table must return
            // enough rows for the offset to be applied again after merging
            @Override
            public boolean visit(SQLLimit sqlLimit) {
                SQLExpr offset = sqlLimit.getOffset();
                SQLExpr rowCount = sqlLimit.getRowCount();
                if (offset == null) {
                    return super.visit(sqlLimit);
                }
                int offsetNum = ((SQLIntegerExpr) offset).getNumber().intValue();
                if (offsetNum == 0) {
                    return super.visit(sqlLimit);
                }
                int rowCountNum = ((SQLIntegerExpr) rowCount).getNumber().intValue();
                sqlLimit.setOffset(0);
                sqlLimit.setRowCount(offsetNum + rowCountNum);
                return super.visit(sqlLimit);
            }

            // Every GROUP BY column must also be selected so rows from
            // different sub-tables can be re-grouped in memory
            @Override
            public boolean visit(SQLSelectGroupByClause groupBy) {
                List<SQLExpr> items = groupBy.getItems();
                MySqlSelectQueryBlock parent = findParent(groupBy);
                for (SQLExpr item : diff(parent.getSelectList(), items)) {
                    // Clone so the GROUP BY node keeps its own parent
                    String name = ((SQLIdentifierExpr) item).getName();
                    parent.addSelectItem(new SQLSelectItem(item.clone(), name));
                }
                return super.visit(groupBy);
            }

            // Every ORDER BY column must also be selected so merged rows
            // can be re-sorted in memory
            @Override
            public boolean visit(SQLOrderBy orderBy) {
                // LinkedHashMap keeps the ORDER BY column order
                Map<String, SQLExpr> map = new LinkedHashMap<>();
                for (SQLSelectOrderByItem item : orderBy.getItems()) {
                    SQLExpr expr = item.getExpr();
                    if (expr instanceof SQLIdentifierExpr) {
                        map.put(((SQLIdentifierExpr) expr).getName(), expr);
                    } else {
                        throw new RuntimeException("Unsupported ORDER BY expression: " + expr);
                    }
                }
                MySqlSelectQueryBlock parent = findParent(orderBy);
                for (Map.Entry<String, SQLExpr> entry : map.entrySet()) {
                    if (needAddItem(entry, parent.getSelectList())) {
                        parent.addSelectItem(new SQLSelectItem(entry.getValue().clone(), entry.getKey()));
                    }
                }
                return super.visit(orderBy);
            }

            // AVG cannot be merged across sub-tables directly, so also select
            // SUM and COUNT of the same column for the merge step
            @Override
            public boolean visit(SQLAggregateExpr x) {
                MySqlSelectQueryBlock parent = findParent(x);
                List<SQLAggregateExpr> aggregates = parent.getSelectList().stream()
                        .map(SQLSelectItem::getExpr)
                        .filter(it -> it instanceof SQLAggregateExpr)
                        .map(it -> (SQLAggregateExpr) it)
                        .collect(Collectors.toList());
                SQLExpr param = null;
                String paramName = "";
                for (SQLAggregateExpr aggregate : aggregates) {
                    if (aggregate.getMethodName().equalsIgnoreCase("avg")) {
                        param = aggregate.getArguments().get(0);
                        paramName = ((SQLIdentifierExpr) param).getName();
                    }
                }
                // No AVG in this SELECT list: nothing to complete
                if (paramName.isEmpty()) {
                    return true;
                }
                boolean hasCount = false, hasSum = false;
                for (SQLAggregateExpr aggregate : aggregates) {
                    SQLExpr arg = aggregate.getArguments().get(0);
                    String name = arg instanceof SQLIdentifierExpr ? ((SQLIdentifierExpr) arg).getName() : "";
                    if (aggregate.getMethodName().equalsIgnoreCase("count") && name.equals(paramName)) {
                        hasCount = true;
                    }
                    if (aggregate.getMethodName().equalsIgnoreCase("sum") && name.equals(paramName)) {
                        hasSum = true;
                    }
                }
                // Also stops the visitor from re-adding columns when it reaches
                // the generated SUM/COUNT items themselves
                if (hasCount && hasSum) {
                    return true;
                }
                if (!hasSum) {
                    SQLAggregateExpr sum = new SQLAggregateExpr("sum");
                    sum.addArgument(param.clone());
                    parent.addSelectItem(new SQLSelectItem(sum, "generate_sum"));
                }
                if (!hasCount) {
                    SQLAggregateExpr count = new SQLAggregateExpr("count");
                    count.addArgument(param.clone());
                    parent.addSelectItem(new SQLSelectItem(count, "generate_count"));
                }
                return super.visit(x);
            }
        });
        DbType mysql = JdbcConstants.MYSQL;
        sql = SQLUtils.toSQLString(query, mysql);
        System.out.println(sql);
    }

    // Returns the GROUP BY items that are not yet in the SELECT list,
    // matched by alias or by expression text
    private static List<SQLExpr> diff(List<SQLSelectItem> selectList, List<SQLExpr> items) {
        List<SQLExpr> res = new ArrayList<>();
        for (SQLExpr item : items) {
            if (!(item instanceof SQLIdentifierExpr)) {
                throw new RuntimeException("Unsupported GROUP BY expression: " + item);
            }
            String name = ((SQLIdentifierExpr) item).getName();
            boolean found = false;
            for (SQLSelectItem sqlSelectItem : selectList) {
                if (name.equalsIgnoreCase(sqlSelectItem.getAlias())
                        || name.equalsIgnoreCase(SQLUtils.toSQLString(sqlSelectItem.getExpr()))) {
                    found = true;
                    break;
                }
            }
            // Only columns missing from the SELECT list need to be added
            if (!found) {
                res.add(item);
            }
        }
        return res;
    }

    // True if the ORDER BY column is not yet in the SELECT list
    private static boolean needAddItem(Map.Entry<String, SQLExpr> entry, List<SQLSelectItem> selectList) {
        for (SQLSelectItem selectItem : selectList) {
            if (entry.getKey().equalsIgnoreCase(selectItem.getAlias())) {
                return false;
            }
            if (SQLUtils.toSQLString(selectItem.getExpr()).equalsIgnoreCase(entry.getKey())) {
                return false;
            }
        }
        return true;
    }

    // Walks up the AST to the nearest enclosing SELECT block
    private static MySqlSelectQueryBlock findParent(SQLObject node) {
        SQLObject parent = node.getParent();
        while (!(parent instanceof MySqlSelectQueryBlock)) {
            parent = parent.getParent();
        }
        return (MySqlSelectQueryBlock) parent;
    }
}

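Result merging then handles each of these cases separately.

Multi-column sorting, for example, works like this. The visitor runs inside the interceptor and re-sorts the merged result list res[0]: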

// Runs inside StatementHandlerInterceptor#intercept; besides the Druid classes
// used above, it needs org.apache.ibatis.reflection.MetaObject,
// org.apache.ibatis.reflection.SystemMetaObject,
// com.alibaba.druid.sql.ast.SQLOrderingSpecification and java.util.Objects
@Override
public boolean visit(SQLOrderBy orderBy) {
    // ORDER BY columns in order, plus whether each one is DESC
    List<String> columns = new ArrayList<>();
    List<Boolean> descending = new ArrayList<>();
    for (SQLSelectOrderByItem item : orderBy.getItems()) {
        SQLExpr itemExpr = item.getExpr();
        if (itemExpr instanceof SQLIdentifierExpr) {
            columns.add(((SQLIdentifierExpr) itemExpr).getName());
            descending.add(item.getType() == SQLOrderingSpecification.DESC);
        } else {
            throw new RuntimeException("Unsupported ORDER BY expression: " + itemExpr);
        }
    }
    @SuppressWarnings("unchecked")
    List<Object> rows = (List<Object>) res[0];
    List<Object> sorted = rows.stream().sorted((o1, o2) -> {
        // MetaObject reads properties from beans and Maps alike
        MetaObject meta1 = SystemMetaObject.forObject(o1);
        MetaObject meta2 = SystemMetaObject.forObject(o2);
        for (int i = 0; i < columns.size(); i++) {
            String order = columns.get(i);
            // SqlUtil.toJavaProperty is the project's own helper,
            // presumably snake_case column -> camelCase property
            String javaOrder = SqlUtil.toJavaProperty(order);
            Object value1 = readColumn(meta1, order, javaOrder);
            Object value2 = readColumn(meta2, order, javaOrder);
            int cmp;
            if (Objects.equals(value1, value2)) {
                continue;
            } else if (value1 == null || value2 == null) {
                // MySQL sorts NULL first in ascending order
                cmp = value1 == null ? -1 : 1;
            } else if (value1 instanceof String) {
                cmp = ((String) value1).compareToIgnoreCase((String) value2);
            } else {
                // Long, Integer, Double, Float, BigDecimal, dates... all implement Comparable
                @SuppressWarnings("unchecked")
                Comparable<Object> comparable = (Comparable<Object>) value1;
                cmp = comparable.compareTo(value2);
            }
            if (cmp != 0) {
                return descending.get(i) ? -cmp : cmp;
            }
        }
        return 0;
    }).collect(Collectors.toList());
    res[0] = sorted;
    return super.visit(orderBy);
}

// Reads a column value by column name, falling back to its Java property name
private Object readColumn(MetaObject meta, String column, String property) {
    if (meta.hasGetter(column)) {
        return meta.getValue(column);
    }
    if (meta.hasGetter(property)) {
        return meta.getValue(property);
    }
    throw new RuntimeException("No value for column " + column);
}
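The same multi-column comparison can also be expressed by composing JDK Comparators. A sketch assuming rows are Map<String, Object> and that values within one column are mutually Comparable; `OrderSpec` and `OrderByComparators` are hypothetical names. NULLs sort first in ascending order and last in descending order, matching MySQL.

```java
import java.util.Comparator;
import java.util.List;
import java.util.Map;

// Hypothetical ORDER BY item: column name plus direction
final class OrderSpec {
    final String column;
    final boolean desc;

    OrderSpec(String column, boolean desc) {
        this.column = column;
        this.desc = desc;
    }
}

final class OrderByComparators {
    // Builds one comparator from the ORDER BY items, left to right
    @SuppressWarnings("unchecked")
    static Comparator<Map<String, Object>> of(List<OrderSpec> specs) {
        Comparator<Map<String, Object>> result = (a, b) -> 0;
        for (OrderSpec spec : specs) {
            // NULL first for ASC; reversed() moves it last for DESC
            Comparator<Comparable<Object>> values = Comparator.nullsFirst(Comparator.naturalOrder());
            Comparator<Map<String, Object>> byColumn =
                    Comparator.comparing(row -> (Comparable<Object>) row.get(spec.column), values);
            result = result.thenComparing(spec.desc ? byColumn.reversed() : byColumn);
        }
        return result;
    }
}
```

String comparison here is case-sensitive, which may differ from MySQL's case-insensitive _ci collations; a column-specific comparator such as String.CASE_INSENSITIVE_ORDER could be plugged in for string columns.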

Code: yzp/sharding-plugin
