Druid data source: adding the tenant id (tenant_id) to MySQL statements

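The system I am developing is multi-tenant and built on JPA + Hibernate. According to the official Hibernate documentation, Hibernate does not support isolating data by a tenant_id column at the SQL level, so I had no choice but to implement it myself: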

Project source: https://gitee.com/97wx/nm-datasource

I. Modifying the source of the druid and druid-spring-boot-starter projects

Project repository:

https://github.com/alibaba/druid

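Project source version:

<dependency>
    <groupId>com.alibaba</groupId>
    <artifactId>druid</artifactId>
    <version>1.1.13-preview_3</version>
</dependency>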

1. Changes to the druid source:

1) Add an extension interface

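/**
 * Tenant SQL provider: rewrites SQL before it is prepared
 */
public interface TenantSqlProvider {

    /**
     * Rewrite the SQL (e.g. append the tenant condition)
     * @param sql the original SQL
     * @return the SQL that is actually prepared
     */
    String prepareStatement(String sql);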
}

2) Add the following member to both DruidDataSource and DruidConnectionHolder

    /**
     * Tenant SQL provider
     */
    private TenantSqlProvider tenantSqlProvider;


    public TenantSqlProvider getTenantSqlProvider() {
        return tenantSqlProvider;
    }

    public void setTenantSqlProvider(TenantSqlProvider tenantSqlProvider) {
        this.tenantSqlProvider = tenantSqlProvider;
    }

3) In DruidDataSource, at the end of getConnectionInternal, pass tenantSqlProvider on to the DruidConnectionHolder:

holder.setTenantSqlProvider(tenantSqlProvider);

4) In DruidPooledConnection, add the following code at the very beginning of every prepareStatement method:

if(holder.getTenantSqlProvider()!=null){
    sql = holder.getTenantSqlProvider().prepareStatement(sql);
}
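The hook above can be tried without patching Druid. The following is a minimal sketch, assuming plain JDBC only: it wraps any `java.sql.Connection` in a `java.lang.reflect.Proxy` and passes the SQL of every `prepareStatement(String, ...)` call through a rewriter before delegating. `TenantConnections` and `wrap` are hypothetical names for illustration, not part of Druid. Like the Druid change, it only intercepts `prepareStatement`, so SQL sent through `createStatement()` is not rewritten.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Proxy;
import java.sql.Connection;
import java.util.function.UnaryOperator;

// Hypothetical helper (not part of Druid): wraps a JDBC Connection so that every
// prepareStatement(String, ...) call passes its SQL through a rewriter first,
// mirroring the hook added at the top of DruidPooledConnection.prepareStatement.
final class TenantConnections {
    private TenantConnections() {
    }

    static Connection wrap(Connection target, UnaryOperator<String> rewriter) {
        InvocationHandler handler = (proxy, method, args) -> {
            // Same guard as the Druid patch: rewrite only when a rewriter is configured.
            if (rewriter != null
                    && "prepareStatement".equals(method.getName())
                    && args != null && args.length > 0
                    && args[0] instanceof String) {
                args[0] = rewriter.apply((String) args[0]);
            }
            try {
                return method.invoke(target, args);
            } catch (InvocationTargetException e) {
                // Surface the target's own exception (e.g. SQLException), not the wrapper.
                throw e.getCause();
            }
        };
        return (Connection) Proxy.newProxyInstance(
                TenantConnections.class.getClassLoader(),
                new Class<?>[] {Connection.class},
                handler);
    }
}
```

In a real setup the rewriter would be the `prepareStatement` method of a `TenantSqlProvider` implementation, e.g. `TenantConnections.wrap(conn, provider::prepareStatement)`.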

2. Changes to the druid-spring-boot-starter source:

1) Inject TenantSqlProvider into DruidDataSourceAutoConfigure and set it on the DruidDataSourceWrapper

public class DruidDataSourceAutoConfigure {

    private static final Logger LOGGER = LoggerFactory.getLogger(DruidDataSourceAutoConfigure.class);

    @Autowired(required = false)
    private TenantSqlProvider tenantSqlProvider;


    @Bean(initMethod = "init")
    @ConditionalOnMissingBean
    public DataSource dataSource() {
        LOGGER.info("Init DruidDataSource");
        DruidDataSourceWrapper wrapper = new DruidDataSourceWrapper();
        wrapper.setTenantSqlProvider(tenantSqlProvider);
        return wrapper;
    }
}

 

II. Overall approach:

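1. Use a stack to find matching parentheses. Because a stack is last-in, first-out, each ")" always pairs with the most recent unmatched "(", so every pair can be found; record the positions of the "(" and the ")".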

Regular expression:

"(\\(|\\))"

Data structure:

@Data
public class OffsetItem {
    public OffsetItem(){}
    public OffsetItem(Integer start) {
        this.start = start;
    }
    private String tableName;

    private String values;

    private String fields;

    private Integer start;
    private Integer end;
    private String subSql;
}
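Step 1 can be sketched on its own as follows, assuming only `java.util` and `java.util.regex`. It uses the same `(\\(|\\))` pattern and records each pair the way `OffsetItem` does in the implementation below: `start` is the index just after "(" and `end` is the index of ")". `ParenPairs` is a hypothetical name; like the original, it does not skip parentheses inside quoted string literals.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Sketch of step 1: pair every "(" with its ")" using a stack.
// Each pair is {start, end}: start is the index just after "(", end is the index
// of ")", so sql.substring(start, end) is the text between the parentheses.
final class ParenPairs {
    private static final Pattern PAREN = Pattern.compile("(\\(|\\))");

    private ParenPairs() {
    }

    static List<int[]> find(String sql) {
        Deque<Integer> open = new ArrayDeque<>();   // start offsets of unmatched "("
        List<int[]> pairs = new ArrayList<>();      // in the order the ")" are met
        Matcher m = PAREN.matcher(sql);
        while (m.find()) {
            if ("(".equals(m.group(1))) {
                open.push(m.end());
            } else if (open.isEmpty()) {
                throw new IllegalArgumentException("unmatched ')' at index " + m.start());
            } else {
                // LIFO: the most recent unmatched "(" belongs to this ")"
                pairs.add(new int[] {open.pop(), m.start()});
            }
        }
        if (!open.isEmpty()) {
            throw new IllegalArgumentException("unmatched '(' in: " + sql);
        }
        return pairs;
    }
}
```

For `a(b(c)d)` the inner pair `c` is reported before the outer pair `b(c)d`, which is exactly why nested subqueries surface innermost-first.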
 

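2. When a ")" is found, pop the stack and take the substring between the pair. If the substring starts with select, it is treated as a SQL statement (a subquery) and saved into its parent node.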

The SQL extracted in the end should look like this:

select u.* from user u left join member m on m.userId=u.id where u.name=123 and u.p=ii and u.id in(select j.id from job j limit 1) and m.id in (select d.id from deparment d where d.userId=u.id and d.jobId in (select j.id from job j)) order by id desc

select j.id from job j limit 1

select j.id from job j
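A small sketch of step 2 under the same conventions: whenever a ")" closes a pair, the text in between is kept if it starts with select. The case-insensitive comparison is an assumption of this sketch (the original compares against lowercase "select"), and `SubqueryExtractor` is a hypothetical name. Because an inner ")" always closes before its outer one, the list comes out innermost-first.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.Locale;

// Sketch of step 2: each closed "(...)" pair whose content starts with "select"
// is a subquery. Results are ordered by their closing ")", i.e. innermost-first.
final class SubqueryExtractor {
    private SubqueryExtractor() {
    }

    static List<String> extract(String sql) {
        Deque<Integer> open = new ArrayDeque<>();   // index just after each unmatched "("
        List<String> subqueries = new ArrayList<>();
        for (int i = 0; i < sql.length(); i++) {
            char c = sql.charAt(i);
            if (c == '(') {
                open.push(i + 1);
            } else if (c == ')' && !open.isEmpty()) {
                String inner = sql.substring(open.pop(), i).trim();
                if (inner.toLowerCase(Locale.ROOT).startsWith("select")) {
                    subqueries.add(inner);   // non-select groups like "(1,2)" are ignored
                }
            }
        }
        return subqueries;
    }
}
```

For the example statement above it returns the `job` subquery from `u.id in(...)` first, then the innermost `job` subquery, then the `deparment` subquery that encloses it.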

3. Recursively decompose subqueries

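If a subquery contains further subqueries, call the procedure recursively and keep decomposing until the innermost subquery contains none; that one can then be rewritten.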
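Steps 3 to 5 together amount to a bottom-up rewrite. The sketch below is one way to express it, under stated assumptions: `transform` stands in for the condition-appending logic, and each top-level subquery is cut out and replaced by a marker `__SUBQUERY_n__` (a hypothetical token; the original uses `%s` with `replaceFirst`) before the outer statement is transformed, so the outer rewrite never sees the subqueries' own WHERE clauses. `RecursiveRewriter` is a hypothetical name. Splicing uses plain `String.replace`, so a `$` or `\` inside a subquery needs no escaping.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.function.UnaryOperator;

// Sketch of steps 3-5: rewrite subqueries bottom-up, then the enclosing statement,
// then splice the rewritten subqueries back in place of their markers.
final class RecursiveRewriter {
    private static final String MARKER_PREFIX = "__SUBQUERY_";
    private static final String MARKER_SUFFIX = "__";

    private RecursiveRewriter() {
    }

    static String rewrite(String sql, UnaryOperator<String> transform) {
        StringBuilder skeleton = new StringBuilder();   // sql with subqueries cut out
        List<String> rewritten = new ArrayList<>();     // rewritten subqueries, by marker number
        int depth = 0;
        int groupStart = -1;   // index just after the "(" of the current top-level group
        int copiedUpTo = 0;    // sql[0, copiedUpTo) is already in the skeleton
        for (int i = 0; i < sql.length(); i++) {
            char c = sql.charAt(i);
            if (c == '(') {
                if (depth == 0) {
                    groupStart = i + 1;
                }
                depth++;
            } else if (c == ')' && depth > 0) {
                depth--;
                if (depth == 0) {
                    String inner = sql.substring(groupStart, i).trim();
                    if (inner.toLowerCase(Locale.ROOT).startsWith("select")) {
                        // Recurse first, so the innermost subqueries are rewritten first.
                        rewritten.add(rewrite(inner, transform));
                        skeleton.append(sql, copiedUpTo, groupStart)
                                .append(MARKER_PREFIX).append(rewritten.size() - 1).append(MARKER_SUFFIX);
                        copiedUpTo = i;   // the closing ")" is copied with the next chunk
                    }
                }
            }
        }
        skeleton.append(sql, copiedUpTo, sql.length());
        // The outer statement no longer contains subqueries, so transforming it is safe.
        String result = transform.apply(skeleton.toString());
        for (int n = 0; n < rewritten.size(); n++) {
            result = result.replace(MARKER_PREFIX + n + MARKER_SUFFIX, rewritten.get(n));
        }
        return result;
    }
}
```

For example, with `s -> "[" + s + "]"` the statement `select a from x where id in (select b from y where k in (select c from z))` becomes `[select a from x where id in ([select b from y where k in ([select c from z])])]`: every level is transformed exactly once.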

4. Extract aliases

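After the subqueries have been rewritten and control returns to the enclosing level, remove the subqueries from the current level's statement and use what remains to extract the aliases. A table without an alias is represented by null.

With the subqueries removed, the statement looks like this: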

select d.id from deparment d where d.userId=u.id and d.jobId in ()

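Regular expressions:

1) Extract the table-name section (the table name together with any left, right, or full join parts)

"((from|update)\\s+\\w+(((\\s+as|\\s+AS)*\\s+(?<as>\\w+))*?)\\s*(where|order|group|set|limit|$))"

2) Extract the aliases; calling find repeatedly yields a list of aliases

((?<=from|join|,|update)\s*\w+(((\s+as|\s+AS)*\s+(?<as>\w+))*?)\s*(?=where|left|right|inner|on|order|group|limit|$|,|set))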
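Both patterns rely on the named group `as` (written `(?<as>\w+)`), which the implementation reads with `group("as")`. A sketch of the two-stage extraction under that assumption, using only `java.util.regex`: stage 1 keeps the `from|update <table> [alias] <keyword>` segments, stage 2 reads one alias per segment, or null when the table has no alias. `AliasExtractor` is a hypothetical name. This is a heuristic rather than a SQL parser; as far as I can tell from the pattern, a main statement such as `from user u left join member m on m.userId=u.id where ...` produces no stage-1 match at all.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Sketch of step 4 with the named group "as" in both patterns.
// Stage 1 keeps "from|update <table> [alias] <keyword>" segments;
// stage 2 reads the alias of each segment (null when there is none).
final class AliasExtractor {
    private static final Pattern TABLE = Pattern.compile(
            "((from|update)\\s+\\w+(((\\s+as|\\s+AS)*\\s+(?<as>\\w+))*?)\\s*(where|order|group|set|limit|$))");
    private static final Pattern ALIAS = Pattern.compile(
            "((?<=from|join|,|update)\\s*\\w+(((\\s+as|\\s+AS)*\\s+(?<as>\\w+))*?)\\s*"
                    + "(?=where|left|right|inner|on|order|group|limit|$|,|set))",
            Pattern.CASE_INSENSITIVE);

    private AliasExtractor() {
    }

    static List<String> aliases(String sql) {
        StringBuilder segments = new StringBuilder();
        Matcher table = TABLE.matcher(sql);
        while (table.find()) {
            segments.append(table.group(1));
        }
        List<String> aliases = new ArrayList<>();
        Matcher alias = ALIAS.matcher(segments);
        while (alias.find()) {
            aliases.add(alias.group("as"));   // null when no alias follows the table
        }
        return aliases;
    }
}
```

On the stripped statement shown in step 4 this yields `[d]`; for `select * from user where id=1` it yields a single null, which is what makes the rewriter emit an unqualified `tenant_id=...`.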

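5. Reassemble the statement and its subqueries

Finally, using the recorded subquery positions, reassemble the complete statement for the current level and return it.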

 

III. Implementation

1. The Java beans used

import lombok.Data;

/***
 *
 * @Auther: guzhangwen
 * @BusinessCard https://blog.csdn.net/gu_zhang_w
 * @Date: 2019/1/24
 * @Time: 14:01
 * @version : V1.0
 */
@Data
public class OffsetItem {
    public OffsetItem() {
    }

    public OffsetItem(Integer start) {
        this.start = start;
    }

    private String tableName;

    private String values;

    private String fields;

    private Integer start;
    private Integer end;
    private String subSql;
}



/**
 * SQL keyword constants
 */
public interface SqlConstant {

    public static final String SELECT = "select";
    public static final String UPDATE = "update";
    public static final String DELETE = "delete";
    public static final String INSERT = "insert";
    public static final String WHERE = "where";
    public static final String ORDER = "order";
    public static final String GROUP = "group";
    public static final String LIMIT = "limit";
    public static final String AS = "AS";
    public static final String as = "as";
    public static final String JOIN = "join";
    public static final String ON = "on";
    public static final String LEFT = "left";
    public static final String RIGHT = "right";
    public static final String INNER = "inner";
    public static final String FROM = "from";
    public static final String AND = "and";

}
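The code below imports `ConditionItem` but the article never shows it. What follows is a hypothetical reconstruction inferred only from how it is used (`new ConditionItem("tenant_id", "10")`, `getName()`, `getValue()`); the `toPredicate` helper is an addition of this sketch that mirrors the format `appendTenant` builds. In a real project it would be a public class in its own file.

```java
// Hypothetical reconstruction of ConditionItem, inferred from its usage only.
final class ConditionItem {
    private final String name;   // column name, e.g. "tenant_id"
    private final String value;  // SQL literal spliced in verbatim, e.g. "10"

    ConditionItem(String name, String value) {
        this.name = name;
        this.value = value;
    }

    public String getName() {
        return name;
    }

    public String getValue() {
        return value;
    }

    // Same format appendTenant produces: "alias.name=value", or "name=value" without an alias.
    public String toPredicate(String alias) {
        return alias == null ? name + "=" + value : alias + "." + name + "=" + value;
    }
}
```

Because the value is concatenated into SQL rather than bound as a parameter, it has to be a trusted, already-valid SQL literal.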

2. TenantSqlSupport: abstract class with the default SQL-rewriting implementation

import com.lvbang.erp.common.support.bean.ConditionItem;
import com.lvbang.erp.common.support.bean.OffsetItem;
import com.lvbang.erp.common.support.bean.SqlConstant;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Stack;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/***
 * Tenant SQL helper
 * @Auther: guzhangwen
 * @BusinessCard https://blog.csdn.net/gu_zhang_w
 * @Date: 2019/1/24
 * @Time: 13:59
 * @version : V1.0
 */
@Slf4j
public abstract class TenantSqlSupport {
    /**
     * Matches "(" and ")" so they can be paired
     */
    private static final Pattern PATTERN = Pattern.compile("(\\(|\\))");

    /**
     * Extracts the table name or alias between from and where
     */
    private static final Pattern ALALIS_PATTERN = Pattern.compile("((?<=from|join|,|update)\\s*\\w+(((\\s+as|\\s+AS)*\\s+(?<as>\\w+))*?)\\s*(?=where|left|right|inner|on|order|group|limit|$|,|set))", Pattern.CASE_INSENSITIVE);

    /**
     * Matches INSERT statements; currently only inserts with an explicit column list are supported, e.g.:
     * insert into tmp(name,age)values (0,'gone'),(1,'123') insert into tmp(name,age)values (0,'gone'),(1,'123')
     */
    private static final Pattern INSERT_PATTERN = Pattern.compile("(\\s*insert into\\s+(?<tableName>\\w+)\\s*\\((?<fields>.*?)\\)\\s*value(s)?\\s*(?<values>(\\(.*?\\),?)+?)(?=insert|$|;))", Pattern.CASE_INSENSITIVE);

    /**
     * Regex for the value tuples of an INSERT
     */
    private static final Pattern INSERT_VALUES_PATTERN = Pattern.compile("((?<=\\().*?(?=\\)))", Pattern.CASE_INSENSITIVE);

    /**
     * Regex for the table section
     */
    private static final Pattern TABLE_PATTERN = Pattern.compile("((from|update)\\s+\\w+(((\\s+as|\\s+AS)*\\s+(?<as>\\w+))*?)\\s*(where|order|group|set|limit|$))");

    /**
     * Rebuild a SELECT statement
     *
     * @param sql
     * @return
     */
    public String rebuilderSelectSql(String sql) {

        /**
         * No extra conditions configured: return the SQL unchanged
         */
        if (CollectionUtils.isEmpty(getCondtion())) {
            return sql;
        }

        Matcher matcher = PATTERN.matcher(sql);
        Stack<OffsetItem> offsetItemStack = new Stack<>();
        List<OffsetItem> offsetItemStackReverse = new ArrayList<>();
        boolean flag = false;
        while (matcher.find()) {
            String operator = matcher.group(1);
            if ("(".equals(operator)) {
                OffsetItem offsetItem = new OffsetItem(matcher.end());
                offsetItemStack.push(offsetItem);
            } else if (")".equals(operator)) {
                OffsetItem offsetItem = offsetItemStack.pop();
                offsetItem.setEnd(matcher.start());
                String subSql = sql.substring(offsetItem.getStart(), offsetItem.getEnd());
                if (subSql.trim().startsWith(SqlConstant.SELECT)) {
                    flag = true;
                    if (offsetItemStack.empty()) {
                        checkedSubSql(offsetItem, subSql);
                        offsetItemStackReverse.add(offsetItem);
                    } else {
                        OffsetItem prevOffsetItem = offsetItemStack.peek();
                        checkedSubSql(prevOffsetItem, subSql);
                    }
                }
            }
        }
        String newSql = null;
        if (flag) {
            newSql = mergeSql(sql, offsetItemStackReverse);
        } else {
            List<String> alaisList = getSqlAlais(sql);
            newSql = appendToCondition(sql, alaisList);
        }
        log.debug("==>>Rebuilt SELECT: rewritten sql={}", newSql);
        return newSql;
    }

    /**
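     * Rebuild an UPDATE statement
     *
     * @param sql
     * @return
     */
    public String rebuilderUpdateSql(String sql) {
        /**
         * No extra conditions configured: return the SQL unchanged
         */
        if (CollectionUtils.isEmpty(getCondtion())) {
            return sql;
        }

        List<String> alaisList = getSqlAlais(sql);
        if (CollectionUtils.isEmpty(alaisList)) {
            log.error("Could not extract an alias; check that the SQL matching rules are correct, sql={}", sql);
            throw new RuntimeException("Could not extract an alias; check that the SQL matching rules are correct");
        }
        // Only the alias of the first table is used
        String newSql = appendToCondition(sql, Arrays.asList(alaisList.get(0)));
        log.debug("==>>Rebuilt UPDATE: rewritten sql={}", newSql);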
        return newSql;
    }

    /**
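     * Rebuild a DELETE statement
     *
     * @param sql
     * @return
     */
    public String rebuilderDeletedSql(String sql) {
        /**
         * No extra conditions configured: return the SQL unchanged
         */
        if (CollectionUtils.isEmpty(getCondtion())) {
            return sql;
        }

        List<String> alaisList = getSqlAlais(sql);
        if (CollectionUtils.isEmpty(alaisList)) {
            log.error("Could not extract an alias; check that the SQL matching rules are correct, sql={}", sql);
            throw new RuntimeException("Could not extract an alias; check that the SQL matching rules are correct");
        }
        // Only the alias of the first table is used
        String newSql = appendToCondition(sql, Arrays.asList(alaisList.get(0)));
        log.debug("==>>Rebuilt DELETE: rewritten sql={}", newSql);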
        return newSql;
    }

    /**
     * Rebuild an INSERT statement
     *
     * @param sql
     * @return
     */
    public String rebuilderInsertSql(String sql) {
        /**
         * No extra conditions configured: return the SQL unchanged
         */
        List<ConditionItem> conditionItemList = getCondtion();
        if (CollectionUtils.isEmpty(conditionItemList)) {
            return sql;
        }
        try {
            Matcher matcher = INSERT_PATTERN.matcher(sql);
            List<OffsetItem> offsetItemList = new ArrayList<>();
            while (matcher.find()) {
                OffsetItem offsetItem = new OffsetItem();
                offsetItem.setTableName(matcher.group("tableName"));
                offsetItem.setFields(matcher.group("fields").trim());
                offsetItem.setValues(matcher.group("values").trim());
                offsetItemList.add(offsetItem);
            }

            if (CollectionUtils.isNotEmpty(offsetItemList)) {
                StringBuffer sb = new StringBuffer();
                for (OffsetItem item : offsetItemList) {
                    if (sb.length() > 0) {
                        sb.append(";");
                    }
                    sb.append("insert into ").append(item.getTableName()).append(" (").append(item.getFields());
                    String fields = item.getFields();
                    for (ConditionItem conditionItem : conditionItemList) {
                        if (fields.contains(conditionItem.getName()) || (StringUtils.isBlank(conditionItem.getName()) || StringUtils.isBlank(conditionItem.getValue()))) {
                            continue;
                        }
                        sb.append(",").append(conditionItem.getName());
                    }
                    sb.append(")");
                    List<String> valueList = getInsertSqlValues(item.getValues());
                    if (CollectionUtils.isEmpty(valueList)) {
                        log.error("Failed to rebuild the INSERT; check the matching rules and the SQL syntax, sql={}, regex={}", sql, INSERT_VALUES_PATTERN.pattern());
                        throw new RuntimeException("Failed to rebuild the INSERT; check the matching rules and the SQL syntax");
                    }
                    sb.append("value");
                    if (valueList.size() > 1) {
                        sb.append("s");
                    }
                    for (int i = 0; i < valueList.size(); i++) {
                        if (i > 0) {
                            sb.append(",");
                        }
                        sb.append("(").append(valueList.get(i));
                        for (ConditionItem conditionItem : conditionItemList) {
                            if (fields.contains(conditionItem.getName()) || (StringUtils.isBlank(conditionItem.getName()) || StringUtils.isBlank(conditionItem.getValue()))) {
                                continue;
                            }
                            sb.append(",").append(conditionItem.getValue());
                        }
                        sb.append(")");
                    }
                }
                String newSql = sb.toString();
                log.debug("==>>Rebuilt INSERT: rewritten sql={}", newSql);
                return newSql;
            }
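        } catch (Exception e) {
            // Do not swallow the failure silently: log it, then fall back to the original SQL
            log.error("Failed to rebuild the INSERT, falling back to the original sql={}", sql, e);
        }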
        return sql;
    }

    /**
     * Extract the value tuples of an INSERT
     *
     * @param values
     * @return
     */
    private List<String> getInsertSqlValues(String values) {
        List<String> valueList = new ArrayList<>();
        Matcher matcher = INSERT_VALUES_PATTERN.matcher(values);
        while (matcher.find()) {
            valueList.add(matcher.group(1));
        }
        return valueList;
    }

    /**
     * Merge the SQL with its rewritten subqueries
     *
     * @param originalSql            the original SQL
     * @param offsetItemStackReverse top-level subquery items {@link OffsetItem}
     * @return
     */
    private String mergeSql(String originalSql, List<OffsetItem> offsetItemStackReverse) {
        StringBuffer sb = new StringBuffer();
        List<String> subSqlList = new ArrayList<>();
        if (CollectionUtils.isNotEmpty(offsetItemStackReverse)) {
            OffsetItem prevItem = null;
            for (int i = 0; i < offsetItemStackReverse.size(); i++) {
                OffsetItem item = offsetItemStackReverse.get(i);
                if (sb.length() <= 0) {
                    sb.append(originalSql.substring(0, item.getStart()));
                }
                // Compare int values: != on two Integer objects compares references
                if (prevItem != null && prevItem.getEnd().intValue() != item.getStart().intValue()) {
                    sb.append(originalSql.substring(prevItem.getEnd(), item.getStart()));
                }
                sb.append("%s");
                if ((i + 1) == offsetItemStackReverse.size()) {
                    sb.append(originalSql.substring(item.getEnd(), originalSql.length()));
                }
                subSqlList.add(item.getSubSql());
                prevItem = item;
            }
        }

        // e.g. select * from a where uid in(%s) and  did in(%s)
        String newSql = sb.toString();
        List<String> alaisList = getSqlAlais(newSql);
        newSql = appendToCondition(newSql, alaisList);
        if (CollectionUtils.isNotEmpty(subSqlList)) {
            for (String s : subSqlList) {
                // quoteReplacement keeps "$" or "\" inside a subquery from being read as a group reference
                newSql = newSql.replaceFirst("%s", Matcher.quoteReplacement(s));
            }
        }
        return newSql;
    }

    /**
     * Check whether subSql still contains nested subqueries
     *
     * @param offsetItem sql clause item {@link OffsetItem}
     * @param subSql     the subquery SQL
     */
    private void checkedSubSql(OffsetItem offsetItem, String subSql) {
        if (appearNumber(subSql, SqlConstant.SELECT) > 1) {
            offsetItem.setSubSql(rebuilderSelectSql(subSql));
        } else {
            List<String> alaisList = getSqlAlais(subSql);
            String newSql = appendToCondition(subSql, alaisList);
            offsetItem.setSubSql(newSql);
        }
    }

    /**
     * Count how many times a pattern occurs in a string
     *
     * @param srcText  source string
     * @param findText text to search for (interpreted as a regular expression)
     * @return
     */
    public int appearNumber(String srcText, String findText) {
        int count = 0;
        Pattern p = Pattern.compile(findText);
        Matcher m = p.matcher(srcText);
        while (m.find()) {
            count++;
        }
        return count;
    }

    /**
     * Get the aliases used in a SQL statement
     *
     * @param sql
     * @return one entry per matched table; null when the table has no alias
     */
    public List<String> getSqlAlais(String sql) {
        Matcher tableMatcher = TABLE_PATTERN.matcher(sql);
        StringBuffer stringBuffer = new StringBuffer();
        while (tableMatcher.find()) {
            stringBuffer.append(tableMatcher.group(1));
        }
        Matcher matcher = ALALIS_PATTERN.matcher(stringBuffer.toString());
        List<String> alaisList = new ArrayList<>();
        while (matcher.find()) {
            alaisList.add(matcher.group(SqlConstant.as));
        }
        return alaisList;
    }

    /**
     * Append the tenant filter conditions
     *
     * @param sql
     * @param alaisList
     * @return
     */
    private String appendToCondition(String sql, List<String> alaisList) {
        StringBuffer sb = new StringBuffer();
        int offset = sql.lastIndexOf(SqlConstant.WHERE);
        if (offset > 0) {
            // Existing WHERE: "... where <tenant conditions> and <original conditions>"
            offset += SqlConstant.WHERE.length();
            sb.append(sql.substring(0, offset)).append(" ");
            boolean flag = appendTenant(sb, alaisList);
            if (flag) {
                sb.append(" ").append(SqlConstant.AND).append(" ");
            }
            sb.append(sql.substring(offset));
        } else {
            // No WHERE: it has to go before GROUP BY, ORDER BY or LIMIT, checked in
            // clause order so that "group by ... order by ..." is handled correctly
            for (String keyword : new String[]{SqlConstant.GROUP, SqlConstant.ORDER, SqlConstant.LIMIT}) {
                offset = sql.lastIndexOf(keyword);
                if (offset > 0) {
                    break;
                }
            }
            if (offset > 0) {
                sb.append(sql.substring(0, offset)).append(SqlConstant.WHERE).append(" ");
                appendTenant(sb, alaisList);
                sb.append(" ").append(sql.substring(offset));
            } else {
                sb.append(sql).append(" ").append(SqlConstant.WHERE).append(" ");
                appendTenant(sb, alaisList);
            }
        }
        return sb.toString();
    }

    /**
     * Append the tenant conditions
     *
     * @param sb        the buffer being built
     * @param alaisList aliases found in the SQL (null means the table has no alias)
     * @return true if at least one condition was appended
     */
    private boolean appendTenant(StringBuffer sb, List<String> alaisList) {
        List<ConditionItem> conditionItemList = getCondtion();
        if (CollectionUtils.isEmpty(conditionItemList)) {
            return false;
        }
        List<String> predicates = new ArrayList<>();
        for (String alais : alaisList) {
            for (ConditionItem item : conditionItemList) {
                // Skip conditions whose name or value is empty
                if (StringUtils.isEmpty(item.getName()) || StringUtils.isEmpty(item.getValue())) {
                    continue;
                }
                // A whitespace-only name or value is a configuration error
                if (StringUtils.isBlank(item.getName()) || StringUtils.isBlank(item.getValue())) {
                    log.error("condition name and value must not be blank");
                    throw new RuntimeException("condition name and value must not be blank");
                }
                if (alais == null) {
                    predicates.add(String.format("%s=%s", item.getName(), item.getValue()));
                } else {
                    predicates.add(String.format("%s.%s=%s", alais, item.getName(), item.getValue()));
                }
            }
        }
        // Join with "and" so several aliases or several conditions still form valid SQL
        sb.append(String.join(" " + SqlConstant.AND + " ", predicates));
        return !predicates.isEmpty();
    }

    /**
     * 获取需要附加的查询条件
     *
     * @return
     */
    public abstract List<ConditionItem> getCondtion();
}
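The placement rule in `appendToCondition` can be isolated for a statement that no longer contains subqueries. The sketch below is a variation, not the class above: it matches whole keywords case-insensitively with `\b` instead of `lastIndexOf`, so a column such as `order_no` is not mistaken for ORDER BY, and it wraps the existing WHERE condition in parentheses, because `where tenant_id=10 and a=1 or b=2` would otherwise let `b=2` rows from every tenant through (AND binds tighter than OR). `WhereInjector` is a hypothetical name.

```java
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Variation on appendToCondition for a statement without subqueries.
// Keywords are matched as whole words, case-insensitively, and the existing
// WHERE condition is parenthesized so "a=1 or b=2" cannot escape the predicate.
final class WhereInjector {
    private static final Pattern WHERE = Pattern.compile("\\bwhere\\b", Pattern.CASE_INSENSITIVE);
    private static final Pattern TAIL = Pattern.compile("\\b(group|order|limit)\\b", Pattern.CASE_INSENSITIVE);

    private WhereInjector() {
    }

    static String inject(String sql, String predicate) {
        Matcher where = WHERE.matcher(sql);
        Matcher tail = TAIL.matcher(sql);
        if (where.find()) {
            // "... where <predicate> and (<original condition>) [group/order/limit ...]"
            int condEnd = tail.find(where.end()) ? tail.start() : sql.length();
            String condition = sql.substring(where.end(), condEnd).trim();
            String rest = sql.substring(condEnd);
            return sql.substring(0, where.start()) + "where " + predicate + " and (" + condition + ")"
                    + (rest.isEmpty() ? "" : " " + rest);
        }
        if (tail.find()) {
            // No WHERE yet: it has to come before GROUP BY / ORDER BY / LIMIT.
            return sql.substring(0, tail.start()) + "where " + predicate + " " + sql.substring(tail.start());
        }
        return sql + " where " + predicate;
    }
}
```

For example, `select * from t where a=1 or b=2 order by id` becomes `select * from t where tenant_id=10 and (a=1 or b=2) order by id`.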

3. Implementation of the TenantSqlProvider extension interface, built on the default SQL-rewriting abstract class

import com.alibaba.druid.tenant.TenantSqlProvider;
import com.lvbang.erp.common.support.bean.ConditionItem;
import com.lvbang.erp.common.support.bean.SqlConstant;

import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/***
 *
 * @Auther: guzhangwen
 * @BusinessCard https://blog.csdn.net/gu_zhang_w
 * @Date: 2019/1/22
 * @Time: 18:01
 * @version : V1.0
 */
public class DefaultTenantSqlProvider extends TenantSqlSupport implements TenantSqlProvider {

    @Override
    public String prepareStatement(String sql) {
        sql = sql.replaceAll("/\\*.*?\\*/","").trim();
        Optional<String> optional = Optional.of(sql);
        if (sql.startsWith(SqlConstant.SELECT)) {
            optional = Optional.ofNullable(getSelectSql(sql));
        } else if (sql.startsWith(SqlConstant.UPDATE)) {
            optional = Optional.ofNullable(getUpdateSql(sql));
        } else if (sql.startsWith(SqlConstant.DELETE)) {
            optional = Optional.ofNullable(getDeletedSql(sql));
        } else if (sql.startsWith(SqlConstant.INSERT)) {
            optional = Optional.ofNullable(getInsertSql(sql));
        }
        return optional.orElse(sql);
    }

    private String getUpdateSql(String sql) {
        return rebuilderUpdateSql(sql);
    }

    private String getSelectSql(String sql) {
        return rebuilderSelectSql(sql);
    }

    private String getDeletedSql(String sql) {
        return rebuilderDeletedSql(sql);
    }

    private String getInsertSql(String sql) {
        return rebuilderInsertSql(sql);
    }

    private static final Pattern alalisPattern = Pattern.compile("((?<=from|join|update|,)\\s*\\w+(((\\s+as|\\s+AS)*\\s+(?<as>\\w+))*?)\\s*(?=where|left|right|inner|on|order|group|limit|$|set|,))",Pattern.CASE_INSENSITIVE);
    private static final Pattern insertPattern = Pattern.compile("(\\s*insert into\\s+\\w+\\s*(?<fields>\\(.*?\\))\\s*value(s)?\\s*(?<values>(\\(.*?\\),?)+?)(?=insert|$))",Pattern.CASE_INSENSITIVE);
    private static final Pattern insertValuesPattern = Pattern.compile("(\\(.*?\\))",Pattern.CASE_INSENSITIVE);

    //private static final Pattern pattern = Pattern.compile("(from\\s*\\w+(((\\s+as|\\s+AS)*\\s+(?<as>\\w+))*?)\\s*(where|order|group|limit|$))");

    private static final Pattern TABLE_PATTERN = Pattern.compile("((from|update)\\s+\\w+(((\\s+as|\\s+AS)*\\s+(?<as>\\w+))*?)\\s*(where|order|group|set|limit|$))");
    public static void main(String[] args) {
        //String sql = "select u.* from user,car c,left join member m on m.userId=u.id join tree t on t.userId=u.id where u.name=123 and u.p=ii and u.id in(select d.id from deparment d where d.userId=u.id and d.jobId in (select j.id from job j)) and m.id in (select p.id from position p where p.userId=u.id) order by id desc";
        //String sql = "update customer_customer_type set status=? where id=? and tenant_id=?";
        //String sql = "delete from financial_bank where id=4";
        String sql = "select customerty0_.id as id1_7_, customerty0_.create_host as create_h2_7_, customerty0_.create_time as create_t3_7_, customerty0_.create_user_id as create_u4_7_, customerty0_.create_user_name as create_u5_7_, customerty0_.tenant_id as tenant_i6_7_, customerty0_.update_host as update_h7_7_, customerty0_.update_time as update_t8_7_, customerty0_.update_user_id as update_u9_7_, customerty0_.update_user_name as update_10_7_, customerty0_.deleted as deleted11_7_, customerty0_.remark as remark12_7_, customerty0_.status as status13_7_, customerty0_.type_name as type_na14_7_ from customer_customer_type customerty0_ where 1=1 and customerty0_.deleted=? and customerty0_.tenant_id=10 order by customerty0_.update_time desc limit ?";
        //String sql = "insert into tmp(name,age)values (0,'gone'),(1,'123')insert into tmp(name,age)values (0,'gone'),(1,'123')";
        Matcher matcher = TABLE_PATTERN.matcher(sql);
        StringBuffer stringBuffer = new StringBuffer();
        while (matcher.find()) {
            //stringBuffer.append();
            System.out.println(matcher.group(1));
        }
        /*DefaultTenantSqlProvider provider = new DefaultTenantSqlProvider();
        System.out.println(provider.getSqlAlais(sql));*/
    }


    @Override
    public List<ConditionItem> getCondtion() {
        //StringHelper.getObjectValue(BaseContextHandler.getTenantId())
        List<ConditionItem> list = Arrays.asList(new ConditionItem("tenant_id", "10"));
        return list;
    }
}
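`getCondtion()` above hard-codes `"10"`, and the commented-out `BaseContextHandler.getTenantId()` suggests the real value comes from a per-request context. A minimal sketch of such a holder, assuming a `ThreadLocal` bound per request (for example by a web filter); `TenantContext` and `callAs` are hypothetical names. Because the id is concatenated into SQL, this sketch only accepts numeric ids, which is an assumption about the id format.

```java
import java.util.function.Supplier;

// Hypothetical per-thread tenant holder standing in for BaseContextHandler.getTenantId().
// A request filter would bind the tenant at the start of a request and clear it afterwards.
final class TenantContext {
    private static final ThreadLocal<String> TENANT_ID = new ThreadLocal<>();

    private TenantContext() {
    }

    static void set(String tenantId) {
        // The id is spliced into SQL verbatim, so only digits are accepted here.
        if (tenantId == null || !tenantId.matches("\\d+")) {
            throw new IllegalArgumentException("tenant id must be numeric: " + tenantId);
        }
        TENANT_ID.set(tenantId);
    }

    static String get() {
        return TENANT_ID.get();
    }

    static void clear() {
        TENANT_ID.remove();
    }

    // Runs work with the tenant bound to the current thread, and always unbinds it afterwards.
    static <T> T callAs(String tenantId, Supplier<T> work) {
        set(tenantId);
        try {
            return work.get();
        } finally {
            clear();
        }
    }
}
```

With such a holder, `getCondtion()` could return `Arrays.asList(new ConditionItem("tenant_id", TenantContext.get()))` while a tenant is bound. If it returned an empty list instead, the `rebuilder*` methods would leave the SQL unchanged, so statements executed without a bound tenant would not be filtered at all; whether that is acceptable is a design decision.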
