利用 MyBatis 拦截器插件实现分片（分表）SQL

关于 MyBatis 拦截器的基本用法，请参阅《MyBatis 拦截器实现》一文。
本文基于该插件机制，实现按字段值哈希分表的功能。
首先，准备一个注解类：

Share

/**
 * Marks an entity class whose table is split ("sharded") into several
 * physical tables.
 *
 * <p>Retained at runtime so it can be read reflectively, and inherited by
 * subclasses of an annotated class.
 */
@Documented
@Inherited
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface Share {

    /** Number of physical tables the logical table is split into. */
    int size();

    /** Strategy used to pick a physical table; {@link ShareStrategy#HASH} unless overridden. */
    ShareStrategy strategy() default ShareStrategy.HASH;

    /** Supported sharding strategies. */
    enum ShareStrategy {
        /** Pick the physical table by hashing the shard value. */
        HASH
    }
}

在实体类上添加该注解：

Model

/**
 * Example entity mapped to the table {@code model}.
 *
 * <p>{@code @Share(size=2)} declares that the table is split into two
 * physical tables; ShareStatementPlugin rewrites {@code model} in the SQL to
 * {@code model_<n>}.
 */
@Share(size=2)
@Table(name = "model")
public class Model {

    /** Primary key. ShareStatementPlugin records the name of the {@code @Id} field as the shard field. */
    @Id
    @Column(name = "id")
    private Integer id;
    //......

}

最重要的是拦截器类：

ShareStatementPlugin

/**
 * @author :sunla
 * @descript:构建分片sql
 */
@Intercepts({
        @Signature(type = StatementHandler.class, method = "prepare", args = {
                Connection.class }),
        @Signature(type = StatementHandler.class, method = "parameterize", args = {
                Statement.class }),
        @Signature(type = StatementHandler.class, method = "batch", args = {
                Statement.class}),
        @Signature(type = StatementHandler.class, method = "update", args = {
                Statement.class }),
        @Signature(type = StatementHandler.class, method = "query", args = {
                Statement.class,ResultHandler.class })})
public class ShareStatementPlugin implements Interceptor {

    private static final Logger LOG = LoggerFactory.getLogger(ShareStatementPlugin.class);

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler handler = (StatementHandler)invocation.getTarget();
        MappedStatement ms = getMappedStatement(handler);
        if(ms == null){
            return invocation.proceed();
        }
        String superGenericClass = getSuperGenericClass(ms);
        ShareEntity entity = getShareInfo(superGenericClass);
        BoundSql boundSql = handler.getBoundSql();
        Object obj = boundSql.getParameterObject();
        /** 参数映射 */
        List list = boundSql.getParameterMappings();
        String sql = boundSql.getSql();
        modifySql(boundSql,buildShareSql(sql,entity,obj));
        return invocation.proceed();
    }

    @Override
    public Object plugin(Object target) {
        if(target instanceof StatementHandler){
            return Plugin.wrap(target, this);
        }
        return target;
    }

    /**
     * 通过反射修改sql
     * @param target BoundSql 实例
     * @param sql 修改的sql
     */
    private void modifySql(Object target,String sql){
        try {
            Field sqlField = target.getClass().getDeclaredField("sql");
            sqlField.setAccessible(true);
            sqlField.set(target,sql);
        }catch(Exception e){
            LOG.error("modifySql error msg is {} sql is {}",e,sql);
        }
    }

    @Override
    public void setProperties(Properties properties) {

    }

    /**
     * 从statementHandler中获取 mappedStatement
     * @param statementHandler
     * @return
     */
    private MappedStatement getMappedStatement(StatementHandler statementHandler){
        final String FILEDNAME = "mappedStatement";
        boolean isBaseStatementHandler = false;
        BaseStatementHandler baseStatementHandler = null;
        try {
            if(statementHandler instanceof RoutingStatementHandler){
                Field field = statementHandler.getClass().getDeclaredField("delegate");
                field.setAccessible(true);
                Object delegate = field.get(statementHandler);
                if( delegate instanceof BaseStatementHandler){
                    baseStatementHandler = (BaseStatementHandler)delegate;
                    isBaseStatementHandler = true;
                }
            }
            /** 判断是否父类为 BaseStatementHandler */
            if(isBaseStatementHandler){
                Field statementField = baseStatementHandler.getClass().getSuperclass().getDeclaredField(FILEDNAME);
                statementField.setAccessible(true);
                Object obj = statementField.get(baseStatementHandler);
                if(obj != null && obj instanceof MappedStatement) return (MappedStatement)obj;
            }
        }catch(Exception e){
            LOG.error("getMappedStatement error msg is {}",e);
        }
        return null;
    }

    /**
     * 获取表的分片信息
     * @param genericClass
     */
    private ShareEntity getShareInfo(String genericClass){
        try {
            Class generic = getClass().getClassLoader().loadClass(genericClass);
            Annotation annotation = generic.getAnnotation(Share.class);
            Annotation tableAnnotation = generic.getAnnotation(Table.class);
            /** 查询分片字段 */
            String shareField = "";
            Field[] fields = generic.getDeclaredFields();
            for(Field field : fields){
                Annotation fieldId = field.getAnnotation(Id.class);
                if(fieldId != null)
                    shareField = field.getName();
            }
            if(annotation != null && tableAnnotation != null){
                Share share = (Share) annotation;
                Table table = (Table) tableAnnotation;
                return new ShareEntity(table.name(),share.size(),shareField);
            }
        }catch (Exception e){
            LOG.error("getShareInfo error is {}",e);
        }
        return null;
    }

    /**
     * 获取mapper中的泛型实体
     * ex:BaseBillMapper 中的blbackUser实体
     * @param ms
     * @return
     */
    private String getSuperGenericClass(MappedStatement ms){
        String mapperId = ms.getId();
        String classPackage = mapperId.substring(0,mapperId.lastIndexOf("."));
        try {
            Class mapper = this.getClass().getClassLoader().loadClass(classPackage);
            Type[] types = mapper.getGenericInterfaces();
            if(types != null && types.length > 0){
                Type genericInterface = types[0];
                ParameterizedType p =(ParameterizedType)genericInterface;
                if(p.getActualTypeArguments() != null && p.getActualTypeArguments().length > 0){
                    Class parameterizedClass = (Class) p.getActualTypeArguments()[0];
                    if(LOG.isDebugEnabled()){
                        LOG.debug("share StatementPlugin getSuperGenericClass name is {}",parameterizedClass.getName());
                    }
                    return parameterizedClass.getName();
                }
            }
        }catch(Exception e){
            LOG.error("ShareBoundSqlPlugin error ms id is {} msg is {}",mapperId,e);
        }
        return null;
    }

    /**
     * 构建分片sql
     * @param sql
     * @param entity
     * @param param
     * @return
     */
    private String buildShareSql(String sql,ShareEntity entity,Object param){
        if(entity.shareSize>1){
            String tableName = entity.getTableName();
            int hash = Integer.valueOf(String.valueOf(param)) % entity.shareSize;
            String newTableName = tableName+"_"+hash;
            sql = sql.replaceAll(tableName,newTableName);
            System.err.println(sql);
        }
        return sql;
    }

    private static class ShareEntity {
        private String tableName;
        private int shareSize;
        private String shareField;

        ShareEntity(String tableName,int shareSize,String shareField){
            this.tableName = tableName;
            this.shareSize = shareSize;
            this.shareField = shareField;
        }

        public int getShareSize(){return this.shareSize;}
        public String getTableName(){return this.tableName;}
        public String getShareField(){return this.shareField;}

    }

}

最后，在 SqlSessionFactory 的配置中注册该拦截器：

spring-mybatis.xml

<!-- SqlSessionFactory with ShareStatementPlugin registered as a MyBatis plugin (interceptor). -->
<bean id="sqlSessionFactory" class="org.mybatis.spring.SqlSessionFactoryBean">
    <property name="dataSource" ref="dataSource"/>
    <property name="mapperLocations">
        <list>
            <value>classpath:mapper/*.xml</value>
        </list>
    </property>
    <property name="typeAliasesPackage" value="com.test.model"/>
    <property name="plugins">
        <array>
            <bean class="com.test.filter.ShareStatementPlugin" />
        </array>
    </property>
</bean>

你可能感兴趣的:(java)