关于 MyBatis 拦截器的基础知识,请移步《MyBatis 拦截器实现》。
今天我们就基于拦截器插件,实现按取值哈希(取模)分表的功能。
首先,准备一个注解类:
/**
 * Marks an entity class whose table is split into several shard tables.
 *
 * <p>The annotation is retained at runtime and inherited by subclasses, so it
 * can be discovered reflectively on the entity type.
 */
@Documented
@Inherited
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface Share {

    /** Number of shards; mandatory because no default is declared. */
    int size();

    /** Sharding strategy; defaults to {@link ShareStrategy#HASH}. */
    ShareStrategy strategy() default ShareStrategy.HASH;

    /** Available sharding strategies. */
    enum ShareStrategy {
        /** Hash sharding. */
        HASH
    }
}
// Sample entity used to demonstrate sharding: it declares two shards over the
// base table "model" (ShareStatementPlugin appends "_" + index to the table
// name, e.g. model_0 / model_1).
@Share(size=2)
@Table(name = "model")
public class Model {
// The @Id field's name is what ShareStatementPlugin records as the shard field.
@Id
@Column(name = "id")
private Integer id;
// ... remaining fields and accessors are omitted from this sample
}
/**
 * MyBatis interceptor that rewrites the SQL of a statement so that it targets
 * a shard table ({@code <table>_<index>}) chosen from the statement parameter.
 *
 * <p>The entity type is taken from the first type argument of the mapper
 * interface's generic super-interface; it must carry both {@code @Share} and
 * {@code @Table}. Statements whose entity cannot be resolved, or whose shard
 * size is not greater than one, are passed through untouched.
 *
 * <p>Only {@code StatementHandler.prepare} is intercepted: the SQL has to be
 * rewritten exactly once, before the JDBC statement is created. Rewriting it
 * again in {@code parameterize}/{@code query}/{@code update}/{@code batch}
 * would operate on already rewritten SQL.
 *
 * <p>NOTE(review): the {@code prepare(Connection)} signature is kept from the
 * original; newer MyBatis versions declare {@code prepare(Connection, Integer)}
 * and would need {@code Integer.class} added to {@code args} - confirm against
 * the MyBatis version in use.
 *
 * @author sunla
 */
@Intercepts({
    @Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class})
})
public class ShareStatementPlugin implements Interceptor {

    private static final Logger LOG = LoggerFactory.getLogger(ShareStatementPlugin.class);

    /**
     * Rewrites the bound SQL to the shard table when the statement belongs to
     * a sharded entity, then continues the invocation chain.
     *
     * @param invocation the intercepted {@code StatementHandler} call
     * @return whatever the intercepted method returns
     * @throws IllegalArgumentException if the entity is sharded but no integer
     *         shard key can be derived from the statement parameter
     * @throws Throwable anything thrown by the intercepted method
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler handler = (StatementHandler) invocation.getTarget();
        MappedStatement ms = getMappedStatement(handler);
        if (ms == null) {
            return invocation.proceed();
        }
        ShareEntity entity = getShareInfo(getSuperGenericClass(ms));
        // Mappers without a sharded entity are the common case: pass them through.
        if (entity == null || entity.getShareSize() <= 1) {
            return invocation.proceed();
        }
        BoundSql boundSql = handler.getBoundSql();
        String originalSql = boundSql.getSql();
        String shardSql = buildShareSql(originalSql, entity, boundSql.getParameterObject());
        if (shardSql != null && !shardSql.equals(originalSql)) {
            modifySql(boundSql, shardSql);
        }
        return invocation.proceed();
    }

    /**
     * Wraps only {@code StatementHandler} targets; every other target is
     * returned unchanged.
     */
    @Override
    public Object plugin(Object target) {
        if (target instanceof StatementHandler) {
            return Plugin.wrap(target, this);
        }
        return target;
    }

    /**
     * Replaces the {@code sql} field of the given object via reflection.
     * Failures are logged and otherwise ignored (best effort, as before).
     *
     * @param target the BoundSql instance
     * @param sql    the SQL to store
     */
    private void modifySql(Object target, String sql) {
        try {
            Field sqlField = target.getClass().getDeclaredField("sql");
            sqlField.setAccessible(true);
            sqlField.set(target, sql);
        } catch (Exception e) {
            // Message first, SQL second, throwable last so SLF4J prints the stack trace.
            LOG.error("modifySql error msg is {} sql is {}", e.getMessage(), sql, e);
        }
    }

    /** This plugin has no configurable properties. */
    @Override
    public void setProperties(Properties properties) {
    }

    /**
     * Extracts the {@code MappedStatement} held by a statement handler.
     *
     * <p>A {@code RoutingStatementHandler} is unwrapped to its delegate first.
     * The fields are looked up on the classes that declare them, so the lookup
     * does not depend on how deep the concrete handler's class hierarchy is.
     *
     * @param statementHandler the intercepted handler
     * @return the mapped statement, or {@code null} if the handler is of an
     *         unsupported type or reflection fails
     */
    private MappedStatement getMappedStatement(StatementHandler statementHandler) {
        try {
            Object candidate = statementHandler;
            if (candidate instanceof RoutingStatementHandler) {
                Field delegateField = RoutingStatementHandler.class.getDeclaredField("delegate");
                delegateField.setAccessible(true);
                candidate = delegateField.get(candidate);
            }
            if (candidate instanceof BaseStatementHandler) {
                Field statementField = BaseStatementHandler.class.getDeclaredField("mappedStatement");
                statementField.setAccessible(true);
                Object value = statementField.get(candidate);
                if (value instanceof MappedStatement) {
                    return (MappedStatement) value;
                }
            }
        } catch (Exception e) {
            LOG.error("getMappedStatement error msg is {}", e.getMessage(), e);
        }
        return null;
    }

    /**
     * Reads the sharding metadata of an entity class.
     *
     * @param genericClass fully qualified entity class name, may be {@code null}
     * @return the metadata, or {@code null} if the class cannot be loaded, is
     *         not annotated with both {@code @Share} and {@code @Table}, or
     *         declares an empty table name
     */
    private ShareEntity getShareInfo(String genericClass) {
        if (genericClass == null) {
            return null;
        }
        try {
            Class<?> entityClass = getClass().getClassLoader().loadClass(genericClass);
            Share share = entityClass.getAnnotation(Share.class);
            Table table = entityClass.getAnnotation(Table.class);
            if (share == null || table == null) {
                return null;
            }
            String tableName = table.name();
            if (tableName == null || tableName.length() == 0) {
                // An empty name would match at every position of the SQL text.
                LOG.warn("getShareInfo: empty @Table name on {}, sharding skipped", genericClass);
                return null;
            }
            // The shard field is the field annotated with @Id (the last one wins,
            // matching the original behaviour when several are present).
            String shareField = "";
            for (Field field : entityClass.getDeclaredFields()) {
                if (field.isAnnotationPresent(Id.class)) {
                    shareField = field.getName();
                }
            }
            return new ShareEntity(tableName, share.size(), shareField);
        } catch (Exception e) {
            LOG.error("getShareInfo error is {}", e.getMessage(), e);
        }
        return null;
    }

    /**
     * Resolves the entity type a mapper works on, i.e. the first type argument
     * of the first parameterized interface the mapper extends
     * (for example {@code Foo} in {@code FooMapper extends BaseMapper<Foo>}).
     *
     * @param ms the mapped statement; its id is expected to be
     *           {@code <mapper class name>.<method>}
     * @return the entity class name, or {@code null} if it cannot be determined
     */
    private String getSuperGenericClass(MappedStatement ms) {
        String mapperId = ms.getId();
        int lastDot = mapperId == null ? -1 : mapperId.lastIndexOf('.');
        if (lastDot <= 0) {
            return null;
        }
        String mapperClassName = mapperId.substring(0, lastDot);
        try {
            Class<?> mapper = this.getClass().getClassLoader().loadClass(mapperClassName);
            for (Type type : mapper.getGenericInterfaces()) {
                if (!(type instanceof ParameterizedType)) {
                    continue;
                }
                Type[] typeArguments = ((ParameterizedType) type).getActualTypeArguments();
                if (typeArguments.length > 0 && typeArguments[0] instanceof Class) {
                    String entityName = ((Class<?>) typeArguments[0]).getName();
                    LOG.debug("share StatementPlugin getSuperGenericClass name is {}", entityName);
                    return entityName;
                }
            }
        } catch (Exception e) {
            LOG.error("ShareBoundSqlPlugin error ms id is {} msg is {}", mapperId, e.getMessage(), e);
        }
        return null;
    }

    /**
     * Builds the shard SQL by replacing every occurrence of the base table
     * name with {@code <table>_<index>}, where {@code index} is the shard key
     * modulo the shard size (always in {@code [0, size)}).
     *
     * <p>The table name is matched literally and only as a whole identifier,
     * so names such as {@code model_id} or an already rewritten
     * {@code model_1} are left alone. The match is purely textual: a string
     * literal containing the bare table name would be rewritten as well.
     *
     * @param sql    the original SQL
     * @param entity sharding metadata
     * @param param  the statement parameter object the shard key is taken from
     * @return the rewritten SQL, or {@code sql} itself when no sharding applies
     * @throws IllegalArgumentException if no integer shard key can be derived
     */
    private String buildShareSql(String sql, ShareEntity entity, Object param) {
        if (sql == null || entity == null || entity.getShareSize() <= 1) {
            return sql;
        }
        int size = entity.getShareSize();
        long key = resolveShareKey(param, entity.getShareField());
        // Normalise the remainder so negative keys never yield a negative suffix.
        long index = ((key % size) + size) % size;
        String tableName = entity.getTableName();
        String newTableName = tableName + "_" + index;
        // Fully qualified names keep this block free of new import requirements.
        java.util.regex.Pattern tablePattern = java.util.regex.Pattern.compile(
                "(?<![\\w$])" + java.util.regex.Pattern.quote(tableName) + "(?![\\w$])");
        String shardSql = tablePattern.matcher(sql)
                .replaceAll(java.util.regex.Matcher.quoteReplacement(newTableName));
        LOG.debug("share StatementPlugin rewritten sql is {}", shardSql);
        return shardSql;
    }

    /**
     * Derives the integer shard key from the statement parameter.
     *
     * <p>Resolution order: a map entry named after the shard field; a field of
     * that name on a parameter object; otherwise the parameter itself. The
     * resulting value must be a {@code Number} or parse as an integer.
     */
    private long resolveShareKey(Object param, String shareField) {
        Object value = param;
        boolean hasShareField = shareField != null && shareField.length() > 0;
        if (hasShareField && param instanceof java.util.Map) {
            java.util.Map<?, ?> map = (java.util.Map<?, ?>) param;
            // containsKey first: some map implementations throw on unknown keys.
            if (map.containsKey(shareField)) {
                value = map.get(shareField);
            }
        } else if (hasShareField && param != null
                && !(param instanceof Number) && !(param instanceof CharSequence)) {
            Field keyField = findField(param.getClass(), shareField);
            if (keyField != null) {
                try {
                    keyField.setAccessible(true);
                    value = keyField.get(param);
                } catch (Exception e) {
                    throw new IllegalArgumentException("Cannot read shard field '" + shareField
                            + "' from " + param.getClass().getName(), e);
                }
            }
        }
        if (value == null) {
            throw new IllegalArgumentException("Shard key is null, cannot select a shard table");
        }
        if (value instanceof Number) {
            return ((Number) value).longValue();
        }
        try {
            return Long.parseLong(String.valueOf(value).trim());
        } catch (NumberFormatException e) {
            // Only the type is reported; the value itself may be sensitive.
            throw new IllegalArgumentException("Shard key of type "
                    + value.getClass().getName() + " is not an integer", e);
        }
    }

    /** Finds a declared field by name on the class or any of its superclasses. */
    private static Field findField(Class<?> type, String name) {
        for (Class<?> current = type; current != null && current != Object.class;
                current = current.getSuperclass()) {
            for (Field field : current.getDeclaredFields()) {
                if (field.getName().equals(name)) {
                    return field;
                }
            }
        }
        return null;
    }

    /** Immutable sharding metadata of one entity. */
    private static final class ShareEntity {
        private final String tableName;
        private final int shareSize;
        private final String shareField;

        ShareEntity(String tableName, int shareSize, String shareField) {
            this.tableName = tableName;
            this.shareSize = shareSize;
            this.shareField = shareField;
        }

        public int getShareSize() {
            return this.shareSize;
        }

        public String getTableName() {
            return this.tableName;
        }

        public String getShareField() {
            return this.shareField;
        }
    }
}
<bean id="sqlSessionFactory" class="org.mybatis.spring.SqlSessionFactoryBean">
<property name="dataSource" ref="dataSource"/>
<property name="mapperLocations">
<list>
<value>classpath:mapper/*.xml</value>
</list>
</property>
<property name="typeAliasesPackage" value="com.test.model"/>
<property name="plugins">
<array>
<bean class="com.test.filter.ShareStatementPlugin" />
</array>
</property>
</bean>