Spring Boot + MyBatis(非 MyBatis-Plus)实现多租户


/**
 * Holds the tenant ID for the current thread.
 *
 * <p>The value lives in a {@link ThreadLocal}, so it is visible only to the thread that set it.
 * Call {@link #clear()} when the unit of work ends (typically in a {@code finally} block). Otherwise
 * a pooled thread keeps the previous tenant ID for whatever work it runs next.
 */
public class TenantContext {
    /** Per-thread tenant ID; typed so {@link #getTenantId()} needs no cast. */
    private static final ThreadLocal<String> CURRENT_TENANT = new ThreadLocal<>();

    /**
     * Binds {@code tenantId} to the current thread.
     *
     * @param tenantId tenant ID to use for subsequent work on this thread
     */
    public static void setTenantId(String tenantId) {
        CURRENT_TENANT.set(tenantId);
    }

    /**
     * Returns the tenant ID bound to the current thread.
     *
     * @return the tenant ID, or {@code null} if none has been set (or it was cleared)
     */
    public static String getTenantId() {
        return CURRENT_TENANT.get();
    }

    /** Removes the tenant ID from the current thread. */
    public static void clear() {
        CURRENT_TENANT.remove();
    }
}

import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SetOperationList;
import net.sf.jsqlparser.statement.update.Update;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;


import java.lang.reflect.Field;
import java.util.Map;
import java.util.Properties;

/**
 * MyBatis {@link Executor} plugin that restricts SELECT, UPDATE and DELETE statements to the tenant
 * bound in {@link TenantContext}.
 *
 * <p>When a tenant ID is set, the SQL is parsed with JSqlParser. The predicate
 * {@code tenant_id = '<tenantId>'} is AND-ed onto the WHERE clause, or becomes the WHERE clause if
 * there is none. Other statement types (e.g. INSERT) run unchanged. When no tenant ID is set, every
 * statement passes through untouched.
 *
 * <p>Limitations: only the top-level WHERE of each SELECT branch is filtered; joined tables and
 * sub-queries are not. The column is also unqualified, so a join where several tables have a
 * {@code tenant_id} column may be rejected by the database as ambiguous.
 */
@Intercepts({
    @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}),
    @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
    @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}),
})
public class TenantInterceptor implements Interceptor {
    /** Name of the tenant discriminator column. */
    private static final String TENANT_ID_COLUMN = "tenant_id";

    /** Argument count of the {@code Executor.query} overload that also takes a CacheKey and BoundSql. */
    private static final int QUERY_WITH_BOUND_SQL_ARGS = 6;

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        String tenantId = TenantContext.getTenantId();
        if (tenantId == null) {
            return invocation.proceed();  // no tenant bound to this thread: pass through
        }

        Object[] args = invocation.getArgs();
        MappedStatement mappedStatement = (MappedStatement) args[0];
        // args[1] is the mapper parameter object (not the whole argument array).
        Object parameter = args[1];
        boolean hasBoundSqlArg = args.length == QUERY_WITH_BOUND_SQL_ARGS;

        // The 6-argument query overload executes the BoundSql it receives, so that is the SQL to rewrite.
        BoundSql boundSql = hasBoundSqlArg
                ? (BoundSql) args[5]
                : mappedStatement.getBoundSql(parameter);
        String originalSql = boundSql.getSql();
        String modifiedSql = modifySql(originalSql, tenantId);
        if (modifiedSql.equals(originalSql)) {
            return invocation.proceed();  // statement type that is not filtered (e.g. INSERT)
        }

        BoundSql newBoundSql = new BoundSql(
                mappedStatement.getConfiguration(),
                modifiedSql,
                boundSql.getParameterMappings(),
                boundSql.getParameterObject());
        // Carry over values bound while evaluating dynamic SQL (<foreach>, <bind>, ...).
        for (ParameterMapping mapping : boundSql.getParameterMappings()) {
            String prop = mapping.getProperty();
            if (boundSql.hasAdditionalParameter(prop)) {
                newBoundSql.setAdditionalParameter(prop, boundSql.getAdditionalParameter(prop));
            }
        }

        MappedStatement newMs = buildMappedStatement(mappedStatement, new BoundSqlSqlSource(newBoundSql));
        args[0] = newMs;

        if (hasBoundSqlArg) {
            // Swap in the rewritten BoundSql. Recompute the CacheKey from the tenant-specific SQL so
            // cached results computed for one tenant are never returned to another.
            Executor executor = (Executor) invocation.getTarget();
            args[4] = executor.createCacheKey(newMs, parameter, (RowBounds) args[2], newBoundSql);
            args[5] = newBoundSql;
        }
        return invocation.proceed();
    }

    /**
     * Adds the tenant predicate to a SELECT, UPDATE or DELETE statement.
     *
     * @param originalSql SQL produced by MyBatis
     * @param tenantId    tenant to restrict to
     * @return the rewritten SQL, or {@code originalSql} itself for any other statement type
     * @throws RuntimeException if the SQL cannot be parsed, or a SELECT has a form this class cannot
     *                          filter (fails loudly instead of running without the tenant filter)
     */
    private String modifySql(String originalSql, String tenantId) {
        Statement statement;
        try {
            statement = CCJSqlParserUtil.parse(originalSql);
        } catch (JSQLParserException e) {
            throw new RuntimeException("SQL 解析失败", e);
        }

        if (statement instanceof Select) {
            applyTenantFilter(((Select) statement).getSelectBody(), tenantId);
        } else if (statement instanceof Update) {
            Update update = (Update) statement;
            update.setWhere(withTenantCondition(update.getWhere(), tenantId));
        } else if (statement instanceof Delete) {
            Delete delete = (Delete) statement;
            delete.setWhere(withTenantCondition(delete.getWhere(), tenantId));
        } else {
            return originalSql;
        }
        return statement.toString();
    }

    /**
     * Adds the tenant predicate to a SELECT body: a plain SELECT, or every branch of a
     * UNION/INTERSECT/EXCEPT.
     *
     * <p>The parameter is typed {@code Object} because the select-body type differs between JSqlParser
     * versions.
     *
     * @throws IllegalStateException for any other SELECT form, so it is never run unfiltered
     */
    private void applyTenantFilter(Object selectBody, String tenantId) {
        if (selectBody instanceof PlainSelect) {
            PlainSelect plainSelect = (PlainSelect) selectBody;
            plainSelect.setWhere(withTenantCondition(plainSelect.getWhere(), tenantId));
        } else if (selectBody instanceof SetOperationList) {
            for (Object branch : ((SetOperationList) selectBody).getSelects()) {
                applyTenantFilter(branch, tenantId);
            }
        } else {
            throw new IllegalStateException("Unsupported SELECT form for tenant filtering: " + selectBody);
        }
    }

    /**
     * Builds the new WHERE expression.
     *
     * @return {@code (where) AND tenant_id = '<tenantId>'}, or just the tenant predicate when
     *         {@code where} is {@code null}
     */
    private Expression withTenantCondition(Expression where, String tenantId) {
        EqualsTo tenantCondition = new EqualsTo();
        tenantCondition.setLeftExpression(new Column(TENANT_ID_COLUMN));
        tenantCondition.setRightExpression(toSqlStringLiteral(tenantId));
        if (where == null) {
            return tenantCondition;
        }
        // Parenthesize the existing condition. Without it, "a OR b AND tenant_id = 'x'" binds as
        // "a OR (b AND tenant_id = 'x')" and returns rows belonging to other tenants.
        return new AndExpression(new Parenthesis(where), tenantCondition);
    }

    /**
     * Wraps {@code value} as a SQL string literal, doubling any embedded single quotes.
     *
     * <p>The value is handed to {@link StringValue} already quoted because its constructor strips a
     * surrounding quote pair. Some JSqlParser versions strip the first and last characters
     * unconditionally.
     *
     * <p>NOTE(review): backslashes are not escaped. On MySQL without NO_BACKSLASH_ESCAPES, a tenant ID
     * ending in '\' is still unsafe. Validate tenant IDs where they enter {@link TenantContext}.
     */
    private static StringValue toSqlStringLiteral(String value) {
        return new StringValue("'" + value.replace("'", "''") + "'");
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {}

    /**
     * Copies {@code ms} into a new MappedStatement with the same id and settings, but executing
     * {@code sqlSource}.
     */
    private MappedStatement buildMappedStatement(MappedStatement ms, SqlSource sqlSource) {
        MappedStatement.Builder builder = new MappedStatement.Builder(
                ms.getConfiguration(), ms.getId(), sqlSource, ms.getSqlCommandType());
        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());
        String keyProperties = joinWithCommas(ms.getKeyProperties());
        if (keyProperties != null) {
            builder.keyProperty(keyProperties);
        }
        String keyColumns = joinWithCommas(ms.getKeyColumns());
        if (keyColumns != null) {
            builder.keyColumn(keyColumns);
        }
        builder.timeout(ms.getTimeout());
        builder.parameterMap(ms.getParameterMap());
        builder.resultMaps(ms.getResultMaps());
        builder.resultSetType(ms.getResultSetType());
        builder.resultOrdered(ms.isResultOrdered());
        String resultSets = joinWithCommas(ms.getResultSets());
        if (resultSets != null) {
            builder.resultSets(resultSets);
        }
        builder.cache(ms.getCache());
        builder.flushCacheRequired(ms.isFlushCacheRequired());
        builder.useCache(ms.isUseCache());
        builder.databaseId(ms.getDatabaseId());
        builder.lang(ms.getLang());
        return builder.build();
    }

    /**
     * Joins {@code values} with commas (the format the MappedStatement builder expects).
     *
     * @return the joined string, or {@code null} for a null or empty array
     */
    private static String joinWithCommas(String[] values) {
        return values == null || values.length == 0 ? null : String.join(",", values);
    }

    /**
     * SqlSource that always returns one pre-built BoundSql; used to build the rewritten
     * MappedStatement.
     */
    public static class BoundSqlSqlSource implements SqlSource {
        final BoundSql boundSql;

        public BoundSqlSqlSource(BoundSql boundSql) {
            this.boundSql = boundSql;
        }

        @Override
        public BoundSql getBoundSql(Object parameterObject) {
            return boundSql;
        }
    }
}

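由于原系统引入了 pagehelper-spring-boot-starter,而 PageHelper 的 PageInterceptor 注册得比较晚,会影响自定义的多租户拦截器。因此需要把自定义拦截器放到最后注册:在 Spring 容器就绪(ApplicationReadyEvent)之后,再把自定义的 MyBatis 拦截器加入 SqlSessionFactory 的配置。MyBatis 按注册顺序层层包装 Executor,最后注册的拦截器位于代理链最外层、最先执行,这样租户条件会在 PageHelper 生成分页 SQL 和 count SQL 之前加入。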


import org.apache.ibatis.session.SqlSessionFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.context.event.ApplicationReadyEvent;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.event.EventListener;

@Configuration
public class TenantInterceptorConfig {

    /** Session factory whose MyBatis configuration receives the tenant interceptor. */
    @Autowired
    private SqlSessionFactory sqlSessionFactory;

    /**
     * Registers a fresh {@link TenantInterceptor} on the session factory's MyBatis configuration.
     *
     * <p>Registration waits for {@link ApplicationReadyEvent}. As the method name says, the intent is
     * to add this interceptor after PageHelper's interceptor.
     *
     * <p>NOTE(review): anything that queries before this event fires runs without the tenant filter
     * (e.g. ApplicationRunner/CommandLineRunner beans). Confirm that is acceptable.
     */
    @EventListener(ApplicationReadyEvent.class)
    public void addInterceptorAfterPageHelper() {
        sqlSessionFactory
                .getConfiguration()
                .addInterceptor(new TenantInterceptor());
    }
}

你可能感兴趣的:(spring,boot,mybatis,java)