【转载】springboot 之 mybatis 拦截器实现(数据分表查询及保存)

已实现功能
  • 自动创建业务分表索引表
  • 保存数据时,自动保存到最新表中
  • 查询时根据时间段查询对应时间的分表里面的记录
具体业务

当数据量越来越大时,需要按时间段把数据保存到不同的表中;查询时也可以根据时间段查询对应分表中的记录

框架依赖
  • springboot
  • mybatis plus
  • (mybatis plus 最新版本已经实现了一些相关插件,可惜项目中使用的版本比较老;换上新版本后,发现 mybatis plus 在 v3.3.1 就移除了对 entity 的泛型提取,而项目中使用的 entity 太多,不想逐一修改,所以放弃了升级)
使用方法

1、在 Mapper 接口类上加上 @ShardTable 注解,并在 interceptMethod 中列出需要拦截的方法(如 insert、selectList、selectPage)即可;保存时数据会自动写入当前时间段对应的分表

/**
 * Mapper for {@link DoorInoutRecord}. The {@code insert}, {@code selectList} and
 * {@code selectPage} statements are intercepted and routed to time-based shard tables.
 */
@ShardTable(
        interceptMethod = {"insert", "selectList", "selectPage"},
        table = DoorInoutRecord.class)
public interface DoorInoutRecordMapper extends BaseMapper<DoorInoutRecord> {
    // No extra methods: everything is inherited from BaseMapper.
}

2、如果需要查询某个时间段内所有分表中的数据,在 mybatis-plus 的 queryWrapper.last() 方法中传入时间段参数即可(格式见下例;不传时默认查询当前分表策略(半年)对应的时间段)

QueryWrapper<DoorInoutRecord> queryWrapper = new QueryWrapper<>();
// Used when querying across shard tables: appends the marker "shardTableTimeParam=<start>,<end>"
// to the end of the SQL. The sharding interceptor reads the range from it, removes it from the SQL,
// and queries the shard tables recorded for that range.
queryWrapper.last(ShardTableConstant.SHARD_TABLE_PARAM + "2020-01-01,2021-01-01");
List<DoorInoutRecord> dataList = doorInoutRecordService.list(queryWrapper);

具体实现完整源码

首先创建一个自定义注解
/**
 * Marks a MyBatis-Plus mapper interface (or a single mapper method) whose statements are routed
 * to time-based shard tables by the sharding interceptor. An annotation on the method takes
 * precedence over one on the interface.
 *
 * @author czx
 * @date 2021/3/10 14:59
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ ElementType.TYPE, ElementType.METHOD })
public @interface ShardTable {

    /**
     * Entity class whose {@code @TableName} value names the base table to shard.
     */
    Class<?> table();

    /**
     * Names of the mapper methods to intercept, e.g. {@code {"insert", "selectList", "selectPage"}}.
     * Defaults to none (nothing is intercepted).
     */
    String[] interceptMethod() default {};

}

然后再创建一个 实体类,用来解析 查询时间段的参数
/**
 * Start and end of the query time range parsed from the
 * {@code shardTableTimeParam=<start>,<end>} SQL marker.
 *
 * <p>Despite the "Data" in their names, both fields hold date strings such as {@code 2020-01-01}.
 * The names are kept because Lombok derives the {@code getStartData()/setStartData()} style
 * accessors used by callers from them.
 *
 * @author czx
 * @date 2021/3/12 14:51
 */
@Data
public class ShardTableParam {

    /** Range start followed by range end (declaration order is preserved for Lombok). */
    public String startData, endData;

}

然后创建一个常量类 ShardTableConstant
/**
 * String constants shared by the table-sharding interceptor.
 *
 * @author czx
 * @date 2021/3/11 14:31
 */
public interface ShardTableConstant {

    /** Suffix of the per-table bookkeeping table, which is named {@code <table>_shard_table}. */
    public static final String SHARD_TABLE_NAME = "_shard_table";

    /**
     * Looks up, in the current schema's {@code information_schema}, the table whose name is bound
     * to the single {@code ?} parameter; an empty result means the table does not exist.
     */
    public static final String WHETHER_HAS_TABLE = "SELECT TABLE_NAME AS tableName FROM information_schema.TABLES WHERE TABLE_SCHEMA = (SELECT DATABASE ()) and TABLE_NAME = ? ";

    /**
     * Marker appended to the SQL to carry the query time range,
     * e.g. {@code shardTableTimeParam=2020-01-01,2021-01-01}.
     */
    public static final String SHARD_TABLE_PARAM = "shardTableTimeParam=";

}

再创建一个 能获取spring上下文的类,用来找bean
/**
 * Captures the Spring {@link ApplicationContext} in a static field so static code can look up
 * beans (e.g. {@code getApplicationContext().getBean(JdbcTemplate.class)}).
 */
@Slf4j
@Component
public class ApplicationContextRegister implements ApplicationContextAware {

    /** Context handed over by Spring; {@code null} until {@link #setApplicationContext} has run. */
    private static ApplicationContext context;

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        ApplicationContextRegister.context = applicationContext;
    }

    /**
     * Returns the captured context.
     *
     * @return the application context, or {@code null} if Spring has not initialised this bean yet
     */
    public static ApplicationContext getApplicationContext() {
        return context;
    }
}

最后自定义一个mybatis 的拦截器

import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.date.DateUtil;
import cn.hutool.core.map.MapUtil;
import cn.hutool.core.util.StrUtil;
import com.baomidou.mybatisplus.annotation.TableName;
import com.google.common.collect.Lists;
import com.suke.zhjg.common.seata.annotation.ShardTable;
import com.suke.zhjg.common.seata.config.ApplicationContextRegister;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.apache.ibatis.type.TypeHandlerRegistry;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Component;
import org.springframework.transaction.support.TransactionTemplate;

import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.sql.Connection;
import java.time.temporal.TemporalAccessor;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
 * @author czx
 * @title: ShardTableInterceptor
 * @projectName zhjg
 * @description: TODO 分表拦截器
 * @date 2021/3/1014:46
 */
@Slf4j
@Component
@Intercepts({
    @Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class}),
    @Signature(type = StatementHandler.class, method = "getBoundSql", args = {}),
    @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}),
    @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
    @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}),
})
public class ShardTableInterceptor implements Interceptor {

    /**
     *  默认策略
     **/
    public final String strategyType = "semester";

    public final String boundSql = "delegate.boundSql";
    public final String boundSqlStr = "delegate.boundSql.sql";
    public final String mappedStatementValue = "delegate.mappedStatement";

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object target = invocation.getTarget();
        Object[] args = invocation.getArgs();

        if (target instanceof Executor) {

        }else {
            // StatementHandler
            StatementHandler statementHandler = (StatementHandler) target;
            // 目前只有StatementHandler.getBoundSql方法args才为null
            if (null == args) {

            }else {
                statementHandler = realTarget(statementHandler);
                MetaObject metaStatementHandler = SystemMetaObject.forObject(statementHandler);
                doSplitTable(metaStatementHandler);
            }
        }
        // 传递给下一个拦截器处理
        return invocation.proceed();
    }

    @Override
    public Object plugin(Object target) {
        // 当目标类是StatementHandler类型时,才包装目标类,否者直接返回目标本身,减少目标被代理的次数
        if (target instanceof Executor || target instanceof StatementHandler) {
            return Plugin.wrap(target, this);
        }
        return target;
    }

    @Override
    public void setProperties(Properties properties) {

    }

    private void doSplitTable(MetaObject metaStatementHandler) throws Exception {
        String originalSql = (String) metaStatementHandler.getValue(boundSqlStr);
        if (originalSql != null && !originalSql.equals("")) {
            // 去掉换行符
            originalSql = originalSql.replaceAll("[\\s]+", " ");
            // 获取分表查询的时间段参数
            ShardTableParam shardTableParam = this.getShardTableParam(originalSql);
            int index = originalSql.indexOf(ShardTableConstant.SHARD_TABLE_PARAM);
            if(index > 0){
                originalSql = originalSql.substring(0,index); // 去掉分表查询的时间段标实
            }
            MappedStatement mappedStatement = (MappedStatement) metaStatementHandler.getValue(mappedStatementValue);
            String id = mappedStatement.getId();
            String className = id.substring(0, id.lastIndexOf("."));
            String methodName = id.substring(id.lastIndexOf(".") + 1);
            Class<?> clazz = Class.forName(className);
            Method method = findMethod(clazz.getDeclaredMethods(), methodName);
            // 根据配置自动生成分表SQL
            ShardTable shardTable = null;
            if (method != null) {
                shardTable = method.getAnnotation(ShardTable.class);
            }
            if (shardTable == null) {
                shardTable = clazz.getAnnotation(ShardTable.class);
            }
            if (shardTable != null) {
                Class tableClass = shardTable.table();
                List<String> shardTableMethod = Arrays.stream(shardTable.interceptMethod()).collect(Collectors.toList());
                // 如果表不是空,拦截方法不是空
                if(tableClass != null && shardTableMethod != null && shardTableMethod.contains(methodName)){
                    log.info("分表前的SQL:{}", originalSql);
                    TableName annotation = (TableName) tableClass.getAnnotation(TableName.class);
                    String originalTableName = annotation.value();
                    JdbcTemplate jdbcTemplate = ApplicationContextRegister.getApplicationContext().getBean(JdbcTemplate.class);
                    // 查询当前数据库是否存在表
                    String shardTableRecordName = originalTableName + ShardTableConstant.SHARD_TABLE_NAME;
                    // 如果不存在就创建记录表
                    if(!this.queryShardTable(jdbcTemplate,shardTableRecordName)){
                        this.createShardTable(jdbcTemplate,shardTableRecordName,originalTableName);
                        // 把原表数据插入,开始时间为2020-01-01 结束时间为 今天
                        this.insertShardTableRecord(jdbcTemplate,shardTableRecordName,originalTableName,originalTableName,"2020-01-01",DateUtil.formatDate(new Date()));
                    }
                    // 获取sql 头,select 、 update 、 delete
                    String sqlHead = originalSql.substring(0, originalSql.indexOf(" ")).toLowerCase();
                    String shardTableName = null;
                    switch (sqlHead){
                        case "select":
                            String startDate = null;
                            String endDate = null;
                            List<String> tableName = Lists.newArrayList();
                            // 如果有传时间段 就使用传的去查询
                            if(shardTableParam != null){
                                startDate = shardTableParam.getStartData();
                                endDate = shardTableParam.getEndData();
                            }else {
                                // 如果没有传 就取当前分表策略的时间段
                                HashMap<String, String> shardTableNameStrategy = this.getShardTableStrategy(strategyType);
                                startDate = MapUtil.getStr(shardTableNameStrategy, "startDate");
                                endDate = MapUtil.getStr(shardTableNameStrategy, "endDate");
                            }
                            List<String> shardTables = this.getShardTableName(jdbcTemplate, shardTableRecordName, originalTableName, startDate, endDate);
                            if(CollUtil.isNotEmpty(shardTables)){
                                tableName.addAll(shardTables);
                            }
                            if(CollUtil.isNotEmpty(tableName)){
                                List<String> existTable = Lists.newArrayList();
                                List<String> convertedSqlList = tableName
                                        .stream()
                                        .filter(name -> this.queryShardTable(jdbcTemplate, name))
                                        .map(name -> {
                                            // 封装参数
                                            existTable.add(name);
                                            String newSql = this.getNewSql(metaStatementHandler, mappedStatement);
                                            return this.sqlPackage(newSql.replaceAll(originalTableName, name));
                                        })
                                        .collect(Collectors.toList());
                                if(CollUtil.isNotEmpty(convertedSqlList)){
                                    List<String> sqlList = Lists.newArrayList();
                                    for(int i=0;i<convertedSqlList.size();i++){
                                        if(i == 0){
                                            // 上面封装参数是,把问号参数都替换了,如果当前这条查询是有参数的,那么会报错
                                            // 所有这里直接替换第一条为原始sql 只替换表名
                                            String sql = originalSql.replaceAll(originalTableName, existTable.get(0));
                                            sqlList.add(this.sqlPackage(sql));
                                        }else {
                                            sqlList.add(convertedSqlList.get(i));
                                        }
                                    }
                                    String convertedSql = sqlList.stream().collect(Collectors.joining(" UNION "));
                                    metaStatementHandler.setValue(boundSqlStr, convertedSql);
                                    log.info("分表后的SQL:{}" , convertedSql);
                                }
                            }else {
                                shardTableName = originalTableName;
                            }
                            break;
                        case "insert":
                            List<String> shardTableNameList = this.getShardTableName(jdbcTemplate, shardTableRecordName, originalTableName);
                            if(CollUtil.isNotEmpty(shardTableNameList)){
                                shardTableName = shardTableNameList.get(0);
                            }
                            break;
                        case "update":
                            break;
                        case "delete":
                            break;
                    }
                    if(StrUtil.isNotEmpty(shardTableName)){
                        String convertedSql = originalSql.replaceAll(originalTableName, shardTableName);
                        metaStatementHandler.setValue(boundSqlStr, convertedSql);
                        log.info("分表后的SQL:{}",convertedSql);
                    }
                }
            }
        }
    }

    /**
     * 获取分表查询的参数
     */
    public ShardTableParam getShardTableParam(String originalSql){
        int index = originalSql.indexOf(ShardTableConstant.SHARD_TABLE_PARAM);
        if(index > 0){
            String param = originalSql.substring(index);
            String[] split = param.split(",");
            if(split.length == 2){
                ShardTableParam shardTableParam = new ShardTableParam();
                shardTableParam.setStartData(split[0]);
                shardTableParam.setEndData(split[1]);
                return shardTableParam;
            }
        }
        return null;
    }

    /**
     * 获取组装后的SQL
     * 替换问号“?”把参数拼进去了
     * 参考 https://www.cnblogs.com/selinamee/p/7110072.html
     */
    public String getNewSql(MetaObject metaStatementHandler,MappedStatement mappedStatement){
        BoundSql boundSqlObj = (BoundSql) metaStatementHandler.getValue(boundSql);
        List<ParameterMapping> parameterMappings = boundSqlObj.getParameterMappings();
        Configuration configuration = mappedStatement.getConfiguration();
        Object parameterObject = boundSqlObj.getParameterObject();
        String sql = boundSqlObj.getSql().replaceAll("[\\s]+", " ");
        sql = sql.substring(0, sql.indexOf(ShardTableConstant.SHARD_TABLE_PARAM)); // 去掉分表查询的时间段标实
        if (parameterMappings.size() > 0 && parameterObject != null) {
            TypeHandlerRegistry typeHandlerRegistry = configuration.getTypeHandlerRegistry();
            if (typeHandlerRegistry.hasTypeHandler(parameterObject.getClass())) {
                sql = sql.replaceFirst("\\?", getParameterValue(parameterObject));
            } else {
                MetaObject metaObject = configuration.newMetaObject(parameterObject);
                for (ParameterMapping parameterMapping : parameterMappings) {
                    String propertyName = parameterMapping.getProperty();
                    if (metaObject.hasGetter(propertyName)) {
                        Object obj = metaObject.getValue(propertyName);
                        sql = sql.replaceFirst("\\?", getParameterValue(obj));
                    } else if (boundSqlObj.hasAdditionalParameter(propertyName)) {
                        Object obj = boundSqlObj.getAdditionalParameter(propertyName);
                        sql = sql.replaceFirst("\\?", getParameterValue(obj));
                    }
                }
            }
        }
        return sql;
    }

    /**
     * 获取参数问号“?”对应的值
     */
    private String getParameterValue(Object obj) {
        String value = null;
        if (obj instanceof String) {
            value = "'" + obj.toString() + "'";
        } else if (obj instanceof Date) {
            Date date = (Date) obj;
            value = "'" + DateUtil.formatDateTime(date) + "'";
        } else {
            if (obj != null) {
                value = obj.toString();
            } else {
                value = "";
            }
        }
        return value;
    }

    /**
     * 包装SQL
     */
    public String sqlPackage(String sql){
        return "select * from ("+ sql +") t";
    }

    /**
     * 查询分表记录表是否存在
     */
    public boolean queryShardTable(JdbcTemplate jdbcTemplate,String shardTableName){
        List<?> result = jdbcTemplate.queryForList(ShardTableConstant.WHETHER_HAS_TABLE, String.class,shardTableName);
        if(result == null || result.size() == 0){
            return false;
        }
        return true;
    }

    /**
     * 创建分表记录表
     */
    public boolean createShardTable(JdbcTemplate jdbcTemplate,String shardTableName,String originalTableName){
        StringBuffer sb = new StringBuffer("");
        sb.append("CREATE TABLE `" + shardTableName + "` (");
        sb.append(" `originalTableName` varchar(100) NOT NULL COMMENT '原表名', ");
        sb.append(" `shardTableName` varchar(50) NOT NULL COMMENT '分表名', ");
        sb.append(" `startDate` varchar(20) NOT NULL COMMENT '开始时间', ");
        sb.append(" `endDate` varchar(20) DEFAULT NULL COMMENT '结束时间',");
        sb.append(" PRIMARY KEY (`shardTableName`)");
        sb.append(") ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='"+originalTableName+"分表信息';");
        try {
            jdbcTemplate.update(sb.toString());
            return true;
        } catch (Exception e) {
            e.printStackTrace();
        }
        return false;
    }

    /**
     * 获取分表策略
     */
    public HashMap<String,String> getShardTableStrategy(String strategy){
        HashMap<String,String> strategyMap = new HashMap();
        int year = DateUtil.thisYear();
        int month = DateUtil.thisMonth();
        switch (strategy){
            case "year":
                strategyMap.put("startDate",  year + "-01-01");
                strategyMap.put("endDate",  year + "-12-31");
                break;
            case "semester":
                if(month < 6){
                    strategyMap.put("startDate",  year + "-01-01");
                    strategyMap.put("endDate",  year + "-06-31");
                }else {
                    strategyMap.put("startDate",  year + "-07-01");
                    strategyMap.put("endDate",  year + "-12-31");
                }
                break;
            case "quarter":
                if(month <= 2){
                    strategyMap.put("startDate",  year + "-01-01");
                    strategyMap.put("endDate",  year + "-03-31");
                }else if(month > 2 && month <= 5) {
                    strategyMap.put("startDate",  year + "-04-01");
                    strategyMap.put("endDate",  year + "-06-31");
                }else if(month > 5 && month <= 8) {
                    strategyMap.put("startDate",  year + "-07-01");
                    strategyMap.put("endDate",  year + "-09-31");
                }else {
                    strategyMap.put("startDate",  year + "-10-01");
                    strategyMap.put("endDate",  year + "-12-31");
                }
                break;
        }
        return strategyMap;
    }

    /**
     * 获取当前分表表名
     */
    public List<String> getShardTableName(JdbcTemplate jdbcTemplate,String shardTableRecordName,String originalTableName){
        HashMap<String, String> shardTableNameInsert = this.getShardTableStrategy(strategyType);
        String startDate = MapUtil.getStr(shardTableNameInsert, "startDate");
        String endDate = MapUtil.getStr(shardTableNameInsert, "endDate");
        List<String> tableName = this.getShardTableName(jdbcTemplate,shardTableRecordName,originalTableName,startDate,endDate);
        if(CollUtil.isNotEmpty(tableName)){
            return tableName;
        }else {
            // 如果没有就获取最新的index 然后去创建
            TransactionTemplate transactionTemplate = ApplicationContextRegister.getApplicationContext().getBean(TransactionTemplate.class);
            try {
                String sql = "select max(CONVERT(substring_index(shardTableName, '_', -1), UNSIGNED INTEGER)) shardTableNameIndex from " + shardTableRecordName + " ;";
                Integer shardTableNameIndex = jdbcTemplate.queryForObject(sql, Integer.class);
                String newTableName = originalTableName + "_" + (shardTableNameIndex + 1);
               boolean result = transactionTemplate.execute((status) ->{
                   String newTableSql = "CREATE TABLE " + newTableName + " LIKE "+ originalTableName +" ;";
                   jdbcTemplate.execute(newTableSql);
                   this.insertShardTableRecord(jdbcTemplate,shardTableRecordName,originalTableName,newTableName,startDate,endDate);
                    return true;
                });
               if(result){
                   return Collections.singletonList(newTableName);
               }
            }catch (Exception e){
                e.printStackTrace();
            }
        }
        return null;
    }

    /**
     * 分表记录表 插入数据
     */
    public void insertShardTableRecord(JdbcTemplate jdbcTemplate,String shardTableRecordName,String originalTableName,String newTableName,String startDate,String endDate){
        StringBuffer insetSql = new StringBuffer("");
        insetSql.append("INSERT INTO "+ shardTableRecordName);
        insetSql.append(" (`originalTableName`, `shardTableName`, `startDate`, `endDate`) ");
        insetSql.append(" VALUES ");
        insetSql.append(" ('"+ originalTableName +"', '"+ newTableName +"', '"+ startDate +"', '"+ endDate +"'); ");
        jdbcTemplate.execute(insetSql.toString());
    }

    /**
     * 获取当前分表表名
     */
    public List<String>  getShardTableName(JdbcTemplate jdbcTemplate,String shardTableRecordName,String originalTableName,String startDate,String endDate){
        StringBuffer sb = new StringBuffer("");
        sb.append("select shardTableName from " + shardTableRecordName + " ");
        sb.append(" where ");
        sb.append(" originalTableName = '" + originalTableName + "' ");
        sb.append(" AND ( ");
        sb.append(" ( startDate >= '" + startDate + "' AND startDate <= '"+endDate+"' ) ");
        sb.append(" OR ");
        sb.append(" ( endDate <= '"+ startDate +"' AND endDate >= '"+endDate+"' ) ");
        sb.append(" OR ");
        sb.append(" ( startDate <= '"+startDate+"' AND endDate >= '"+endDate+"' ) ");
        sb.append(" ) ORDER BY shardTableName desc ; ");
        List<String> tableName = jdbcTemplate.queryForList(sb.toString(), String.class);
        return tableName;
    }

    private Method findMethod(Method[] methods, String methodName) {
        for (Method method : methods) {
            if (method.getName().equals(methodName)) {
                return method;
            }
        }
        return null;
    }

    /**
     * 获得真正的处理对象,可能多层代理.
     */
    @SuppressWarnings("unchecked")
    public static <T> T realTarget(Object target) {
        if (Proxy.isProxyClass(target.getClass())) {
            MetaObject metaObject = SystemMetaObject.forObject(target);
            return realTarget(metaObject.getValue("h.target"));
        }
        return (T) target;
    }

}

最后……把查询的时间段参数放到 last() 里面传递有些不妥,各位大佬如果有更好的传参方法,欢迎指教。

你可能感兴趣的:(数据库与mybatis,mybatis实现数据库分表)