使用mybatis plus自定义拦截器,实现数据权限

需求

为了增强程序的安全性,需要在用户访问数据库的时候进行权限判断后选择性进行判断是否需要增强sql,来达到限制低级别权限用户访问数据的目的.

根据业务需要,这里将角色按照数据范围做权限限定.比如,角色如下:

编号 名称 描述
1 管理员 全部数据权限
2 普通角色 自定义数据权限
3 部门权限 部门权限
4 部门及以下数据权限 部门及以下数据权限
5 本人数据 本人数据

部门如下:

编号 父id 名称 描述
1 0 北京总公司
101 1 北京公司1
102 1 北京公司2
10101 101 丰台公司1
10102 101 丰台公司2
10201 102 昌平公司1
10202 102 昌平公司2

思路:

  1. 可以模仿 PageHelper.startPage 用来声明哪些操作需要做范围限制
  2. 定义 Mybatis plus自定义拦截器Interceptor,用于每次拦截查询sql语句,附带数据范围权限sql条件
  3. 因为使用了 PageHelper.startPage 分页插件的使用,先计算总数,怎么在这之前拦截,需要拦截多次
  4. 考虑到Mybatis拦截器能够拦截SQL执行的整个过程,因为我们可以考虑SQL执行之前,对SQL进行重写,从而达到数据行权限的目的。

步骤:

  1. 声明datescope
protected static final ThreadLocal threadLocal = new ThreadLocal();
/**
 * 设置权限标识
 */
public static void startDataScope(){
    threadLocal.set(SecurityConstants.DATA_SCOPE);
}
/**
 * 获取权限标识
 */
public static String getDataScope(){
    return threadLocal.get();
}
/**
 * 清除权限标识
 */
public static void cleanDataScope(){
    threadLocal.remove();
}
复制代码
  1. 定义 Mybatis plus自定义拦截器
/**
 * @author zhc
 * @description 数据权限插件
 * @date 2022-04-01 17:03
 */
@Intercepts(
        {
                @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
                @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}),
        }
)
@Slf4j
public class MybatisDataPermissionIntercept implements Interceptor {
    CCJSqlParserManager parserManager = new CCJSqlParserManager();

    /**
     * 全部数据权限
     */
    public static final String DATA_SCOPE_ALL = "1";

    /**
     * 自定数据权限
     */
    public static final String DATA_SCOPE_CUSTOM = "2";

    /**
     * 部门数据权限
     */
    public static final String DATA_SCOPE_DEPT = "3";

    /**
     * 部门及以下数据权限
     */
    public static final String DATA_SCOPE_DEPT_AND_CHILD = "4";

    /**
     * 仅本人数据权限
     */
    public static final String DATA_SCOPE_SELF = "5";

    @Override
    public Object intercept(Invocation invocation) throws Throwable {

        try {
            Object[] args = invocation.getArgs();
            MappedStatement ms = (MappedStatement) args[0];
            Object parameter = args[1];
            RowBounds rowBounds = (RowBounds) args[2];
            ResultHandler resultHandler = (ResultHandler) args[3];
            Executor executor = (Executor) invocation.getTarget();
            CacheKey cacheKey;
            BoundSql boundSql;
            //由于逻辑关系,只会进入一次
            if (args.length == 4) {
                //4 个参数时
                boundSql = ms.getBoundSql(parameter);
                cacheKey = executor.createCacheKey(ms, parameter, rowBounds, boundSql);
            } else {
                //6 个参数时
                cacheKey = (CacheKey) args[4];
                boundSql = (BoundSql) args[5];
            }
            //TODO 自己要进行的各种处理
            String sql = boundSql.getSql();
            log.info("原始SQL: {}", sql);
            //判断线程内是否有权限信息
            String dataScope = SecurityUtils.getDataScope();
            if (SecurityConstants.DATA_SCOPE.equals(dataScope)){
                // 增强sql
                Select select = (Select) parserManager.parse(new StringReader(sql));
                SelectBody selectBody = select.getSelectBody();
                if (selectBody instanceof PlainSelect) {
                    this.setWhere((PlainSelect) selectBody);
                } else if (selectBody instanceof SetOperationList) {
                    SetOperationList setOperationList = (SetOperationList) selectBody;
                    List selectBodyList = setOperationList.getSelects();
                    selectBodyList.forEach((s) -> {
                        this.setWhere((PlainSelect) s);
                    });
                }
                String dataPermissionSql = select.toString();
                log.info("增强SQL: {}", dataPermissionSql);
                BoundSql dataPermissionBoundSql = new BoundSql(ms.getConfiguration(), dataPermissionSql, boundSql.getParameterMappings(), parameter);
                return executor.query(ms, parameter, rowBounds, resultHandler, cacheKey, dataPermissionBoundSql);
            }
            //注:下面的方法可以根据自己的逻辑调用多次,在分页插件中,count 和 page 各调用了一次
            return executor.query(ms, parameter, rowBounds, resultHandler, cacheKey, boundSql);
        } finally {
            //清除线程中权限参数
            SecurityUtils.cleanDataScope();
        }
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
    }

    protected void setWhere(PlainSelect plainSelect) {
        Expression sqlSegment = this.getSqlSegment(plainSelect.getWhere());
        if (null != sqlSegment) {
            plainSelect.setWhere(sqlSegment);
        }
    }

    @SneakyThrows
    public Expression getSqlSegment(Expression where) {
        JSONObject loginUser = getLoginUser();
        if (loginUser == null){
            return where;
        }
        Integer deptId = loginUser.getInteger("deptId");
        String userId = loginUser.getString("userId");
        JSONArray roles = loginUser.getJSONArray("roles");
        StringBuilder sqlString = new StringBuilder();
        for (Object role : roles) {
            JSONObject roleJson = JSONObject.parseObject(role.toString());
            String dataScopeNum = roleJson.getString(SecurityConstants.DATA_SCOPE);
            Integer roleId = roleJson.getInteger("roleId");
            if (DATA_SCOPE_ALL.equals(dataScopeNum)) {
                // 全部数据权限
                sqlString = new StringBuilder();
                break;
            } else if (DATA_SCOPE_CUSTOM.equals(dataScopeNum)) {
                sqlString.append(" OR `sys_dept`.dept_id IN ( SELECT dept_id FROM `sys_role_dept` WHERE role_id = '")
                        .append(roleId)
                        .append("' ) ");
            } else if (DATA_SCOPE_DEPT.equals(dataScopeNum)) {
                sqlString.append(" OR `sys_dept`.dept_id = '").append(deptId).append("' ");
            } else if (DATA_SCOPE_DEPT_AND_CHILD.equals(dataScopeNum)) {
                sqlString.append(" OR `sys_dept`.dept_id IN ( SELECT dept_id FROM `sys_dept` WHERE dept_id = '")
                        .append(deptId)
                        .append("' or find_in_set( '")
                        .append(deptId)
                        .append("' , ancestors ) ) ");
            }else if (DATA_SCOPE_SELF.equals(dataScopeNum)) {
                //TODO 暂时有问题
                sqlString.append(" OR `sys_user`.user_id = '").append(userId).append("' ");
            }
        }
        if (StringUtils.isNotBlank(sqlString.toString())) {
            if (where == null){
                where = new HexValue(" 1 = 1 ");
            }
            sqlString.insert(0," AND (");
            sqlString.append(")");
            sqlString.delete(7, 9);
            //判断是不是分页, 分页完成之后 清除权限标识
            return CCJSqlParserUtil.parseCondExpression(where + sqlString.toString());
        }else {
            return where;
        }
    }

}
复制代码
  1. 修改MybatisDataPermissionIntercept在PageHelper插件后面执行
org.springframework.boot.autoconfigure.EnableAutoConfiguration=\
com.zhc.cloud.mybatis.intercept.MybatisInterceptorAutoConfiguration

/**
 * @author zhc
 * @description 数据权限插件 配置
 * @date 2022-04-01 17:01
 */
@AutoConfigureAfter(PageHelperAutoConfiguration.class)
@Configuration
public class MybatisInterceptorAutoConfiguration implements InitializingBean{

    @Autowired
    private List sqlSessionFactoryList;

    @Override
    @PostConstruct
    public void afterPropertiesSet() throws Exception {
        //创建自定义mybatis拦截器,添加到chain的最后面
        MybatisDataPermissionIntercept mybatisInterceptor = new MybatisDataPermissionIntercept();

        for (SqlSessionFactory sqlSessionFactory : sqlSessionFactoryList) {
            org.apache.ibatis.session.Configuration configuration = sqlSessionFactory.getConfiguration();
            //自己添加
            configuration.addInterceptor(mybatisInterceptor);
        }
    }

}


 

你可能感兴趣的:(面试,后端,java,spring,cloud,mysql)