通过注解进行权限控制

1.自定义注解

元注解介绍

@Target

@Target说明了Annotation所修饰的对象范围:Annotation可被用于 packages、types(类、接口、枚举、Annotation类型)、类型成员(方法、构造方法、成员变量、枚举值)、方法参数和本地变量(如循环变量、catch参数)。在Annotation类型的声明中使用了target可更加明晰其修饰的目标。

  1. CONSTRUCTOR:用于描述构造器

  2. FIELD:用于描述域

  3. LOCAL_VARIABLE:用于描述局部变量

  4. METHOD:用于描述方法

  5. PACKAGE:用于描述包

  6. PARAMETER:用于描述参数

  7. TYPE:用于描述类、接口(包括注解类型) 或enum声明

2.@Retention

表示需要在什么级别保存该注释信息,用于描述注解的生命周期(即:被描述的注解在什么范围内有效)

3.@Documented
Documented注解表明这个注释是由 javadoc记录的,在默认情况下也有类似的记录工具。 如果一个类型声明被注释了文档化,它的注释成为公共API的一部分。

自定义注解示例
@Target({ElementType.METHOD, ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface AuthCheck
{
  public abstract String module();

  public abstract String oper();

  public abstract String userProfile();

  public abstract String dbYear();

  public abstract String entityId();

  public abstract String entity();

  public abstract boolean failWhenUserProfileIsNull();

  public abstract String failCode();
}

2.AOP

AOP概念

面向切面编程:扩展功能不修改源代码实现,AOP采用横向抽取机制,取代传统的纵向继承体系重复性代码(性能监视,事务管理,安全检查,缓存)

AOP操作术语:

(JoinPoint)连接点:类里面可以被增强的方法

(Pointcut)切入点:实际增强的方法,也就是新增的方法

(Advice)通知/增强:增强的逻辑,比如扩展日志功能,这个日志功能增强

前置通知、后置通知、异常通知、最终通知、环绕通知

(Aspect)切面:把增强应用到具体方法上,过程称为切面

对@AuthCheck注解的处理实例
package com.supporter.prj.eip.auth_engine.service;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.Signature;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.stereotype.Service;
/**
 * 提供对@AuthCheck注解的处理.
 * @Aspect 实现spring aop 切面(Aspect):
 * 一个关注点的模块化,这个关注点可能会横切多个对象。事务管理是J2EE应用中一个关于横切关注点的很好的例子。 在Spring
 * AOP中,切面可以使用通用类(基于模式的风格) 或者在普通类中以 @Aspect 注解(@AspectJ风格)来实现。
 * AOP代理(AOP Proxy): AOP框架创建的对象,用来实现切面契约(aspect contract)(包括通知方法执行等功能)。
 * 在Spring中,AOP代理可以是JDK动态代理或者CGLIB代理。 注意:Spring 2.0引入的基于模式(schema-based)风格和@AspectJ注解风格的切面声明,对于使用这些风格的用户来说,代理的创建是透明的。
 */
@Service
@Aspect
public class AuthAnnotationService {
    private ILogService getLogService() {
        return LogUtil.getAuthDebugLogService();
    }
    /**
     * 构造方法.
     */
    public AuthAnnotationService() {}
    
    /**
     * 定义一个PointCut.
     * @param point
     * @throws Throwable
     */
    @Pointcut("@annotation(com.supporter.prj.eip_service.authority.annotation.AuthCheck)")
    public void methodPointcut() {}

    // 方法执行的前后调用
    @Around("methodPointcut()")
    public Object around(ProceedingJoinPoint point) throws Throwable {
        //getLogService().debug("类实例:" + point.getTarget().getClass() + ":" + point.getTarget().hashCode());
        boolean authCheckPassed = true;
        getLogService().debug("[权限引擎]注解方式开始执行,进入around(), 方法:" + point.getSignature().getName());
        String moduleSetting = this.getModule(point);
        IModule module = EIPService.getModuleService().getModule(moduleSetting);
        if (module != null) {
            //只有module设置有效才继续,否则会影响性能
            String operSetting = this.getOper(point);
            getLogService().debug("[权限引擎]operSetting:" + operSetting);
            IOper oper = EIPService.getModuleService().getOper(operSetting, module);
            if (oper != null) {
                //只有oper设置有效才继续,否则会影响性能
                UserProfile userProfile = this.getUserProfile(point);
                //getLogService().debug("[权限引擎]userProfile:" + userProfile.getAccountLogin());
                if (userProfile != null) {
                    //只有userProfile才继续,否则影响性能 
                    String dbYearSetting = this.getDbYearSetting(point);
                    getLogService().debug("[权限引擎]dbYearSetting:" + dbYearSetting);
                    if (dbYearSetting.length() > 0) {
                        //如果存在dbYear设置,那么说明是年库相关
                        int dbYear = this.getDbYear(point);
                        getLogService().debug("[权限引擎]dbYear:" + dbYear);
                        if (dbYear <= 0) {
                            getLogService().debug("[权限引擎]failed to get dbYear, auth check point:" + this.getMethodDesc(point));
                            authCheckPassed = false;
                        } else {
                               EIPService.getAuthorityService().canAccess(dbYear, userProfile, oper,entity);
                                if (getLogService().isDebugEnabled()) {
                                    getLogService().debug("[权限引擎]权限检查: " + userProfile.getAccountLogin() 
                                            + ":" + oper.getModuleId() + "." + oper.getName());
                                }
                                authCheckPassed = EIPService.getAuthorityService().canAccess(userProfile, oper);
//                          }
                        }
                    } else {
                        //否则与年库无关,在应用库中
//                      if (entity != null) {
//                          authCheckPassed = EIPService.getAuthorityService().canAccess(userProfile, oper, entity);
//                      } else {
                            //没有传递过来entityId,那么忽略数据限定条件
                            authCheckPassed = EIPService.getAuthorityService().canAccess(userProfile, oper);
//                      }
                    }
                } else {
                    //注:在userProfile为null的情况下,是否允许执行?根据failWhenUserProfileIsNull来判断
                    getLogService().debug("[权限引擎]userProfile is null. auch check for method:" + this.getMethodDesc(point));
                    boolean failWhenUserProfileIsNull = getFailWhenUserProfileIsNull(point);
                    if (failWhenUserProfileIsNull) authCheckPassed = false;
                }
            } else {
                getLogService().debug("[权限引擎]invalid oper:" + operSetting + ", module:" + moduleSetting);
            }
        } else {
            getLogService().debug("[权限引擎]invalid module:" + moduleSetting);
        }
        
        if (!authCheckPassed) throw new BaseRuntimeException(getFailCode(point), "[权限引擎]当前用户无权执行本操作:" + this.getMethodDesc(point));
        
        //执行被注解的方法.
        Object object;
        try {
            object = point.proceed();
        } catch (Throwable t) {
            t.printStackTrace();
            throw t;
        }
        long e = System.currentTimeMillis();
        getLogService().debug("[权限引擎]注解方式执行完毕:" + point.getSignature().getName()+",耗时:"+(e-s));
 
        return object;
    }

    // 方法运行出现异常时调用  
    // @AfterThrowing(pointcut = "execution(* com.supporter..*(..))", throwing = "ex")
    public void afterThrowing(Exception ex) {
        if (ex != null) ex.printStackTrace();
    }

    /**
     * key为"className:methodName",value是一个List,里面是同名的方法(可能参数不同).
     */
    @SuppressWarnings({ "unchecked", "rawtypes" })
    private static Map < String, List < Method > > methodMap = new ConcurrentHashMap();
    
    @SuppressWarnings({ "unchecked", "rawtypes" })
    private static List < Method > getMethods(String className, String methodName) {
        String key = className + ":" + methodName;
        if (methodMap.containsKey(key)) return methodMap.get(key);
        
        //否则获取一番
        try {
            Class < ? > targetClass = Class.forName(className);
            return getMethods(targetClass, methodName);
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
        }

        return new ArrayList();
    }
    
    @SuppressWarnings({ "unchecked", "rawtypes" })
    private static List < Method > getMethods(Class < ? > clazz, String methodName) {
        if (clazz == null) return null;
        
        String key = clazz.getName() + ":" + methodName;
        if (methodMap.containsKey(key)) return methodMap.get(key);
        
        List < Method > methodList = new ArrayList();
        //否则获取一番
        //Method[] methods = targetClass.getMethods(); //getMethods()只返回public的方法
        Method[] methods = clazz.getDeclaredMethods();
        for (Method m : methods) {
            if (m.getName().equals(methodName)) {
                methodList.add(m);
            }
        }
        
        if (methodList.size() == 0) {
            //在当前类中找不到,试试看祖先有没有
            if (clazz.getSuperclass() != null) {
                methodList = getMethods(clazz.getSuperclass(), methodName);
            }
        }
        
        methodMap.put(key, methodList); //放入缓存中
        
        return methodList;
    }
    
    private String getMethodDesc(ProceedingJoinPoint joinPoint) {
        String targetName = joinPoint.getTarget().getClass().getName();
        String methodName = joinPoint.getSignature().getName();
        Signature signature = joinPoint.getSignature();
        MethodSignature methodSignature = (MethodSignature) signature;
        
        String paramNameStr = "";
        String[] paramNames = methodSignature.getParameterNames();
        for (String paramName : paramNames) {
            if (paramNameStr.length() > 0) paramNameStr += ",";
            paramNameStr += paramName;
        }
        return targetName + "." + methodName + "(" + paramNameStr + ")";
    }
    
    /**
     * 获取被注解的方法所属的模块.
     * @param joinPoint
     * @return
     * @throws Exception
     */
    private String getModule(ProceedingJoinPoint joinPoint) throws Exception {
        Method m = getMethod(joinPoint);
        if (m == null) {
            String methodDesc = this.getMethodDesc(joinPoint);
            throw new RuntimeException("找不到被注解的方法:" + methodDesc);
        } else {
            AuthCheck authCheck = m.getAnnotation(AuthCheck.class);
            if (authCheck != null) {
                return authCheck.module();
            } else {
                return null;
            }
        }
    }
    
    /**
     * 获取被注解的方法所属的操作.
     * @param joinPoint
     * @return
     * @throws Exception
     */
    private String getOper(ProceedingJoinPoint joinPoint) throws Exception {
        Method m = getMethod(joinPoint);
        if (m == null) {
            String methodDesc = this.getMethodDesc(joinPoint);
            throw new RuntimeException("找不到被注解的方法:" + methodDesc);
        } else {
            AuthCheck authCheck = m.getAnnotation(AuthCheck.class);
            if (authCheck != null) {
                return authCheck.oper();
            } else {
                return null;
            }
        }
    }
    
    /**
     * 获取被注解的方法所属的模块.
     * @param joinPoint
     * @return
     * @throws Exception
     */
    private boolean getFailWhenUserProfileIsNull(ProceedingJoinPoint joinPoint) throws Exception {
        Method m = getMethod(joinPoint);
        if (m == null) {
            String methodDesc = this.getMethodDesc(joinPoint);
            throw new RuntimeException("找不到被注解的方法:" + methodDesc);
        } else {
            AuthCheck authCheck = m.getAnnotation(AuthCheck.class);
            if (authCheck != null) {
                return authCheck.failWhenUserProfileIsNull();
            } else {
                return true;
            }
        }
    }
    
    /**
     * 获取被注解的方法在校验失败后返回的代码.
     * @param joinPoint
     * @return
     * @throws Exception
     */
    private String getFailCode(ProceedingJoinPoint joinPoint) throws Exception {
        Method m = getMethod(joinPoint);
        if (m == null) {
            String methodDesc = this.getMethodDesc(joinPoint);
            throw new RuntimeException("找不到被注解的方法:" + methodDesc);
        } else {
            AuthCheck authCheck = m.getAnnotation(AuthCheck.class);
            if (authCheck != null) {
                return authCheck.failCode();
            } else {
                return null;
            }
        }
    }
    
    /**
     * 获取被注解的方法所属的userProfile.
     * @param joinPoint
     * @return
     * @throws Exception
     */
    private UserProfile getUserProfile(ProceedingJoinPoint joinPoint) throws Exception {
        Method m = getMethod(joinPoint);
        if (m == null) {
            String methodDesc = this.getMethodDesc(joinPoint);
            throw new RuntimeException("找不到被注解的方法:" + methodDesc);
        } else {
            AuthCheck authCheck = m.getAnnotation(AuthCheck.class);
            if (authCheck != null) {
                String userProfileSetting = authCheck.userProfile();
                if (userProfileSetting == null) throw new RuntimeException("userProfileSetting is null");
                userProfileSetting = userProfileSetting.trim();
                //logService.debug("userProfileSetting:" + userProfileSetting);
                
                if (userProfileSetting.endsWith("()")) {
                    //应该是一个方法名,那么通过反射的形式来获取相应的值
                    String methodName = userProfileSetting.substring(0, userProfileSetting.length() - 2);
                    
                    Object returnVal = this.getReturnVal(methodName, joinPoint);
                    if (returnVal == null) {
                        return null;
                    } else {
                        return (UserProfile) returnVal;
                    }
                } else {
                    //应该是参数名,那么寻找对应的参数值
                    //寻找对应的参数值
                    Object paramVal = getParamVal(userProfileSetting, joinPoint);
                    if (paramVal == null) {
                        return null;
                    } else {
                        return (UserProfile) paramVal;
                    }
                }
            } else {
                return null;
            }
        }
    }
    
    /**
     * 获取被注解的方法对应的操作相关的业务实体ID.
     * @param joinPoint
     * @return
     * @throws Exception
     */
    private Object getEnityId(ProceedingJoinPoint joinPoint) throws Exception {
        String entityIdSetting = getEntityIdSetting(joinPoint);
        if (entityIdSetting.length() == 0) {
            return null;
        }
        
        if (entityIdSetting.endsWith("()")) {
            //应该是一个方法名,那么通过反射的形式来获取相应的值
            String methodName = entityIdSetting.substring(0, entityIdSetting.length() - 2);
            return this.getReturnVal(methodName, joinPoint);
        } else {
            //应该是参数名,那么寻找对应的参数值
            //寻找对应的参数值
            return getParamVal(entityIdSetting, joinPoint);
        }
    }
    
    /**
     * 获取被注解的方法所属的entity.
     * @param joinPoint
     * @return
     * @throws Exception
     */
    private Object getEntity(ProceedingJoinPoint joinPoint) throws Exception {
        String entitySetting = getEntitySetting(joinPoint);
        if (entitySetting.length() == 0) {
            return null;
        }
        if (entitySetting.endsWith("()")) {
            //应该是一个方法名,那么通过反射的形式来获取相应的值
            String methodName = entitySetting.substring(0, entitySetting.length() - 2);
            return this.getReturnVal(methodName, joinPoint);
        } else {
            //应该是参数名,那么寻找对应的参数值
            //寻找对应的参数值
            return getParamVal(entitySetting, joinPoint);
        }
    }
    
    /**
     * 获取被注解的方法所属的dbYear设置值字符串,它的含义可能是一个参数名,也可能是一个方法名.
     * @param joinPoint
     * @return
     * @throws Exception
     */
    private String getDbYearSetting(ProceedingJoinPoint joinPoint) throws Exception {
        Method m = getMethod(joinPoint);
        if (m == null) {
            String methodDesc = this.getMethodDesc(joinPoint);
            throw new RuntimeException("找不到被注解的方法:" + methodDesc);
        } else {
            AuthCheck authCheck = m.getAnnotation(AuthCheck.class);
            if (authCheck == null) {
                EIPService.getLogService().error("无法找到相应的注解:" + getMethodDesc(joinPoint));
                return "";
            }

            String dbYear = authCheck.dbYear();
            if (dbYear == null) throw new RuntimeException("dbYear is null");
            
            return dbYear.trim(); 
        }
    }
    
    /**
     * 获取被注解的方法所属的entityId设置值字符串,它的含义可能是一个参数名,也可能是一个方法名.
     * @param joinPoint
     * @return
     * @throws Exception
     */
    private String getEntityIdSetting(ProceedingJoinPoint joinPoint) throws Exception {
        Method m = getMethod(joinPoint);
        if (m == null) {
            String methodDesc = this.getMethodDesc(joinPoint);
            throw new RuntimeException("找不到被注解的方法:" + methodDesc);
        } else {
            AuthCheck authCheck = m.getAnnotation(AuthCheck.class);
            if (authCheck == null) {
                EIPService.getLogService().error("无法找到相应的注解:" + getMethodDesc(joinPoint));
                return "";
            }

            String entityId = authCheck.entityId();
            if (entityId == null) return "";
            
            return entityId.trim(); 
        }
    }
    
    /**
     * 获取被注解的方法所属的entity设置值字符串,它的含义可能是一个参数名,也可能是一个方法名.
     * @param joinPoint
     * @return
     * @throws Exception
     */
    private String getEntitySetting(ProceedingJoinPoint joinPoint) throws Exception {
        Method m = getMethod(joinPoint);
        if (m == null) {
            String methodDesc = this.getMethodDesc(joinPoint);
            throw new RuntimeException("找不到被注解的方法:" + methodDesc);
        } else {
            AuthCheck authCheck = m.getAnnotation(AuthCheck.class);
            if (authCheck == null) {
                EIPService.getLogService().error("无法找到相应的注解:" + getMethodDesc(joinPoint));
                return "";
            }

            String entity = authCheck.entity();
            if (entity == null) return "";
            
            return entity.trim(); 
        }
    }
    
    /**
     * 获取被注解的方法的指定名称的参数值.
     * @param paramName
     * @param joinPoint
     * @return
     */
    private Object getParamVal(String paramName, ProceedingJoinPoint joinPoint) {
        String pName = CommonUtil.trim(paramName);
        if (pName.length() == 0) {
            getLogService().error("paramName is empty.");
            return null;
        }
        
        //寻找对应的参数值
        Signature signature = joinPoint.getSignature();
        MethodSignature methodSignature = (MethodSignature) signature;
        String[] paramNames = methodSignature.getParameterNames();
        Object[] paramVals = joinPoint.getArgs();
        for (int i = 0; i < paramNames.length; i++) {
            String name = paramNames[i];
            //logService.debug(paramNames[i] + ":" + paramVals[i]);
            if (pName.equals(name)) {
                return paramVals[i];
            }
        }
        
        getLogService().error("找不到指定的参数:" + paramName + ", 方法:" + this.getMethodDesc(joinPoint));
        return null;
    }
    
    /**
     * 获取指定方法(没有任何参数)的返回值.
     * @param methodName
     * @param joinPoint
     * @return
     */
    private Object getReturnVal(String methodName, ProceedingJoinPoint joinPoint) {
        String mName = CommonUtil.trim(methodName);
        if (mName.length() == 0) {
            getLogService().error("找不到指定的方法:" + methodName + ", 类:" + joinPoint.getTarget().getClass().getName());
            return null;
        }
        
        Object target = joinPoint.getTarget();
        List < Method > methods = getMethods(target.getClass().getName(), methodName);
        
        Method method = null; 
        for (Method mthd : methods) {
            if (mthd.getParameterTypes().length > 0) {
                continue;
            } else {
                method = mthd;
                break;
            }
        }
        if (method == null) {
            String msg = "无法在类中找到指定方法:" + methodName + ", 所属类:" + target.getClass().getName();
            EIPService.getLogService().error(msg);
            throw new RuntimeException(msg);
        } else {
            //以反射形式执行
            Object returnVal = null;
            try {
                method.setAccessible(true); //据称设置为true可以让java不再检查方法是否私有等,有助性能提高
                returnVal = method.invoke(target);
            } catch (Throwable e) {
                e.printStackTrace();
            } 
            
            if (returnVal == null) {
                String msg = "指定方法返回null值或没有返回值:" + methodName + ", 所属类:" + target.getClass().getName();
                EIPService.getLogService().error(msg);
                throw new RuntimeException(msg);
            } else {
                return returnVal;
            }
        }
    }
    
    /**
     * 获取被注解的方法所属的userProfile.
     * @param joinPoint
     * @return
     * @throws Exception
     */
    private int getDbYear(ProceedingJoinPoint joinPoint) throws Exception {
        String dbYearSetting = getDbYearSetting(joinPoint);
        if (dbYearSetting.length() == 0) {
            return 0;
        }
        
        if (dbYearSetting.endsWith("()")) {
            //应该是一个方法名,那么通过反射的形式来获取相应的值
            String methodName = dbYearSetting.substring(0, dbYearSetting.length() - 2);
            
            Object returnVal = this.getReturnVal(methodName, joinPoint);
            if (returnVal == null) {
                return 0;
            } else {
                return CommonUtil.parseInt(returnVal.toString());
            }

        } else {
            //应该是参数名,那么寻找对应的参数值
            //寻找对应的参数值
            Object paramVal = getParamVal(dbYearSetting, joinPoint);
            if (paramVal == null) {
                return 0;
            } else {
                return CommonUtil.parseInt(paramVal.toString(), 0);
            }
        }
    }
    
    /**
     * 获取被注解的方法.
     * @param joinPoint
     * @return
     * @throws Exception
     */
    private Method getMethod(ProceedingJoinPoint joinPoint) throws Exception {
        String targetName = joinPoint.getTarget().getClass().getName();
        String methodName = joinPoint.getSignature().getName();
        Object[] arguments = joinPoint.getArgs();
        
        List < Method > methods = getMethods(targetName, methodName);
        
        if (methods == null || methods.size() == 0) return null; //不存在,不过这个一般是不可能的
        
        if (methods.size() == 1) return methods.get(0);
        
        for (Method m : methods) {
            //如果不止一个同名方法,那么要判断一番
            Class < ? > [] paramTypes = m.getParameterTypes();
            if (paramTypes.length != arguments.length) continue; //参数的个数不一样,跳过

            //检查是否真的匹配
            boolean matched = true; //先假设是匹配的
            for (int i = 0; i < paramTypes.length; i++) {
                Class < ? > paramType = paramTypes[i];
                Object argument = arguments[i];
                if (argument == null) continue;
                if (!paramType.isInstance(argument)) {
                    matched = false;
                    break; //不匹配,不需要再继续
                }
            }

            if (matched) return m; 
        }
        
        return null;
    }
     
}

权限控制逻辑处理

1.通过反射机制在目标对象中获取用户信息
2.通过应用id及权限项找到所有有该用户权限的用户
3.若所有有该权限用户集合中包含此目标用户则放行

你可能感兴趣的:(通过注解进行权限控制)