GitHub: https://github.com/cjx913/cbatis
具体代码如下:
package com.cjx913.cbatis.core;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import java.util.HashSet;
import java.util.Properties;
import java.util.Set;
@Intercepts({
@Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
@Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}),
@Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}),
})
public class SubTableInterceptor implements Interceptor {
private SubTableHandler subTableHandler;
private Set handleSqlIds;
public Object intercept(Invocation invocation) throws Throwable {
Executor executor = (Executor) invocation.getTarget();
Object[] args = invocation.getArgs();
MappedStatement mappedStatement = (MappedStatement) args[0];
Object parameter = args[1];
String sqlId = mappedStatement.getId();
//是否处理
if (handleSqlIds == null || handleSqlIds.isEmpty() || !handleSqlIds.contains(sqlId)) {
return invocation.proceed();
}
if (mappedStatement.getSqlCommandType().compareTo(SqlCommandType.SELECT) == 0) {
RowBounds rowBounds = (RowBounds) args[2];
ResultHandler resultHandler = (ResultHandler) args[3];
CacheKey cacheKey;
BoundSql boundSql;
if (args.length == 4) {
boundSql = mappedStatement.getBoundSql(parameter);
cacheKey = executor.createCacheKey(mappedStatement, parameter, rowBounds, boundSql);
} else {
boundSql = (BoundSql) args[5];
cacheKey = (CacheKey) args[4];
}
subTableHandler.boundSqlHandler(boundSql);
return executor.query(mappedStatement, parameter, rowBounds, resultHandler, cacheKey, boundSql);
} else {
MappedStatement newMappedStatement = subTableHandler.mappedStatementHandler(mappedStatement, parameter);
return executor.update(newMappedStatement, parameter);
}
}
public Object plugin(Object target) {
return Plugin.wrap(target, this);
}
public void setProperties(Properties properties) {
String tableNameHandlerClassName = properties.getProperty("table-name-handler").trim();
if (tableNameHandlerClassName == null || "".equalsIgnoreCase(tableNameHandlerClassName)) {
throw new CbatisException("无效的名表处理类!请在插件添加\n");
}
subTableHandler = new SubTableHandler(tableNameHandlerClassName);
String[] sqlIds = properties.getProperty("handle-sql-ids").split(",");
if (sqlIds != null && sqlIds.length > 0) {
handleSqlIds = new HashSet <>();
for (String s : sqlIds) {
handleSqlIds.add(s.trim());
}
}
}
public Set getHandleSqlIds() {
return handleSqlIds;
}
public void setHandleSqlIds(Set handleSqlIds) {
this.handleSqlIds = handleSqlIds;
}
}
package com.cjx913.cbatis.core;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.util.TablesNamesFinder;
import org.apache.ibatis.builder.StaticSqlSource;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlSource;
import java.lang.reflect.Field;
import java.util.*;
/**
* 分表处理器
*/
public class SubTableHandler {
private AbstractTableNameHandler tableNameHandler;
public SubTableHandler(String tableNameHandlerClassName) {
try {
tableNameHandler = (AbstractTableNameHandler) Class.forName(tableNameHandlerClassName).newInstance();
} catch (InstantiationException | IllegalAccessException | ClassNotFoundException e) {
throw new CbatisException("分表处理器错误->表名处理类错误!请确认" + tableNameHandlerClassName + "是否存在", e);
}
}
public MappedStatement mappedStatementHandler(MappedStatement mappedStatement, Object parameter) {
SqlSource sqlSource = mappedStatement.getSqlSource();
BoundSql boundSql = sqlSource.getBoundSql(parameter);
boundSqlHandler(boundSql);
StaticSqlSource newSqlSource = new StaticSqlSource(mappedStatement.getConfiguration(), boundSql.getSql(), boundSql.getParameterMappings());
MappedStatement newMappedStatement = new MappedStatement.Builder(mappedStatement.getConfiguration(), mappedStatement.getId(), newSqlSource, mappedStatement.getSqlCommandType())
.resource(mappedStatement.getResource())
.parameterMap(mappedStatement.getParameterMap())
.flushCacheRequired(mappedStatement.isFlushCacheRequired())
.keyGenerator(mappedStatement.getKeyGenerator())
.keyProperty(CbatisUtil.split(mappedStatement.getKeyProperties()))
.keyColumn(CbatisUtil.split(mappedStatement.getKeyColumns()))
.timeout(mappedStatement.getTimeout())
.cache(mappedStatement.getCache())
.useCache(mappedStatement.isUseCache())
.databaseId(mappedStatement.getDatabaseId())
.fetchSize(mappedStatement.getFetchSize())
.lang(mappedStatement.getLang())
.resultMaps(mappedStatement.getResultMaps())
.resultOrdered(mappedStatement.isResultOrdered())
.build();
return newMappedStatement;
}
public String boundSqlHandler(BoundSql boundSql) {
try {
String sql = boundSql.getSql();
Object parameter = boundSql.getParameterObject();
sql = sqlHandler(sql);
Set tableNames = getTableNames(sql);
if (!tableNames.isEmpty()) {
Map newTableNames = getNewTableNames(tableNames, parameter);
if (!newTableNames.isEmpty()) {
sql = replaceTableName(sql, newTableNames);
//通过反射修改sql语句
Field field = boundSql.getClass().getDeclaredField("sql");
field.setAccessible(true);
field.set(boundSql, sql);
}
}
return sql;
} catch (Exception e) {
throw new CbatisException("Sql修改错误!", e);
}
}
/**
* sql处理,去掉换行和多余空格
*
* @param sql
* @return
*/
private String sqlHandler(String sql) {
String[] strings = sql.trim().replace("\n", " ").split("\\s+");
StringBuilder stringBuilder = new StringBuilder();
for (int i = 0; i < strings.length; i++) {
if (strings[i].endsWith(",")) {
stringBuilder.append(strings[i]);
} else {
stringBuilder.append(strings[i] + " ");
}
}
sql = stringBuilder.toString().trim().toUpperCase();
return sql;
}
/**
* @param sql
* @return 获取sql所有的表名
*/
private Set getTableNames(String sql) {
Statement statement = null;
try {
statement = CCJSqlParserUtil.parse(sql);
} catch (JSQLParserException e) {
throw new CbatisException("解析sql语句错误!sql:" + sql, e);
}
TablesNamesFinder tablesNamesFinder = new TablesNamesFinder();
List tableList = tablesNamesFinder.getTableList(statement);
Set tableNames = new HashSet <>();
for (String tableName : tableList) {
//获取去掉“`”的表名
if (tableName.startsWith("`") && tableName.endsWith("`")) {
tableNames.add(tableName.substring(1, tableName.length()-1));
}else {
tableNames.add(tableName);
}
}
return tableNames;
}
/**
* @param tableNames
* @param parameter
* @return 原表名与新表名的Map
*/
private Map getNewTableNames(Set tableNames, Object parameter) {
Map newTableNames = new HashMap <>();
for (String tableName : tableNames) {
//获取新表名(新表名都带“`”)
String newTableName = "`"+tableNameHandler.tableNameHandler(tableName, parameter)+"`";
newTableNames.put(tableName, newTableName);
}
return newTableNames;
}
/**
* 表名替换
*
* @param sql
* @param tableNames
* @return
*/
private String replaceTableName(String sql, Map tableNames) {
//去掉sq的“`”
sql = sql.replace("`", "");
Set > entrySet = tableNames.entrySet();
for (Map.Entry entry : entrySet) {
if (!entry.getKey().equalsIgnoreCase(entry.getValue())) {
sql = sql.replace(" " + entry.getKey() + " ", " " + entry.getValue() + " ")
.replace(" " + entry.getKey() + ",", " " + entry.getValue() + ",")
.replace("," + entry.getKey() + " ", "," + entry.getValue() + " ")
.replace("," + entry.getKey() + ",", "," + entry.getValue() + ",")
.replace(" " + entry.getKey() + "(", " " + entry.getValue() + "(");
}
}
return sql.toUpperCase();
}
}
package com.cjx913.cbatis.core;
import org.apache.ibatis.binding.MapperMethod;
import java.beans.IntrospectionException;
import java.beans.PropertyDescriptor;
import java.lang.annotation.Annotation;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.*;
/**
* 表名处理器
*/
public abstract class AbstractTableNameHandler {
/**
 * Registered table-name handler methods; tableNameHandler(...) looks entries up by the
 * upper-cased table name.
 */
private Map<String, TableNameHandlerMethod> tableMethodMap = new HashMap<>();

/**
 * Fills {@code tableMethodMap} by calling {@code initTableMethodMap()}.
 * <p>
 * NOTE(review): this runs inside the superclass constructor, i.e. before any subclass field
 * initializers; if initTableMethodMap() delegates to an overridable hook (presumably
 * setTableMethodSet()), that hook must not depend on subclass instance state — confirm.
 */
public AbstractTableNameHandler() {
    initTableMethodMap();
}
/**
 * Resolves the physical table name for {@code tableName}.
 *
 * @param tableName logical table name as found in the SQL
 * @param parameter statement parameter object; its values are mapped onto the handler
 *                  method's arguments by paramHandler
 * @return the trimmed, upper-cased table name returned by the registered handler method, or the
 *         trimmed, upper-cased {@code tableName} when no handler method is registered for it
 * @throws CbatisException if the handler method cannot be invoked or does not return a String
 */
public final String tableNameHandler(String tableName, Object parameter) {
    TableNameHandlerMethod handlerMethod = tableMethodMap.get(tableName.toUpperCase());
    if (handlerMethod == null) {
        // No routing rule for this table: keep its name.
        return tableName.trim().toUpperCase();
    }
    Method target = handlerMethod.getMethod();
    // Argument values for the handler method; the returned array may be null.
    Object[] arguments = paramHandler(target, parameter);
    Object result;
    try {
        result = target.invoke(handlerMethod.getInstanse(), arguments);
    } catch (IllegalAccessException | InvocationTargetException e) {
        throw new CbatisException("调用分表处理方法错误!表名:" + tableName + ",方法:" + target.getName(), e);
    }
    if (!(result instanceof String)) {
        throw new CbatisException("调用分表处理方法错误!->返回非字符串。表名:" + tableName);
    }
    return ((String) result).trim().toUpperCase();
}
/**
* 获取参数值
*
* @param method
* @param parameter
* @return
*/
private Object[] paramHandler(Method method, Object parameter) {
MapperMethod.ParamMap
package com.cjx913.cbatis.core;
import java.lang.annotation.*;
/**
 * Names a parameter of a table-name handler method, e.g. {@code @Param("username")}.
 * Retained at runtime so it can be read reflectively; presumably matched against the statement's
 * parameter object when the handler method is invoked — confirm against paramHandler.
 */
@Target({ElementType.PARAMETER})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Param {
    /** Name of the value to bind to the annotated parameter. */
    String value();
}
package com.cjx913.cbatis.core;
/**
 * Unchecked exception used throughout cbatis to report configuration and SQL-rewriting errors.
 */
public class CbatisException extends RuntimeException {

    /** Creates an exception with a detail message and the underlying cause. */
    public CbatisException(String message, Throwable cause) {
        super(message, cause);
    }

    /** Creates an exception with a detail message and no cause. */
    public CbatisException(String message) {
        super(message);
    }

    /** Wraps {@code cause}; the detail message becomes {@code cause.toString()} (or null). */
    public CbatisException(Throwable cause) {
        super(cause);
    }

    /** Creates an exception with neither detail message nor cause. */
    public CbatisException() {
        super();
    }

    /**
     * Creates an exception with explicit control over suppression and stack-trace capture.
     *
     * @param enableSuppression  whether suppressed exceptions are recorded
     * @param writableStackTrace whether the stack trace is captured
     */
    public CbatisException(String message, Throwable cause, boolean enableSuppression, boolean writableStackTrace) {
        super(message, cause, enableSuppression, writableStackTrace);
    }
}
package com.cjx913.cbatis.core;
/**
 * Small string helpers used by cbatis.
 */
public class CbatisUtil {
    /**
     * Joins the elements with commas, e.g. {"id", "name"} becomes "id,name". Used to turn a
     * statement's key properties/columns back into comma-delimited form.
     *
     * @param strings elements to join; null elements are rendered as "null"
     * @return the joined string, "" for an empty array, or null when {@code strings} is null
     */
    public static String split(String[] strings) {
        if (strings == null) {
            return null;
        }
        // Insert the separator before every element except the first; this also handles an empty
        // array (the old "append then delete last char" approach threw on it).
        StringBuilder joined = new StringBuilder();
        for (int i = 0; i < strings.length; i++) {
            if (i > 0) {
                joined.append(',');
            }
            joined.append(strings[i]);
        }
        return joined.toString();
    }
}
在 MyBatis 的配置文件中添加该插件(属性 `table-name-handler` 指定表名处理类的全限定类名,`handle-sql-ids` 指定需要分表的语句 ID,多个以逗号分隔)。表名处理类示例如下:
package com.cjx913.cbatis;
import com.cjx913.cbatis.core.AbstractTableNameHandler;
import com.cjx913.cbatis.core.Param;
import java.util.HashSet;
import java.util.Set;
/**
 * Example table-name handler that routes the {@code user} and {@code order} tables.
 */
public class SimpleTableNameHandler extends AbstractTableNameHandler {
    /**
     * Routes the {@code user} table by the parity of {@code username.hashCode()}.
     *
     * @param username value bound via {@code @Param("username")}; may be null
     * @return "user" when username is null or blank, otherwise "user_0" (even hash) or "user_1"
     */
    public String user(@Param("username") String username) {
        if (username == null || username.trim().isEmpty()) {
            return "user";
        }
        boolean evenHash = username.hashCode() % 2 == 0;
        return evenHash ? "user_0" : "user_1";
    }

    /**
     * Routes the {@code order} table across three sub-tables by {@code userId + id}.
     * <p>
     * NOTE(review): the sum can be negative (negative keys or int overflow); Java's % then yields
     * -1 or -2, and both fall through to "order_3", skewing the distribution. Changing this would
     * re-route existing rows, so it is left as is.
     *
     * @param userId value bound via {@code @Param("userId")}; null means no routing
     * @param id     value bound via {@code @Param("id")}; null is treated as 0
     * @return "order", "order_1", "order_2" or "order_3"
     */
    public String order(@Param("userId") Integer userId, @Param("id") Integer id) {
        if (userId == null) {
            return "order";
        }
        int idHash = (id == null) ? 0 : id.hashCode();
        switch ((userId.hashCode() + idHash) % 3) {
            case 0:
                return "order_1";
            case 1:
                return "order_2";
            default:
                return "order_3";
        }
    }

    /**
     * Registers the routing methods above.
     * <p>
     * NOTE(review): "user" is registered via the TableNameHandlerMethod constructor while "order"
     * uses getThisTableNameHandlerMethod(...); confirm the difference is intended.
     */
    @Override
    public Set setTableMethodSet() throws NoSuchMethodException {
        Class<?> self = getClass();
        Set handlerMethods = new HashSet<>();
        handlerMethods.add(new TableNameHandlerMethod(self.getMethod("user", String.class)));
        handlerMethods.add(getThisTableNameHandlerMethod(self.getMethod("order", Integer.class, Integer.class)));
        return handlerMethods;
    }
}