由于开发的系统为多租户系统,并且技术上采用了 JPA + Hibernate;查阅 Hibernate 官方文档后发现,它并不支持以 SQL 方式按 tenant_id 进行数据隔离,所以只能自行实现:
项目源码:https://gitee.com/97wx/nm-datasource
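Druid 源码:https://github.com/alibaba/druid
Maven 依赖:
<dependency>
    <groupId>com.alibaba</groupId>
    <artifactId>druid</artifactId>
    <version>1.1.13-preview_3</version>
</dependency>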
/**
 * Extension point for tenant SQL rewriting: receives a SQL statement and returns the
 * (possibly rewritten) statement that should be used instead.
 */
@FunctionalInterface
public interface TenantSqlProvider {

    /**
     * Rewrites a SQL statement.
     *
     * @param sql the original SQL text
     * @return the SQL text to use in place of {@code sql}
     */
    String prepareStatement(String sql);
}
/**
 * Tenant SQL rewriter attached to this object; {@code null} when none has been set.
 */
private TenantSqlProvider tenantSqlProvider;

/**
 * @return the tenant SQL rewriter, or {@code null} if none has been set
 */
public TenantSqlProvider getTenantSqlProvider() {
    return this.tenantSqlProvider;
}

/**
 * @param tenantSqlProvider the tenant SQL rewriter to attach; {@code null} clears it
 */
public void setTenantSqlProvider(TenantSqlProvider tenantSqlProvider) {
    this.tenantSqlProvider = tenantSqlProvider;
}
// Blog fragment; the enclosing Druid method is not shown here.
// Copy the configured tenant SQL provider onto the connection holder so that
// connection-level code can read it back via holder.getTenantSqlProvider().
holder.setTenantSqlProvider(tenantSqlProvider);
// Blog fragment; the enclosing Druid method is not shown.
// NOTE(review): presumably this sits at the start of the pooled connection's
// prepareStatement path — confirm.
// When the holder has a tenant SQL provider, let it rewrite `sql`.
// Otherwise the SQL is used unchanged.
if(holder.getTenantSqlProvider()!=null){
sql = holder.getTenantSqlProvider().prepareStatement(sql);
}
public class DruidDataSourceAutoConfigure {

    private static final Logger LOGGER = LoggerFactory.getLogger(DruidDataSourceAutoConfigure.class);

    /**
     * Optional tenant SQL rewriter; injection is not required, so this stays {@code null}
     * when no such bean exists.
     */
    @Autowired(required = false)
    private TenantSqlProvider tenantSqlProvider;

    /**
     * Creates the Druid data source (initialised through its {@code init} method) and hands it
     * the tenant SQL rewriter, which may be {@code null}.
     *
     * @return the Druid data source wrapper
     */
    @Bean(initMethod = "init")
    @ConditionalOnMissingBean
    public DataSource dataSource() {
        LOGGER.info("Init DruidDataSource");
        final DruidDataSourceWrapper druidDataSource = new DruidDataSourceWrapper();
        druidDataSource.setTenantSqlProvider(this.tenantSqlProvider);
        return druidDataSource;
    }
}
正则表达式(匹配左右括号,用于配对定位子查询):
"(\\(|\\))"
数据结构:
/**
 * Bookkeeping for one SQL fragment.
 * <p>
 * SELECT rewriting uses {@code start}/{@code end} (offsets just inside a pair of parentheses)
 * and {@code subSql} (the rewritten sub-query). INSERT rewriting uses {@code tableName},
 * {@code fields} and {@code values} of one parsed statement.
 */
@Data
public class OffsetItem {
    /** Target table of a parsed INSERT. */
    private String tableName;
    /** Raw VALUES part of a parsed INSERT. */
    private String values;
    /** Column list of a parsed INSERT. */
    private String fields;
    /** Offset of the first character after "(". */
    private Integer start;
    /** Offset of the matching ")". */
    private Integer end;
    /** Rewritten text of the sub-query between start and end. */
    private String subSql;

    public OffsetItem() {
        this(null);
    }

    public OffsetItem(Integer start) {
        this.start = start;
    }
}
以下面这条 SQL 为例:
select u.* from user u left join member m on m.userId=u.id where u.name=123 and u.p=ii and u.id in(select j.id from job limit 1) and m.id in (select d.id from deparment d where d.userId=u.id and d.jobId in (select j.id from job j)) order by id desc
按括号配对后,最终提取出来的(不再包含子查询的)子句应该是这样的:
select j.id from job j limit 1
select j.id from job j
若子句中仍包含子查询,则递归调用,继续分解;直到最内层的子句不再包含子查询时,才对其进行替换(追加租户条件)。
子句替换完成后返回上一层:先把当前层语句中的子句剔除,再用剩下的部分提取别名;如果表不存在别名,则用 null 代替。
将子句剔除后应该是这样的:
select d.id from deparment d where d.userId=u.id and d.jobId in ()
正则表达式:
1)提取表名、左右连接、全连接部分:
"((from|update)\\s+\\w+(((\\s+as|\\s+AS)*\\s+(?<as>\\w+))*?)\\s*(where|order|group|set|limit|$))"
2)提取别名:对上一步得到的片段循环调用 find,即可得到一个别名数组(命名分组为 as):
((?<=from|join|,|update)\s*\w+(((\s+as|\s+AS)*\s+(?<as>\w+))*?)\s*(?=where|left|right|inner|on|order|group|limit|$|,|set))
最后根据各子句的位置,把当前层的语句完整地拼接回去并返回。
import lombok.Data;
/***
 * Bookkeeping for one SQL fragment.
 * <p>
 * SELECT rewriting uses {@code start}/{@code end} (offsets just inside a pair of parentheses)
 * and {@code subSql} (the rewritten sub-query). INSERT rewriting uses {@code tableName},
 * {@code fields} and {@code values} of one parsed statement.
 *
 * @Auther: guzhangwen
 * @BusinessCard https://blog.csdn.net/gu_zhang_w
 * @Date: 2019/1/24
 * @Time: 14:01
 * @version : V1.0
 */
@Data
public class OffsetItem {
    /** Target table of a parsed INSERT. */
    private String tableName;
    /** Raw VALUES part of a parsed INSERT. */
    private String values;
    /** Column list of a parsed INSERT. */
    private String fields;
    /** Offset of the first character after "(". */
    private Integer start;
    /** Offset of the matching ")". */
    private Integer end;
    /** Rewritten text of the sub-query between start and end. */
    private String subSql;

    public OffsetItem() {
        this(null);
    }

    public OffsetItem(Integer start) {
        this.start = start;
    }
}
/**
 * SQL keywords used by the tenant SQL rewriter.
 * <p>
 * All values are lower case except {@link #AS}. {@link #as} doubles as the name of the
 * alias capture group in the rewriter's regular expressions. Interface fields are
 * implicitly {@code public static final}.
 */
public interface SqlConstant {
    // statement types
    String SELECT = "select";
    String UPDATE = "update";
    String DELETE = "delete";
    String INSERT = "insert";

    // clauses
    String FROM = "from";
    String WHERE = "where";
    String GROUP = "group";
    String ORDER = "order";
    String LIMIT = "limit";

    // joins
    String JOIN = "join";
    String LEFT = "left";
    String RIGHT = "right";
    String INNER = "inner";
    String ON = "on";

    // aliasing and predicates
    String AS = "AS";
    String as = "as";
    String AND = "and";
}
import com.lvbang.erp.common.support.bean.ConditionItem;
import com.lvbang.erp.common.support.bean.OffsetItem;
import com.lvbang.erp.common.support.bean.SqlConstant;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Deque;
import java.util.List;
import java.util.Stack;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/***
* 租户sql 帮助类
* @Auther: guzhangwen
* @BusinessCard https://blog.csdn.net/gu_zhang_w
* @Date: 2019/1/24
* @Time: 13:59
* @version : V1.0
*/
@Slf4j
public abstract class TenantSqlSupport {
/**
* 提取配对的"("")"
*/
private static final Pattern PATTERN = Pattern.compile("(\\(|\\))");
/**
* 提取from 到where 之间的别名或者表名
*/
private static final Pattern ALALIS_PATTERN = Pattern.compile("((?<=from|join|,|update)\\s*\\w+(((\\s+as|\\s+AS)*\\s+(?\\w+))*?)\\s*(?=where|left|right|inner|on|order|group|limit|$|,|set))", Pattern.CASE_INSENSITIVE);
/**
* 保存sql 匹配,当前仅仅匹配带字段的sql即:
* insert into tmp(name,age)values (0,'gone'),(1,'123') insert into tmp(name,age)values (0,'gone'),(1,'123')
*/
private static final Pattern INSERT_PATTERN = Pattern.compile("(\\s*insert into\\s+(?\\w+)\\s*\\((?.*?)\\)\\s*value(s)?\\s*(?(\\(.*?\\),?)+?)(?=insert|$|;))", Pattern.CASE_INSENSITIVE);
/**
* 保存的值正则
*/
private static final Pattern INSERT_VALUES_PATTERN = Pattern.compile("((?<=\\().*?(?=\\)))", Pattern.CASE_INSENSITIVE);
/**
* 表的正则
*/
private static final Pattern TABLE_PATTERN = Pattern.compile("((from|update)\\s+\\w+(((\\s+as|\\s+AS)*\\s+(?\\w+))*?)\\s*(where|order|group|set|limit|$))");
/**
* 构建查询语句
*
* @param sql
* @return
*/
public String rebuilderSelectSql(String sql) {
/**
* 不加自定义处理
*/
if (CollectionUtils.isEmpty(getCondtion())) {
return sql;
}
Matcher matcher = PATTERN.matcher(sql);
Stack offsetItemStack = new Stack<>();
List offsetItemStackReverse = new ArrayList<>();
boolean flag = false;
while (matcher.find()) {
String operator = matcher.group(1);
if ("(".equals(operator)) {
OffsetItem offsetItem = new OffsetItem(matcher.end());
offsetItemStack.push(offsetItem);
} else if (")".equals(operator)) {
OffsetItem offsetItem = offsetItemStack.pop();
offsetItem.setEnd(matcher.start());
String subSql = sql.substring(offsetItem.getStart(), offsetItem.getEnd());
if (subSql.trim().startsWith(SqlConstant.SELECT)) {
flag = true;
if (offsetItemStack.empty()) {
checkedSubSql(offsetItem, subSql);
offsetItemStackReverse.add(offsetItem);
} else {
OffsetItem prevOffsetItem = offsetItemStack.peek();
checkedSubSql(prevOffsetItem, subSql);
}
}
}
}
String newSql = null;
if (flag) {
newSql = mergeSql(sql, offsetItemStackReverse);
} else {
List alaisList = getSqlAlais(sql);
newSql = appendToCondition(sql, alaisList);
}
log.debug("==>>构建查询语句:置换后的sql={}", newSql);
return newSql;
}
/**
* 重构更新语句
*
* @param sql
* @return
*/
public String rebuilderUpdateSql(String sql) {
/**
* 不加自定义处理
*/
if (CollectionUtils.isEmpty(getCondtion())) {
return sql;
}
List alaisList = getSqlAlais(sql);
if (CollectionUtils.isEmpty(alaisList)) {
log.error("无法获取别名、请检查sql 匹配规则是否正确,语句={}", sql);
throw new RuntimeException("无法获取别名、请检查sql 匹配规则是否正确");
}
String newSql = appendToCondition(sql, Arrays.asList(alaisList.get(0)));
log.debug("==>>重构更新语句:置换后的sql={}", newSql);
return newSql;
}
/**
* 重构删除语句
*
* @param sql
* @return
*/
public String rebuilderDeletedSql(String sql) {
/**
* 不加自定义处理
*/
if (CollectionUtils.isEmpty(getCondtion())) {
return sql;
}
List alaisList = getSqlAlais(sql);
if (CollectionUtils.isEmpty(alaisList)) {
log.error("无法获取别名、请检查sql 匹配规则是否正确,语句={}", sql);
throw new RuntimeException("无法获取别名、请检查sql 匹配规则是否正确");
}
String newSql = appendToCondition(sql, Arrays.asList(alaisList.get(0)));
log.debug("==>>重构删除语句:置换后的sql={}", newSql);
return newSql;
}
/**
* 构建插入语句
*
* @param sql
* @return
*/
public String rebuilderInsertSql(String sql) {
/**
* 不加自定义处理
*/
List conditionItemList = getCondtion();
if (CollectionUtils.isEmpty(conditionItemList)) {
return sql;
}
try {
Matcher matcher = INSERT_PATTERN.matcher(sql);
List offsetItemList = new ArrayList<>();
while (matcher.find()) {
OffsetItem offsetItem = new OffsetItem();
offsetItem.setTableName(matcher.group("tableName"));
offsetItem.setFields(matcher.group("fields").trim());
offsetItem.setValues(matcher.group("values").trim());
offsetItemList.add(offsetItem);
}
if (CollectionUtils.isNotEmpty(offsetItemList)) {
StringBuffer sb = new StringBuffer();
for (OffsetItem item : offsetItemList) {
if (sb.length() > 0) {
sb.append(";");
}
sb.append("insert into ").append(item.getTableName()).append(" (").append(item.getFields());
String fields = item.getFields();
for (ConditionItem conditionItem : conditionItemList) {
if (fields.contains(conditionItem.getName()) || (StringUtils.isBlank(conditionItem.getName()) || StringUtils.isBlank(conditionItem.getValue()))) {
continue;
}
sb.append(",").append(conditionItem.getName());
}
sb.append(")");
List valueList = getInsertSqlValues(item.getValues());
if (CollectionUtils.isEmpty(valueList)) {
log.error("插入语句重构失败,请检查匹配规则与语法是否正确,sql={},正则表达式={}", sql, INSERT_VALUES_PATTERN.pattern());
throw new RuntimeException("插入语句重构失败,请检查匹配规则与语法是否正确");
}
sb.append("value");
if (valueList.size() > 1) {
sb.append("s");
}
for (int i = 0; i < valueList.size(); i++) {
if (i > 0) {
sb.append(",");
}
sb.append("(").append(valueList.get(i));
for (ConditionItem conditionItem : conditionItemList) {
if (fields.contains(conditionItem.getName()) || (StringUtils.isBlank(conditionItem.getName()) || StringUtils.isBlank(conditionItem.getValue()))) {
continue;
}
sb.append(",").append(conditionItem.getValue());
}
sb.append(")");
}
}
String newSql = sb.toString();
log.debug("==>>构建插入语句:置换后的sql={}", newSql);
return newSql;
}
} catch (Exception e) {
}
return sql;
}
/**
* 获取插入语句的sql
*
* @param values
* @return
*/
private List getInsertSqlValues(String values) {
List valueList = new ArrayList<>();
Matcher matcher = INSERT_VALUES_PATTERN.matcher(values);
while (matcher.find()) {
valueList.add(matcher.group(1));
}
return valueList;
}
/**
* 合并sql
*
* @param originalSql 原sql
* @param offsetItemStackReverse sql子句对象{@link OffsetItem}
* @return
*/
private String mergeSql(String originalSql, List offsetItemStackReverse) {
StringBuffer sb = new StringBuffer();
List subSqlList = new ArrayList<>();
if (CollectionUtils.isNotEmpty(offsetItemStackReverse)) {
OffsetItem prevItem = null;
for (int i = 0; i < offsetItemStackReverse.size(); i++) {
OffsetItem item = offsetItemStackReverse.get(i);
if (sb.length() <= 0) {
sb.append(originalSql.substring(0, item.getStart()));
}
if (prevItem != null && prevItem.getEnd() != item.getStart()) {
sb.append(originalSql.substring(prevItem.getEnd(), item.getStart()));
}
sb.append("%s");
if ((i + 1) == offsetItemStackReverse.size()) {
sb.append(originalSql.substring(item.getEnd(), originalSql.length()));
}
subSqlList.add(item.getSubSql());
prevItem = item;
}
}
//例如 select * from a where uid in(%s) and did in(%s)
String newSql = sb.toString();
List alaisList = getSqlAlais(newSql);
newSql = appendToCondition(newSql, alaisList);
if (CollectionUtils.isNotEmpty(subSqlList)) {
for (String s : subSqlList) {
newSql = newSql.replaceFirst("%s", s);
}
}
return newSql;
}
/**
* 检查subSql 是否还存在子查询
*
* @param offsetItem sql子句对象{@link OffsetItem}
* @param subSql 子查询sql
*/
private void checkedSubSql(OffsetItem offsetItem, String subSql) {
if (appearNumber(subSql, SqlConstant.SELECT) > 1) {
offsetItem.setSubSql(rebuilderSelectSql(subSql));
} else {
List alaisList = getSqlAlais(subSql);
String newSql = appendToCondition(subSql, alaisList);
offsetItem.setSubSql(newSql);
}
}
/**
* 获取指定字符串出现的次数
*
* @param srcText 源字符串
* @param findText 要查找的字符串
* @return
*/
public int appearNumber(String srcText, String findText) {
int count = 0;
Pattern p = Pattern.compile(findText);
Matcher m = p.matcher(srcText);
while (m.find()) {
count++;
}
return count;
}
/**
* 获取sql 语句的别名
*
* @param sql
* @return
*/
public List getSqlAlais(String sql) {
Matcher tableMatcher = TABLE_PATTERN.matcher(sql);
StringBuffer stringBuffer = new StringBuffer();
while (tableMatcher.find()) {
stringBuffer.append(tableMatcher.group(1));
}
Matcher matcher = ALALIS_PATTERN.matcher(stringBuffer.toString());
List alaisList = new ArrayList<>();
while (matcher.find()) {
alaisList.add(matcher.group(SqlConstant.as));
}
return alaisList;
}
/**
* 添加查询过滤条件
*
* @param sql
* @param alaisList
* @return
*/
private String appendToCondition(String sql, List alaisList) {
StringBuffer sb = new StringBuffer();
int offset = 0;
offset = sql.lastIndexOf(SqlConstant.WHERE);
if (offset > 0) {
offset += 5;
sb.append(sql.substring(0, offset)).append(" ");
boolean flag = appendTenant(sb, alaisList);
if (flag) {
sb.append(" and ");
}
sb.append(sql.substring(offset, sql.length()));
} else {
try {
offset = sql.lastIndexOf(SqlConstant.ORDER);
if (offset > 0) {
throw new Exception();
}
offset = sql.lastIndexOf(SqlConstant.GROUP);
if (offset > 0) {
throw new Exception();
}
offset = sql.lastIndexOf(SqlConstant.LIMIT);
if (offset > 0) {
throw new Exception();
}
} catch (Exception e) {
}
if (offset > 0) {
sb.append(sql.substring(0, offset)).append(SqlConstant.WHERE).append(" ");
appendTenant(sb, alaisList);
sb.append(" ").append(sql.substring(offset, sql.length()));
} else {
sb.append(sql).append(" ").append(SqlConstant.WHERE + " ");
appendTenant(sb, alaisList);
}
}
return sb.toString();
}
/**
* 添加租户id
*
* @param sb 拼接的字符串
* @param alaisList sql 中包含的别名
*/
private boolean appendTenant(StringBuffer sb, List alaisList) {
boolean flag = false;
for (int i = 0; i < alaisList.size(); i++) {
String alais = alaisList.get(i);
if (i >= 1) {
sb.append(" " + SqlConstant.AND + " ");
}
List conditionItemList = getCondtion();
if (CollectionUtils.isNotEmpty(conditionItemList)) {
for (ConditionItem item : conditionItemList) {
//已经包含指定字段名称
if (StringUtils.isEmpty(item.getName()) || StringUtils.isEmpty(item.getValue())) {
continue;
}
flag = true;
if (StringUtils.isBlank(item.getName()) || StringUtils.isBlank(item.getValue())) {
log.error("condition name 与 value 都不能为null");
throw new RuntimeException("condition name 与 value 都不能为null");
}
if (alais == null) {
sb.append(String.format("%s=%s", item.getName(), item.getValue()));
} else {
sb.append(String.format("%s.%s=%s", alais, item.getName(), item.getValue()));
}
}
}
}
return flag;
}
/**
* 获取需要附加的查询条件
*
* @return
*/
public abstract List getCondtion();
}
3、默认实现类:同时实现 TenantSqlProvider 扩展接口、并继承默认 SQL 置换抽象类 TenantSqlSupport 的 DefaultTenantSqlProvider
import com.alibaba.druid.tenant.TenantSqlProvider;
import com.lvbang.erp.common.support.bean.ConditionItem;
import com.lvbang.erp.common.support.bean.SqlConstant;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/***
 * Default tenant SQL rewriter.
 * <p>
 * Strips block comments, then sends SELECT / UPDATE / DELETE / INSERT statements to the
 * matching {@link TenantSqlSupport} rebuilder. Any other statement is returned unchanged
 * (comments stripped).
 *
 * @Auther: guzhangwen
 * @BusinessCard https://blog.csdn.net/gu_zhang_w
 * @Date: 2019/1/22
 * @Time: 18:01
 * @version : V1.0
 */
public class DefaultTenantSqlProvider extends TenantSqlSupport implements TenantSqlProvider {

    /**
     * SQL block comments. DOTALL so that comments spanning several lines are removed too.
     */
    private static final Pattern BLOCK_COMMENT_PATTERN = Pattern.compile("/\\*.*?\\*/", Pattern.DOTALL);

    /**
     * "from|update table [[as] alias] keyword" fragments; used only by {@link #main(String[])}.
     */
    private static final Pattern TABLE_PATTERN = Pattern.compile("((from|update)\\s+\\w+(((\\s+as|\\s+AS)*\\s+(?<as>\\w+))*?)\\s*(where|order|group|set|limit|$))");

    /**
     * Rewrites a statement according to its leading keyword.
     *
     * @param sql the SQL about to be prepared
     * @return the rewritten SQL, or the comment-stripped SQL for other statement types
     */
    @Override
    public String prepareStatement(String sql) {
        // Remove comments first so a leading comment cannot hide the statement keyword.
        String cleaned = BLOCK_COMMENT_PATTERN.matcher(sql).replaceAll("").trim();
        // NOTE(review): dispatch is case-sensitive, so upper-case statements ("SELECT ...")
        // pass through without tenant conditions; confirm all SQL reaching here is lower-case.
        Optional<String> rebuilt = Optional.empty();
        if (cleaned.startsWith(SqlConstant.SELECT)) {
            rebuilt = Optional.ofNullable(rebuilderSelectSql(cleaned));
        } else if (cleaned.startsWith(SqlConstant.UPDATE)) {
            rebuilt = Optional.ofNullable(rebuilderUpdateSql(cleaned));
        } else if (cleaned.startsWith(SqlConstant.DELETE)) {
            rebuilt = Optional.ofNullable(rebuilderDeletedSql(cleaned));
        } else if (cleaned.startsWith(SqlConstant.INSERT)) {
            rebuilt = Optional.ofNullable(rebuilderInsertSql(cleaned));
        }
        return rebuilt.orElse(cleaned);
    }

    /**
     * Ad-hoc check: prints the fragments {@link #TABLE_PATTERN} extracts from a sample query.
     */
    public static void main(String[] args) {
        String sql = "select customerty0_.id as id1_7_, customerty0_.create_host as create_h2_7_, customerty0_.create_time as create_t3_7_, customerty0_.create_user_id as create_u4_7_, customerty0_.create_user_name as create_u5_7_, customerty0_.tenant_id as tenant_i6_7_, customerty0_.update_host as update_h7_7_, customerty0_.update_time as update_t8_7_, customerty0_.update_user_id as update_u9_7_, customerty0_.update_user_name as update_10_7_, customerty0_.deleted as deleted11_7_, customerty0_.remark as remark12_7_, customerty0_.status as status13_7_, customerty0_.type_name as type_na14_7_ from customer_customer_type customerty0_ where 1=1 and customerty0_.deleted=? and customerty0_.tenant_id=10 order by customerty0_.update_time desc limit ?";
        Matcher matcher = TABLE_PATTERN.matcher(sql);
        while (matcher.find()) {
            System.out.println(matcher.group(1));
        }
    }

    /**
     * Returns the tenant condition to add to every statement.
     *
     * @return a single {@code tenant_id=10} condition
     */
    @Override
    public List<ConditionItem> getCondtion() {
        // TODO: resolve the current tenant instead of the hard-coded "10", e.g.
        // StringHelper.getObjectValue(BaseContextHandler.getTenantId()).
        return Arrays.asList(new ConditionItem("tenant_id", "10"));
    }
}