mybatis拦截器 多租户隔离 及 数据权限隔离 动态可扩展

多租户隔离, 数据权限隔离 (动态扩展)

定义多租户注解，在 controller 类或接口方法上添加该注解即可开启多租户模式

使用案例:

@GetMapping("/test")
// Enable region multi-tenant isolation (usable on a method or on a controller class).
@DataSpace
// Enable the aaa and bbb data-permission strategies; the constants must exist in DataPermissionEnum.
@DataPermission({DataPermissionEnum.aaa, DataPermissionEnum.bbb})
public Object test() {
    return null;
}
/**
 * Example data-permission strategy "aaa".
 *
 * <p>Produces three rules: an equality filter on {@code table_a.id}, an IN filter on
 * {@code table_b.id} and a JSON_CONTAINS filter on {@code table_c.group_ids}. The last two are
 * built from the group ids returned by {@code ContextHolderUtil.getDataPermissionList()}.
 */
public class Aaa_DataPermissionHandler implements DataPermissionConditionHandler {

    @Override
    public List<SqlConditionDTO> handle() {
        List<Integer> permittedGroupIds = ContextHolderUtil.getDataPermissionList();

        SqlConditionDTO tableARule = SqlConditionDTO.eq("table_a", "id", 1);
        SqlConditionDTO tableBRule = SqlConditionDTO.in("table_b", "id", permittedGroupIds);
        SqlConditionDTO tableCRule = SqlConditionDTO.jsonContains("table_c", "group_ids", permittedGroupIds);

        return List.of(tableARule, tableBRule, tableCRule);
    }
}

开始

/**
 * Marker annotation that switches on region (地区) multi-tenant isolation for a request.
 * It may be placed on a controller method or on a controller class.
 *
 * @author xiaopeng
 * @date 2021-06-09
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD, ElementType.TYPE})
public @interface DataSpace {
}

上下文参数传递工具类

/**
 * Static helpers for request-scoped context: the current HTTP request and session, the
 * authenticated-user attributes stored on the request, and the region multi-tenant switch.
 *
 * <p>Every accessor tolerates being called outside an HTTP request, for example from a scheduled
 * job or during startup, where MyBatis interceptors still run. In that case {@link #getRequest()}
 * returns {@code null}, and the other helpers fall back to {@code null} / {@code false} instead of
 * throwing {@link NullPointerException}.
 */
public class ContextHolderUtil {

    /** Request attribute that flags multi-tenant mode for the current request. */
    private static final String DATA_SPACE_ATTRIBUTE = "openDataSpace";

    /** Value stored under {@link #DATA_SPACE_ATTRIBUTE} when multi-tenant mode is on. */
    private static final String DATA_SPACE_OPEN = "open";

    /** Header name (or query-parameter name, as a fallback) carrying the selected region id. */
    private static final String SPACE_ID_NAME = "SpaceId";

    /**
     * Returns the servlet request bound to the current thread.
     *
     * @return the current request, or {@code null} when the thread is not serving a servlet request
     */
    public static HttpServletRequest getRequest() {
        Object attributes = RequestContextHolder.getRequestAttributes();
        if (attributes instanceof ServletRequestAttributes) {
            return ((ServletRequestAttributes) attributes).getRequest();
        }
        return null;
    }

    /**
     * Returns the HTTP session of the current request (creating one if needed).
     *
     * @return the session, or {@code null} when there is no current request
     */
    public static HttpSession getSession() {
        HttpServletRequest request = getRequest();
        return request == null ? null : request.getSession();
    }

    /**
     * Reads a request attribute.
     *
     * @param key attribute name
     * @return the attribute value, or {@code null} when absent or when there is no current request
     */
    public static Object getRequestAttribute(String key) {
        HttpServletRequest request = getRequest();
        return request == null ? null : request.getAttribute(key);
    }

    /**
     * Returns the authenticated user id stored in request attribute {@code id}
     * (presumably populated from the token by an auth filter — not shown here).
     *
     * @return the user id, or {@code null} when the attribute is absent
     * @throws NumberFormatException if the attribute is present but not an integer
     */
    public static Integer getAuthedUserId() {
        return getIntegerAttribute("id");
    }

    /**
     * Returns the role id stored in request attribute {@code roleId}.
     *
     * @return the role id, or {@code null} when the attribute is absent
     * @throws NumberFormatException if the attribute is present but not an integer
     */
    public static Integer getAuthedUserRoleId() {
        return getIntegerAttribute("roleId");
    }

    /**
     * Returns the role weight stored in request attribute {@code roleWeight}.
     *
     * @return the role weight, or {@code null} when the attribute is absent
     * @throws NumberFormatException if the attribute is present but not an integer
     */
    public static Integer getAuthedUserRoleWeight() {
        return getIntegerAttribute("roleWeight");
    }

    /**
     * Returns the user info object stored in request attribute {@code userInfo}.
     *
     * @return the user info, or {@code null} when absent
     */
    public static UserInfo getAuthedUserInfo() {
        return (UserInfo) getRequestAttribute("userInfo");
    }

    /**
     * Enables region multi-tenant mode for the current request.
     * Does nothing when there is no current request.
     */
    public static void openDataSpace() {
        HttpServletRequest request = getRequest();
        if (request != null) {
            request.setAttribute(DATA_SPACE_ATTRIBUTE, DATA_SPACE_OPEN);
        }
    }

    /**
     * Tells whether region multi-tenant mode is enabled for the current request.
     *
     * @return {@code true} only inside a request on which {@link #openDataSpace()} was called
     */
    public static boolean isDataSpace() {
        HttpServletRequest request = getRequest();
        return request != null && DATA_SPACE_OPEN.equals(request.getAttribute(DATA_SPACE_ATTRIBUTE));
    }

    /**
     * Returns the region id sent by the front end.
     *
     * <p>The {@code SpaceId} header is read first; the {@code SpaceId} query parameter is used as a
     * fallback. When multi-tenant mode is on and no id was sent,
     * {@code CommonUtil.expressionThrowResponseCodeAndSetAttribute} is invoked with
     * {@code ResponseCode.NOT_REGION}. Judging by its name it aborts the request — confirm in CommonUtil.
     *
     * @return the raw (untrusted) region id, or {@code null} when absent or when there is no current request
     */
    public static String getDataSpaceId() {
        HttpServletRequest request = getRequest();
        if (request == null) {
            return null;
        }
        String dataSpaceId = request.getHeader(SPACE_ID_NAME);
        if (StringUtil.isEmpty(dataSpaceId)) {
            // fall back to the URL / form parameter
            dataSpaceId = request.getParameter(SPACE_ID_NAME);
        }
        if (isDataSpace()) {
            // a tenant-scoped endpoint must not run without a region id
            CommonUtil.expressionThrowResponseCodeAndSetAttribute(StringUtil.isEmpty(dataSpaceId), ResponseCode.NOT_REGION);
        }
        return dataSpaceId;
    }

    /**
     * Parses an integer request attribute.
     *
     * @param key attribute name
     * @return the parsed value, or {@code null} when the attribute is absent
     */
    private static Integer getIntegerAttribute(String key) {
        Object value = getRequestAttribute(key);
        return value == null ? null : Integer.valueOf(value.toString());
    }
}

多租户注解切面, 将注解的信息保存进 上下文参数

/**
 * Aspect that switches on region multi-tenant mode for the current request when
 * {@link DataSpace} is present either on the invoked method or on the controller class.
 * The flag is stored via {@code ContextHolderUtil.openDataSpace()}.
 */
@Component
@Aspect
@Slf4j
public class DataSpaceAspect {

    /** Matches any method annotated with {@link DataSpace}. */
    @Pointcut("@annotation(com.xxxx.xxxx.common.annotation.DataSpace)")
    public void pointcut() {
    }

    /** Matches every method of every {@code *Controller} class below the base package. */
    @Pointcut("execution(* com.xxxx.xxxx..*Controller.*(..))")
    public void pointcutClass() {
    }

    /** Method-level switch: runs before any method that carries {@link DataSpace}. */
    @Before("pointcut()")
    public void openDataSpace() {
        ContextHolderUtil.openDataSpace();
    }

    /**
     * Class-level switch: wraps every controller call and enables multi-tenant mode when the
     * target class itself carries {@link DataSpace}.
     *
     * @param joinPoint the intercepted controller invocation
     * @return whatever the controller method returns
     * @throws Throwable anything thrown by the controller method
     */
    @Around("pointcutClass()")
    public Object openDataSpace(ProceedingJoinPoint joinPoint) throws Throwable {
        boolean classOptsIn = joinPoint.getTarget().getClass().isAnnotationPresent(DataSpace.class);
        if (classOptsIn) {
            ContextHolderUtil.openDataSpace();
        }
        return joinPoint.proceed();
    }
}

multi-tenant.properties 配置文件：配置多租户字段，以及需要做多租户隔离的表（多个表用逗号分隔）

tenant.id-field=region_id
tenant.table=tfmes_sales_order,tfmes_work_procedure_data

mybatis SqlSessionFactory 添加自己写的mybatis多租户拦截器

/**
 * Registers the multi-tenant MyBatis interceptor, configured from {@code multi-tenant.properties}.
 */
@Configuration
@PropertySource("classpath:multi-tenant.properties")
public class TfmesMybatisConfig {

    /**
     * Tenant column name (property {@code tenant.id-field}).
     */
    @Value("${tenant.id-field}")
    private String tenantIdField;

    /**
     * Comma-separated names of the tables that carry the tenant column (property {@code tenant.table}).
     */
    @Value("${tenant.table}")
    private String tableSet;

    /**
     * Creates a {@link TfmesMultiTenantPlugin} and adds it to the given factory's configuration.
     *
     * <p>The returned String is only a placeholder bean; the useful work is the side effect of
     * registering the interceptor. With several data sources, qualify the factory, e.g.
     * {@code public String interceptorTfmes(@Qualifier("xxxxSqlSessionFactory") SqlSessionFactory sqlSessionFactory)}.
     *
     * @param sqlSessionFactory factory whose configuration receives the interceptor
     * @return the placeholder value {@code "interceptorTfmes"}
     */
    @Bean
    public String interceptorTfmes(SqlSessionFactory sqlSessionFactory) {
        // the interceptor class is TfmesMultiTenantPlugin; there is no MultiTenantPlugin class
        TfmesMultiTenantPlugin multiTenantPlugin = new TfmesMultiTenantPlugin();
        Properties properties = new Properties();
        properties.setProperty("tenantIdField", tenantIdField);
        properties.setProperty("tableNames", tableSet);
        multiTenantPlugin.setProperties(properties);

        // further MyBatis interceptors can be registered the same way
        sqlSessionFactory.getConfiguration().addInterceptor(multiTenantPlugin);
        return "interceptorTfmes";
    }
}

TfmesMultiTenantPlugin：MyBatis 多租户拦截器

/**
 * MyBatis interceptor for region-based multi-tenant isolation and audit-column filling.
 *
 * <p>Every intercepted {@code Executor.update} / {@code Executor.query} call has its SQL parsed with
 * druid and rewritten:
 * <ul>
 *   <li>INSERT / UPDATE statements get audit columns via {@link SqlFieldHelper}. In multi-tenant mode,
 *       INSERT also gets the tenant column for configured tables.</li>
 *   <li>In multi-tenant mode ({@code ContextHolderUtil.isDataSpace()}), SELECT / UPDATE / DELETE get a
 *       {@code <table-or-alias>.<tenantIdField> = <tenant id>} condition for every configured table,
 *       via {@link SqlConditionHelper}.</li>
 * </ul>
 * SELECTs outside multi-tenant mode are passed through untouched.
 */
@Slf4j
@Intercepts({@Signature(
        type = Executor.class,
        method = "update",
        args = {MappedStatement.class, Object.class}
), @Signature(
        type = Executor.class,
        method = "query",
        args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}
), @Signature(
        type = Executor.class,
        method = "query",
        args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}
)})
// Intentionally not a @Component: the MyBatis configuration class already registers this interceptor.
public class TfmesMultiTenantPlugin implements Interceptor {

    /**
     * Whitelist for tenant ids. The id comes from an HTTP header / query parameter and is written into
     * the WHERE clause, so anything outside this shape is rejected rather than risk SQL injection.
     */
    private static final java.util.regex.Pattern SAFE_TENANT_VALUE = java.util.regex.Pattern.compile("[A-Za-z0-9_-]+");

    /**
     * Tenant column name (property {@code tenantIdField}).
     */
    private String tenantIdField;

    /**
     * Tables that carry the tenant column (property {@code tableNames}); {@code null} means none.
     */
    private Set<String> tableSet;

    /** Fills audit / tenant columns on INSERT and UPDATE. */
    private SqlFieldHelper sqlFieldHelper;

    /** Adds tenant conditions to WHERE / ON clauses. */
    private SqlConditionHelper conditionHelper;

    /**
     * Rewrites the SQL of the intercepted call and proceeds with the rewritten statement.
     *
     * @param invocation the intercepted {@code Executor} call
     * @return the result of the underlying executor call
     * @throws Throwable anything the executor throws
     * @throws IllegalArgumentException when multi-tenant mode is on and the tenant id has an illegal shape
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        // arguments of the intercepted method, see the @Signature declarations above
        final Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[0];
        Object parameter = args[1];

        // SELECTs outside region multi-tenant mode are not rewritten
        if (SqlCommandType.SELECT == ms.getSqlCommandType() && !ContextHolderUtil.isDataSpace()) {
            return invocation.proceed();
        }

        // region id sent by the front end
        String tenantFieldValue = ContextHolderUtil.getDataSpaceId();

        // With the 6-argument query(...) overload MyBatis executes args[5] directly, so that exact
        // BoundSql must be rewritten; otherwise the new SQL would be built but the old SQL executed.
        BoundSql boundSql = args.length == 6 ? (BoundSql) args[5] : ms.getBoundSql(parameter);

        String processSql = boundSql.getSql();
        log.debug("替换前SQL:{}", processSql);

        // core step: parse the SQL and generate the rewritten statement(s)
        String newSql = getNewSql(processSql, tenantFieldValue);

        log.debug("替换后SQL:{}", newSql);

        // overwrite the SQL of the original BoundSql reflectively (needed for the 6-argument path)
        MetaObject boundSqlMeta = MetaObject.forObject(boundSql, new DefaultObjectFactory(), new DefaultObjectWrapperFactory(), new DefaultReflectorFactory());
        boundSqlMeta.setValue("sql", newSql);

        // build a new BoundSql carrying the original parameter mappings and additional parameters
        BoundSql newBoundSql = new BoundSql(ms.getConfiguration(), newSql,
                boundSql.getParameterMappings(), boundSql.getParameterObject());
        for (ParameterMapping mapping : boundSql.getParameterMappings()) {
            String prop = mapping.getProperty();
            if (boundSql.hasAdditionalParameter(prop)) {
                newBoundSql.setAdditionalParameter(prop, boundSql.getAdditionalParameter(prop));
            }
        }
        // swap in a MappedStatement whose SqlSource always yields the rewritten BoundSql
        args[0] = newMappedStatement(ms, new BoundSqlSqlSource(newBoundSql));

        return invocation.proceed();
    }

    /**
     * Reads the plugin configuration.
     *
     * @param properties must contain {@code tenantIdField}; may contain {@code tableNames}
     *                   (comma-separated, surrounding whitespace ignored)
     * @throws IllegalArgumentException when {@code tenantIdField} is blank
     */
    @Override
    public void setProperties(Properties properties) {
        tenantIdField = properties.getProperty("tenantIdField");
        if (StringUtils.isBlank(tenantIdField)) {
            throw new IllegalArgumentException("MultiTenantPlugin need tenantIdField property value");
        }

        String tableNames = properties.getProperty("tableNames");
        if (!StringUtils.isBlank(tableNames)) {
            Set<String> tables = new HashSet<>();
            for (String tableName : StringUtils.split(tableNames, ",")) {
                // "a, b" must match table "b", not " b"
                String trimmed = tableName.trim();
                if (!trimmed.isEmpty()) {
                    tables.add(trimmed);
                }
            }
            tableSet = tables;
        }

        // decides which tables receive the tenant column on INSERT
        TableFieldConditionDecision conditionDecision = new TableFieldConditionDecision() {
            @Override
            public boolean isAllowNullValue() {
                return false;
            }

            @Override
            public boolean adjudge(String tableName, String fieldName) {
                return isTenantTable(tableName);
            }
        };
        sqlFieldHelper = new SqlFieldHelper(conditionDecision);
        conditionHelper = new SqlConditionHelper();
    }

    /**
     * Tells whether a (possibly back-quoted) table name is one of the configured tenant tables.
     *
     * @param tableName table name as it appears in the SQL
     * @return {@code true} if the table carries the tenant column
     */
    private boolean isTenantTable(String tableName) {
        return tableName != null && tableSet != null && tableSet.contains(stripBackticks(tableName));
    }

    /** Removes MySQL identifier quotes (back-ticks). */
    private static String stripBackticks(String name) {
        return name.replace("`", "");
    }

    /**
     * Rejects tenant ids that do not match {@link #SAFE_TENANT_VALUE}.
     *
     * @param tenantFieldValue tenant id from the request; {@code null} is accepted (no condition is generated)
     * @throws IllegalArgumentException when the id has an illegal shape
     */
    private static void requireSafeTenantValue(String tenantFieldValue) {
        if (tenantFieldValue != null && !SAFE_TENANT_VALUE.matcher(tenantFieldValue).matches()) {
            // the value itself is deliberately not echoed: it is attacker-controlled
            throw new IllegalArgumentException("Illegal tenant id value");
        }
    }

    /**
     * SqlSource that always returns one precomputed BoundSql.
     */
    static class BoundSqlSqlSource implements SqlSource {

        private final BoundSql boundSql;

        public BoundSqlSqlSource(BoundSql boundSql) {
            this.boundSql = boundSql;
        }

        @Override
        public BoundSql getBoundSql(Object parameterObject) {
            return boundSql;
        }

    }

    /**
     * Copies a MappedStatement, replacing only its SqlSource.
     *
     * @param ms           original statement
     * @param newSqlSource SqlSource of the copy
     * @return the copy
     */
    private MappedStatement newMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
        MappedStatement.Builder builder = new
                MappedStatement.Builder(ms.getConfiguration(), ms.getId(), newSqlSource, ms.getSqlCommandType());
        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());
        // keep ALL key properties / columns (composite keys), not only the first one
        if (ms.getKeyProperties() != null && ms.getKeyProperties().length > 0) {
            builder.keyProperty(String.join(",", ms.getKeyProperties()));
        }
        if (ms.getKeyColumns() != null && ms.getKeyColumns().length > 0) {
            builder.keyColumn(String.join(",", ms.getKeyColumns()));
        }
        builder.timeout(ms.getTimeout());
        builder.parameterMap(ms.getParameterMap());
        builder.resultMaps(ms.getResultMaps());
        builder.resultSetType(ms.getResultSetType());
        builder.cache(ms.getCache());
        builder.flushCacheRequired(ms.isFlushCacheRequired());
        builder.useCache(ms.isUseCache());
        builder.databaseId(ms.getDatabaseId());
        builder.lang(ms.getLang());
        builder.resultOrdered(ms.isResultOrdered());
        if (ms.getResultSets() != null && ms.getResultSets().length > 0) {
            builder.resultSets(String.join(",", ms.getResultSets()));
        }
        return builder.build();
    }

    /**
     * Rewrites the SQL: fills INSERT/UPDATE columns and, in multi-tenant mode, adds the tenant condition.
     *
     * @param sql              original SQL
     * @param tenantFieldValue current tenant id (untrusted request input)
     * @return the rewritten SQL, or {@code sql} unchanged when it parses to no statement
     * @throws IllegalArgumentException when multi-tenant mode is on and the tenant id has an illegal shape
     */
    private String getNewSql(String sql, String tenantFieldValue) {

        List<SQLStatement> statementList = SQLUtils.parseStatements(sql, JdbcConstants.MYSQL);
        if (statementList.isEmpty()) {
            return sql;
        }

        boolean dataSpace = ContextHolderUtil.isDataSpace();
        if (dataSpace) {
            // the tenant id is written into WHERE as SQL text, so validate it first
            requireSafeTenantValue(tenantFieldValue);
        }

        // Rewrite every statement: with multi-statement SQL, a skipped statement would run unfiltered.
        for (SQLStatement sqlStatement : statementList) {
            // INSERT / UPDATE column filling (created_by/at, updated_by/at, tenant column)
            sqlFieldHelper.addStatementField(sqlStatement, tenantIdField, tenantFieldValue);

            // SELECT / UPDATE / DELETE: add the tenant condition for every configured table
            if (dataSpace) {
                conditionHelper.addStatementCondition(sqlStatement, (String tableName, String tableAlias) -> {
                    String bareTableName = stripBackticks(tableName);
                    if (tableSet == null || !tableSet.contains(bareTableName)) {
                        return null;
                    }
                    String qualifier = StringUtils.isBlank(tableAlias) ? bareTableName : tableAlias;
                    return SqlConditionJointHelper.jointEqSql(qualifier + "." + tenantIdField, tenantFieldValue);
                });
            }
        }

        String newSql = SQLUtils.toSQLString(statementList, JdbcConstants.MYSQL);
        // Removes backslashes added during printing.
        // NOTE(review): this strips EVERY backslash, including ones inside string literals of the
        // original SQL — confirm which druid output it is meant to undo.
        return newSql.replaceAll("\\\\", "");
    }

}

sql 新增,编辑 字段赋值 (druid 解析sql)

/**
 * Fills columns into druid-parsed INSERT / UPDATE statements.
 *
 * <ul>
 *   <li>INSERT ... VALUES with an explicit column list: {@code created_by} / {@code created_at}, plus the
 *       tenant column when multi-tenant mode is on and {@code conditionDecision} accepts the table.</li>
 *   <li>UPDATE: {@code updated_by} / {@code updated_at}.</li>
 * </ul>
 * A column already present in the statement is not added again. For INSERT, a literal SQL {@code NULL}
 * value is replaced. Column names are compared ignoring back-ticks, table qualifiers and case.
 *
 * <p>A single instance is held by TfmesMultiTenantPlugin and may be used concurrently, so this class
 * keeps no mutable state (the timestamp formatter is an immutable {@code DateTimeFormatter}).
 */
public class SqlFieldHelper {

    /** Format of created_at / updated_at values; DateTimeFormatter is immutable and thread-safe. */
    private static final java.time.format.DateTimeFormatter CURRENT_TIME_FORMAT =
            java.time.format.DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");

    private final TableFieldConditionDecision conditionDecision;

    /**
     * @param conditionDecision decides whether a table receives the tenant column on INSERT
     */
    public SqlFieldHelper(TableFieldConditionDecision conditionDecision) {
        this.conditionDecision = conditionDecision;
    }

    /**
     * Adds the columns described in the class comment to a statement, in place.
     * Statements other than INSERT / UPDATE are left unchanged.
     *
     * @param sqlStatement parsed statement
     * @param fieldName    tenant column name
     * @param fieldValue   tenant id; {@code null} means "do not set the tenant column"
     */
    public void addStatementField(SQLStatement sqlStatement, String fieldName, String fieldValue) {
        if (sqlStatement instanceof SQLInsertStatement) {
            addInsertStatementField((SQLInsertStatement) sqlStatement, fieldName, fieldValue);
        } else if (sqlStatement instanceof SQLUpdateStatement) {
            addUpdateStatementField((SQLUpdateStatement) sqlStatement, fieldName, fieldValue);
        }
    }

    /**
     * Adds tenant and creation columns to an INSERT ... VALUES statement.
     *
     * @param insertStatement statement to modify
     * @param fieldName       tenant column name
     * @param fieldValue      tenant id
     */
    private void addInsertStatementField(SQLInsertStatement insertStatement, String fieldName, String fieldValue) {
        List<SQLExpr> columns = insertStatement.getColumns();
        // one ValuesClause per row (batch inserts have several)
        List<SQLInsertStatement.ValuesClause> valuesList = insertStatement.getValuesList();
        // Without an explicit column list, or for INSERT ... SELECT (no VALUES rows), appending a column
        // would leave column and value counts mismatched and break the statement.
        if (columns == null || columns.isEmpty() || valuesList == null || valuesList.isEmpty()) {
            return;
        }
        // tenant column
        if (ContextHolderUtil.isDataSpace() && conditionDecision.adjudge(insertStatement.getTableName().getSimpleName(), fieldName)) {
            addInsertItemField(columns, valuesList, fieldName, fieldValue);
        }
        addInsertItemField(columns, valuesList, "created_by", ContextHolderUtil.getAuthedUserId());
        addInsertItemField(columns, valuesList, "created_at", getCurrentTime());
    }

    /**
     * Sets one column on every VALUES row. If the column exists, only literal NULL values are
     * replaced; otherwise the column is appended with {@code fieldValue} on every row.
     * Does nothing when {@code fieldName} or {@code fieldValue} is {@code null}.
     */
    private void addInsertItemField(List<SQLExpr> columns, List<SQLInsertStatement.ValuesClause> valuesList, String fieldName, Object fieldValue) {
        if (fieldName == null || fieldValue == null) {
            return;
        }
        int index = indexOfColumn(columns, fieldName);
        if (index >= 0) {
            for (SQLInsertStatement.ValuesClause valuesClause : valuesList) {
                List<SQLExpr> values = valuesClause.getValues();
                // replace only an explicit NULL; a caller-supplied value wins
                if (index < values.size() && values.get(index) instanceof SQLNullExpr) {
                    values.set(index, getValuableExpr(fieldValue));
                }
            }
            return;
        }

        // column not present: append it to the column list and to every row
        columns.add(new SQLIdentifierExpr(fieldName));
        for (SQLInsertStatement.ValuesClause valuesClause : valuesList) {
            valuesClause.getValues().add(getValuableExpr(fieldValue));
        }
    }

    /**
     * Adds update audit columns to an UPDATE statement.
     * {@code fieldName} / {@code fieldValue} (tenant column and id) are not used: the tenant column is
     * not rewritten on UPDATE.
     */
    private void addUpdateStatementField(SQLUpdateStatement updateStatement, String fieldName, Object fieldValue) {
        List<SQLUpdateSetItem> items = updateStatement.getItems();
        addUpdateItemField(items, "updated_by", ContextHolderUtil.getAuthedUserId());
        addUpdateItemField(items, "updated_at", getCurrentTime());
    }

    /**
     * Appends {@code SET fieldName = fieldValue} unless the column is already assigned.
     * Does nothing when {@code fieldName} or {@code fieldValue} is {@code null}.
     */
    private void addUpdateItemField(List<SQLUpdateSetItem> items, String fieldName, Object fieldValue) {
        if (fieldName == null || fieldValue == null) {
            return;
        }
        for (SQLUpdateSetItem item : items) {
            if (columnName(item.getColumn()).equalsIgnoreCase(fieldName)) {
                return;
            }
        }

        SQLUpdateSetItem sqlUpdateSetItem = new SQLUpdateSetItem();
        sqlUpdateSetItem.setColumn(new SQLIdentifierExpr(fieldName));
        sqlUpdateSetItem.setValue(getValuableExpr(fieldValue));
        items.add(sqlUpdateSetItem);
    }

    /**
     * @return index of the column named {@code fieldName} (see {@link #columnName}), or -1
     */
    private static int indexOfColumn(List<SQLExpr> columns, String fieldName) {
        for (int i = 0; i < columns.size(); i++) {
            if (columnName(columns.get(i)).equalsIgnoreCase(fieldName)) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Returns the bare column name of a column expression: table qualifier and back-ticks removed.
     * Works for plain identifiers and for qualified ones such as {@code a.name} (which are not
     * {@code SQLIdentifierExpr} and previously caused a ClassCastException).
     */
    private static String columnName(SQLExpr column) {
        String name = column instanceof SQLIdentifierExpr
                ? ((SQLIdentifierExpr) column).getName()
                : String.valueOf(column);
        int dot = name.lastIndexOf('.');
        if (dot >= 0) {
            name = name.substring(dot + 1);
        }
        return name.replace("`", "");
    }

    /**
     * Wraps a Java value in a druid literal expression.
     *
     * @param val value; numbers become number literals, everything else a quoted char literal
     * @return the literal expression ({@code NULL} for {@code null})
     */
    private SQLValuableExpr getValuableExpr(Object val) {
        if (val == null) {
            return new SQLNullExpr();
        } else if (val instanceof Number) {
            return new SQLNumberExpr((Number) val);
        } else if (val instanceof String) {
            return new SQLCharExpr((String) val);
        } else {
            return new SQLCharExpr(val.toString());
        }
    }

    /**
     * @return the current local time formatted as {@code yyyy-MM-dd HH:mm:ss}
     */
    private String getCurrentTime() {
        return java.time.LocalDateTime.now().format(CURRENT_TIME_FORMAT);
    }


    public static void main(String[] args) {
//        String sql = "select * from user s  ";
//        String sql = "select * from user s where s.name='333'";
//        String sql = "select * from (select * from tab t where id = 2 and name = 'wenshao') s where s.name='333'";
//        String sql="select u.*,g.name from user u join user_group g on u.groupId=g.groupId where u.name='123'";

        String sql = "update user set name=? where id =(select id from user s)";
//        String sql = "delete from user where id = ( select id from user s )";

//        String sql = "insert into user (id,name) select g.id,g.name from user_group g where id=1";
//        String sql = "insert into user (id,name) values (1, null), (2, '小白')";

//        String sql = "SELECT id, pw_id, warehouse_top, warehouse_level, warehouse_name  , is_default, can_to, disabled, data_space_id, created_at  , created_by, updated_at, updated_by FROM erp_warehouses_warehouse WHERE disabled = 0 AND is_default = 1 AND warehouse_top = (   SELECT id   FROM erp_warehouses_warehouse   WHERE is_default = 1    AND disabled = 0  )";
        List<SQLStatement> statementList = SQLUtils.parseStatements(sql, JdbcConstants.MYSQL);
        SQLStatement sqlStatement = statementList.get(0);

        // decision that accepts every table
        SqlFieldHelper sqlFieldHelper = new SqlFieldHelper(new TableFieldConditionDecision() {
            @Override
            public boolean adjudge(String tableName, String fieldName) {
                return true;
            }

            @Override
            public boolean isAllowNullValue() {
                return false;
            }
        });
        sqlFieldHelper.addStatementField(sqlStatement, "regionId", "2");

        System.out.println("源sql:" + sql);
        System.out.println("修改后sql:" + SQLUtils.toSQLString(statementList, JdbcConstants.MYSQL));
    }

}

sql where 条件 处理工具类 (druid 解析sql)

public class SqlConditionHelper {



    /**
     * 为sql语句添加指定where条件
     *
     * @param sqlStatement
     * @param sqlConditionGenerate
     */
    public void addStatementCondition(SQLStatement sqlStatement, SqlConditionGenerate sqlConditionGenerate) {
        if (sqlStatement instanceof SQLSelectStatement) {
            SQLSelectQueryBlock queryObject = (SQLSelectQueryBlock) ((SQLSelectStatement) sqlStatement).getSelect().getQuery();
            addSelectStatementCondition(queryObject, queryObject.getFrom(), sqlConditionGenerate);
        } else if (sqlStatement instanceof SQLUpdateStatement) {
            SQLUpdateStatement updateStatement = (SQLUpdateStatement) sqlStatement;
            addUpdateStatementCondition(updateStatement, sqlConditionGenerate);
        } else if (sqlStatement instanceof SQLDeleteStatement) {
            SQLDeleteStatement deleteStatement = (SQLDeleteStatement) sqlStatement;
            addDeleteStatementCondition(deleteStatement, sqlConditionGenerate);
        } else if (sqlStatement instanceof SQLInsertStatement) {
            SQLInsertStatement insertStatement = (SQLInsertStatement) sqlStatement;
            addInsertStatementCondition(insertStatement, sqlConditionGenerate);
        }
    }

    /**
     * 为insert语句添加where条件
     *
     * @param insertStatement
     * @param sqlConditionGenerate
     */
    private void addInsertStatementCondition(SQLInsertStatement insertStatement, SqlConditionGenerate sqlConditionGenerate) {
        if (insertStatement != null) {
            SQLSelect sqlSelect = insertStatement.getQuery();
            if (sqlSelect != null) {
                SQLSelectQueryBlock selectQueryBlock = (SQLSelectQueryBlock) sqlSelect.getQuery();
                addSelectStatementCondition(selectQueryBlock, selectQueryBlock.getFrom(), sqlConditionGenerate);
            }
        }
    }


    /**
     * 为delete语句添加where条件
     *
     * @param deleteStatement
     * @param sqlConditionGenerate
     */
    private void addDeleteStatementCondition(SQLDeleteStatement deleteStatement, SqlConditionGenerate sqlConditionGenerate) {
        SQLExpr where = deleteStatement.getWhere();
        //添加子查询中的where条件
        addSQLExprCondition(where, sqlConditionGenerate);

        SQLExpr newCondition = newEqualityCondition(deleteStatement.getTableName().getSimpleName(),
                deleteStatement.getTableSource().getAlias(), sqlConditionGenerate, where);
        deleteStatement.setWhere(newCondition);
    }

    /**
     * where中添加指定筛选条件
     *
     * @param where      源where条件
     * @param sqlConditionGenerate
     */
    private void addSQLExprCondition(SQLExpr where, SqlConditionGenerate sqlConditionGenerate) {
        if (where instanceof SQLInSubQueryExpr) {
            SQLInSubQueryExpr inWhere = (SQLInSubQueryExpr) where;
            SQLSelect subSelectObject = inWhere.getSubQuery();
            SQLSelectQueryBlock subQueryObject = (SQLSelectQueryBlock) subSelectObject.getQuery();
            addSelectStatementCondition(subQueryObject, subQueryObject.getFrom(), sqlConditionGenerate);
        } else if (where instanceof SQLBinaryOpExpr) {
            SQLBinaryOpExpr opExpr = (SQLBinaryOpExpr) where;
            SQLExpr left = opExpr.getLeft();
            SQLExpr right = opExpr.getRight();
            addSQLExprCondition(left, sqlConditionGenerate);
            addSQLExprCondition(right, sqlConditionGenerate);
        } else if (where instanceof SQLQueryExpr) {
            SQLSelectQueryBlock selectQueryBlock = (SQLSelectQueryBlock) (((SQLQueryExpr) where).getSubQuery()).getQuery();
            addSelectStatementCondition(selectQueryBlock, selectQueryBlock.getFrom(), sqlConditionGenerate);
        }
    }

    /**
     * 为update语句添加where条件
     *
     * @param updateStatement
     * @param sqlConditionGenerate
     */
    private void addUpdateStatementCondition(SQLUpdateStatement updateStatement, SqlConditionGenerate sqlConditionGenerate) {
        SQLExpr where = updateStatement.getWhere();
        //添加子查询中的where条件
        addSQLExprCondition(where, sqlConditionGenerate);
        SQLExpr newCondition = newEqualityCondition(updateStatement.getTableName().getSimpleName(),
                updateStatement.getTableSource().getAlias(), sqlConditionGenerate, where);
        updateStatement.setWhere(newCondition);
    }

    /**
     * 给一个查询对象添加一个where条件
     *
     * @param queryObject
     * @param sqlConditionGenerate
     */
    private void addSelectStatementCondition(SQLSelectQueryBlock queryObject, SQLTableSource from, SqlConditionGenerate sqlConditionGenerate) {
        if (from == null || queryObject == null) {
            return;
        }

        SQLExpr originCondition = queryObject.getWhere();
        // 添加子查询中的where条件
        addSQLExprCondition(originCondition, sqlConditionGenerate);
        if (from instanceof SQLExprTableSource) {
            // TODO 对于JOIN_TABLE支持有问题,待优化
            String tableName = ((SQLIdentifierExpr) ((SQLExprTableSource) from).getExpr()).getName();
            String alias = from.getAlias();
            SQLExpr newCondition = newEqualityCondition(tableName, alias, sqlConditionGenerate, originCondition);
            queryObject.setWhere(newCondition);
        } else if (from instanceof SQLJoinTableSource) {
            SQLJoinTableSource joinObject = (SQLJoinTableSource) from;
            SQLTableSource left = joinObject.getLeft();
            SQLTableSource right = joinObject.getRight();

            addSelectStatementCondition(queryObject, left, sqlConditionGenerate);
            addSelectStatementCondition(queryObject, right, sqlConditionGenerate);

        } else if (from instanceof SQLSubqueryTableSource) {
            SQLSelect subSelectObject = ((SQLSubqueryTableSource) from).getSelect();
            SQLSelectQueryBlock subQueryObject = (SQLSelectQueryBlock) subSelectObject.getQuery();
            addSelectStatementCondition(subQueryObject, subQueryObject.getFrom(), sqlConditionGenerate);
        } else {
            throw new NotImplementedException("未处理的异常");
        }
    }

    /**
     * 根据原来的condition创建一个新的condition
     *
     * @param tableName       表名称
     * @param tableAlias      表别名 字段名
     * @param sqlConditionGenerate 字段值
     * @param originCondition 原始条件
     * @return
     */
    private SQLExpr newEqualityCondition(String tableName, String tableAlias, SqlConditionGenerate sqlConditionGenerate, SQLExpr originCondition) {


        String newConditionSql  = sqlConditionGenerate.generate(tableName, tableAlias);
        if (StringUtil.isEmpty(newConditionSql)){
            return originCondition;
        } else {
            return SQLUtils.buildCondition(SQLBinaryOperator.BooleanAnd, new SQLVariantRefExpr(newConditionSql), false, originCondition);
        }
    }


    public static void main(String[] args) {
//        String sql = "select * from user s  ";
//        String sql = "select * from user s where s.name='333'";
//        String sql = "select * from (select * from tab t where id = 2 and name = 'wenshao') s where s.name='333'";
//        String sql="select u.*,g.name from user u join user_group g on u.groupId=g.groupId where u.name='123'";

//        String sql = "update user set name=? where id =(select id from user s)";
//        String sql = "delete from user where id = ( select id from user s )";

        String sql = "insert into user (id,name) select g.id,g.name from user_group g where id=1";

//        String sql = "SELECT id, pw_id, warehouse_top, warehouse_level, warehouse_name  , is_default, can_to, disabled, data_space_id, created_at  , created_by, updated_at, updated_by FROM erp_warehouses_warehouse WHERE disabled = 0 AND is_default = 1 AND warehouse_top = (   SELECT id   FROM erp_warehouses_warehouse   WHERE is_default = 1    AND disabled = 0  )";
        List<SQLStatement> statementList = SQLUtils.parseStatements(sql, JdbcConstants.MYSQL);
        SQLStatement sqlStatement = statementList.get(0);

        SQLInsertStatement sqlSelectStatement = (SQLInsertStatement) sqlStatement;
        //决策器定义
        SqlConditionHelper helper = new SqlConditionHelper();
        //添加多租户条件,domain是字段ignc,yay是筛选值
        helper.addStatementCondition(sqlStatement, (String tableName, String tableAlias) ->{
            String fieldName = "data_space_id";
            fieldName = StringUtils.isBlank(tableAlias) ? tableName + "." + fieldName : tableAlias + "." + fieldName;
            return fieldName + " = 1";
        });
        System.out.println("源sql:" + sql);
        System.out.println("修改后sql:" + SQLUtils.toSQLString(statementList, JdbcConstants.MYSQL));
    }

}

sql 条件组装工具类

/**
 * Builds SQL condition fragments (equality, IN list, JSON_CONTAINS) as plain strings.
 *
 * <p>Values are rendered as SQL literals:
 * <ul>
 *   <li>numbers and booleans are emitted as-is;</li>
 *   <li>{@code null} becomes {@code NULL};</li>
 *   <li>anything else becomes a single-quoted string with {@code \} and {@code '} escaped.</li>
 * </ul>
 * Request-supplied values such as the tenant id therefore cannot inject SQL. Field names are NOT
 * escaped and must come from trusted configuration / code.
 */
public class SqlConditionJointHelper {

    /** Operation key for {@code field = value}. */
    public static final String operation_eq = "operation_eq";
    /** Operation key for {@code field in (v1,v2,...)}. */
    public static final String operation_in = "operation_in";
    /** Operation key for {@code (JSON_CONTAINS(field, 'v1') OR ...)}. */
    public static final String operation_jsonContains = "operation_jsonContains";

    /**
     * Builds the condition for the given operation key.
     *
     * @param fieldName      (qualified) column name
     * @param operation      one of the {@code operation_*} constants
     * @param fieldValue     value for {@link #operation_eq}
     * @param fieldValueList values for {@link #operation_in} / {@link #operation_jsonContains}
     * @return the SQL fragment, or {@code null} for an unknown operation
     */
    public static String joint(String fieldName, String operation, Object fieldValue, List<?> fieldValueList) {
        if (operation_eq.equals(operation)) {
            return jointEqSql(fieldName, fieldValue);
        }
        if (operation_in.equals(operation)) {
            return jointInSql(fieldName, fieldValueList);
        }
        if (operation_jsonContains.equals(operation)) {
            return jointJsonContainsSql(fieldName, fieldValueList);
        }
        return null;
    }

    /**
     * Builds {@code fieldName = <literal>}.
     *
     * @param fieldName  (qualified) column name
     * @param fieldValue value; strings are quoted and escaped
     * @return the condition, or {@code null} when {@code fieldName} is empty or {@code fieldValue} is null
     */
    public static String jointEqSql(String fieldName, Object fieldValue) {
        if (fieldName == null || fieldName.isEmpty() || fieldValue == null) {
            return null;
        }
        return fieldName + " = " + toSqlLiteral(fieldValue);
    }

    /**
     * Builds {@code fieldName in (<literal>,<literal>,...)}.
     *
     * @param fieldName (qualified) column name
     * @param list      allowed values; {@code null} elements become {@code NULL}
     * @return the condition, or {@code "FALSE"} (matches no rows) when the list is null or empty
     */
    public static String jointInSql(String fieldName, List<?> list) {
        if (list == null || list.isEmpty()) {
            return "FALSE";
        }
        StringBuilder values = new StringBuilder();
        String separator = "";
        for (Object item : list) {
            values.append(separator).append(toSqlLiteral(item));
            separator = ",";
        }
        return fieldName + " in (" + values + ")";
    }

    /**
     * Builds {@code (JSON_CONTAINS(fieldName, '<v1>') OR JSON_CONTAINS(fieldName, '<v2>') ...)}.
     *
     * @param fieldName JSON column name
     * @param list      candidate values; each is written as {@code String.valueOf(item)}, escaped
     * @return the condition, or {@code "FALSE"} (matches no rows) when the list is null or empty
     */
    public static String jointJsonContainsSql(String fieldName, List<?> list) {
        if (list == null || list.isEmpty()) {
            return "FALSE";
        }
        StringBuilder sb = new StringBuilder();
        for (Object item : list) {
            if (sb.length() > 0) {
                sb.append(" OR ");
            }
            sb.append("JSON_CONTAINS(").append(fieldName).append(", '")
                    .append(escapeSqlString(String.valueOf(item))).append("')");
        }
        return "(" + sb + ")";
    }

    /**
     * Renders a value as a SQL literal (see class comment).
     */
    private static String toSqlLiteral(Object value) {
        if (value == null) {
            return "NULL";
        }
        if (value instanceof Number || value instanceof Boolean) {
            return value.toString();
        }
        return "'" + escapeSqlString(value.toString()) + "'";
    }

    /**
     * Escapes text for use inside a single-quoted MySQL string literal: backslashes are doubled
     * (MySQL treats {@code \} as an escape character by default) and quotes are doubled.
     */
    private static String escapeSqlString(String value) {
        return value.replace("\\", "\\\\").replace("'", "''");
    }
}

升级部分

数据权限 动态可扩展

存储sql添加生成规则的dto

/**
 * One data-permission rule: "restrict {@code tableName.fieldName} using {@code operation}".
 *
 * <p>{@code operation} holds one of the {@code SqlConditionJointHelper.operation_*} keys. Equality
 * rules carry {@code fieldValue}; IN / JSON_CONTAINS rules carry {@code fieldValueList}.
 * Getters and setters are generated by Lombok {@code @Data}.
 */
@Data
public class SqlConditionDTO {

    // table the rule applies to
    String tableName;
    // column that is restricted
    String fieldName;
    // operation key (SqlConditionJointHelper.operation_*)
    String operation;
    // single value, used by equality rules
    Object fieldValue;
    // value list, used by IN / JSON_CONTAINS rules
    List<?> fieldValueList;

    /**
     * Creates an equality rule {@code tableName.fieldName = fieldValue}.
     */
    public static SqlConditionDTO eq(String tableName, String fieldName, Object fieldValue) {
        SqlConditionDTO rule = newRule(SqlConditionJointHelper.operation_eq, tableName, fieldName);
        rule.setFieldValue(fieldValue);
        return rule;
    }

    /**
     * Creates an IN rule {@code tableName.fieldName in (fieldValueList)}.
     */
    public static SqlConditionDTO in(String tableName, String fieldName, List<?> fieldValueList) {
        SqlConditionDTO rule = newRule(SqlConditionJointHelper.operation_in, tableName, fieldName);
        rule.setFieldValueList(fieldValueList);
        return rule;
    }

    /**
     * Creates a JSON_CONTAINS rule matching any of {@code fieldValueList}.
     */
    public static SqlConditionDTO jsonContains(String tableName, String fieldName, List<?> fieldValueList) {
        SqlConditionDTO rule = newRule(SqlConditionJointHelper.operation_jsonContains, tableName, fieldName);
        rule.setFieldValueList(fieldValueList);
        return rule;
    }

    /**
     * Shared part of the factories: operation, table and column.
     */
    private static SqlConditionDTO newRule(String operation, String tableName, String fieldName) {
        SqlConditionDTO rule = new SqlConditionDTO();
        rule.setOperation(operation);
        rule.setTableName(tableName);
        rule.setFieldName(fieldName);
        return rule;
    }
}

数据权限 MyBatis 拦截器中的 SQL 条件规则处理者：负责生成 SQL 条件规则 DTO

/**
 * Strategy that produces the SQL condition rules for one data-permission scope
 * (数据权限mybatis拦截器 sql 生成规则 处理者).
 *
 * <p>This is a functional interface, so a handler may also be written as a lambda.
 */
@FunctionalInterface
public interface DataPermissionConditionHandler {

    /**
     * Produces the condition rules to apply for the current request.
     *
     * @return the rules; {@code null} or an empty list is allowed (the aspect only collects non-empty results)
     */
    List<SqlConditionDTO> handle();
}

枚举常量：配合注解，为接口指定要使用哪些 SQL 条件生成处理者

public enum DataPermissionEnum {

    /**
     * 用于aaa的数据隔离
     */
    aaa(new aaa_DataPermissionHandler(), "用于aaa的数据隔离")
    /**
     * 用于bbb的数据隔离
     */
    bbb(new bbb_DataPermissionHandler(), "用于bbb的数据隔离")

    ;

    /**
     * 数据权限mybatis拦截器 sql 生成规则 处理这
     */
    private DataPermissionConditionHandler dataPermissionConditionHandler;
    /**
     * 提示
     */
    private String msg;

    DataPermissionEnum(DataPermissionConditionHandler dataPermissionConditionHandler, String msg) {
        this.dataPermissionConditionHandler = dataPermissionConditionHandler;
        this.msg = msg;
    }

    public DataPermissionConditionHandler getDataPermissionConditionHandler() {
        return dataPermissionConditionHandler;
    }
    public String getMsg(){
        return msg;
    }
}

数据权限注解 (指定枚举常量,指定用哪些sql条件处理者)

/**
 * Enables data-permission filtering for the annotated method. It lists the strategies whose
 * handlers supply extra SQL conditions for the call. Only method-level use is supported.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface DataPermission {

    /** Strategies to apply; the conditions from all of them are collected together. */
    DataPermissionEnum[] value();
}

数据权限注解切面, 将注解的信息保存进 上下文参数

/**
 * Aspect for {@code @DataPermission}. Before an annotated method runs, it calls the handler of
 * every listed strategy and stores the collected SQL conditions through {@code ContextHolderUtil}.
 */
@Component
@Aspect
@Slf4j
public class DataPermissionAspect {

    /** Matches methods annotated with {@code @DataPermission}. */
    @Pointcut("@annotation(com.kintexgroup.tfmes.common.annotation.DataPermission)")
    public void pointcut(){
    }

    /**
     * Collects the data-permission conditions for the intercepted call and saves them with
     * {@code ContextHolderUtil.setDataPermissionSqlConditionList}.
     *
     * @param joinPoint the intercepted method invocation
     * @throws IllegalStateException if the {@code @DataPermission} annotation cannot be resolved
     */
    @Before("pointcut()")
    public void openDataSpaceMethod(JoinPoint joinPoint){
        DataPermission annotation = resolveAnnotation(joinPoint);
        if (annotation == null) {
            // Fail closed: proceeding without conditions would silently skip data-permission filtering.
            throw new IllegalStateException("Cannot resolve @DataPermission on " + joinPoint.getSignature());
        }

        // getRequest() can be null outside a request thread; don't throw just to build a debug log.
        var request = ContextHolderUtil.getRequest();
        log.debug("接口 {} ", request == null ? null : request.getRequestURL());

        List<SqlConditionDTO> list = new ArrayList<>();
        for (DataPermissionEnum dataPermissionEnum : annotation.value()) {
            log.debug("\t -数据权限处理:{}", dataPermissionEnum.getMsg());
            List<SqlConditionDTO> conditions = dataPermissionEnum.getDataPermissionConditionHandler().handle();
            if (conditions != null && !conditions.isEmpty()) {
                list.addAll(conditions);
            }
        }

        ContextHolderUtil.setDataPermissionSqlConditionList(list);
    }

    /**
     * Reads {@code @DataPermission} from the intercepted method. If the signature's method
     * (for example, one declared on an interface) does not carry it, the lookup falls back to
     * the target class's public method with the same name and parameter types.
     *
     * @return the annotation, or {@code null} if neither method carries it
     */
    private DataPermission resolveAnnotation(JoinPoint joinPoint) {
        var method = ((MethodSignature) joinPoint.getSignature()).getMethod();
        DataPermission annotation = method.getAnnotation(DataPermission.class);
        Object target = joinPoint.getTarget();
        if (annotation == null && target != null) {
            try {
                annotation = target.getClass()
                        .getMethod(method.getName(), method.getParameterTypes())
                        .getAnnotation(DataPermission.class);
            } catch (NoSuchMethodException ignored) {
                // Not a public method of the target class; the caller treats null as unresolved.
            }
        }
        return annotation;
    }
}

之前的上下文参数传递工具类 (添加代码:保存数据权限处理者)

public class ContextHolderUtil {

    /** Request-attribute key under which the data-permission conditions are stored. */
    private static final String SQL_CONDITION_DTO_LIST = "SqlConditionDTOList";

    /**
     * Whether data-permission conditions are active for the current request.
     *
     * @return true when a non-empty condition list is stored on the current request
     */
    public static boolean isDataPermission() {
        List<SqlConditionDTO> sqlConditionDTOList = getDataPermissionSqlConditionList();
        return sqlConditionDTOList != null && !sqlConditionDTOList.isEmpty();
    }

    /**
     * Stores the data-permission conditions on the current request.
     *
     * @param list conditions collected for the current call
     * @throws IllegalStateException if no request is bound to the current thread
     */
    public static void setDataPermissionSqlConditionList(List<SqlConditionDTO> list) {
        var request = getRequest();
        if (request == null) {
            // Fail loudly rather than with a bare NPE. The conditions cannot be stored without a
            // request, and isDataPermission() would then report false, skipping the filtering.
            throw new IllegalStateException(
                    "No HTTP request bound to the current thread; cannot store data-permission conditions");
        }
        request.setAttribute(SQL_CONDITION_DTO_LIST, list);
    }

    /**
     * Reads the data-permission conditions stored on the current request.
     *
     * @return the stored list, or {@code null} if none is stored (or no request is bound)
     */
    @SuppressWarnings("unchecked") // only setDataPermissionSqlConditionList writes this attribute
    public static List<SqlConditionDTO> getDataPermissionSqlConditionList() {
        Object target = getRequestAttribute(SQL_CONDITION_DTO_LIST);
        return target instanceof List ? (List<SqlConditionDTO>) target : null;
    }
}

之前的 MultiTenantPlugin mybatis拦截器 升级 (添加代码)

@Slf4j
@Intercepts({@Signature(
        type = Executor.class,
        method = "update",
        args = {MappedStatement.class, Object.class}
), @Signature(
        type = Executor.class,
        method = "query",
        args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}
), @Signature(
        type = Executor.class,
        method = "query",
        args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}
)})
// @Component is intentionally omitted: MybatisConfig.java already registers this plugin.
/**
 * MyBatis interceptor that rewrites SQL for region multi-tenancy and data permissions.
 * It fills the tenant and audit columns on writes. It adds the tenant condition when region
 * tenancy is enabled for the request, and adds the data-permission conditions when any are
 * stored for the request.
 */
public class TfmesMultiTenantPlugin implements Interceptor {

    /**
     * Name of the tenant column.
     */
    private String tenantIdField;

    /**
     * Tables that carry the tenant column; null when the {@code tableNames} property is blank.
     */
    private Set<String> tableSet;

    /** Assigns column values on INSERT/UPDATE. */
    private SqlFieldHelper sqlFieldHelper;
    /** Appends WHERE conditions to the parsed statement. */
    private SqlConditionHelper conditionHelper;

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        // Region (tenant) id supplied by the front end.
        String tenantFieldValue = ContextHolderUtil.getDataSpaceId();

        // Arguments of the intercepted method, as declared in @Signature above.
        final Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[0];
        Object parameter = args[1];
        // The 6-argument query already carries the BoundSql MyBatis will execute. That object must
        // be the one rewritten; otherwise the new SQL is built but the old one still runs.
        BoundSql boundSql = args.length == 6 ? (BoundSql) args[5] : ms.getBoundSql(parameter);

        // A SELECT with neither region tenancy nor data permission runs unchanged.
        // Writes always go through the rewrite so that column values get assigned.
        if (SqlCommandType.SELECT == ms.getSqlCommandType()
                && !ContextHolderUtil.isDataSpace() && !ContextHolderUtil.isDataPermission()) {
            return invocation.proceed();
        }

        String processSql = boundSql.getSql();
        log.debug("替换前SQL:{}", processSql);

        // Core step: parse the SQL and generate the rewritten statement.
        String newSql = getNewSql(processSql, tenantFieldValue);

        log.debug("替换后SQL:{}", newSql);

        // Rewrite the existing BoundSql in place via reflection. This covers the 6-argument query.
        MetaObject boundSqlMeta = MetaObject.forObject(boundSql, new DefaultObjectFactory(), new DefaultObjectWrapperFactory(), new DefaultReflectorFactory());
        boundSqlMeta.setValue("sql", newSql);

        // Also replace the MappedStatement so that paths which derive their BoundSql from the
        // statement see the new SQL.
        BoundSql newBoundSql = new BoundSql(ms.getConfiguration(), newSql,
                boundSql.getParameterMappings(), boundSql.getParameterObject());
        for (ParameterMapping mapping : boundSql.getParameterMappings()) {
            String prop = mapping.getProperty();
            if (boundSql.hasAdditionalParameter(prop)) {
                newBoundSql.setAdditionalParameter(prop, boundSql.getAdditionalParameter(prop));
            }
        }
        args[0] = newMappedStatement(ms, new BoundSqlSqlSource(newBoundSql));

        return invocation.proceed();
    }

    /**
     * Reads the plugin configuration.
     *
     * @param properties {@code tenantIdField} (required): the tenant column name;
     *                   {@code tableNames} (optional): comma-separated list of tenant tables
     * @throws IllegalArgumentException if {@code tenantIdField} is blank
     */
    @Override
    public void setProperties(Properties properties) {
        tenantIdField = properties.getProperty("tenantIdField");
        if (StringUtils.isBlank(tenantIdField)) {
            throw new IllegalArgumentException("MultiTenantPlugin need tenantIdField property value");
        }

        String tableNames = properties.getProperty("tableNames");
        if (!StringUtils.isBlank(tableNames)) {
            tableSet = new HashSet<>(Arrays.asList(StringUtils.split(tableNames, ",")));
        }

        // Decides which table/column pairs receive the tenant value.
        TableFieldConditionDecision conditionDecision = new TableFieldConditionDecision() {
            @Override
            public boolean isAllowNullValue() {
                return false;
            }
            @Override
            public boolean adjudge(String tableName, String fieldName) {
                // Strip MySQL backquotes before the lookup.
                tableName = tableName.replace("`", "");
                return tableSet != null && tableSet.contains(tableName);
            }
        };
        sqlFieldHelper = new SqlFieldHelper(conditionDecision);
        conditionHelper = new SqlConditionHelper();
    }

    /**
     * SqlSource that always returns a fixed, pre-built BoundSql.
     */
    static class BoundSqlSqlSource implements SqlSource {

        private final BoundSql boundSql;

        public BoundSqlSqlSource(BoundSql boundSql) {
            this.boundSql = boundSql;
        }

        @Override
        public BoundSql getBoundSql(Object parameterObject) {
            return boundSql;
        }
    }

    /**
     * Copies a MappedStatement, replacing only its SqlSource.
     *
     * @param ms           the original MappedStatement
     * @param newSqlSource the new SqlSource
     * @return a new MappedStatement equivalent to {@code ms} except for its SQL
     */
    private MappedStatement newMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
        MappedStatement.Builder builder = new
                MappedStatement.Builder(ms.getConfiguration(), ms.getId(), newSqlSource, ms.getSqlCommandType());
        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());
        // The builder takes comma-separated lists. Passing only element [0] would drop the other
        // generated-key properties and columns of composite keys.
        if (ms.getKeyProperties() != null && ms.getKeyProperties().length > 0) {
            builder.keyProperty(String.join(",", ms.getKeyProperties()));
        }
        if (ms.getKeyColumns() != null && ms.getKeyColumns().length > 0) {
            builder.keyColumn(String.join(",", ms.getKeyColumns()));
        }
        builder.timeout(ms.getTimeout());
        builder.parameterMap(ms.getParameterMap());
        builder.resultMaps(ms.getResultMaps());
        builder.resultSetType(ms.getResultSetType());
        builder.cache(ms.getCache());
        builder.flushCacheRequired(ms.isFlushCacheRequired());
        builder.useCache(ms.isUseCache());
        builder.databaseId(ms.getDatabaseId());
        builder.lang(ms.getLang());
        builder.resultOrdered(ms.isResultOrdered());
        if (ms.getResultSets() != null && ms.getResultSets().length > 0) {
            builder.resultSets(String.join(",", ms.getResultSets()));
        }
        return builder.build();
    }

    /**
     * Rewrites the SQL: it assigns columns on writes, then adds the tenant and data-permission
     * conditions when they are enabled for the current request.
     *
     * @param sql              the SQL to rewrite
     * @param tenantFieldValue the current tenant (region) id
     * @return the rewritten SQL, or {@code sql} unchanged if it parses to no statements
     */
    private String getNewSql(String sql, String tenantFieldValue) {
        List<SQLStatement> statementList = SQLUtils.parseStatements(sql, JdbcConstants.MYSQL);
        if (statementList.isEmpty()) {
            return sql;
        }

        // Only the first parsed statement is modified.
        SQLStatement sqlStatement = statementList.get(0);

        // INSERT/UPDATE: assign columns (creator/time, modifier/time, tenant column).
        sqlFieldHelper.addStatementField(sqlStatement, tenantIdField, tenantFieldValue);

        // Region multi-tenancy: add the tenant condition to the WHERE clause of SELECT/UPDATE/DELETE.
        if (ContextHolderUtil.isDataSpace()) {
            conditionHelper.addStatementCondition(sqlStatement, (String tableName, String tableAlias) ->
                    buildTenantCondition(tableName, tableAlias, tenantFieldValue));
        }

        // Data-permission conditions, grouped by the table they apply to.
        if (ContextHolderUtil.isDataPermission()) {
            Map<String, List<SqlConditionDTO>> groupMap = ContextHolderUtil.getDataPermissionSqlConditionList().stream()
                    // groupingBy throws on a null key; a condition without a table can't match anything.
                    .filter(item -> item.getTableName() != null)
                    .collect(Collectors.groupingBy(SqlConditionDTO::getTableName));
            conditionHelper.addStatementCondition(sqlStatement, (String tableName, String tableAlias) ->
                    buildDataPermissionCondition(tableName, tableAlias, groupMap));
        }

        String newSql = SQLUtils.toSQLString(statementList, JdbcConstants.MYSQL);
        // Remove the backslashes added during SQL generation.
        // NOTE(review): this also strips legitimate backslashes inside string literals.
        return newSql.replace("\\", "");
    }

    /**
     * Builds the tenant condition for one table reference.
     *
     * @return the condition, or {@code null} when the table is not in {@code tableSet}
     */
    private String buildTenantCondition(String tableName, String tableAlias, String tenantFieldValue) {
        // Strip MySQL backquotes before the lookup.
        String name = tableName.replace("`", "");
        if (tableSet == null || !tableSet.contains(name)) {
            return null;
        }
        String qualifier = StringUtils.isBlank(tableAlias) ? name : tableAlias;
        // NOTE(review): tenantFieldValue comes from the client. Confirm that jointEqSql quotes and
        // escapes it; otherwise this concatenation is an SQL-injection vector.
        return SqlConditionJointHelper.jointEqSql(qualifier + "." + tenantIdField, tenantFieldValue);
    }

    /**
     * Builds the AND-joined data-permission conditions for one table reference.
     *
     * @return the condition, or {@code null} when no condition targets this table
     */
    private String buildDataPermissionCondition(String tableName, String tableAlias,
                                                Map<String, List<SqlConditionDTO>> groupMap) {
        // Strip MySQL backquotes before the lookup.
        String name = tableName.replace("`", "");
        List<SqlConditionDTO> conditions = groupMap.get(name);
        if (conditions == null || conditions.isEmpty()) {
            return null;
        }
        String qualifier = StringUtils.isBlank(tableAlias) ? name : tableAlias;
        String sqlCondition = null;
        for (SqlConditionDTO item : conditions) {
            String part = SqlConditionJointHelper.joint(qualifier + "." + item.getFieldName(),
                    item.getOperation(), item.getFieldValue(), item.getFieldValueList());
            sqlCondition = sqlCondition == null ? part : sqlCondition + " and " + part;
        }
        return sqlCondition;
    }
}

你可能感兴趣的:(java,mybatis,java,spring)