MyBatis SQL 拦截修改(通过拦截器改写 SQL 实现订单表分表)

自定义注解:



import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Inherited;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Annotation carrying table-sharding metadata; it may be placed on types and on methods
 * and is retained at runtime.
 */
@Documented
@Inherited
@Retention(RetentionPolicy.RUNTIME)
@Target({ ElementType.TYPE, ElementType.METHOD })
public @interface TableSeg {

    /** Table name. */
    String tableName() default "tb_shop_order";

    /** Sharding strategy. */
    String shardType() default "";

    /** Field(s) to shard by; separate multiple fields with a comma. */
    String shardBy();
}

自定义拦截器:

package com.fhgl.shop.interceptor;

import java.lang.reflect.Method;
import java.sql.Connection;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;

import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.DefaultReflectorFactory;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.fhgl.shop.annotation.TableSeg;
import com.fhgl.shop.utils.DateUtils;
import com.fhgl.shop.utils.OrderTableShardingUtils;

/**
 * 分表Interceptor
 *
 */
@Intercepts({
    @Signature(type = StatementHandler .class, method = "prepare", args = {Connection.class, Integer.class}),
})
public class MybatisSqlInterceptor implements Interceptor{

	private static final Logger LOGGER = LoggerFactory.getLogger(MybatisSqlInterceptor.class);
	
	@SuppressWarnings("unchecked")
	@Override
	public Object intercept(Invocation invocation) throws Throwable {
		 try {
			StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
			    //通过MetaObject优雅访问对象的属性,这里是访问statementHandler的属性;:MetaObject是Mybatis提供的一个用于方便、
			    //优雅访问对象属性的对象,通过它可以简化代码、不需要try/catch各种reflect异常,同时它支持对JavaBean、Collection、Map三种类型对象的操作。
			    MetaObject metaObject = MetaObject
			        .forObject(statementHandler, SystemMetaObject.DEFAULT_OBJECT_FACTORY, SystemMetaObject.DEFAULT_OBJECT_WRAPPER_FACTORY,
			            new DefaultReflectorFactory());
			    //先拦截到RoutingStatementHandler,里面有个StatementHandler类型的delegate变量,其实现类是BaseStatementHandler,然后就到BaseStatementHandler的成员变量mappedStatement
			    MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
			    //id为执行的mapper方法的全路径名,如com.uv.dao.UserMapper.insertUser
			    String sqlId = mappedStatement.getId();
			    //sql语句类型 select、delete、insert、update
			    String sqlCommandType = mappedStatement.getSqlCommandType().toString();
			    
		        String className = sqlId.substring(0, sqlId.lastIndexOf("."));
		        
		        String methodName = sqlId.replace(className+".", "");
		        Class classObj = Class.forName(className);
		        Method[] methods = classObj.getDeclaredMethods();
		        Method method = null;
		        for(Method method1 : methods) {
		        	if(method1.getName().equals(methodName)) {
		        		method = method1;
		        		break;
		        	}
		        }
			    TableSeg tableSeg = method.getAnnotation(TableSeg.class);
		        if(null == tableSeg){
		            //不需要分表,直接传递给下一个拦截器处理
		            return invocation.proceed();
		        }
		        LOGGER.info("MybatisSqlInterceptor--sqlId:{}",sqlId);
		        //获取分表字段
		        String[] shardField = tableSeg.shardBy().split(",");
		        List timeList = new ArrayList<>();
		        String orderId = "";
		        List orderIdList = new ArrayList<>();
		        Set tableNameSet = new HashSet();
		        Map parameters = (Map) metaObject.getValue("delegate.boundSql.parameterObject");
		        LOGGER.info("MybatisSqlInterceptor--parameters:{}",parameters);
		        //获取参数
		        for (int i = 0;i < shardField.length;i++) {
		        	Object paramValue = parameters.get(shardField[i]);
		        	if(paramValue != null) {
		        		System.out.println(paramValue.getClass());
		        		if(paramValue.getClass() == String.class) {
		        			//参数是订单号或者时间
		        			String paramStr = (String) paramValue;
		        			String paramIndex = paramStr.substring(0, 2);
		        			if(paramIndex.equals("DD")) {
		        				//订单号
		        				orderId = paramStr;
		        				String tableName = OrderTableShardingUtils.getTableNameByOrderId(orderId);
		        				if(tableName != null) {
		        					tableNameSet.add(tableName);
		        				}
		        				
		        			} else {
		        				//时间
		        				boolean isDate = DateUtils.checkDate(paramStr, "yyyy-MM-dd HH:mm:ss");
		        				
		        				if(isDate) {
		        					//时间
		        					timeList.add(paramStr);
		        				}
		        			}
		        		} else if(List.class.isAssignableFrom(paramValue.getClass())) {
		        			//参数是orderList
		        			orderIdList = (List) paramValue;
		        			if(orderIdList != null && orderIdList.size()>0) {
		        				for(String ordersId : orderIdList) {
		        					String tableName = OrderTableShardingUtils.getTableNameByOrderId(ordersId);
			        				if(tableName != null) {
			        					tableNameSet.add(tableName);
			        				}
		        				}
		        			}
		        		}
		        	}
		        }
		        //判断时间
		        Date beginTime = null;
		        Date endTime = null;
		        if(timeList.size() == 2) {
		        	beginTime = DateUtils.strToDateLong(timeList.get(0));
		        	endTime = DateUtils.strToDateLong(timeList.get(1));
		        	List tableName = new ArrayList<>();
		        	if (beginTime.getTime() > endTime.getTime()) {
		        		tableName = OrderTableShardingUtils.getTableNameListByTimes(timeList.get(1), timeList.get(0));
		        	} else {
		        		tableName = OrderTableShardingUtils.getTableNameListByTimes(timeList.get(0), timeList.get(1));
		        	}
		        	if (tableName != null && tableName.size()>0) {
		        		for(String tn: tableName) {
        					if(tn != null) {
        						tableNameSet.add(tn);
        					}
        				}
		        	}
		        	
		        }
		        
			    //拦截查询sql
			    if(sqlCommandType.equals("SELECT")) {
			    	BoundSql boundSql = statementHandler.getBoundSql();
				    String sql = boundSql.getSql();
				    if(sql.toLowerCase().indexOf("tb_shop_order") < 0 || sql.toLowerCase().indexOf("tb_shop_order_") >= 0) {
				    	return invocation.proceed();
				    }
				    if(tableNameSet.size() == 0) {
				    	return invocation.proceed();
				    } 
				    Iterator it = tableNameSet.iterator();
				    List tableNameList = new ArrayList<>();
				    while(it.hasNext()) {
				    	tableNameList.add(it.next());
				    }
				    String pattern = "(?i)tb_shop_order";
				    if(tableNameList.size() == 1) {
					    sql = sql.replaceAll(pattern, tableNameList.get(0));
				    } else {
				    	for (int i=0;i

处理逻辑工具类:

package com.fhgl.shop.utils;

import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.TimeZone;

/**
 * Utility methods that map order ids and time ranges to the names of the sharded order tables.
 *
 * @author admin
 */
public class OrderTableShardingUtils {

	/** Un-suffixed order table. */
	private static final String CURRENT_TABLE = "tb_shop_order";

	/** Single table used for everything before the quarterly shards start. */
	private static final String TABLE_2018 = "tb_shop_order_2018";

	/** Prefix of the per-quarter tables: {@code tb_shop_order_<year>_<season>}. */
	private static final String SHARD_PREFIX = "tb_shop_order_";

	/** Format of the time strings accepted by {@link #getTableNameListByTimes}. */
	private static final String TIME_PATTERN = "yyyy-MM-dd HH:mm:ss";

	/** Instant from which quarterly tables are used instead of {@link #TABLE_2018}. */
	private static final String QUARTERLY_START = "2019-01-01 00:00:00";

	private static final long DAY_MILLIS = 24L * 3600 * 1000;

	/**
	 * Resolves the table name for a single order id.
	 * <ul>
	 * <li>14-character ids map to {@code tb_shop_order_2018};</li>
	 * <li>19-character ids carry an epoch-millisecond timestamp in characters 2..14, which
	 * selects {@code tb_shop_order_<year>_<season>};</li>
	 * <li>anything else yields {@code null}.</li>
	 * </ul>
	 *
	 * @param orderId the order id; may be {@code null}
	 * @return the table name, or {@code null} when the id is {@code null}, has an unsupported
	 *         length, or its timestamp segment is not a number
	 */
	public static String getTableNameByOrderId(String orderId) {
		if (orderId == null) {
			return null;
		}
		if (orderId.length() == 14) {
			return TABLE_2018;
		}
		if (orderId.length() != 19) {
			return null;
		}
		final long createdAtMillis;
		try {
			createdAtMillis = Long.parseLong(orderId.substring(2, 15));
		} catch (NumberFormatException e) {
			// A malformed id is treated like any other unrecognised id instead of failing the query.
			return null;
		}
		Date createdAt = new Date(createdAtMillis);
		return quarterTable(ShardingUtil.getYear(createdAt), ShardingUtil.getSeason(createdAt));
	}

	/**
	 * Resolves the tables that have to be consulted for a time range.
	 *
	 * @param beginTime range start, formatted as {@code yyyy-MM-dd HH:mm:ss}
	 * @param endTime range end, formatted as {@code yyyy-MM-dd HH:mm:ss}
	 * @return the table names in chronological order; {@code null} when an argument is
	 *         {@code null} or unparsable, when the range is reversed, or when it ends in the future
	 */
	public static List<String> getTableNameListByTimes(String beginTime, String endTime) {
		if (beginTime == null || endTime == null) {
			return null;
		}
		// SimpleDateFormat is not thread-safe, so a fresh instance is created per call.
		SimpleDateFormat sdf = new SimpleDateFormat(TIME_PATTERN);
		List<String> tableNameList = new ArrayList<>();
		try {
			Date beginDate = sdf.parse(beginTime);
			Date endDate = sdf.parse(endTime);
			long quarterlyStart = sdf.parse(QUARTERLY_START).getTime();

			long now = System.currentTimeMillis();
			// Local midnight of the current day: shift into local time before truncating to whole
			// days, then shift back. Truncating the UTC timestamp first lands on the wrong day
			// whenever the local date differs from the UTC date.
			long zoneOffset = TimeZone.getDefault().getOffset(now);
			long todayMidnight = (now + zoneOffset) / DAY_MILLIS * DAY_MILLIS - zoneOffset;

			if (beginDate.getTime() > endDate.getTime() || endDate.getTime() > now) {
				return null;
			}
			// The whole range lies within the last 100 days: only the un-suffixed table is used.
			if (beginDate.getTime() > todayMidnight - 100 * DAY_MILLIS) {
				tableNameList.add(CURRENT_TABLE);
				return tableNameList;
			}
			// The whole range lies before the quarterly shards start.
			if (endDate.getTime() <= quarterlyStart) {
				tableNameList.add(TABLE_2018);
				return tableNameList;
			}
			// The range straddles the start of quarterly sharding: take the 2018 table and continue
			// from the first quarterly shard.
			if (beginDate.getTime() < quarterlyStart) {
				tableNameList.add(TABLE_2018);
				beginDate = new Date(quarterlyStart);
			}

			int beginYear = ShardingUtil.getYear(beginDate);
			int beginSeason = ShardingUtil.getSeason(beginDate);
			int endYear = ShardingUtil.getYear(endDate);
			int endSeason = ShardingUtil.getSeason(endDate);
			for (int year = beginYear; year <= endYear; year++) {
				int firstSeason = (year == beginYear) ? beginSeason : 1;
				int lastSeason = (year == endYear) ? endSeason : 4;
				for (int season = firstSeason; season <= lastSeason; season++) {
					tableNameList.add(quarterTable(year, season));
				}
			}

			// NOTE(review): this threshold is 10 days while the early return above uses 100 days;
			// kept as in the original - confirm which retention window is intended.
			if (endDate.getTime() >= todayMidnight - 10 * DAY_MILLIS) {
				tableNameList.add(CURRENT_TABLE);
			}
		} catch (Exception e) {
			// Deliberate best effort: a null result is this method's "cannot resolve" signal,
			// so any failure is reported that way rather than propagated.
			return null;
		}
		return tableNameList;
	}

	/**
	 * Placeholder that is not implemented yet.
	 *
	 * @param TableName unused
	 * @param beginTime unused
	 * @param endTime unused
	 * @return always {@code null}
	 */
	public static List getBeginTimeAndEndTimeByTableName(String TableName,String beginTime,String endTime) {
		return null;
	}

	/** Builds the name of the quarterly table for the given year and season. */
	private static String quarterTable(int year, int season) {
		return SHARD_PREFIX + year + "_" + season;
	}
}

在启动类中注册拦截器:

@Bean
  ConfigurationCustomizer mybatisConfigurationCustomizer() {
      // Customizer that adds a MybatisSqlInterceptor to the MyBatis configuration it is given.
      final ConfigurationCustomizer interceptorRegistrar = new ConfigurationCustomizer() {
          @Override
          public void customize(org.apache.ibatis.session.Configuration mybatisConfiguration) {
              final MybatisSqlInterceptor sqlInterceptor = new MybatisSqlInterceptor();
              mybatisConfiguration.addInterceptor(sqlInterceptor);
          }
      };
      return interceptorRegistrar;
  }

 

你可能感兴趣的:(web开发)