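We only consider writing a horizontal table-sharding plugin in the form of a MyBatis plugin. Database sharding is out of scope, so there is no need to deal with proxy data sources and similar issues.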
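Most implementations on GitHub work like this:
1. Parse the routing configuration to learn which tables need routing, which do not, and what the routing parameters are;
2. Parse the SQL with JSqlParser or druid's SQL parser to extract the routing key and its actual value from the static SQL. If the routing key cannot be found, every sub-table has to be scanned; if the key is found but its actual value cannot be, it is still unknown which sub-tables to route to, so every sub-table has to be scanned as well.
3. Compute the target tables from the routing key and its value, and generate the new SQL.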
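As a rough illustration of steps 2 and 3, the sketch below assumes a modulo strategy over physical tables named tbl_0 to tbl_(n-1). The class ModuloTableRouter and its methods are hypothetical and not taken from any particular plugin. A routing value that could not be extracted from the SQL is modelled as null, which falls back to scanning every sub-table.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical router: assumes physical tables <logicTable>_0 .. <logicTable>_(shardCount - 1)
// and a simple modulo strategy. Not the plugin's actual routing code.
class ModuloTableRouter {

    /**
     * Returns the physical tables a statement must hit.
     * shardValue == null models "routing key or its value not found in the SQL",
     * which forces a full scan of every sub-table.
     */
    static List<String> route(String logicTable, Long shardValue, int shardCount) {
        List<String> targets = new ArrayList<>();
        if (shardValue == null) {
            for (int i = 0; i < shardCount; i++) {
                targets.add(logicTable + "_" + i);
            }
            return targets;
        }
        // floorMod keeps the table index non-negative for negative keys
        targets.add(logicTable + "_" + Math.floorMod(shardValue, shardCount));
        return targets;
    }

    /** Naive step-3 rewrite: replace the logical table name with the physical one. */
    static String rewrite(String sql, String logicTable, String physicalTable) {
        // Word boundaries stop "tbl" from matching inside "tbl_user"
        return sql.replaceAll("\\b" + logicTable + "\\b", physicalTable);
    }

    public static void main(String[] args) {
        System.out.println(route("tbl", 10L, 4));   // [tbl_2]
        System.out.println(route("tbl", null, 3));  // [tbl_0, tbl_1, tbl_2]
        System.out.println(rewrite("select * from tbl where id = 10", "tbl", "tbl_2"));
    }
}
```

A real plugin would rewrite the table node in the parsed AST rather than use a regular expression; the regex only keeps the sketch self-contained.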
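This only handles fairly simple single-table CRUD and cannot cope with the complex SQL found in real projects; more work is needed. If we look at how Sharding-JDBC works, there is also a "result merging" step: to make a query over sharded tables return exactly the same result as the same query over a single table, the three steps above are far from enough. In MyBatis we cannot override DefaultResultSetHandler to merge results, because MyBatis hard-codes it (Configuration#newResultSetHandler instantiates it directly), so the only option is to intercept StatementHandler#query, take the result there, and merge it ourselves.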
@Component
@Intercepts({
        @Signature(type = StatementHandler.class,
                method = "query", args = {Statement.class, ResultHandler.class})})
public class StatementHandlerInterceptor implements Interceptor {
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object target = invocation.getTarget();
        Object result = invocation.proceed();
        Object[] res = new Object[1];
        if (!(result instanceof List)) {
            return result;
        } else {
            res[0] = new ArrayList<>((List
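How do we merge? Group the rows again by the GROUP BY columns, recompute the aggregate functions, re-sort according to ORDER BY, and re-split according to LIMIT; the details are fairly involved. Before merging, the SQL also has to be rewritten to add supplementary columns, i.e. the columns the merger needs but the original query does not select (for example sum and count for avg, and the GROUP BY and ORDER BY columns). With druid's SQL parser this is quite simple: the SQL is rewritten through the Visitor pattern.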
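A minimal sketch of the re-grouping and AVG recomputation, assuming each shard already returns department, generate_sum and generate_count per group (the supplementary columns the rewrite below adds). GroupByAvgMerger and Row are hypothetical names, not part of the plugin. The point it shows is that the merged average must be the total sum divided by the total count; averaging the per-shard averages gives the wrong answer.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical merger for "SELECT avg(id) ... GROUP BY department" after the rewrite
// added sum(id) AS generate_sum and count(id) AS generate_count.
class GroupByAvgMerger {

    /** One row returned by one shard: group key plus the supplementary columns. */
    static final class Row {
        final String department;
        final double generateSum;
        final long generateCount;

        Row(String department, double generateSum, long generateCount) {
            this.department = department;
            this.generateSum = generateSum;
            this.generateCount = generateCount;
        }
    }

    /** Re-groups all shard rows by department and recomputes avg = total sum / total count. */
    static Map<String, Double> merge(List<List<Row>> shardResults) {
        Map<String, double[]> totals = new TreeMap<>(); // department -> {sum, count}
        for (List<Row> shard : shardResults) {
            for (Row r : shard) {
                double[] t = totals.computeIfAbsent(r.department, k -> new double[2]);
                t[0] += r.generateSum;
                t[1] += r.generateCount;
            }
        }
        Map<String, Double> avg = new TreeMap<>();
        for (Map.Entry<String, double[]> e : totals.entrySet()) {
            double[] t = e.getValue();
            // count is 0 only if every value in the group was NULL; MySQL's avg() is then NULL
            avg.put(e.getKey(), t[1] == 0 ? null : t[0] / t[1]);
        }
        return avg;
    }

    public static void main(String[] args) {
        List<Row> shard0 = Arrays.asList(new Row("dev", 30, 3), new Row("ops", 10, 1));
        List<Row> shard1 = Arrays.asList(new Row("dev", 20, 1));
        System.out.println(merge(Arrays.asList(shard0, shard1))); // {dev=12.5, ops=10.0}
    }
}
```

With this sample data, dev merges to 50 / 4 = 12.5, whereas averaging the two per-shard averages (10 and 20) would give 15.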
Take the following SQL:
select t.* from (select avg(id) from tbl where id>10
group by department order by name,age limit 5,3) t
After a successful rewrite it becomes:
SELECT t.*
FROM (
SELECT avg(id), sum(id) AS generate_sum
, count(id) AS generate_count, department AS department, name AS name
, age AS age
FROM tbl
WHERE id > 10
GROUP BY department
ORDER BY name, age
LIMIT 0, 8
) t
The rewriting code is as follows:
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import com.alibaba.druid.DbType;
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLLimit;
import com.alibaba.druid.sql.ast.SQLObject;
import com.alibaba.druid.sql.ast.SQLOrderBy;
import com.alibaba.druid.sql.ast.expr.SQLAggregateExpr;
import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr;
import com.alibaba.druid.sql.ast.expr.SQLIntegerExpr;
import com.alibaba.druid.sql.ast.statement.SQLSelectGroupByClause;
import com.alibaba.druid.sql.ast.statement.SQLSelectItem;
import com.alibaba.druid.sql.ast.statement.SQLSelectOrderByItem;
import com.alibaba.druid.sql.ast.statement.SQLSelectStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSelectQueryBlock;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
import com.alibaba.druid.sql.visitor.SQLASTVisitorAdapter;
import com.alibaba.druid.util.JdbcConstants;

public class Solution11 {
    public static void main(String[] args) {
        String sql = "select t.* from (select avg(id) from tbl where id>10 group by department order by name,age limit 5,3) t";
        MySqlStatementParser parser = new MySqlStatementParser(sql);
        SQLSelectStatement selectStatement = (SQLSelectStatement) parser.parseStatement();
        MySqlSelectQueryBlock query = (MySqlSelectQueryBlock) selectStatement.getSelect().getQuery();
        query.accept(new SQLASTVisitorAdapter() {
            // LIMIT offset, rowCount -> LIMIT 0, offset + rowCount: every shard must return
            // enough rows for the global page to be cut out after merging
            @Override
            public boolean visit(SQLLimit sqlLimit) {
                SQLExpr offset = sqlLimit.getOffset();
                SQLExpr rowCount = sqlLimit.getRowCount();
                if (offset == null) {
                    return super.visit(sqlLimit);
                }
                int offsetNum = ((SQLIntegerExpr) offset).getNumber().intValue();
                if (offsetNum == 0) {
                    return super.visit(sqlLimit);
                }
                int rowCountNum = ((SQLIntegerExpr) rowCount).getNumber().intValue();
                sqlLimit.setOffset(0);
                sqlLimit.setRowCount(offsetNum + rowCountNum);
                return super.visit(sqlLimit);
            }

            // Add GROUP BY columns that are missing from the select list, so rows can be re-grouped
            @Override
            public boolean visit(SQLSelectGroupByClause groupBy) {
                MySqlSelectQueryBlock parent = findParent(groupBy);
                for (SQLExpr item : diff(parent.getSelectList(), groupBy.getItems())) {
                    // Clone so the GROUP BY clause keeps its own node and parent
                    parent.addSelectItem(selectItem(item.clone(), ((SQLIdentifierExpr) item).getName()));
                }
                return super.visit(groupBy);
            }

            // Add ORDER BY columns that are missing from the select list, so rows can be re-sorted
            @Override
            public boolean visit(SQLOrderBy orderBy) {
                // LinkedHashMap keeps the ORDER BY column order in the output
                Map<String, SQLExpr> map = new LinkedHashMap<>();
                for (SQLSelectOrderByItem item : orderBy.getItems()) {
                    SQLExpr expr = item.getExpr();
                    if (expr instanceof SQLIdentifierExpr) {
                        map.put(((SQLIdentifierExpr) expr).getName(), expr);
                    } else {
                        throw new RuntimeException("Unsupported ORDER BY expression: " + expr);
                    }
                }
                MySqlSelectQueryBlock parent = findParent(orderBy);
                for (Map.Entry<String, SQLExpr> entry : map.entrySet()) {
                    if (needAddItem(entry.getKey(), parent.getSelectList())) {
                        parent.addSelectItem(selectItem(entry.getValue().clone(), entry.getKey()));
                    }
                }
                return super.visit(orderBy);
            }

            // For avg(col), also select sum(col) and count(col): the merged average is the
            // total sum divided by the total count, not the mean of the shard averages
            @Override
            public boolean visit(SQLAggregateExpr x) {
                MySqlSelectQueryBlock parent = findParent(x);
                List<SQLAggregateExpr> aggregates = parent.getSelectList().stream()
                        .map(SQLSelectItem::getExpr)
                        .filter(it -> it instanceof SQLAggregateExpr)
                        .map(it -> (SQLAggregateExpr) it)
                        .collect(Collectors.toList());
                SQLExpr param = null;
                String paramName = "";
                for (SQLAggregateExpr aggregate : aggregates) {
                    if ("avg".equalsIgnoreCase(aggregate.getMethodName())
                            && aggregate.getArguments().get(0) instanceof SQLIdentifierExpr) {
                        param = aggregate.getArguments().get(0);
                        paramName = ((SQLIdentifierExpr) param).getName();
                    }
                }
                if (param == null) {
                    return true; // no avg(column) in this select list, nothing to add
                }
                boolean hasCount = false, hasSum = false;
                for (SQLAggregateExpr aggregate : aggregates) {
                    if (aggregate.getArguments().isEmpty()) {
                        continue;
                    }
                    SQLExpr arg = aggregate.getArguments().get(0);
                    String name = arg instanceof SQLIdentifierExpr ? ((SQLIdentifierExpr) arg).getName() : "";
                    if ("count".equalsIgnoreCase(aggregate.getMethodName()) && name.equals(paramName)) {
                        hasCount = true;
                    }
                    if ("sum".equalsIgnoreCase(aggregate.getMethodName()) && name.equals(paramName)) {
                        hasSum = true;
                    }
                }
                // Only add what is missing, each with its own copy of the argument node
                if (!hasSum) {
                    SQLAggregateExpr sum = new SQLAggregateExpr("sum");
                    sum.addArgument(param.clone());
                    parent.addSelectItem(selectItem(sum, "generate_sum"));
                }
                if (!hasCount) {
                    SQLAggregateExpr count = new SQLAggregateExpr("count");
                    count.addArgument(param.clone());
                    parent.addSelectItem(selectItem(count, "generate_count"));
                }
                return super.visit(x);
            }
        });
        DbType mysql = JdbcConstants.MYSQL;
        sql = SQLUtils.toSQLString(query, mysql);
        System.out.println(sql);
    }

    /** GROUP BY items that do not yet appear in the select list (by alias or by expression). */
    private static List<SQLExpr> diff(List<SQLSelectItem> selectList, List<SQLExpr> items) {
        List<SQLExpr> res = new ArrayList<>();
        for (SQLExpr item : items) {
            if (!(item instanceof SQLIdentifierExpr)) {
                throw new RuntimeException("Unsupported GROUP BY expression: " + item);
            }
            // Only keep columns that are really missing; the original loop added every item
            if (needAddItem(((SQLIdentifierExpr) item).getName(), selectList)) {
                res.add(item);
            }
        }
        return res;
    }

    /** True when no alias and no expression in the select list matches the column. */
    private static boolean needAddItem(String column, List<SQLSelectItem> selectList) {
        for (SQLSelectItem selectItem : selectList) {
            String alias = selectItem.getAlias();
            if (alias != null && alias.equalsIgnoreCase(column)) {
                return false;
            }
            if (SQLUtils.toSQLString(selectItem.getExpr()).equalsIgnoreCase(column)) {
                return false;
            }
        }
        return true;
    }

    /** Builds "expr AS alias" and links expr to its new parent item. */
    private static SQLSelectItem selectItem(SQLExpr expr, String alias) {
        SQLSelectItem item = new SQLSelectItem();
        expr.setParent(item);
        item.setExpr(expr);
        item.setAlias(alias);
        return item;
    }

    /** Walks up the AST to the query block that owns the given node. */
    private static MySqlSelectQueryBlock findParent(SQLObject node) {
        SQLObject parent = node.getParent();
        while (parent != null && !(parent instanceof MySqlSelectQueryBlock)) {
            parent = parent.getParent();
        }
        if (parent == null) {
            throw new IllegalStateException("No enclosing query block for " + node);
        }
        return (MySqlSelectQueryBlock) parent;
    }
}
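On the merge side, the rewritten LIMIT 0, 8 pairs with an in-memory paging step. The sketch below is a hypothetical LimitMerger that assumes each shard's rows already arrive sorted by the ORDER BY key; it does a k-way merge with a java.util.PriorityQueue and then applies the original offset 5 and row count 3 exactly once.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

// Hypothetical merger: every shard ran "ORDER BY ... LIMIT 0, offset + rowCount",
// so each list below is sorted and holds at most offset + rowCount rows.
class LimitMerger {

    /** K-way merges the sorted shard results, then applies the original LIMIT offset, rowCount. */
    static <T> List<T> mergeAndPage(List<List<T>> sortedShards, Comparator<? super T> order,
                                    int offset, int rowCount) {
        // Each heap entry is {shardIndex, positionInShard}; the heap orders by the row it points to
        PriorityQueue<int[]> heap = new PriorityQueue<>((a, b) -> order.compare(
                sortedShards.get(a[0]).get(a[1]), sortedShards.get(b[0]).get(b[1])));
        for (int i = 0; i < sortedShards.size(); i++) {
            if (!sortedShards.get(i).isEmpty()) {
                heap.add(new int[]{i, 0});
            }
        }
        List<T> page = new ArrayList<>();
        int seen = 0; // rows consumed in global order so far
        while (!heap.isEmpty() && page.size() < rowCount) {
            int[] top = heap.poll();
            List<T> shard = sortedShards.get(top[0]);
            if (seen++ >= offset) {
                page.add(shard.get(top[1]));
            }
            if (top[1] + 1 < shard.size()) {
                heap.add(new int[]{top[0], top[1] + 1});
            }
        }
        return page;
    }

    public static void main(String[] args) {
        // Rows are interleaved across the two shards
        List<Integer> shard0 = Arrays.asList(1, 3, 5, 7, 9, 11, 13, 15);
        List<Integer> shard1 = Arrays.asList(2, 4, 6, 8, 10, 12, 14, 16);
        System.out.println(mergeAndPage(Arrays.asList(shard0, shard1),
                Comparator.<Integer>naturalOrder(), 5, 3)); // [6, 7, 8]
    }
}
```

With the interleaved sample the correct page is [6, 7, 8]. If each shard had run LIMIT 5, 3 on its own, the shards would return [11, 13, 15] and [12, 14, 16], and the merged page would wrongly start at 11; that is why the offset is folded into the row count before the query is sent.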
For result merging, each situation we run into is handled case by case.
Multi-column sorting, for example:
public boolean visit(SQLOrderBy orderBy) {
    List<SQLSelectOrderByItem> items = orderBy.getItems();
    List<String> list = new ArrayList<>();
    for (SQLSelectOrderByItem item : items) {
        SQLExpr itemExpr = item.getExpr();
        if (itemExpr instanceof SQLIdentifierExpr) {
            list.add(((SQLIdentifierExpr) itemExpr).getName());
        } else {
            throw new RuntimeException("Unknown type");
        }
    }
    List
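Once the ORDER BY column names are collected, the merged rows can be re-sorted with a comparator that compares column by column and only moves to the next column on a tie. The sketch assumes rows are Map&lt;String, Object&gt; keyed by column label with Comparable values, and that NULLs sort first in ascending order as in MySQL. MultiColumnSort, OrderItem and row are hypothetical helpers, not part of the plugin.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical helper: re-sorts merged rows by several ORDER BY columns.
// Assumes rows are column-label -> Comparable maps and NULLs sort first in ASC (as in MySQL).
class MultiColumnSort {

    /** One ORDER BY item: column label plus direction. */
    static final class OrderItem {
        final String column;
        final boolean asc;

        OrderItem(String column, boolean asc) {
            this.column = column;
            this.asc = asc;
        }
    }

    @SuppressWarnings("unchecked")
    private static int compareValues(Object a, Object b) {
        if (a == null && b == null) return 0;
        if (a == null) return -1; // NULL is smaller than any value
        if (b == null) return 1;
        return ((Comparable<Object>) a).compareTo(b);
    }

    /** Compares column by column; a later column only breaks ties left by earlier ones. */
    static Comparator<Map<String, Object>> comparator(List<OrderItem> items) {
        return (r1, r2) -> {
            for (OrderItem item : items) {
                int c = compareValues(r1.get(item.column), r2.get(item.column));
                if (c != 0) {
                    // signum avoids overflow when negating Integer.MIN_VALUE
                    return item.asc ? c : -Integer.signum(c);
                }
            }
            return 0;
        };
    }

    /** Builds a row for the example (hypothetical test data). */
    static Map<String, Object> row(String name, Integer age) {
        Map<String, Object> r = new HashMap<>();
        r.put("name", name);
        r.put("age", age);
        return r;
    }

    public static void main(String[] args) {
        List<Map<String, Object>> merged = new ArrayList<>(Arrays.asList(
                row("bob", 30), row("amy", 25), row("bob", 20)));
        // ORDER BY name, age
        merged.sort(comparator(Arrays.asList(new OrderItem("name", true), new OrderItem("age", true))));
        for (Map<String, Object> r : merged) {
            System.out.println(r.get("name") + " " + r.get("age")); // amy 25, bob 20, bob 30
        }
    }
}
```

Reversing DESC columns by negating the comparison also moves NULLs to the end, which matches MySQL's DESC ordering.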
Code: yzp/sharding-plugin