修改若依的数据权限功能

若依(cloud版本)的数据权限功能是通过注解实现的,在需要数据权限的方法上加上注解。

在注解中判断当前用户的角色对应的数据权限类型,在执行的sql语句后面拼接部门和用户的sql过滤条件从而实现数据权限功能。

数据权限的核心功能的代码如下:修改若依的数据权限功能_第1张图片

如果当前用户有多个角色,那么多个角色之间不同的数据权限通过sql的or来连接。

最后在mybatis的xml中通过拼接sql语句实现。

修改若依的数据权限功能_第2张图片

这个功能没有问题,使用$来拼接sql也没有问题,因为数据权限中需要拼接的语句是写死的,而且若依也做了防注入的操作。

修改若依的数据权限功能_第3张图片

 但是我们的客户验收项目的时候会做代码扫描,这种通过$来拼接sql语句的情况是不允许出现的,哪怕这么写完全没有注入风险也不可以。

所以要找到另一种方法替换若依当前的数据权限功能。

网上翻了翻其他的成品框架,决定升级mybatis为myabtis-plus,使用mybatis-plus的内置拦截器来实现。数据权限的核心功能。

写一个自定义拦截器,实现myabtis-plus的内置拦截器InnerInterceptor ,就可以在sql语句执行前拦截到sql相关信息 ,在这里进行数据权限的操作。

先上拦截器的参考代码:

@Component
public class DataScopeInnerInterceptor implements InnerInterceptor {

    @Autowired
    private RedisService redisService;

    @Autowired
    private DataScopeService dataScopeService;//这个service的功能是使用feign调用auth模块中的代码,也就是下面代码的getUserDataScope方法。在登录的时候获得当前数据权限的角色id和部门id

    @Override
    public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds,
                            ResultHandler resultHandler, BoundSql boundSql) {
        String originalSql = boundSql.getSql();
        PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
        //带有datascope注解的mapper方法名
        Object cacheObject1 = redisService.getCacheObject(UserConstants.DATA_SCOPE_ANNO);
        String scopeMethod = String.valueOf(cacheObject1);
        //获得所执行sql语句的mapper方法全限定名
        String mapperMethod = ms.getId();
        mapperMethod = mapperMethod.substring(mapperMethod.lastIndexOf(".")+1);
        if (scopeMethod.contains(mapperMethod)) {
            //获得当前登录用户
            LoginUser loginUser = SecurityUtils.getLoginUser();
            if (null != loginUser){//项目启动和登录时没有用户
                SysUser sysUser = loginUser.getSysUser();
                Object cacheObject = redisService.getCacheObject(UserConstants.USER_DATA_SCOPE + sysUser.getUserId());
                //如果redis中没有存,就再调用方法存一次
                if (null == cacheObject) {
                    dataScopeService.getUserDataScope(loginUser);
                    cacheObject = redisService.getCacheObject(UserConstants.USER_DATA_SCOPE + sysUser.getUserId());
                }
                String scoprStr = String.valueOf(cacheObject);
                if ("#".equals(scoprStr)) {//查询全部数据
                    originalSql = String.format("SELECT * FROM (%s) temp_data_scope",
                            originalSql);
                }
                else {
                    String[] split = scoprStr.split("#");
                    if (scoprStr.length()>1 && scoprStr.startsWith("#")) {   //只有对当前用户的数据权限   
                        originalSql = String.format("SELECT * FROM (%s) temp_data_scope WHERE temp_data_scope.user_id = %s", originalSql,split[1]);
                    } else if (scoprStr.length()>1 && scoprStr.endsWith("#")) {     //只有对部门的数据权限(包括本部门,本部门及子部门和自定义)  
                        originalSql = String.format("SELECT * FROM (%s) temp_data_scope WHERE temp_data_scope.dept_id IN (%s)", originalSql,split[0]);
                    } else if (!scoprStr.startsWith("#") &&  !scoprStr.endsWith("#")) {    //既有部门的数据权限又有用户的数据权限
                        originalSql = String.format("SELECT * FROM (%s) temp_data_scope WHERE temp_data_scope.dept_id IN (%s) or temp_data_scope.user_id = (%s)",originalSql, split[0],split[1]);
                    }
                }
                mpBs.sql(originalSql);
            }
        }
    }
}

这个拦截器默认拦截所有的sql语句操作,可以看到,这个是通过把原有的sql语句作为子查询来实现数据权限的,如果这样实现有以下问题需要解决:

1,子查询中的字段不能是重复的字段名,比如用户表和部门表关联,那么用户的user_id和部门的user_id不能在一起,必须有一个要起别名。所以要修改原有的语句。

2,若依之前的注解是有参数的,参数可以指定数据权限是过滤部门id还是用户id,

@DataScope(deptAlias = "d", userAlias = "u")

比如如果数据权限是“只有当前用户”的时候,按照上面的代码是在子查询外面加WHERE temp_data_scope.user_id = %s,这样就要求语句中一定要有user_id这个字段,就需要userAllias这个参数,如果没有这个参数,如果只有deptAlias参数,数据权限是不过滤的。

我打算把若依的注解的参数去掉,所以要在需要数据权限的语句中添加用户id和部门id,如果没有就关联用户表和部门表然后加上id。

3,拦截器默认拦截所有sql操作,所以要在在过滤器中找到哪些方法是加了若依之前的数据权限注解的,修改后的数据权限也是需要对加了注解的方法才进行过滤。

解决方法如下:在项目启动的时候扫描所有加注解的方法名,并redis缓存

参考代码:

@Component
public class PackageUtil {

    @Autowired
    RedisService redisService;

    @PostConstruct
    public void doTheJob(){
        getDatascopeAnnotationMethodName("com.tbenefitofcloud.system");
    }

    private void getDatascopeAnnotationMethodName(String packageName) {
        Set> classSet = getClassSet(packageName);
        List anno = new ArrayList<>();
        for (Class clazz : classSet) {
            Method[] declaredMethods = clazz.getDeclaredMethods();
            for (Method declaredMethod : declaredMethods) {
                String isNotNullStr = "";
                // 判断是否方法上存在注解  MethodInterface
                boolean annotationPresent = declaredMethod.isAnnotationPresent(DataScope.class);
                if (annotationPresent) {
                    // 获取自定义注解对象
                    DataScope methodAnno = declaredMethod.getAnnotation(DataScope.class);
                    // 根据对象获取注解值
                    anno.add(declaredMethod.getName());
                }
            }
        }
        System.out.println(anno);
        String join = StringUtils.join(anno.toArray(), ",");
        redisService.setCacheObject(UserConstants.DATA_SCOPE_ANNO,join);
    }

    /**
     * 获取类加载器
     */
    public ClassLoader getClassLoader() {
        return Thread.currentThread().getContextClassLoader();
    }

    /**
     * 加载类
     */
    public Class loadClass(String className, boolean isInitialized) {
        Class cls;
        try {
            cls = Class.forName(className, isInitialized, getClassLoader());
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
        return cls;
    }

    /**
     * 获取指定包名下的所有类
     */
    public Set> getClassSet(String packageName) {
        Set> classSet = new HashSet>();
        try {
            Enumeration urls = getClassLoader().getResources(packageName.replace(".", "/"));
            while (urls.hasMoreElements()) {
                URL url = urls.nextElement();
                if (url != null) {
                    String protocol = url.getProtocol();
                    if (protocol.equals("file")) {
                        String packagePath = url.getPath().replaceAll("%20", " ");
                        addClass(classSet, packagePath, packageName);
                    } else if (protocol.equals("jar")) {
                        JarURLConnection jarURLConnection = (JarURLConnection) url.openConnection();
                        if (jarURLConnection != null) {
                            JarFile jarFile = jarURLConnection.getJarFile();
                            if (jarFile != null) {
                                Enumeration jarEntries = jarFile.entries();
                                while (jarEntries.hasMoreElements()) {
                                    JarEntry jarEntry = jarEntries.nextElement();
                                    String jarEntryName = jarEntry.getName();
                                    if (jarEntryName.endsWith(".class")) {
                                        String className = jarEntryName.substring(0, jarEntryName.lastIndexOf(".")).replaceAll("/", ".");
                                        doAddClass(classSet, className);
                                    }
                                }
                            }
                        }
                    }
                }
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        return classSet;
    }

    private void addClass(Set> classSet, String packagePath, String packageName) {
        File[] files = new File(packagePath).listFiles(new FileFilter() {
            public boolean accept(File file) {
                return (file.isFile() && file.getName().endsWith(".class")) || file.isDirectory();
            }
        });
        for (File file : files) {
            String fileName = file.getName();
            if (file.isFile()) {
                String className = fileName.substring(0, fileName.lastIndexOf("."));
                if (StringUtil.isNotEmpty(packageName)) {
                    className = packageName + "." + className;
                }
                doAddClass(classSet, className);
            } else {
                String subPackagePath = fileName;
                if (StringUtil.isNotEmpty(packagePath)) {
                    subPackagePath = packagePath + "/" + subPackagePath;
                }
                String subPackageName = fileName;
                if (StringUtil.isNotEmpty(packageName)) {
                    subPackageName = packageName + "." + subPackageName;
                }
                addClass(classSet, subPackagePath, subPackageName);
            }
        }
    }

    private void doAddClass(Set> classSet, String className) {
        Class cls = loadClass(className, false);
        classSet.add(cls);
    }
}

在登录的时候查找出当前用户对应角色所需的用户id和部门id,并缓存

参考代码:

public void getUserDataScope(LoginUser userInfo) {
        SysUser sysUser = userInfo.getSysUser();
        List roles = sysUser.getRoles();
        //该用户对应部门id
        List deptIdsByRoleId = new ArrayList<>();
        String scopeType = "";
        String userId = "";//用户id,数据权限为仅本人数据权限时使用
        for (SysRole role : roles) {
            scopeType = role.getDataScope();
            if ("2".equals(scopeType)){     //自定数据权限,从sys_role_dept表中查询
                List scopeDeptIds = remoteScopeService.getDeptIdsByRoleId(role.getRoleId());
                deptIdsByRoleId.addAll(scopeDeptIds);
            }else if ("3".equals(scopeType)){       //部门数据权限
                //当前角色对应部门id
                deptIdsByRoleId.add(sysUser.getDeptId());
            }else if ("4".equals(scopeType)){       //部门及下属部门
                List deptAndChild = remoteScopeService.findDeptAndChild(sysUser.getDeptId());
                deptIdsByRoleId.addAll(deptAndChild);
            }else if ("5".equals(scopeType)){//仅本人数据权限
                userId = sysUser.getUserId().toString();
            }
        }
        Set scopeDepts = new HashSet<>(deptIdsByRoleId);
        String join = StringUtils.join(scopeDepts.toArray(), ",");
        System.out.println(join);
        redisService.setCacheObject(UserConstants.USER_DATA_SCOPE+sysUser.getUserId(),join+"#"+userId);
    }

你可能感兴趣的:(自己总结,spring,反射)