by shihang.mai
1. mybatis层次结构
2. 插件实现
代码思路:利用MyBatis插件,在上述执行流程中拦截SQL,执行自定义的业务逻辑(多租户条件注入、敏感SQL校验)
3. 代码实现(核心步骤)
注解:分别控制是否开启多租户、是否开启敏感SQL拦截
/**
 * Marks a mapper type or mapper method whose SQL should receive the framework-level
 * tenant filter. When both the method and its type carry the annotation, the method's
 * value wins.
 */
@Target({ElementType.METHOD, ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
public @interface MultiTenant {

    /**
     * Whether framework-level multi-tenant handling is enabled for the annotated element.
     *
     * @return {@code true} (the default) to apply the tenant filter, {@code false} to opt out
     */
    boolean flag() default true;
}
/**
 * Marks a mapper type or mapper method whose SQL should be inspected by the sensitive-SQL
 * interceptor. By project default, DROP / CREATE / ALTER / TRUNCATE statements are disallowed.
 * When both the method and its type carry the annotation, the method's value wins.
 */
@Target({ElementType.METHOD, ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
public @interface SqlLimit {

    /**
     * Whether the sensitive-SQL check is enabled for the annotated element.
     *
     * @return {@code true} (the default) to run the check, {@code false} to skip it
     */
    boolean flag() default true;
}
核心配置类PluginConfiguration,需要在bootstrap中注册
/**
 * Registers the MyBatis multi-tenant interceptor, the sensitive-SQL interceptor, their
 * annotation-lookup caches, and a tenant-aware async executor.
 */
@Configuration
public class PluginConfiguration {

    /** SQL dialect handed to both interceptors (shared so they cannot drift apart). */
    private static final String DB_DIALECT = "postgresql";

    @Resource
    private BeanFactory beanFactory;

    /**
     * Cache of resolved {@code @MultiTenant} flags per mapper, so annotation lookups are
     * not repeated on every statement (a speed-up, not core logic).
     */
    @Bean
    public MultiTenantMapperCacheManager multiTenantMapperCacheManager() {
        return new MultiTenantMapperCacheManager();
    }

    /**
     * Cache of resolved {@code @SqlLimit} flags per mapper, so annotation lookups are
     * not repeated on every statement (a speed-up, not core logic).
     */
    @Bean
    public SqlLimitMapperCacheManager sqlLimitMapperCacheManager() {
        return new SqlLimitMapperCacheManager();
    }

    /**
     * Interceptor that adds the tenant-id condition to SQL of {@code @MultiTenant} mappers.
     *
     * @return the configured {@link TenantInterceptor}
     */
    @Bean
    public Interceptor tenantInterceptor() {
        Interceptor interceptor = new TenantInterceptor(multiTenantMapperCacheManager());
        Properties properties = new Properties();
        properties.setProperty(TenantConstant.DIALECT, DB_DIALECT);
        properties.setProperty(TenantConstant.TENANTID_FIELD, TenantConstant.TENANT_ID);
        interceptor.setProperties(properties);
        return interceptor;
    }

    /**
     * Interceptor that rejects sensitive SQL issued by {@code @SqlLimit} mappers.
     *
     * @return the configured {@link SqlCheckInterceptor}
     */
    @Bean
    public Interceptor sqlCheckInterceptor() {
        Interceptor interceptor = new SqlCheckInterceptor(sqlLimitMapperCacheManager());
        Properties properties = new Properties();
        properties.setProperty(TenantConstant.DIALECT, DB_DIALECT);
        interceptor.setProperties(properties);
        return interceptor;
    }

    /**
     * Multi-tenant thread pool, used to propagate the tenant id to asynchronous threads.
     *
     * @return MultiTenantLazyTraceThreadPoolTaskExecutor wrapping a small bounded pool
     */
    @Bean
    public MultiTenantLazyTraceThreadPoolTaskExecutor multiTenantLazyTraceThreadPoolTaskExecutor() {
        ThreadPoolTaskExecutor threadPoolTaskExecutor = new ThreadPoolTaskExecutor();
        threadPoolTaskExecutor.setCorePoolSize(3);
        threadPoolTaskExecutor.setKeepAliveSeconds(60);
        threadPoolTaskExecutor.setMaxPoolSize(5);
        threadPoolTaskExecutor.setQueueCapacity(1000);
        threadPoolTaskExecutor.setAllowCoreThreadTimeOut(true);
        // Fixed typo in the prefix (was "base-multitsenant-pool-").
        threadPoolTaskExecutor.setThreadNamePrefix("base-multitenant-pool-");
        // Once the queue is full and max threads are busy, submissions fail fast.
        threadPoolTaskExecutor.setRejectedExecutionHandler(new ThreadPoolExecutor.AbortPolicy());
        threadPoolTaskExecutor.setWaitForTasksToCompleteOnShutdown(true);
        threadPoolTaskExecutor.initialize();
        return new MultiTenantLazyTraceThreadPoolTaskExecutor(this.beanFactory, threadPoolTaskExecutor);
    }
}
租户拦截类TenantInterceptor
/**
 * MyBatis executor plugin that appends a tenant-id condition to the SQL of mapper
 * statements opted in via {@link MultiTenant} (method annotation takes precedence over the
 * class annotation).
 */
@Intercepts({
    @Signature(type = Executor.class, method = "update",
        args = {MappedStatement.class, Object.class}),
    @Signature(type = Executor.class, method = "query",
        args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
    @Signature(type = Executor.class, method = "query",
        args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}),
})
@Order(-20)
public class TenantInterceptor implements Interceptor {
    private static final Logger logger = LoggerFactory.getLogger(TenantInterceptor.class);

    /** Arity of {@code Executor.query(ms, parameter, rowBounds, handler, cacheKey, boundSql)}. */
    private static final int QUERY_WITH_BOUND_SQL_ARG_COUNT = 6;
    private static final int ROW_BOUNDS_INDEX = 2;
    private static final int CACHE_KEY_INDEX = 4;
    private static final int BOUND_SQL_INDEX = 5;

    /**
     * Dialect of the current database, used for both parsing and re-printing SQL.
     */
    private String dialect;
    /**
     * Name of the multi-tenant column added to the filter condition.
     */
    private String tenantIdField;
    private SqlConditionHelper conditionHelper;
    private SqlLimitHelper sqlLimitHelper;
    private AnnotationHelper annotationHelper;
    private final MultiTenantMapperCacheManager multiTenantMapperCacheManager;

    public TenantInterceptor(MultiTenantMapperCacheManager multiTenantMapperCacheManager) {
        this.multiTenantMapperCacheManager = multiTenantMapperCacheManager;
    }

    /**
     * Rewrites the intercepted statement's SQL with the current tenant condition when a tenant
     * id is bound to the thread and the mapper has opted in; otherwise proceeds unchanged.
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        String tenantId = MultiTenantHolders.getTenantId();
        // No tenant id: leave the statement untouched. (A hard-coded test tenant id used to be
        // substituted here, which silently scoped every tenant-less call to that one tenant.)
        if (StringUtils.isBlank(tenantId)) {
            return invocation.proceed();
        }
        Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[0];
        // Only statements whose mapper class/method carries an enabled @MultiTenant are rewritten.
        ClazzMethodInfo clazzMethodInfo = annotationHelper.getClassAndMethod(ms.getId());
        if (!checkAnnotation(clazzMethodInfo)) {
            return invocation.proceed();
        }
        Object parameter = args[1];
        boolean hasBoundSqlArg = args.length == QUERY_WITH_BOUND_SQL_ARG_COUNT;
        // The 6-arg query overload executes the BoundSql it is given, not ms.getBoundSql(...),
        // so that is the SQL that must be rewritten.
        BoundSql boundSql = hasBoundSqlArg ? (BoundSql) args[BOUND_SQL_INDEX] : ms.getBoundSql(parameter);
        logger.debug("old sql:{}", boundSql.getSql());
        String newSql = addTenantCondition(boundSql.getSql(), tenantId);
        logger.debug("new sql:{}", newSql);

        BoundSql newBoundSql = copyBoundSql(ms, boundSql, newSql);
        MappedStatement newMs = copyFromMappedStatement(ms, new BoundSqlSqlSource(newBoundSql));
        args[0] = newMs;
        if (hasBoundSqlArg) {
            // Swap in the rewritten SQL and rebuild the cache key from it; the incoming key was
            // computed from the unfiltered SQL.
            Executor executor = (Executor) invocation.getTarget();
            args[CACHE_KEY_INDEX] = executor.createCacheKey(
                newMs, parameter, (RowBounds) args[ROW_BOUNDS_INDEX], newBoundSql);
            args[BOUND_SQL_INDEX] = newBoundSql;
        }
        return invocation.proceed();
    }

    /**
     * Builds a BoundSql with {@code sql} but the same parameter mappings, parameter object and
     * additional parameters (e.g. those generated by {@code <foreach>}) as {@code source}.
     *
     * @param mappedStatement statement providing the MyBatis configuration
     * @param source          BoundSql to copy parameters from
     * @param sql             replacement SQL text
     * @return the new BoundSql
     */
    private BoundSql copyBoundSql(MappedStatement mappedStatement, BoundSql source, String sql) {
        BoundSql target = new BoundSql(mappedStatement.getConfiguration(), sql,
            source.getParameterMappings(), source.getParameterObject());
        for (ParameterMapping mapping : source.getParameterMappings()) {
            String prop = mapping.getProperty();
            if (source.hasAdditionalParameter(prop)) {
                target.setAdditionalParameter(prop, source.getAdditionalParameter(prop));
            }
        }
        return target;
    }

    /**
     * Copies {@code ms} with a different SqlSource, preserving the statement's other settings.
     *
     * @param ms           statement to copy
     * @param newSqlSource SqlSource for the copy
     * @return the copied MappedStatement
     */
    private MappedStatement copyFromMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
        MappedStatement.Builder builder = new MappedStatement.Builder(
            ms.getConfiguration(), ms.getId(), newSqlSource, ms.getSqlCommandType());
        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());
        // Copy every key property/column; previously composite keys were cut to the first one.
        if (ms.getKeyProperties() != null && ms.getKeyProperties().length > 0) {
            builder.keyProperty(String.join(",", ms.getKeyProperties()));
        }
        if (ms.getKeyColumns() != null && ms.getKeyColumns().length > 0) {
            builder.keyColumn(String.join(",", ms.getKeyColumns()));
        }
        builder.timeout(ms.getTimeout());
        builder.parameterMap(ms.getParameterMap());
        builder.resultMaps(ms.getResultMaps());
        builder.resultSetType(ms.getResultSetType());
        builder.cache(ms.getCache());
        builder.flushCacheRequired(ms.isFlushCacheRequired());
        builder.useCache(ms.isUseCache());
        builder.resultOrdered(ms.isResultOrdered());
        builder.databaseId(ms.getDatabaseId());
        builder.lang(ms.getLang());
        return builder.build();
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /**
     * Reads plugin properties and creates the helpers.
     *
     * @param properties must contain non-blank {@code TenantConstant.DIALECT} and
     *                   {@code TenantConstant.TENANTID_FIELD} values
     * @throws IllegalArgumentException if either property is missing or blank
     */
    @Override
    public void setProperties(Properties properties) {
        dialect = properties.getProperty(TenantConstant.DIALECT);
        if (StringUtils.isBlank(dialect)) {
            throw new IllegalArgumentException("MultiTenantPlugin need dialect property value");
        }
        tenantIdField = properties.getProperty(TenantConstant.TENANTID_FIELD);
        if (StringUtils.isBlank(tenantIdField)) {
            throw new IllegalArgumentException("MultiTenantPlugin need tenantIdField property value");
        }
        // Decision helper for the tenant condition field
        conditionHelper = new SqlConditionHelper(() -> false);
        sqlLimitHelper = new SqlLimitHelper();
        annotationHelper = new AnnotationHelper();
    }

    /**
     * Resolves whether {@link MultiTenant} is enabled for the mapper class/method, caching the
     * result per namespace. The method annotation takes precedence over the class annotation.
     *
     * @param data class/method info of the mapper statement
     * @return true if the tenant condition should be applied
     */
    private boolean checkAnnotation(ClazzMethodInfo data) {
        if (Objects.isNull(data) || Objects.isNull(data.getMethodName()) || Objects.isNull(data.getClassName())) {
            return false;
        }
        String namespace = data.getNamespace();
        String cacheKey = getCacheKey(namespace);
        if (StringUtils.isNotEmpty(cacheKey)) {
            return Boolean.valueOf(cacheKey);
        }
        MultiTenant annotation = annotationHelper.checkAnnotation(data, MultiTenant.class);
        if (Objects.isNull(annotation)) {
            setCacheKey(namespace, false);
            return false;
        }
        setCacheKey(namespace, annotation.flag());
        return annotation.flag();
    }

    private void setCacheKey(String namespace, boolean flag) {
        logger.debug("设置annotation缓存:{}, flag:{}", namespace, flag);
        multiTenantMapperCacheManager.update(namespace, flag);
    }

    private String getCacheKey(String namespace) {
        logger.debug("get annotation缓存:{}", namespace);
        return multiTenantMapperCacheManager.get(namespace);
    }

    /**
     * Adds the tenant-id filter condition to a SQL statement.
     *
     * @param sql      SQL to rewrite
     * @param tenantId current tenant id
     * @return the rewritten SQL, or {@code sql} unchanged when it is blank or unparseable to
     *         any statement
     */
    private String addTenantCondition(String sql, String tenantId) {
        if (StringUtils.isBlank(sql) || StringUtils.isBlank(tenantIdField)) {
            return sql;
        }
        // Split off the LIMIT/OFFSET part before parsing; it is re-appended after rewriting.
        String sqlToParse = sql;
        String limitClause = "";
        SqlLimiter sqlLimiter = sqlLimitHelper.splitLimitOffsetSize(sql);
        if (Objects.nonNull(sqlLimiter)) {
            if (StringUtils.isNotEmpty(sqlLimiter.getSql())) {
                sqlToParse = sqlLimiter.getSql();
            }
            if (StringUtils.isNotEmpty(sqlLimiter.getLimit())) {
                limitClause = " " + sqlLimiter.getLimit();
            }
        }
        List<SQLStatement> statementList = SQLUtils.parseStatements(sqlToParse, dialect);
        if (CollectionUtils.isEmpty(statementList)) {
            // Return the untouched original so the LIMIT part is not lost.
            return sql;
        }
        conditionHelper.addStatementCondition(statementList.get(0), tenantIdField, tenantId);
        // Print with the same dialect used for parsing instead of a hard-coded one.
        return SQLUtils.toSQLString(statementList, DbType.of(dialect)) + limitClause;
    }
}
敏感sql拦截类SqlCheckInterceptor
/**
 * MyBatis StatementHandler plugin that inspects SQL of mapper statements opted in via
 * {@link SqlLimit} and delegates the sensitive-SQL check to {@code SqlConditionHelper}
 * before the statement is prepared.
 */
@Intercepts({
    @Signature(type = StatementHandler.class, method = "prepare",
        args = {Connection.class, Integer.class}),
})
@Order(-10)
public class SqlCheckInterceptor implements Interceptor {
    private static final Logger logger = LoggerFactory.getLogger(SqlCheckInterceptor.class);

    /**
     * Dialect of the current database, used to parse SQL.
     */
    private String dialect;
    private SqlConditionHelper conditionHelper;
    private SqlLimitHelper sqlLimitHelper;
    private AnnotationHelper annotationHelper;
    private final SqlLimitMapperCacheManager sqlLimitMapperCacheManager;

    public SqlCheckInterceptor(SqlLimitMapperCacheManager sqlLimitMapperCacheManager) {
        this.sqlLimitMapperCacheManager = sqlLimitMapperCacheManager;
    }

    /**
     * Runs the sensitive-SQL check for opted-in mappers, then proceeds with {@code prepare}.
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        // StatementHandler has no getter for the mapped statement id, so read it reflectively
        // through the handler's "delegate" property.
        MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
        String id = (String) metaObject.getValue("delegate.mappedStatement.id");
        // Check whether the mapper class/method on the call path carries an enabled @SqlLimit.
        ClazzMethodInfo clazzMethodInfo = annotationHelper.getClassAndMethod(id);
        if (checkAnnotation(clazzMethodInfo)) {
            BoundSql boundSql = statementHandler.getBoundSql();
            logger.debug("old sql:{}", boundSql.getSql());
            checkCondition(boundSql.getSql());
        }
        return invocation.proceed();
    }

    /**
     * Resolves whether {@link SqlLimit} is enabled for the mapper class/method, caching the
     * result per namespace. The method annotation takes precedence over the class annotation.
     *
     * @param data class/method info of the mapper statement
     * @return true if the SQL should be checked
     */
    private boolean checkAnnotation(ClazzMethodInfo data) {
        if (Objects.isNull(data) || Objects.isNull(data.getMethodName()) || Objects.isNull(data.getClassName())) {
            return false;
        }
        String namespace = data.getNamespace();
        String cacheKey = getCacheKey(namespace);
        if (StringUtils.isNotEmpty(cacheKey)) {
            return Boolean.valueOf(cacheKey);
        }
        SqlLimit annotation = annotationHelper.checkAnnotation(data, SqlLimit.class);
        if (Objects.isNull(annotation)) {
            setCacheKey(namespace, false);
            return false;
        }
        setCacheKey(namespace, annotation.flag());
        return annotation.flag();
    }

    private void setCacheKey(String namespace, boolean flag) {
        logger.debug("设置annotation缓存:{}, flag:{}", namespace, flag);
        sqlLimitMapperCacheManager.update(namespace, flag);
    }

    private String getCacheKey(String namespace) {
        logger.debug("get annotation缓存:{}", namespace);
        return sqlLimitMapperCacheManager.get(namespace);
    }

    /**
     * Parses {@code sql} and hands the first statement to
     * {@code SqlConditionHelper.checkNonCondition}. Intended checks (per original notes):
     * update/delete without a condition; drop/create/alter; update conditions not hitting an
     * index. TODO(review): confirm which of these checkNonCondition actually enforces.
     *
     * @param sql SQL about to be prepared; blank input is ignored
     */
    private void checkCondition(String sql) {
        if (StringUtils.isBlank(sql)) {
            return;
        }
        // Split off the LIMIT/OFFSET part before parsing.
        SqlLimiter sqlLimiter = sqlLimitHelper.splitLimitOffsetSize(sql);
        if (Objects.nonNull(sqlLimiter) && StringUtils.isNotEmpty(sqlLimiter.getSql())) {
            sql = sqlLimiter.getSql();
        }
        List<SQLStatement> statementList = SQLUtils.parseStatements(sql, dialect);
        if (CollectionUtils.isEmpty(statementList)) {
            return;
        }
        conditionHelper.checkNonCondition(statementList.get(0));
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /**
     * Reads plugin properties and creates the helpers.
     *
     * @param properties must contain a non-blank {@code TenantConstant.DIALECT} value (the
     *                   same key {@code PluginConfiguration} writes)
     * @throws IllegalArgumentException if the dialect is missing or blank
     */
    @Override
    public void setProperties(Properties properties) {
        dialect = properties.getProperty(TenantConstant.DIALECT);
        if (StringUtils.isBlank(dialect)) {
            throw new IllegalArgumentException("SqlCheckInterceptor need dialect property value");
        }
        // Decision helpers
        conditionHelper = new SqlConditionHelper(() -> false);
        sqlLimitHelper = new SqlLimitHelper();
        annotationHelper = new AnnotationHelper();
    }
}