使用案例:
@GetMapping("/test")
// Enables region-based multi-tenancy for this endpoint (the annotation may be placed on a
// method or on the controller class).
@DataSpace
// Applies the "aaa" and "bbb" data-permission strategies; the constant names must match
// the ones declared in DataPermissionEnum.
@DataPermission({DataPermissionEnum.aaa, DataPermissionEnum.bbb})
public Object test() {
    return null;
}
// Example data-permission strategy ("aaaa"): contributes three conditions.
public class Aaa_DataPermissionHandler implements DataPermissionConditionHandler {

    /**
     * Builds the conditions contributed by this strategy.
     *
     * @return an equality condition on table_a.id, an IN condition on table_b.id and a
     *         JSON-contains condition on table_c.group_ids; the last two use the ids
     *         returned by {@code ContextHolderUtil.getDataPermissionList()}
     */
    @Override
    public List<SqlConditionDTO> handle() {
        final List<Integer> permittedIds = ContextHolderUtil.getDataPermissionList();

        final SqlConditionDTO tableAById = SqlConditionDTO.eq("table_a", "id", 1);
        final SqlConditionDTO tableBByIds = SqlConditionDTO.in("table_b", "id", permittedIds);
        final SqlConditionDTO tableCByGroups =
                SqlConditionDTO.jsonContains("table_c", "group_ids", permittedIds);

        return List.of(tableAById, tableBByIds, tableCByGroups);
    }
}
/**
 * Marker annotation that turns on region-based multi-tenancy.
 * It can be placed on a method or on a type and is retained at runtime.
 * @author xiaopeng
 * @date 2021-06-09
 */
@Target({ElementType.METHOD, ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
public @interface DataSpace {
}
/**
 * Static accessors for state attached to the HTTP request bound to the current thread:
 * the authenticated user's attributes and the "data space" (region multi-tenancy) flag.
 *
 * <p>All read accessors tolerate the absence of a bound request (for example when SQL is
 * executed outside a web request) and return {@code null} / {@code false} in that case.
 */
public class ContextHolderUtil {

    /** Request-attribute key under which the data-space flag is stored. */
    private static final String DATA_SPACE_ATTRIBUTE = "openDataSpace";

    /** Value of the data-space flag when multi-tenancy is enabled. */
    private static final String DATA_SPACE_OPEN = "open";

    /** Name of the request header / URL parameter carrying the region id. */
    private static final String SPACE_ID_NAME = "SpaceId";

    /**
     * Returns the servlet request bound to the current thread.
     *
     * @return the current request, or {@code null} when no servlet request is bound
     */
    public static HttpServletRequest getRequest() {
        var attributes = RequestContextHolder.getRequestAttributes();
        // instanceof also covers null and non-servlet attribute implementations.
        if (attributes instanceof ServletRequestAttributes) {
            return ((ServletRequestAttributes) attributes).getRequest();
        }
        return null;
    }

    /**
     * Returns the HTTP session of the current request.
     *
     * @return the session, or {@code null} when no request is bound
     */
    public static HttpSession getSession() {
        var request = getRequest();
        return request == null ? null : request.getSession();
    }

    /**
     * Reads an attribute from the current request.
     *
     * @param key attribute name
     * @return the attribute value, or {@code null} when absent or when no request is bound
     */
    public static Object getRequestAttribute(String key) {
        var request = getRequest();
        return request == null ? null : request.getAttribute(key);
    }

    /**
     * Returns the authenticated user's id, read from the request attribute {@code "id"}.
     *
     * @return the user id, or {@code null} when the attribute is not present
     * @throws NumberFormatException if the attribute is present but not an integer
     */
    public static Integer getAuthedUserId() {
        return getIntegerRequestAttribute("id");
    }

    /**
     * Returns the authenticated user's role id (request attribute {@code "roleId"}).
     *
     * @return the role id, or {@code null} when the attribute is not present
     * @throws NumberFormatException if the attribute is present but not an integer
     */
    public static Integer getAuthedUserRoleId() {
        return getIntegerRequestAttribute("roleId");
    }

    /**
     * Returns the authenticated user's role weight (request attribute {@code "roleWeight"}).
     *
     * @return the role weight, or {@code null} when the attribute is not present
     * @throws NumberFormatException if the attribute is present but not an integer
     */
    public static Integer getAuthedUserRoleWeight() {
        return getIntegerRequestAttribute("roleWeight");
    }

    /**
     * Returns the user info object stored in the request attribute {@code "userInfo"}.
     *
     * @return the user info, or {@code null} when absent
     */
    public static UserInfo getAuthedUserInfo() {
        return (UserInfo) getRequestAttribute("userInfo");
    }

    /**
     * Enables region multi-tenancy for the current request.
     *
     * @throws IllegalStateException if no request is bound to the current thread; the flag
     *         cannot be stored, and ignoring that silently would disable tenant filtering
     */
    public static void openDataSpace() {
        HttpServletRequest request = getRequest();
        if (request == null) {
            throw new IllegalStateException("Cannot enable data space: no HTTP request is bound to the current thread");
        }
        request.setAttribute(DATA_SPACE_ATTRIBUTE, DATA_SPACE_OPEN);
    }

    /**
     * Tells whether region multi-tenancy was enabled for the current request.
     *
     * @return {@code true} only when a request is bound and the flag has been set
     */
    public static boolean isDataSpace() {
        HttpServletRequest request = getRequest();
        return request != null && DATA_SPACE_OPEN.equals(request.getAttribute(DATA_SPACE_ATTRIBUTE));
    }

    /**
     * Returns the region id sent by the client: the {@code SpaceId} header, or the
     * {@code SpaceId} request parameter when the header is empty.
     *
     * <p>When multi-tenancy is enabled and no region id was supplied,
     * {@code CommonUtil.expressionThrowResponseCodeAndSetAttribute} is invoked with
     * {@code ResponseCode.NOT_REGION}; per the original author's note this aborts the
     * request with an exception.
     *
     * @return the region id, possibly empty or {@code null}; {@code null} when no request is bound
     */
    public static String getDataSpaceId() {
        HttpServletRequest request = getRequest();
        if (request == null) {
            return null;
        }
        String dataSpaceId = request.getHeader(SPACE_ID_NAME);
        if (StringUtil.isEmpty(dataSpaceId)) {
            // Not in the headers: fall back to the URL/form parameter.
            dataSpaceId = request.getParameter(SPACE_ID_NAME);
        }
        if (isDataSpace()) {
            CommonUtil.expressionThrowResponseCodeAndSetAttribute(StringUtil.isEmpty(dataSpaceId), ResponseCode.NOT_REGION);
        }
        return dataSpaceId;
    }

    /** Parses a request attribute as an integer; an absent attribute yields {@code null}. */
    private static Integer getIntegerRequestAttribute(String key) {
        Object value = getRequestAttribute(key);
        return value == null ? null : Integer.valueOf(Integer.parseInt(value.toString()));
    }
}
/**
 * Aspect that switches on region multi-tenancy for the current request when the invoked
 * method, or the class of the invoked controller, carries {@code @DataSpace}.
 */
@Component
@Aspect
@Slf4j
public class DataSpaceAspect {

    /** Matches methods annotated with {@code @DataSpace}. */
    @Pointcut("@annotation(com.xxxx.xxxx.common.annotation.DataSpace)")
    public void pointcut() {
    }

    /** Matches every method of classes whose name ends in {@code Controller} under the base package. */
    @Pointcut("execution(* com.xxxx.xxxx..*Controller.*(..))")
    public void pointcutClass() {
    }

    /** Method-level case: enables multi-tenancy before the annotated method runs. */
    @Before("pointcut()")
    public void openDataSpace() {
        ContextHolderUtil.openDataSpace();
    }

    /**
     * Class-level case: intercepts every controller method and enables multi-tenancy when
     * the target's class itself is annotated with {@code @DataSpace}.
     *
     * @param joinPoint the intercepted invocation
     * @return whatever the intercepted method returns
     * @throws Throwable anything thrown by the intercepted method
     */
    @Around("pointcutClass()")
    public Object openDataSpace(ProceedingJoinPoint joinPoint) throws Throwable {
        final Class<?> targetType = joinPoint.getTarget().getClass();
        if (targetType.isAnnotationPresent(DataSpace.class)) {
            ContextHolderUtil.openDataSpace();
        }
        return joinPoint.proceed();
    }
}
tenant.id-field=region_id
tenant.table=tfmes_sales_order,tfmes_work_procedure_data
/**
 * Registers the multi-tenant MyBatis interceptor, configured from
 * {@code multi-tenant.properties} on the classpath.
 */
@Configuration
@PropertySource("classpath:multi-tenant.properties")
public class TfmesMybatisConfig {
    /**
     * Name of the tenant column (property {@code tenant.id-field}).
     */
    @Value("${tenant.id-field}")
    private String tenantIdField;
    /**
     * Comma-separated names of the tables that carry the tenant column
     * (property {@code tenant.table}).
     */
    @Value("${tenant.table}")
    private String tableSet;

    /**
     * Creates the interceptor, passes it the tenant settings and adds it to the given
     * session factory's configuration. The bean value itself is only a marker string;
     * the method exists for its side effect.
     *
     * <p>With several data sources / session factories, qualify the parameter, e.g.
     * {@code @Qualifier("xxxxSqlSessionFactory") SqlSessionFactory sqlSessionFactory}.
     *
     * @param sqlSessionFactory the session factory to add the interceptor to
     * @return the constant {@code "interceptorTfmes"}
     */
    @Bean
    public String interceptorTfmes(SqlSessionFactory sqlSessionFactory) {
        // Use the interceptor class defined alongside this configuration.
        TfmesMultiTenantPlugin multiTenantPlugin = new TfmesMultiTenantPlugin();
        Properties properties = new Properties(4);
        properties.setProperty("tenantIdField", tenantIdField);
        properties.setProperty("tableNames", tableSet);
        multiTenantPlugin.setProperties(properties);
        // Further MyBatis interceptors can be added the same way.
        sqlSessionFactory.getConfiguration().addInterceptor(multiTenantPlugin);
        return "interceptorTfmes";
    }
}
/**
 * MyBatis interceptor that rewrites SQL before execution: it fills audit/tenant fields
 * on INSERT/UPDATE and, when region multi-tenancy is enabled for the current request,
 * appends a tenant predicate for every configured table.
 */
@Slf4j
@Intercepts({@Signature(
        type = Executor.class,
        method = "update",
        args = {MappedStatement.class, Object.class}
), @Signature(
        type = Executor.class,
        method = "query",
        args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}
), @Signature(
        type = Executor.class,
        method = "query",
        args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}
)})
// Deliberately not a @Component: MybatisConfig.java registers this interceptor itself.
public class TfmesMultiTenantPlugin implements Interceptor {
    /**
     * Name of the tenant column.
     */
    private String tenantIdField;
    /**
     * Names of the tables that carry the tenant column.
     */
    private Set<String> tableSet;
    /** Fills fields on INSERT / UPDATE statements. */
    private SqlFieldHelper sqlFieldHelper;
    /** Appends WHERE conditions to statements. */
    private SqlConditionHelper conditionHelper;

    /**
     * Rewrites the statement's SQL and proceeds with a copy of the mapped statement that
     * carries the rewritten SQL. SELECT statements are passed through untouched when
     * multi-tenancy is not enabled for the current request.
     *
     * @param invocation the intercepted {@code Executor.update/query} call
     * @return the result of the (possibly rewritten) invocation
     * @throws Throwable anything thrown by the underlying executor
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        // Region id sent by the client; resolving it first keeps the original ordering,
        // including its "no region selected" failure.
        String tenantFieldValue = ContextHolderUtil.getDataSpaceId();
        // Arguments of the intercepted Executor method (see @Signature above).
        final Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[0];
        // Plain SELECT without multi-tenancy: nothing to rewrite. Checked before the
        // BoundSql is built so untouched queries do not pay for it.
        if (SqlCommandType.SELECT == ms.getSqlCommandType() && !ContextHolderUtil.isDataSpace()) {
            return invocation.proceed();
        }
        Object parameter = args[1];
        // The 6-argument query overload receives its BoundSql as args[5]. Per the original
        // author's note the rewrite must be applied to that instance, otherwise MyBatis
        // still executes the old SQL.
        BoundSql boundSql = args.length == 6 ? (BoundSql) args[5] : ms.getBoundSql(parameter);
        String processSql = boundSql.getSql();
        log.debug("替换前SQL:{}", processSql);
        // Core step: parse the SQL and generate the rewritten statement.
        String newSql = getNewSql(processSql, tenantFieldValue);
        log.debug("替换后SQL:{}", newSql);
        // Overwrite the "sql" field of the existing BoundSql reflectively.
        MetaObject boundSqlMeta = MetaObject.forObject(boundSql, new DefaultObjectFactory(), new DefaultObjectWrapperFactory(), new DefaultReflectorFactory());
        boundSqlMeta.setValue("sql", newSql);
        // Build a fresh BoundSql around the new SQL and carry over additional parameters.
        BoundSql newBoundSql = new BoundSql(ms.getConfiguration(), newSql,
                boundSql.getParameterMappings(), boundSql.getParameterObject());
        for (ParameterMapping mapping : boundSql.getParameterMappings()) {
            String prop = mapping.getProperty();
            if (boundSql.hasAdditionalParameter(prop)) {
                newBoundSql.setAdditionalParameter(prop, boundSql.getAdditionalParameter(prop));
            }
        }
        // Hand the executor a statement whose SqlSource returns the new BoundSql.
        args[0] = newMappedStatement(ms, new BoundSqlSqlSource(newBoundSql));
        return invocation.proceed();
    }

    /**
     * Reads the plugin configuration.
     *
     * @param properties must contain {@code tenantIdField}; may contain {@code tableNames},
     *                   a comma-separated list of table names
     * @throws IllegalArgumentException if {@code tenantIdField} is blank
     */
    @Override
    public void setProperties(Properties properties) {
        tenantIdField = properties.getProperty("tenantIdField");
        if (StringUtils.isBlank(tenantIdField)) {
            throw new IllegalArgumentException("MultiTenantPlugin need tenantIdField property value");
        }
        String tableNames = properties.getProperty("tableNames");
        if (!StringUtils.isBlank(tableNames)) {
            // Trim entries so "a, b" in the properties file still matches table "b".
            Set<String> tables = new HashSet<>();
            for (String name : StringUtils.split(tableNames, ",")) {
                String trimmed = name.trim();
                if (!trimmed.isEmpty()) {
                    tables.add(trimmed);
                }
            }
            tableSet = tables;
        }
        // Decides, per table, whether the tenant field applies.
        TableFieldConditionDecision conditionDecision = new TableFieldConditionDecision() {
            @Override
            public boolean isAllowNullValue() {
                return false;
            }

            @Override
            public boolean adjudge(String tableName, String fieldName) {
                return isTenantTable(tableName);
            }
        };
        sqlFieldHelper = new SqlFieldHelper(conditionDecision);
        conditionHelper = new SqlConditionHelper();
    }

    /**
     * SqlSource that always returns one pre-built BoundSql.
     */
    static class BoundSqlSqlSource implements SqlSource {
        private final BoundSql boundSql;

        public BoundSqlSqlSource(BoundSql boundSql) {
            this.boundSql = boundSql;
        }

        @Override
        public BoundSql getBoundSql(Object parameterObject) {
            return boundSql;
        }
    }

    /** Tells whether the (possibly backtick-quoted) table is one of the configured tenant tables. */
    private boolean isTenantTable(String tableName) {
        return tableSet != null && tableSet.contains(tableName.replace("`", ""));
    }

    /**
     * Copies a MappedStatement, replacing only its SqlSource.
     *
     * @param ms           statement to copy
     * @param newSqlSource SqlSource for the copy
     * @return the new statement
     */
    private MappedStatement newMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
        MappedStatement.Builder builder = new
                MappedStatement.Builder(ms.getConfiguration(), ms.getId(), newSqlSource, ms.getSqlCommandType());
        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());
        // The builder takes comma-delimited lists; pass every entry, not just the first,
        // so statements with several generated keys keep working.
        String[] keyProperties = ms.getKeyProperties();
        if (keyProperties != null && keyProperties.length > 0) {
            builder.keyProperty(String.join(",", keyProperties));
        }
        String[] keyColumns = ms.getKeyColumns();
        if (keyColumns != null && keyColumns.length > 0) {
            builder.keyColumn(String.join(",", keyColumns));
        }
        builder.timeout(ms.getTimeout());
        builder.parameterMap(ms.getParameterMap());
        builder.resultMaps(ms.getResultMaps());
        builder.resultSetType(ms.getResultSetType());
        builder.cache(ms.getCache());
        builder.flushCacheRequired(ms.isFlushCacheRequired());
        builder.useCache(ms.isUseCache());
        return builder.build();
    }

    /**
     * Rewrites one SQL text: fills fields on INSERT/UPDATE and, when multi-tenancy is
     * enabled, adds the tenant predicate for configured tables.
     *
     * @param sql              SQL to rewrite
     * @param tenantFieldValue current tenant (region) id
     * @return the rewritten SQL, or {@code sql} unchanged when it parses to no statement
     */
    private String getNewSql(String sql, String tenantFieldValue) {
        List<SQLStatement> statementList = SQLUtils.parseStatements(sql, JdbcConstants.MYSQL);
        if (statementList.isEmpty()) {
            return sql;
        }
        // Only the first parsed statement is rewritten.
        SQLStatement sqlStatement = statementList.get(0);
        // INSERT / UPDATE: fill creator/modifier, timestamps and the tenant field.
        sqlFieldHelper.addStatementField(sqlStatement, tenantIdField, tenantFieldValue);
        // SELECT / UPDATE / DELETE: add the tenant predicate to the WHERE clause.
        if (ContextHolderUtil.isDataSpace()) {
            conditionHelper.addStatementCondition(sqlStatement, (String tableName, String tableAlias) -> {
                String plainName = tableName.replace("`", "");
                if (!isTenantTable(plainName)) {
                    return null;
                }
                String qualifier = StringUtils.isBlank(tableAlias) ? plainName : tableAlias;
                return SqlConditionJointHelper.jointEqSql(qualifier + "." + tenantIdField, tenantFieldValue);
            });
        }
        String newSql = SQLUtils.toSQLString(statementList, JdbcConstants.MYSQL);
        // Removes the backslashes added during rendering.
        // NOTE(review): this strips every backslash, including ones inside string literals.
        return newSql.replaceAll("\\\\", "");
    }
}
/**
 * Adds audit fields (creator / modifier and timestamps) and, for inserts, the tenant
 * field to parsed INSERT and UPDATE statements.
 */
public class SqlFieldHelper {
    /**
     * Timestamp format for created_at / updated_at. DateTimeFormatter is immutable and
     * thread-safe, unlike a shared SimpleDateFormat, so one instance can serve all threads.
     */
    private static final java.time.format.DateTimeFormatter TIMESTAMP_FORMAT =
            java.time.format.DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");

    private final TableFieldConditionDecision conditionDecision;

    public SqlFieldHelper(TableFieldConditionDecision conditionDecision) {
        this.conditionDecision = conditionDecision;
    }

    /**
     * Adds fields to an INSERT or UPDATE statement; other statement types are left untouched.
     *
     * @param sqlStatement parsed statement, modified in place
     * @param fieldName    tenant column name
     * @param fieldValue   tenant column value
     */
    public void addStatementField(SQLStatement sqlStatement, String fieldName, String fieldValue) {
        if (sqlStatement instanceof SQLInsertStatement) {
            addInsertStatementField((SQLInsertStatement) sqlStatement, fieldName, fieldValue);
        } else if (sqlStatement instanceof SQLUpdateStatement) {
            addUpdateStatementField((SQLUpdateStatement) sqlStatement, fieldName, fieldValue);
        }
    }

    /**
     * Adds the tenant field (when multi-tenancy is on and the table is accepted by the
     * decision object) plus {@code created_by} / {@code created_at} to an INSERT.
     *
     * <p>Statements without an explicit column list or without VALUES rows (for example
     * {@code INSERT ... SELECT}) are left untouched: appending a column there would not
     * be matched by a value and would produce invalid SQL.
     *
     * @param insertStatement statement to modify
     * @param fieldName       tenant column name
     * @param fieldValue      tenant column value
     */
    private void addInsertStatementField(SQLInsertStatement insertStatement, String fieldName, String fieldValue) {
        List<SQLExpr> columns = insertStatement.getColumns();
        // A list because one INSERT may carry several VALUES rows (batch insert).
        List<SQLInsertStatement.ValuesClause> valuesList = insertStatement.getValuesList();
        if (columns.isEmpty() || valuesList.isEmpty()) {
            return;
        }
        if (ContextHolderUtil.isDataSpace() && conditionDecision.adjudge(insertStatement.getTableName().getSimpleName(), fieldName)) {
            addInsertItemField(columns, valuesList, fieldName, fieldValue);
        }
        addInsertItemField(columns, valuesList, "created_by", ContextHolderUtil.getAuthedUserId());
        addInsertItemField(columns, valuesList, "created_at", getCurrentTime());
    }

    /**
     * Sets one column on every VALUES row: when the column is already listed, only rows
     * whose value is the NULL literal are filled; otherwise the column and its value are
     * appended. Does nothing when the name or value is {@code null}.
     */
    private void addInsertItemField(List<SQLExpr> columns, List<SQLInsertStatement.ValuesClause> valuesList, String fieldName, Object fieldValue) {
        if (fieldName == null || fieldValue == null) {
            return;
        }
        for (int i = 0; i < columns.size(); i++) {
            if (!isColumnNamed(columns.get(i), fieldName)) {
                continue;
            }
            for (SQLInsertStatement.ValuesClause valuesClause : valuesList) {
                List<SQLExpr> values = valuesClause.getValues();
                // Only replace an explicit NULL; a value supplied by the caller wins.
                if (i < values.size() && values.get(i) instanceof SQLNullExpr) {
                    values.set(i, getValuableExpr(fieldValue));
                }
            }
            return;
        }
        // Column not listed yet: append it together with a value for each row.
        columns.add(new SQLIdentifierExpr(fieldName));
        for (SQLInsertStatement.ValuesClause valuesClause : valuesList) {
            valuesClause.getValues().add(getValuableExpr(fieldValue));
        }
    }

    /**
     * Adds {@code updated_by} / {@code updated_at} to an UPDATE. The tenant field is not
     * written on updates, so {@code fieldName} and {@code fieldValue} are not used here.
     *
     * @param updateStatement statement to modify
     * @param fieldName       unused
     * @param fieldValue      unused
     */
    private void addUpdateStatementField(SQLUpdateStatement updateStatement, String fieldName, Object fieldValue) {
        List<SQLUpdateSetItem> items = updateStatement.getItems();
        addUpdateItemField(items, "updated_by", ContextHolderUtil.getAuthedUserId());
        addUpdateItemField(items, "updated_at", getCurrentTime());
    }

    /**
     * Appends {@code fieldName = fieldValue} to the SET list unless the column is already
     * assigned there. Does nothing when the name or value is {@code null}.
     */
    private void addUpdateItemField(List<SQLUpdateSetItem> items, String fieldName, Object fieldValue) {
        if (fieldName == null || fieldValue == null) {
            return;
        }
        for (SQLUpdateSetItem item : items) {
            if (isColumnNamed(item.getColumn(), fieldName)) {
                return;
            }
        }
        SQLUpdateSetItem sqlUpdateSetItem = new SQLUpdateSetItem();
        sqlUpdateSetItem.setColumn(new SQLIdentifierExpr(fieldName));
        sqlUpdateSetItem.setValue(getValuableExpr(fieldValue));
        items.add(sqlUpdateSetItem);
    }

    /**
     * Tells whether a column expression is the plain identifier {@code fieldName},
     * ignoring backticks and letter case. Expressions that are not plain identifiers
     * (such as qualified {@code t.col} references) never match instead of failing a cast.
     */
    private static boolean isColumnNamed(SQLExpr column, String fieldName) {
        if (!(column instanceof SQLIdentifierExpr)) {
            return false;
        }
        String name = ((SQLIdentifierExpr) column).getName();
        return name != null && name.replace("`", "").equalsIgnoreCase(fieldName);
    }

    /**
     * Wraps a Java value as a SQL literal expression: NULL for {@code null}, a number
     * literal for {@link Number}, otherwise a character literal of {@code toString()}.
     *
     * @param val value to wrap
     * @return the literal expression
     */
    private SQLValuableExpr getValuableExpr(Object val) {
        if (val == null) {
            return new SQLNullExpr();
        }
        if (val instanceof Number) {
            return new SQLNumberExpr((Number) val);
        }
        return new SQLCharExpr(val.toString());
    }

    /**
     * Returns the current local time formatted as {@code yyyy-MM-dd HH:mm:ss}.
     */
    private String getCurrentTime() {
        return java.time.LocalDateTime.now().format(TIMESTAMP_FORMAT);
    }

    /** Manual smoke test: parses a sample statement, adds fields and prints the result. */
    public static void main(String[] args) {
        // String sql = "select * from user s ";
        // String sql = "select * from user s where s.name='333'";
        // String sql = "select * from (select * from tab t where id = 2 and name = 'wenshao') s where s.name='333'";
        // String sql="select u.*,g.name from user u join user_group g on u.groupId=g.groupId where u.name='123'";
        String sql = "update user set name=? where id =(select id from user s)";
        // String sql = "delete from user where id = ( select id from user s )";
        // String sql = "insert into user (id,name) select g.id,g.name from user_group g where id=1";
        // String sql = "insert into user (id,name) values (1, null), (2, '小白')";
        // String sql = "SELECT id, pw_id, warehouse_top, warehouse_level, warehouse_name , is_default, can_to, disabled, data_space_id, created_at , created_by, updated_at, updated_by FROM erp_warehouses_warehouse WHERE disabled = 0 AND is_default = 1 AND warehouse_top = ( SELECT id FROM erp_warehouses_warehouse WHERE is_default = 1 AND disabled = 0 )";
        List<SQLStatement> statementList = SQLUtils.parseStatements(sql, JdbcConstants.MYSQL);
        SQLStatement sqlStatement = statementList.get(0);
        SqlFieldHelper sqlFieldHelper = new SqlFieldHelper(new TableFieldConditionDecision() {
            @Override
            public boolean adjudge(String tableName, String fieldName) {
                return true;
            }

            @Override
            public boolean isAllowNullValue() {
                return false;
            }
        });
        sqlFieldHelper.addStatementField(sqlStatement, "regionId", "2");
        System.out.println("源sql:" + sql);
        System.out.println("修改后sql:" + SQLUtils.toSQLString(statementList, JdbcConstants.MYSQL));
    }
}
/**
 * Appends generated predicates to the WHERE clauses of a parsed statement, including
 * the subqueries found in those WHERE clauses and in FROM.
 */
public class SqlConditionHelper {
    /**
     * Adds the generated conditions to a statement. SELECT, UPDATE, DELETE and
     * INSERT ... SELECT are handled; other statement types are left untouched.
     *
     * @param sqlStatement         parsed statement, modified in place
     * @param sqlConditionGenerate produces the predicate text for a (table, alias) pair,
     *                             or an empty result when the table needs none
     */
    public void addStatementCondition(SQLStatement sqlStatement, SqlConditionGenerate sqlConditionGenerate) {
        if (sqlStatement instanceof SQLSelectStatement) {
            SQLSelectQueryBlock queryObject = (SQLSelectQueryBlock) ((SQLSelectStatement) sqlStatement).getSelect().getQuery();
            addSelectStatementCondition(queryObject, queryObject.getFrom(), sqlConditionGenerate);
        } else if (sqlStatement instanceof SQLUpdateStatement) {
            addUpdateStatementCondition((SQLUpdateStatement) sqlStatement, sqlConditionGenerate);
        } else if (sqlStatement instanceof SQLDeleteStatement) {
            addDeleteStatementCondition((SQLDeleteStatement) sqlStatement, sqlConditionGenerate);
        } else if (sqlStatement instanceof SQLInsertStatement) {
            addInsertStatementCondition((SQLInsertStatement) sqlStatement, sqlConditionGenerate);
        }
    }

    /**
     * INSERT: only the SELECT part of an {@code INSERT ... SELECT} receives conditions.
     *
     * @param insertStatement      statement to modify
     * @param sqlConditionGenerate predicate generator
     */
    private void addInsertStatementCondition(SQLInsertStatement insertStatement, SqlConditionGenerate sqlConditionGenerate) {
        if (insertStatement == null) {
            return;
        }
        SQLSelect sqlSelect = insertStatement.getQuery();
        if (sqlSelect != null) {
            SQLSelectQueryBlock selectQueryBlock = (SQLSelectQueryBlock) sqlSelect.getQuery();
            addSelectStatementCondition(selectQueryBlock, selectQueryBlock.getFrom(), sqlConditionGenerate);
        }
    }

    /**
     * DELETE: processes subqueries in the WHERE clause, then adds the condition for the
     * deleted table.
     *
     * @param deleteStatement      statement to modify
     * @param sqlConditionGenerate predicate generator
     */
    private void addDeleteStatementCondition(SQLDeleteStatement deleteStatement, SqlConditionGenerate sqlConditionGenerate) {
        SQLExpr where = deleteStatement.getWhere();
        addSQLExprCondition(where, sqlConditionGenerate);
        SQLExpr newCondition = newEqualityCondition(deleteStatement.getTableName().getSimpleName(),
                deleteStatement.getTableSource().getAlias(), sqlConditionGenerate, where);
        deleteStatement.setWhere(newCondition);
    }

    /**
     * Walks a WHERE expression and adds conditions to the subqueries found in it:
     * {@code IN (subquery)}, scalar subqueries and both operands of binary operators.
     * Other expression kinds (and {@code null}) are ignored.
     *
     * @param where                expression to walk; may be {@code null}
     * @param sqlConditionGenerate predicate generator
     */
    private void addSQLExprCondition(SQLExpr where, SqlConditionGenerate sqlConditionGenerate) {
        if (where instanceof SQLInSubQueryExpr) {
            SQLSelectQueryBlock subQueryObject = (SQLSelectQueryBlock) ((SQLInSubQueryExpr) where).getSubQuery().getQuery();
            addSelectStatementCondition(subQueryObject, subQueryObject.getFrom(), sqlConditionGenerate);
        } else if (where instanceof SQLBinaryOpExpr) {
            SQLBinaryOpExpr opExpr = (SQLBinaryOpExpr) where;
            addSQLExprCondition(opExpr.getLeft(), sqlConditionGenerate);
            addSQLExprCondition(opExpr.getRight(), sqlConditionGenerate);
        } else if (where instanceof SQLQueryExpr) {
            SQLSelectQueryBlock selectQueryBlock = (SQLSelectQueryBlock) (((SQLQueryExpr) where).getSubQuery()).getQuery();
            addSelectStatementCondition(selectQueryBlock, selectQueryBlock.getFrom(), sqlConditionGenerate);
        }
    }

    /**
     * UPDATE: processes subqueries in the WHERE clause, then adds the condition for the
     * updated table.
     *
     * @param updateStatement      statement to modify
     * @param sqlConditionGenerate predicate generator
     */
    private void addUpdateStatementCondition(SQLUpdateStatement updateStatement, SqlConditionGenerate sqlConditionGenerate) {
        SQLExpr where = updateStatement.getWhere();
        addSQLExprCondition(where, sqlConditionGenerate);
        SQLExpr newCondition = newEqualityCondition(updateStatement.getTableName().getSimpleName(),
                updateStatement.getTableSource().getAlias(), sqlConditionGenerate, where);
        updateStatement.setWhere(newCondition);
    }

    /**
     * Adds conditions to one query block: first to the subqueries inside its WHERE
     * clause, then for every table reachable from {@code from}.
     *
     * <p>The WHERE subqueries are processed exactly once per query block. Doing it again
     * for each joined table would append the same predicate to those subqueries repeatedly.
     *
     * @param queryObject          query block to modify
     * @param from                 the block's FROM source
     * @param sqlConditionGenerate predicate generator
     */
    private void addSelectStatementCondition(SQLSelectQueryBlock queryObject, SQLTableSource from, SqlConditionGenerate sqlConditionGenerate) {
        if (from == null || queryObject == null) {
            return;
        }
        addSQLExprCondition(queryObject.getWhere(), sqlConditionGenerate);
        addTableSourceCondition(queryObject, from, sqlConditionGenerate);
    }

    /**
     * Adds the condition for each table in a FROM source to the owning query block:
     * plain tables contribute directly, joins are walked on both sides, and FROM
     * subqueries are processed as query blocks of their own.
     */
    private void addTableSourceCondition(SQLSelectQueryBlock queryObject, SQLTableSource from, SqlConditionGenerate sqlConditionGenerate) {
        if (from == null) {
            return;
        }
        if (from instanceof SQLExprTableSource) {
            // TODO JOIN tables are not fully supported yet (original author's note).
            String tableName = ((SQLIdentifierExpr) ((SQLExprTableSource) from).getExpr()).getName();
            // Re-read WHERE each time: an earlier table of the same join may have extended it.
            SQLExpr newCondition = newEqualityCondition(tableName, from.getAlias(), sqlConditionGenerate, queryObject.getWhere());
            queryObject.setWhere(newCondition);
        } else if (from instanceof SQLJoinTableSource) {
            SQLJoinTableSource joinObject = (SQLJoinTableSource) from;
            addTableSourceCondition(queryObject, joinObject.getLeft(), sqlConditionGenerate);
            addTableSourceCondition(queryObject, joinObject.getRight(), sqlConditionGenerate);
        } else if (from instanceof SQLSubqueryTableSource) {
            SQLSelectQueryBlock subQueryObject = (SQLSelectQueryBlock) ((SQLSubqueryTableSource) from).getSelect().getQuery();
            addSelectStatementCondition(subQueryObject, subQueryObject.getFrom(), sqlConditionGenerate);
        } else {
            throw new NotImplementedException("未处理的异常");
        }
    }

    /**
     * Combines the existing condition with the predicate generated for a table.
     *
     * @param tableName            table name
     * @param tableAlias           table alias; may be blank
     * @param sqlConditionGenerate predicate generator
     * @param originCondition      existing condition; may be {@code null}
     * @return {@code originCondition} unchanged when no predicate is generated, otherwise
     *         the predicate combined with it through AND
     */
    private SQLExpr newEqualityCondition(String tableName, String tableAlias, SqlConditionGenerate sqlConditionGenerate, SQLExpr originCondition) {
        String newConditionSql = sqlConditionGenerate.generate(tableName, tableAlias);
        if (StringUtil.isEmpty(newConditionSql)) {
            return originCondition;
        }
        // The generated text is wrapped in a variant-reference expression so it can be attached to the AST.
        return SQLUtils.buildCondition(SQLBinaryOperator.BooleanAnd, new SQLVariantRefExpr(newConditionSql), false, originCondition);
    }

    /** Manual smoke test: parses a sample statement, adds a condition and prints the result. */
    public static void main(String[] args) {
        // String sql = "select * from user s ";
        // String sql = "select * from user s where s.name='333'";
        // String sql = "select * from (select * from tab t where id = 2 and name = 'wenshao') s where s.name='333'";
        // String sql="select u.*,g.name from user u join user_group g on u.groupId=g.groupId where u.name='123'";
        // String sql = "update user set name=? where id =(select id from user s)";
        // String sql = "delete from user where id = ( select id from user s )";
        String sql = "insert into user (id,name) select g.id,g.name from user_group g where id=1";
        // String sql = "SELECT id, pw_id, warehouse_top, warehouse_level, warehouse_name , is_default, can_to, disabled, data_space_id, created_at , created_by, updated_at, updated_by FROM erp_warehouses_warehouse WHERE disabled = 0 AND is_default = 1 AND warehouse_top = ( SELECT id FROM erp_warehouses_warehouse WHERE is_default = 1 AND disabled = 0 )";
        List<SQLStatement> statementList = SQLUtils.parseStatements(sql, JdbcConstants.MYSQL);
        SQLStatement sqlStatement = statementList.get(0);
        SqlConditionHelper helper = new SqlConditionHelper();
        // Add "<table or alias>.data_space_id = 1" for every table.
        helper.addStatementCondition(sqlStatement, (String tableName, String tableAlias) -> {
            String fieldName = "data_space_id";
            fieldName = StringUtils.isBlank(tableAlias) ? tableName + "." + fieldName : tableAlias + "." + fieldName;
            return fieldName + " = 1";
        });
        System.out.println("源sql:" + sql);
        System.out.println("修改后sql:" + SQLUtils.toSQLString(statementList, JdbcConstants.MYSQL));
    }
}
/**
 * Builds SQL predicate fragments ({@code =}, {@code IN}, {@code JSON_CONTAINS}) for a
 * field and its value(s).
 *
 * <p>Values may originate from request data, so they are never concatenated raw:
 * numbers and booleans are emitted as-is, everything else becomes a quoted string
 * literal with backslashes and single quotes escaped.
 */
public class SqlConditionJointHelper {
    public static final String operation_eq = "operation_eq";
    public static final String operation_in = "operation_in";
    public static final String operation_jsonContains = "operation_jsonContains";

    /**
     * Builds the fragment for the given operation.
     *
     * @param fieldName      (qualified) column name
     * @param operation      one of the {@code operation_*} constants
     * @param fieldValue     value for {@code operation_eq}
     * @param fieldValueList values for {@code operation_in} / {@code operation_jsonContains}
     * @return the fragment, or {@code null} for an unknown operation
     */
    public static String joint(String fieldName, String operation, Object fieldValue, List<?> fieldValueList) {
        if (operation_eq.equals(operation)) {
            return jointEqSql(fieldName, fieldValue);
        } else if (operation_in.equals(operation)) {
            return jointInSql(fieldName, fieldValueList);
        } else if (operation_jsonContains.equals(operation)) {
            return jointJsonContainsSql(fieldName, fieldValueList);
        }
        return null;
    }

    /**
     * Builds {@code field = value}.
     *
     * @param fieldName  (qualified) column name
     * @param fieldValue value to compare with
     * @return the fragment, or {@code null} when the field name is empty or the value is {@code null}
     */
    public static String jointEqSql(String fieldName, Object fieldValue) {
        if (fieldName == null || fieldName.isEmpty() || fieldValue == null) {
            return null;
        }
        return fieldName + " = " + toSqlLiteral(fieldValue);
    }

    /**
     * Builds {@code field in (v1,v2,...)}.
     *
     * @param fieldName (qualified) column name
     * @param list      values
     * @return the fragment, or {@code "FALSE"} when the list is {@code null} or empty so
     *         that the condition matches no row
     */
    public static String jointInSql(String fieldName, List<?> list) {
        if (list == null || list.isEmpty()) {
            return "FALSE";
        }
        StringBuilder sb = new StringBuilder().append(fieldName).append(" in (");
        boolean first = true;
        for (Object item : list) {
            if (!first) {
                sb.append(',');
            }
            sb.append(toSqlLiteral(item));
            first = false;
        }
        return sb.append(')').toString();
    }

    /**
     * Builds {@code (JSON_CONTAINS(field, 'v1') OR JSON_CONTAINS(field, 'v2') ...)}.
     *
     * @param fieldName (qualified) column name
     * @param list      candidate values; each is rendered inside a quoted literal
     * @return the fragment, or {@code "FALSE"} when the list is {@code null} or empty
     */
    public static String jointJsonContainsSql(String fieldName, List<?> list) {
        if (list == null || list.isEmpty()) {
            return "FALSE";
        }
        StringBuilder sb = new StringBuilder("(");
        boolean first = true;
        for (Object item : list) {
            if (!first) {
                sb.append(" OR ");
            }
            sb.append("JSON_CONTAINS(").append(fieldName).append(", '")
                    .append(escape(String.valueOf(item))).append("')");
            first = false;
        }
        return sb.append(')').toString();
    }

    /** Renders a value as a SQL literal: NULL, a bare number/boolean, or a quoted, escaped string. */
    private static String toSqlLiteral(Object value) {
        if (value == null) {
            return "NULL";
        }
        if (value instanceof Number || value instanceof Boolean) {
            return value.toString();
        }
        return "'" + escape(value.toString()) + "'";
    }

    /** Escapes backslashes and single quotes so the text cannot terminate its string literal. */
    private static String escape(String text) {
        return text.replace("\\", "\\\\").replace("'", "''");
    }
}
/**
 * Describes one SQL condition to add for a table: the table, the field, the operation
 * (one of the {@code SqlConditionJointHelper.operation_*} constants) and its value(s).
 */
@Data
public class SqlConditionDTO {
    /** Table the condition applies to. */
    String tableName;
    /** Field the condition applies to. */
    String fieldName;
    /** Operation identifier. */
    String operation;
    /** Single value, set by {@link #eq}. */
    Object fieldValue;
    /** Value list, set by {@link #in} and {@link #jsonContains}. */
    List<?> fieldValueList;

    /**
     * Creates an equality condition.
     *
     * @param tableName  table name
     * @param fieldName  field name
     * @param fieldValue value to compare with
     * @return the new condition
     */
    public static SqlConditionDTO eq(String tableName, String fieldName, Object fieldValue) {
        final SqlConditionDTO condition = newCondition(SqlConditionJointHelper.operation_eq, tableName, fieldName);
        condition.setFieldValue(fieldValue);
        return condition;
    }

    /**
     * Creates an IN condition.
     *
     * @param tableName      table name
     * @param fieldName      field name
     * @param fieldValueList values
     * @return the new condition
     */
    public static SqlConditionDTO in(String tableName, String fieldName, List<?> fieldValueList) {
        final SqlConditionDTO condition = newCondition(SqlConditionJointHelper.operation_in, tableName, fieldName);
        condition.setFieldValueList(fieldValueList);
        return condition;
    }

    /**
     * Creates a JSON-contains condition.
     *
     * @param tableName      table name
     * @param fieldName      field name
     * @param fieldValueList values
     * @return the new condition
     */
    public static SqlConditionDTO jsonContains(String tableName, String fieldName, List<?> fieldValueList) {
        final SqlConditionDTO condition = newCondition(SqlConditionJointHelper.operation_jsonContains, tableName, fieldName);
        condition.setFieldValueList(fieldValueList);
        return condition;
    }

    /** Creates a condition with the parts shared by all factory methods filled in. */
    private static SqlConditionDTO newCondition(String operation, String tableName, String fieldName) {
        final SqlConditionDTO condition = new SqlConditionDTO();
        condition.setOperation(operation);
        condition.setTableName(tableName);
        condition.setFieldName(fieldName);
        return condition;
    }
}
/**
 * Handler that supplies the SQL generation rules (conditions) used by the
 * data-permission MyBatis interceptor.
 */
public interface DataPermissionConditionHandler {
List<SqlConditionDTO> handle();
}
public enum DataPermissionEnum {
/**
* 用于aaa的数据隔离
*/
aaa(new aaa_DataPermissionHandler(), "用于aaa的数据隔离")
/**
* 用于bbb的数据隔离
*/
bbb(new bbb_DataPermissionHandler(), "用于bbb的数据隔离")
;
/**
* 数据权限mybatis拦截器 sql 生成规则 处理这
*/
private DataPermissionConditionHandler dataPermissionConditionHandler;
/**
* 提示
*/
private String msg;
DataPermissionEnum(DataPermissionConditionHandler dataPermissionConditionHandler, String msg) {
this.dataPermissionConditionHandler = dataPermissionConditionHandler;
this.msg = msg;
}
public DataPermissionConditionHandler getDataPermissionConditionHandler() {
return dataPermissionConditionHandler;
}
public String getMsg(){
return msg;
}
}
/**
 * Method-level annotation listing the data-permission strategies to apply.
 */
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface DataPermission {
// Strategies whose handlers are consulted for SQL conditions.
DataPermissionEnum[] value();
}
/**
 * Aspect that, before a method annotated with {@code @DataPermission} runs, collects the
 * SQL conditions of every listed strategy and stores them through
 * {@code ContextHolderUtil.setDataPermissionSqlConditionList}.
 */
@Component
@Aspect
@Slf4j
public class DataPermissionAspect {

    /** Matches methods annotated with {@code @DataPermission}. */
    @Pointcut("@annotation(com.kintexgroup.tfmes.common.annotation.DataPermission)")
    public void pointcut() {
    }

    /**
     * Gathers the conditions from the handlers of all strategies named on the annotation.
     *
     * @param joinPoint the intercepted invocation; its method carries the annotation
     */
    @Before("pointcut()")
    public void openDataSpaceMethod(JoinPoint joinPoint) {
        final MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        final DataPermission permission = signature.getMethod().getAnnotation(DataPermission.class);
        final DataPermissionEnum[] strategies = permission.value();

        log.debug("接口 {} ", ContextHolderUtil.getRequest().getRequestURL());

        final List<SqlConditionDTO> collected = new ArrayList<>();
        for (DataPermissionEnum strategy : strategies) {
            log.debug("\t -数据权限处理:{}", strategy.getMsg());
            final List<SqlConditionDTO> produced = strategy.getDataPermissionConditionHandler().handle();
            // Handlers may return null or an empty list when they contribute nothing.
            if (produced == null || produced.isEmpty()) {
                continue;
            }
            collected.addAll(produced);
        }
        ContextHolderUtil.setDataPermissionSqlConditionList(collected);
    }
}
public class ContextHolderUtil {

    /** Request-attribute key under which the data-permission conditions are stored. */
    private static final String SQL_CONDITION_LIST_KEY = "SqlConditionDTOList";

    /**
     * Tells whether data-permission conditions are active for the current request.
     *
     * @return {@code true} when a non-empty condition list has been stored
     */
    public static boolean isDataPermission() {
        final List<SqlConditionDTO> conditions = getDataPermissionSqlConditionList();
        if (conditions == null) {
            return false;
        }
        return !conditions.isEmpty();
    }

    /**
     * Stores the data-permission conditions on the current request.
     *
     * @param list conditions to store
     */
    public static void setDataPermissionSqlConditionList(List<SqlConditionDTO> list) {
        var request = getRequest();
        request.setAttribute(SQL_CONDITION_LIST_KEY, list);
    }

    /**
     * Returns the data-permission conditions stored on the current request.
     *
     * @return the stored list, or {@code null} when none has been stored
     */
    public static List<SqlConditionDTO> getDataPermissionSqlConditionList() {
        final Object stored = getRequestAttribute(SQL_CONDITION_LIST_KEY);
        return stored == null ? null : (List<SqlConditionDTO>) stored;
    }
}
@Slf4j
@Intercepts({@Signature(
type = Executor.class,
method = "update",
args = {MappedStatement.class, Object.class}
), @Signature(
type = Executor.class,
method = "query",
args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}
), @Signature(
type = Executor.class,
method = "query",
args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}
)})
//@Component 不要这个 MybatisConfig.java 已经注入了
public class TfmesMultiTenantPlugin implements Interceptor {
/**
* 多租户字段名称
*/
private String tenantIdField;
/**
* 需要识别多租户字段的表名称列表
*/
private Set<String> tableSet;
//新增,编辑 时候字段赋值
private SqlFieldHelper sqlFieldHelper;
//sql条件生成
private SqlConditionHelper conditionHelper;
@Override
public Object intercept(Invocation invocation) throws Throwable {
//获取前端传过来的 地区id
String tenantFieldValue = ContextHolderUtil.getDataSpaceId();
//@Signature 上面拦截的方法 的参数
final Object[] args = invocation.getArgs();
MappedStatement ms = (MappedStatement) args[0];
Object parameter = args[1];
//这个巨重要 不然会 新sql成功生成, 但mybatis运行的是旧sql
BoundSql boundSql;
if(args.length == 6){
// 6 个参数时
boundSql = (BoundSql) args[5];
} else {
boundSql = ms.getBoundSql(parameter);
}
//如果查询没有 地区 多租户 不生成新sql
if (SqlCommandType.SELECT == ms.getSqlCommandType() && !ContextHolderUtil.isDataSpace() && !ContextHolderUtil.isDataPermission()){
return invocation.proceed();
}
String processSql = boundSql.getSql();
log.debug("替换前SQL:{}", processSql);
//TODO 核心 语法分析生成新sql
String newSql = getNewSql(processSql, tenantFieldValue);
log.debug("替换后SQL:{}", newSql);
//通过反射修改sql字段
MetaObject boundSqlMeta = MetaObject.forObject(boundSql, new DefaultObjectFactory(), new DefaultObjectWrapperFactory(), new DefaultReflectorFactory());
// 把新sql设置到boundSql
boundSqlMeta.setValue("sql", newSql);
// 重新new一个查询语句对象
BoundSql newBoundSql = new BoundSql(ms.getConfiguration(), newSql,
boundSql.getParameterMappings(), boundSql.getParameterObject());
// 把新的查询放到statement里
MappedStatement newMs = newMappedStatement(ms, new BoundSqlSqlSource(newBoundSql));
for (ParameterMapping mapping : boundSql.getParameterMappings()) {
String prop = mapping.getProperty();
if (boundSql.hasAdditionalParameter(prop)) {
newBoundSql.setAdditionalParameter(prop, boundSql.getAdditionalParameter(prop));
}
}
args[0] = newMs;
return invocation.proceed();
}
@Override
public void setProperties(Properties properties) {
tenantIdField = properties.getProperty("tenantIdField");
if (StringUtils.isBlank(tenantIdField)) {
throw new IllegalArgumentException("MultiTenantPlugin need tenantIdField property value");
}
String tableNames = properties.getProperty("tableNames");
if (!StringUtils.isBlank(tableNames)) {
tableSet = new HashSet<>(Arrays.asList(StringUtils.split(tableNames, ",")));
}
// 多租户条件字段决策器
TableFieldConditionDecision conditionDecision = new TableFieldConditionDecision() {
@Override
public boolean isAllowNullValue() {
return false;
}
@Override
public boolean adjudge(String tableName, String fieldName) {
// 去除反引号
tableName = tableName.replace("`", "");
return tableSet != null && tableSet.contains(tableName);
}
};
sqlFieldHelper = new SqlFieldHelper(conditionDecision);
conditionHelper = new SqlConditionHelper();
}
/**
* 定义一个内部辅助类,作用是包装 SQL
*/
static class BoundSqlSqlSource implements SqlSource {
private final BoundSql boundSql;
public BoundSqlSqlSource(BoundSql boundSql) {
this.boundSql = boundSql;
}
@Override
public BoundSql getBoundSql(Object parameterObject) {
return boundSql;
}
}
/**
* 根据原MappedStatement更新SqlSource生成新MappedStatement
*
* @param ms MappedStatement
* @param newSqlSource 新SqlSource
* @return
*/
private MappedStatement newMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
MappedStatement.Builder builder = new
MappedStatement.Builder(ms.getConfiguration(), ms.getId(), newSqlSource, ms.getSqlCommandType());
builder.resource(ms.getResource());
builder.fetchSize(ms.getFetchSize());
builder.statementType(ms.getStatementType());
builder.keyGenerator(ms.getKeyGenerator());
if (ms.getKeyProperties() != null && ms.getKeyProperties().length > 0) {
builder.keyProperty(ms.getKeyProperties()[0]);
}
builder.timeout(ms.getTimeout());
builder.parameterMap(ms.getParameterMap());
builder.resultMaps(ms.getResultMaps());
builder.resultSetType(ms.getResultSetType());
builder.cache(ms.getCache());
builder.flushCacheRequired(ms.isFlushCacheRequired());
builder.useCache(ms.isUseCache());
return builder.build();
}
/**
* 给sql语句where添加租户id过滤条件
*
* @param sql 要添加过滤条件的sql语句
* @param tenantFieldValue 当前的租户id
* @return 添加条件后的sql语句
*/
private String getNewSql(String sql, String tenantFieldValue) {
List<SQLStatement> statementList = SQLUtils.parseStatements(sql, JdbcConstants.MYSQL);
if (statementList.size() == 0) {
return sql;
}
SQLStatement sqlStatement = statementList.get(0);
//新增,修改 字段赋值 (创建人 时间,修改人 时间,多租户字段)
sqlFieldHelper.addStatementField(sqlStatement, tenantIdField, tenantFieldValue);
//多地区 多租户
//查询、修改、删除 where条件添加多租户
if (ContextHolderUtil.isDataSpace()){
// conditionHelper.addStatementCondition(sqlStatement, tenantIdField, tenantFieldValue);
conditionHelper.addStatementCondition(sqlStatement, (String tableName, String tableAlias) ->{
// 去除反引号
tableName = tableName.replace("`", "");
if (tableSet != null && tableSet.contains(tableName)){
String fieldName = StringUtils.isBlank(tableAlias) ? tableName + "." + tenantIdField : tableAlias + "." + tenantIdField;
return SqlConditionJointHelper.jointEqSql(fieldName, tenantFieldValue);
}
return null;
});
}
//数据权限条件
if (ContextHolderUtil.isDataPermission()){
List<SqlConditionDTO> dataPermissionSqlConditionList = ContextHolderUtil.getDataPermissionSqlConditionList();
Map<String, List<SqlConditionDTO>> groupMap = dataPermissionSqlConditionList.stream().collect(Collectors.groupingBy(SqlConditionDTO::getTableName));
conditionHelper.addStatementCondition(sqlStatement, (String tableName, String tableAlias) ->{
// 去除反引号
tableName = tableName.replace("`", "");
String sqlCondition = null;
List<SqlConditionDTO> sqlConditionDTOList = groupMap.get(tableName);
if (sqlConditionDTOList != null && sqlConditionDTOList.size() > 0){
for (SqlConditionDTO item : sqlConditionDTOList) {
String fieldName = StringUtils.isBlank(tableAlias) ? tableName + "." + item.getFieldName() : tableAlias + "." + item.getFieldName();
if (sqlCondition == null){
sqlCondition = SqlConditionJointHelper.joint(fieldName, item.getOperation(), item.getFieldValue(), item.getFieldValueList());
} else {
sqlCondition = sqlCondition + " and " + SqlConditionJointHelper.joint(fieldName, item.getOperation(), item.getFieldValue(), item.getFieldValueList());
}
}
}
return sqlCondition;
});
}
String newSql = SQLUtils.toSQLString(statementList, JdbcConstants.MYSQL);
//去掉自动加上去的 \
return newSql.replaceAll("\\\\", "");
}
}