基于注解扫描的高性能 mybatis 自动加解密拦截器实现

背景:工作中经常需要对数据库中的某些敏感字段进行密文存储。为了避免把加解密逻辑嵌入到业务流程中,我利用 MyBatis 的拦截器机制实现了自动加解密功能,让业务代码更专注于业务本身。 

完整代码详见:https://github.com/xiananliu/Mybatis-Encryptor

流程如下:

  1. 项目启动,根据配置的包路径扫描自定义注解
  2. 根据扫描到的注解,利用反射获取该字段的set、get 方法
  3. 将每个字段对应的读方法和写方法缓存起来备用
  4. 拦截器执行时,先取出参数并对其进行加密;获得返回对象后,再对返回对象进行解密

加解密配置

方式一:自定义注解



import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Marks an entity field whose value is stored encrypted in the database.
 *
 * <p>The DB interceptor scans the configured packages for this marker at startup, encrypts the
 * field before statements run and decrypts it in query results. Only {@code String} and
 * {@code List} fields are accepted; other types are rejected with an error log.
 *
 * <p>Created by lxn on 2018/9/19.
 */
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
public @interface AutoEncrypt {
}

方式二:xml配置

xml dtd:  db-security-cfg.dtd


        
        
        
        
        
        
        
        

xml 配置示例:




    
        xxxxxxx
        
        
    

    
        buyerPhone
    
    
        buyerName
        buyerMobile
    
    
        ownerName
        ownerCardNo
        ownerMobile
    
    
        insuredName
        insuredCardNo
        insuredMobile
    
    
        insuredName
        insuredCardNo
        benefitName
        benefitCardNo
        benefitMobile
    
    
        insuredName
        insuredCardNo
    
    
        ownerName
        ownerCardNo
        ownerMobile
    
    
        customName
        customCardNo
        customMobile
    


在 MyBatis 配置中注册拦截器,并通过 configLocation 属性指定上述 xml 的位置:




    
        
    
    
        

        
            
        
        
            
            
        
    

 

遇到的坑:

1,为了兼容 mybatis-generator 生成的 Dao 中 QueryBuilder 字段类型可能为 List 的情况,做了特殊处理,即对 String xxx 和 List<String> xxxList 两种类型的字段都可以处理

2,mybatis-generator 生成的 Dao 中,PO 字段的写方法可能没有 set 前缀,为此做了特殊处理。注:使用 JDK 内省 PropertyDescriptor 获取字段的 set 方法时,该方法的返回值必须为 void,否则无法获取到该 Method;因此这里没有使用 JDK 的 PropertyDescriptor 获取字段的读写方法,而是直接使用了反射。

所以这里兼容的读写方法为:读:getXxx();写:setXxx(Xxx xxx) 和 xxx(Xxx xxx)

当然如果您的PO严格按照set、get的方式命名,也可以使用 PropertyDescriptor来获取。

3,MyBatis 的参数 Map 中会用不同的 key 指向同一个 Object,遍历此 Map 进行加密时可能会对同一个 Object 加密两次,因此使用了 Set 做去重处理

 

 

 

package com.jd.baoxian.order.trade.dao.interceptor;

import com.google.common.base.Throwables;
import com.jd.baoxian.order.trade.common.annotation.AutoEncrypt;
import com.jd.baoxian.order.trade.common.utils.AESCoder;
import com.jd.baoxian.order.trade.common.utils.XmlUtil;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.springframework.core.io.Resource;
import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
import org.springframework.core.io.support.ResourcePatternResolver;
import org.springframework.core.type.classreading.CachingMetadataReaderFactory;
import org.springframework.core.type.classreading.MetadataReader;
import org.springframework.core.type.classreading.MetadataReaderFactory;
import org.springframework.util.ClassUtils;
import org.springframework.util.StringUtils;

import java.io.FileNotFoundException;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.*;

import static java.util.Locale.ENGLISH;

/**
 * db自动加解密拦截器
 * Created by lxn on 2018/9/16.
 */
@Slf4j
@Intercepts({
        @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class }),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class,ResultHandler.class }) })
public class DBInterceptor implements Interceptor {


    private Map> cache =new HashMap<>();

    private Settings settings;



    /**
     * 注解扫描
     * @param basePackage 扫描的路径
     * @return List 结果
     */
    private List scanAnnotationFrom(String basePackage){
        List results=new ArrayList<>();

        ResourcePatternResolver resourcePatternResolver = new PathMatchingResourcePatternResolver();
        final String RESOURCE_PATTERN = "/**/*.class";
        String pattern = ResourcePatternResolver.CLASSPATH_ALL_URL_PREFIX + ClassUtils.convertClassNameToResourcePath(basePackage)
                + RESOURCE_PATTERN;
        try {
            Resource[] resources = resourcePatternResolver.getResources(pattern);
            MetadataReaderFactory readerFactory = new CachingMetadataReaderFactory(resourcePatternResolver);
            for (Resource resource : resources) {
                if (resource.isReadable()) {
                    MetadataReader reader = readerFactory.getMetadataReader(resource);
                    //扫描到的class
                    String className = reader.getClassMetadata().getClassName();
                    Class clazz = Class.forName(className);
                    Field[] fields = clazz.getDeclaredFields();
                    if (fields==null){
                        continue;
                    }
                    for (Field eachField:fields){
                        AutoEncrypt autoEncrypt=eachField.getAnnotation(AutoEncrypt.class);
                        if (autoEncrypt!=null){
                            //找到注解,加入配置
                            results.add(new ScanResult(clazz,eachField.getName(),eachField.getType()));
                        }
                    }
                }
            }
        } catch (Exception e) {
            log.error("自动加解密扫描包异常basePackage:{},e:{}",basePackage,e);
        }

        return results;
    }






    /**
     * 扫描注解结果
     */

    @AllArgsConstructor
    @Data
    class ScanResult{
        private Class aClass;
        private String field;
        private Class fieldClass;
    }

    /**
     * 批量添加映射
     * @param list
     */
    private void addBeanPropertyList(List list){
        for (ScanResult each:list){
            addBeanProperty(each.getAClass(),each.getField(),each.getFieldClass());
        }
    }

    /**
     * Returns a String which capitalizes the first letter of the string.
     */
    public static String capitalize(String name) {
        if (name == null || name.length() == 0) {
            return name;
        }
        return name.substring(0, 1).toUpperCase(ENGLISH) + name.substring(1);
    }

    /**
     * 添加字段到映射
     * @param clazz
     * @param field
     */
    private void addBeanProperty(Class clazz,String field,Class fieldClass){
        addBeanProperty(clazz,clazz,field,fieldClass);
    }

    /**
     *  添加字段到映射
     * @param clazz
     * @param methodClass
     * @param field
     * @param fieldClass
     */
    private void addBeanProperty(Class clazz,Class methodClass,String field,Class fieldClass){
        log.info("db自动加解密拦截器注册字段:{}-->{}",clazz,field);
        if (!fieldClass.equals(String.class)&&!fieldClass.equals(List.class)&&!clazz.getName().endsWith("UpdateBuilder")){
            log.error("db 自动加解密字段仅支持String 或 List 字段,class:{},field:{}",clazz,field);
            return;
        }

        Method readMethod = null;
        try {
            readMethod=methodClass.getDeclaredMethod("get"+capitalize(field),null);
        } catch (NoSuchMethodException e) {
            log.error("db 解密字段加载异常,class:{},field:{},e:{}", clazz, field,Throwables.getStackTraceAsString(e));
        }
        Method writeMethod=null;
        try {
            writeMethod=methodClass.getDeclaredMethod("set"+capitalize(field),fieldClass);
        } catch (NoSuchMethodException e) {
            //List 特殊处理
            try {
                writeMethod=methodClass.getDeclaredMethod(field,fieldClass);
            } catch (NoSuchMethodException e1) {
                log.error("db 解密字段加载异常,class:{},field:{},e:{}", clazz, field,Throwables.getStackTraceAsString(e));
            }
        }

        if (readMethod==null||writeMethod==null){
            return;
        }

//        FastClass cglibBeanClass = FastClass.create(methodClass);
//        FastMethod readFastMethod=cglibBeanClass.getMethod(readMethod);
//        FastMethod writeFastMethod=cglibBeanClass.getMethod(writeMethod);
        if(cache.get(clazz)==null){
            cache.put(clazz,new HashMap<>());
        }
        Map propertyDescriptorMap=cache.get(clazz);
        propertyDescriptorMap.put(field,new MethodBox(readMethod,writeMethod));
    }

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        String methodName = invocation.getMethod().getName();
        Object parameter=invocation.getArgs()[1];


        if(parameter instanceof Map) {
            //浅复制字段去重,避免同一个字段被加密多次
            Set disctinct = new HashSet<>();
            for (Object each : ((Map) parameter).values()) {
                boolean result = disctinct.add(each);
                if (result) {
                    modifyField(each, true);
                }
            }
        //特殊处理
        }else if (parameter.getClass().getName().endsWith("UpdateBuilder")){
            Class UpdateBuilderClazz=  Class.forName(parameter.getClass().getName());
            Map cacheMethod=cache.get(UpdateBuilderClazz);
            if (cacheMethod!=null){
                //读取set 和 where
                Object setObj=cacheMethod.get("set").getReadMethod().invoke(parameter);
                Object whereObj=cacheMethod.get("where").getReadMethod().invoke(parameter);
                modifyField(setObj,true);
                modifyField(whereObj,true);
            }
        }else {
            modifyField(parameter,true);
        }

        Object object = invocation.proceed();

        if (methodName.equals("query")) {
            //解密
            if(object instanceof List){
                for (Object each:(List)object){
                    modifyField(each,false);
                }
            }else {
                modifyField(object,false);
            }
        }

        return object;
    }





    /**
     * 结果解密
     * @param object
     */
    private void modifyField(Object object,boolean encript){

        Map map= cache.get(object.getClass());
        if(map==null){
            return;
        }
        for (Map.Entry each:map.entrySet()){
            MethodBox methodBox=each.getValue();
            try {
                Object value=methodBox.getReadMethod().invoke(object,null);
                if(value==null){
                   continue;
                }
                //String
                if (value instanceof String){
                   Object after = dealString(value,encript,object.getClass(),each.getKey());
                   methodBox.getWriteMethod().invoke(object,new Object[]{after});
                   continue;
                }
                //List
                if (value instanceof List) {
                  List list=  (List)value;
                  for (int i=0;i");
            } catch (Exception e) {
                log.error("classs:{},field:{} 不能访问,exception:{}",object.getClass(),each.getKey(),e);
            }
        }
    }



    private Object dealString(Object value,boolean encript,Class clazz,String field){
        if(!(value instanceof String)){
            log.error("classs:{},field:{} 不能访问,exception:{}",clazz,field,"不是String或 List");
            return value;
        }
        if(!StringUtils.hasText((String)value)){
            return value;
        }
        String after;
        if (encript){
            //加密
            after=encrypt((String)value);
        }else {
            //解密
            after=decrypt((String)value);
        }

        return after;
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
        String configLocation=properties.getProperty("configLocation");
        if (!StringUtils.hasText(configLocation)){
            throw new RuntimeException("db加解密拦截器 configLocation 配置丢失");
        }
        try {
           this.settings= XmlUtil.readXml(Settings.class,configLocation);
        } catch (FileNotFoundException e) {
           throw new RuntimeException("db加解密拦截器配置文件加载异常",e.fillInStackTrace());
        }

        init();
    }

    /**
     * 初始化
     */
    private void init(){
        //扫描
        if (settings.getSetting().getScanPackages()!=null) {
            for (String scanPackage : settings.getSetting().getScanPackages()) {
                List annos = scanAnnotationFrom(scanPackage);
                addBeanPropertyList(annos);
            }
        }

        specialProcess();

        //读取配置
        for (Mapper eachMapper:settings.getMappers()){
            Class clazz = null;
            try {
                clazz = Class.forName(eachMapper.getClassName());
            } catch (ClassNotFoundException e) {
                log.error(Throwables.getStackTraceAsString(e));
            }

            for (String eachField:eachMapper.getPropertys()){
                Field field;
                Class currClazz=clazz;
                do {
                    try {
                        field = currClazz.getDeclaredField(eachField);
                        //找到注解,加入配置
                        addBeanProperty(clazz,currClazz,field.getName(),field.getType());
                        break;
                    } catch (NoSuchFieldException e) {
                        currClazz=currClazz.getSuperclass();
                        if (currClazz==Object.class){
                            log.error(Throwables.getStackTraceAsString(e));
                        }
                    }
                }while (currClazz!=Object.class);


            }
        }
    }

    /**
     * 针对自动生成po的特殊处理
     */
    void specialProcess(){
        List specialMappers=new ArrayList<>();
        List mappers=settings.getMappers();
        if (mappers==null){
            return;
        }
        for (Mapper eachMapper:mappers){
            if (!eachMapper.isSpecial()){
                continue;
            }
            Mapper queryBuilder=new Mapper();
            queryBuilder.setClassName(eachMapper.getClassName().concat("$QueryBuilder"));
            Mapper conditionBuilder=new Mapper();
            conditionBuilder.setClassName(eachMapper.getClassName().concat("$ConditionBuilder"));

            List pros=new ArrayList<>();
            for (String orgPro :eachMapper.getPropertys()) {
                pros.add(orgPro.concat("List"));
            }

            queryBuilder.setPropertys(new ArrayList<>());
            queryBuilder.getPropertys().addAll(eachMapper.getPropertys());
            queryBuilder.getPropertys().addAll(pros);
            conditionBuilder.setPropertys(pros);
            specialMappers.add(queryBuilder);
            specialMappers.add(conditionBuilder);
            //updateBuilder 缓存读写方法
            try {
                Class updateBuilder = Class.forName(eachMapper.getClassName().concat("$UpdateBuilder"));
                Class setClass=Class.forName(eachMapper.getClassName());
                Class conditionClass=Class.forName(eachMapper.getClassName().concat("$ConditionBuilder"));
                addBeanProperty(updateBuilder,"set",setClass );
                addBeanProperty(updateBuilder,"where",conditionClass);
            }catch (ClassNotFoundException e) {
                e.printStackTrace();
            }


        }
        mappers.addAll(specialMappers);
    }



    /**远程调用接口实现解密
     * 若处理失败 ,则返回空串,防止错误数据传播
     * @param decryptData
     * @return
     */
    public String decrypt(String decryptData) {
        if(!StringUtils.hasText(decryptData)){
            return decryptData;
        }
        String deData = "";
        try {
            deData=  AESCoder.decrypt(decryptData,settings.getSetting().getAeskey());
            return deData;
        } catch (Exception e) {
            log.error("数据" + decryptData + "解密异常", e);
        }
        //不能解密则返回原文,因为数据库可能有未加密数据
        return decryptData;
    }
    /**远程调用接口实现加密
     * 若处理失败 ,则把原请求数据返回,防止数据丢失
     * @param encryptData
     * @return
     */
    public String encrypt(String encryptData) {
        if(!StringUtils.hasText(encryptData)){
            return encryptData;
        }
        String enData = "";
        try {
            enData=  AESCoder.encrpt(encryptData,settings.getSetting().getAeskey());
            return enData;
        } catch (Exception e) {
            log.error("数据加密异常", e);
            enData=encryptData;
        }
        return enData;
    }



    @AllArgsConstructor
    @Data
    class MethodBox{


        private Method readMethod;

        private Method writeMethod;

    }


} 
  

 

你可能感兴趣的:(java网络编程,软件工程)