Data masking with Spring Boot + MyBatis interceptors + custom annotations

Update 2024-05-10: in use I found that things broke when a mapper method has only one parameter, so I fixed it.

1. Introduction

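In real business scenarios, some fields have to be encrypted for security before they are written to the database, and decrypted again in query results. To save developers this work, we use MyBatis interceptors to encrypt and decrypt selected fields automatically. Nobody has to call the encryption code by hand, which makes development faster.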

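I previously wrote another post, "Data masking with Spring Boot + MyBatis + custom annotations", which is based on Spring AOP; readers who are interested can look at that one. This article implements data masking with MyBatis interceptors instead.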

2. Source code

Annotation placed on the mapper:
import java.lang.annotation.*;

@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE})
public @interface Message {
    FieldName[] value();
}
Annotation used inside @Message:
import java.lang.annotation.*;

@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE})
@Documented
public @interface FieldName {
    String methodName();
    String[] fieldName();
    String collectionName() default "";
    boolean onlyOneParam();
    boolean required() default false;
}

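@Message is placed on the mapper interface, and @FieldName is the element type of @Message's value.

@FieldName has five attributes:

methodName: the mapper method name; marks the method that needs encryption/decryption.

fieldName: property names; the values of these properties in the parameters need encryption/decryption.

collectionName: collection name; if a parameter is a List, put that parameter's name here.

onlyOneParam: whether the method has only one parameter (required).

required: whether that single parameter needs to be encrypted (used together with onlyOneParam; optional, defaults to false).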

Example:

@Message({
        @FieldName(methodName = "searchAllByPhone", fieldName = {"phone"}, collectionName = "studentInListList", onlyOneParam = false)
})
public interface StudentMapper extends IBaseMapper {
    List<Student> searchAllByPhone(List<StudentInList> studentInListList);
}

public class StudentInList {

    private String phone;

}
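Here is a minimal sketch of how the @Message/@FieldName pair can be read by reflection to find the entry for a given method name. This is essentially what both interceptors below do once they have derived the mapper class from the MappedStatement id. The annotations are copies of the ones above. DemoStudentMapper and FieldNameLookup are hypothetical names used only for this demo.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// Copies of the article's annotations, so the demo compiles on its own.
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE})
@interface FieldName {
    String methodName();
    String[] fieldName();
    String collectionName() default "";
    boolean onlyOneParam();
    boolean required() default false;
}

@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE})
@interface Message {
    FieldName[] value();
}

// Hypothetical mapper used only for this demo.
@Message({
        @FieldName(methodName = "searchAllByPhone", fieldName = {"phone"},
                collectionName = "studentInListList", onlyOneParam = false)
})
interface DemoStudentMapper {
}

class FieldNameLookup {

    // Returns the property names to encrypt/decrypt for mapperClass.methodName,
    // or an empty list when the method is not listed in @Message.
    static List<String> fieldsFor(Class<?> mapperClass, String methodName) {
        Message message = mapperClass.getAnnotation(Message.class);
        if (message == null) {
            return Collections.emptyList();
        }
        for (FieldName fieldName : message.value()) {
            if (methodName.equals(fieldName.methodName())) {
                return Arrays.asList(fieldName.fieldName());
            }
        }
        return Collections.emptyList();
    }

    public static void main(String[] args) {
        System.out.println(fieldsFor(DemoStudentMapper.class, "searchAllByPhone")); // [phone]
        System.out.println(fieldsFor(DemoStudentMapper.class, "selectAll"));        // []
    }
}
```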
Custom Example class
import tk.mybatis.mapper.entity.Example;

import java.util.List;

public class NewExample extends Example {

    public NewExample(Class<?> entityClass) {
        super(entityClass);
    }

    public NewExample(Class<?> entityClass, boolean exists) {
        super(entityClass, exists);
    }

    public NewExample(Class<?> entityClass, boolean exists, boolean notNull) {
        super(entityClass, exists, notNull);
    }

    // Names of the properties whose query values need encryption
    private List<String> filedNameList;

    public List<String> getFiledNameList() {
        return filedNameList;
    }

    public void setFiledNameList(List<String> filedNameList) {
        this.filedNameList = filedNameList;
    }
}

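The only reason for a custom Example class is the extra filedNameList property. When we query with Example in our code (the built-in mapper query methods such as selectByExample), this property tells the interceptor which property names need encryption/decryption.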

Example:

public List<Student> selectAll(String name, Integer pageNum, Integer pageSize) {
    NewExample newExample = new NewExample(Student.class);
    Example.Criteria criteria1 = newExample.createCriteria();
    criteria1.andEqualTo("name", name).andEqualTo("age", 24);
    newExample.setFiledNameList(Arrays.asList("name"));
    PageHelper.startPage(pageNum, pageSize);
    return studentMapper.selectByExample(newExample);
}

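The code above marks the field name as needing encryption/decryption: the value passed to andEqualTo("name", name) is encrypted before the query runs. Note that the ResultSet interceptor below only decrypts results for methods listed in the mapper's @Message annotation. If the results of selectByExample should also be decrypted, selectByExample must be listed there too.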
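The Executor interceptor below rewrites an Example condition only when its condition string is exactly "<field> =", which is what the interceptor's equals() check expects andEqualTo to produce. Other operators such as like or in are left alone. The sketch below isolates that matching rule. DemoCriterion is a hypothetical, simplified stand-in for tk.mybatis Example.Criterion, and the "encryption" is a fake function so the effect is visible.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.UnaryOperator;

// Hypothetical stand-in for Example.Criterion: a condition such as "name =" plus its value.
class DemoCriterion {
    final String condition;
    Object value;

    DemoCriterion(String condition, Object value) {
        this.condition = condition;
        this.value = value;
    }
}

class CriterionEncryptor {

    // Replaces the value of every "<field> =" condition whose field is in fieldNames.
    // Any other operator is skipped, mirroring the interceptor's condition.equals(s + " =") check.
    static void encryptEqualityValues(List<DemoCriterion> criteria, List<String> fieldNames,
                                      UnaryOperator<String> encrypt) {
        for (DemoCriterion criterion : criteria) {
            for (String field : fieldNames) {
                if (criterion.condition.equals(field + " =") && criterion.value != null) {
                    criterion.value = encrypt.apply(String.valueOf(criterion.value));
                }
            }
        }
    }

    public static void main(String[] args) {
        List<DemoCriterion> criteria = new ArrayList<>();
        criteria.add(new DemoCriterion("name =", "Tom"));
        criteria.add(new DemoCriterion("age =", 24));
        criteria.add(new DemoCriterion("name like", "T%"));
        // Fake encryption for the demo; the article calls AESUtils.encrypt here.
        encryptEqualityValues(criteria, Arrays.asList("name"), s -> "enc(" + s + ")");
        for (DemoCriterion c : criteria) {
            System.out.println(c.condition + " " + c.value); // name = enc(Tom) / age = 24 / name like T%
        }
    }
}
```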

AESUtils encryption/decryption utility class

import javax.crypto.Cipher;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import java.nio.charset.StandardCharsets;
import java.util.Base64;

public class AESUtils {

    // User-defined IV; AES/CBC requires exactly 16 bytes
    private static final String DEFAULT_V = "********";
    private static final String ALGORITHM = "AES";
    // User-defined key; getKey() truncates or zero-pads it to 16 bytes (AES-128)
    private static final String KEY = "***********";

    private static String byteToHexString(byte[] bytes) {
        StringBuilder sb = new StringBuilder();
        for (byte aByte : bytes) {
            String strHex = Integer.toHexString(aByte);
            if (strHex.length() > 3) {
                sb.append(strHex.substring(6));
            } else {
                if (strHex.length() < 2) {
                    sb.append("0").append(strHex);
                } else {
                    sb.append(strHex);
                }
            }
        }
        return sb.toString();
    }

    private static SecretKeySpec getKey() {
        byte[] arrBTmp = KEY.getBytes(StandardCharsets.UTF_8);
        // Create an empty 16-byte array (all zeros by default) and copy at most 16 key bytes into it
        byte[] arrB = new byte[16];
        for (int i = 0; i < arrBTmp.length && i < arrB.length; i++) {
            arrB[i] = arrBTmp[i];
        }
        return new SecretKeySpec(arrB, ALGORITHM);
    }

    /**
     * Encrypt with AES/CBC/PKCS5Padding and return the result Base64-encoded
     */
    public static String encrypt(String content) throws Exception {
        final Base64.Encoder encoder = Base64.getEncoder();
        SecretKeySpec keySpec = getKey();
        Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
        IvParameterSpec iv = new IvParameterSpec(DEFAULT_V.getBytes(StandardCharsets.UTF_8));
        cipher.init(Cipher.ENCRYPT_MODE, keySpec, iv);
        byte[] encrypted = cipher.doFinal(content.getBytes(StandardCharsets.UTF_8));
        return encoder.encodeToString(encrypted);
    }

    /**
     * Decrypt a Base64 string produced by encrypt
     */
    public static String decrypt(String content) throws Exception {
        final Base64.Decoder decoder = Base64.getDecoder();
        SecretKeySpec keySpec = getKey();
        Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
        IvParameterSpec iv = new IvParameterSpec(DEFAULT_V.getBytes(StandardCharsets.UTF_8));
        cipher.init(Cipher.DECRYPT_MODE, keySpec, iv);
        byte[] base64 = decoder.decode(content);
        byte[] original = cipher.doFinal(base64);
        return new String(original, StandardCharsets.UTF_8);
    }

}
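AESUtils uses a fixed IV, so the same plaintext always produces the same ciphertext. This is what lets the Executor interceptor handle queries: the encrypted parameter can be compared with = against the encrypted column. The trade-off is that equal values can be recognized as equal in the database. The sketch below is a self-contained version of the same AES/CBC/PKCS5Padding round trip, using hypothetical 16-byte demo key and IV values (never hard-code real keys). It also shows that the stored Base64 string is longer than the plaintext, which matters when sizing the column.

```java
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.util.Base64;
import javax.crypto.Cipher;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;

class AesCbcDemo {

    // Hypothetical demo values: AES-128 needs a 16-byte key and CBC needs a 16-byte IV.
    private static final byte[] KEY = "0123456789abcdef".getBytes(StandardCharsets.UTF_8);
    private static final byte[] IV = "fedcba9876543210".getBytes(StandardCharsets.UTF_8);

    private static Cipher cipher(int mode) throws GeneralSecurityException {
        Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
        cipher.init(mode, new SecretKeySpec(KEY, "AES"), new IvParameterSpec(IV));
        return cipher;
    }

    static String encrypt(String plain) {
        try {
            byte[] encrypted = cipher(Cipher.ENCRYPT_MODE).doFinal(plain.getBytes(StandardCharsets.UTF_8));
            return Base64.getEncoder().encodeToString(encrypted);
        } catch (GeneralSecurityException e) {
            throw new IllegalStateException("encryption failed", e);
        }
    }

    static String decrypt(String base64) {
        try {
            byte[] original = cipher(Cipher.DECRYPT_MODE).doFinal(Base64.getDecoder().decode(base64));
            return new String(original, StandardCharsets.UTF_8);
        } catch (GeneralSecurityException e) {
            throw new IllegalStateException("decryption failed", e);
        }
    }

    public static void main(String[] args) {
        String first = encrypt("13800000000");
        String second = encrypt("13800000000");
        System.out.println(first.equals(second)); // true: fixed key + fixed IV -> same ciphertext
        System.out.println(first.length());       // 24: 11 bytes padded to 16, then Base64
        System.out.println(decrypt(first));       // 13800000000
    }
}
```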
Executor interceptor
import com.bt.dlife.common.base.NewExample;
import com.bt.dlife.common.interceptor.aop.FieldName;
import com.bt.dlife.common.interceptor.aop.Message;
import com.bt.dlife.common.utils.AESUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.springframework.stereotype.Component;
import tk.mybatis.mapper.entity.Example;

import java.lang.reflect.Field;
import java.util.*;


/**
 * @author zhl
 * @date 2023/08/20
 */

@Intercepts({
        @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class,
                RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class})
})
@Component
@Slf4j
public class ExecutorInterceptor implements Interceptor {


    @Override
    public Object plugin(Object target) {
        return Interceptor.super.plugin(target);
    }

    @Override
    public void setProperties(Properties properties) {
        Interceptor.super.setProperties(properties);
    }

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        log.info("in:{}", invocation);
        MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];
        Object object = invocation.getArgs()[1];
        if (object instanceof NewExample) {
            // Example query: encrypt the values of "<field> =" conditions for the fields in filedNameList
            NewExample newExample = (NewExample) object;
            List<String> fieldNameList = newExample.getFiledNameList();
            if (fieldNameList == null || fieldNameList.isEmpty()) {
                return invocation.proceed();
            }
            List<Example.Criteria> oredCriteria = newExample.getOredCriteria();
            oredCriteria.forEach(criteria -> {
                List<Example.Criterion> criterionList = criteria.getCriteria();
                criterionList.forEach(criterion -> {
                    String condition = criterion.getCondition();
                    fieldNameList.forEach(s -> {
                        if (condition.equals(s + " =")) {
                            Field[] fields = criterion.getClass().getDeclaredFields();
                            for (Field field : fields) {
                                field.setAccessible(true);
                                if (field.getName().equals("value")) {
                                    try {
                                        field.set(criterion, AESUtils.encrypt(String.valueOf(field.get(criterion))));
                                    } catch (Exception e) {
                                        log.error("Encryption failed:", e);
                                    }
                                }
                            }
                        }
                    });
                });
            });

        } else {
            // The MappedStatement id is "<mapper interface>.<method name>"
            String[] split = mappedStatement.getId().split("\\.");
            List<String> stringList = new ArrayList<>(Arrays.asList(split));
            String methodName = stringList.remove(stringList.size() - 1);
            String path = String.join(".", stringList);
            Class<?> o = Class.forName(path);
            Message annotation = o.getAnnotation(Message.class);
            if (null == annotation) {
                return invocation.proceed();
            }
            for (FieldName fieldName : annotation.value()) {
                if (!methodName.equals(fieldName.methodName())) {
                    continue;
                }
                Object param = invocation.getArgs()[1];
                // 2024-05-10: single-parameter fix. A lone parameter without @Param is passed as-is
                // rather than in a Map, so encrypt the argument itself (only for the matching method)
                if (fieldName.onlyOneParam()) {
                    if (fieldName.required() && param != null) {
                        invocation.getArgs()[1] = AESUtils.encrypt(String.valueOf(param));
                    }
                    return invocation.proceed();
                }
                List<String> nameList = Arrays.asList(fieldName.fieldName());
                String collectionName = fieldName.collectionName();
                if (param instanceof Map) {
                    @SuppressWarnings("unchecked")
                    Map<String, Object> arg = (Map<String, Object>) param;
                    if ("".equals(collectionName) && arg.containsKey("arg0")) {
                        collectionName = "arg0";
                    }
                    if (!"".equals(collectionName)) {
                        // List parameter: encrypt String elements directly, or the listed fields of each element
                        List<?> list = (List<?>) arg.get(collectionName);
                        List<Object> newList = new ArrayList<>();
                        list.forEach(g -> {
                            if (g instanceof String) {
                                try {
                                    newList.add(AESUtils.encrypt(String.valueOf(g)));
                                } catch (Exception e) {
                                    log.error("Encryption failed:", e);
                                }
                            } else {
                                changeFieldValue(g, nameList);
                                newList.add(g);
                            }
                        });
                        arg.put(collectionName, newList);
                    } else {
                        // Multiple named parameters: encrypt the ones whose names are listed
                        nameList.forEach(s -> {
                            if (arg.containsKey(s) && arg.get(s) != null) {
                                Object value = arg.get(s);
                                try {
                                    value = AESUtils.encrypt(String.valueOf(value));
                                } catch (Exception e) {
                                    log.error("Encryption failed:", e);
                                }
                                arg.put(s, value);
                            }
                        });
                    }
                } else if (param != null) {
                    // Single entity parameter: encrypt its listed fields in place
                    changeFieldValue(param, nameList);
                }
            }
        }
        return invocation.proceed();
    }

    private void changeFieldValue(Object object, List<String> fieldNameList) {
        Field[] fields = object.getClass().getDeclaredFields();
        for (Field field : fields) {
            field.setAccessible(true);
            if (fieldNameList.contains(field.getName())) {
                try {
                    Object value = field.get(object);
                    // Skip nulls so they are not stored as the encrypted string "null"
                    if (value != null) {
                        field.set(object, AESUtils.encrypt(String.valueOf(value)));
                    }
                } catch (Exception e) {
                    log.error("Encryption failed:", e);
                }
            }
        }
    }

} 
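Both interceptors derive the mapper interface and the method name from the MappedStatement id, which is the namespace, a dot, and the statement id. The split/remove/join sequence above is equivalent to cutting the id at its last dot. The hypothetical helper below shows that simpler form. Like the original, it assumes the namespace is the mapper interface's fully qualified class name; otherwise Class.forName would throw ClassNotFoundException. The package name com.example.mapper is made up for the demo.

```java
class MappedStatementIds {

    // Splits "<mapper FQCN>.<method>" at the last dot into {mapperClassName, methodName}.
    // An id without any dot is returned as {"", id}.
    static String[] split(String id) {
        int lastDot = id.lastIndexOf('.');
        if (lastDot < 0) {
            return new String[] {"", id};
        }
        return new String[] {id.substring(0, lastDot), id.substring(lastDot + 1)};
    }

    public static void main(String[] args) {
        String[] parts = split("com.example.mapper.StudentMapper.searchAllByPhone");
        System.out.println(parts[0]); // com.example.mapper.StudentMapper
        System.out.println(parts[1]); // searchAllByPhone
    }
}
```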
  
ResultSet interceptor
import com.bt.dlife.common.interceptor.aop.FieldName;
import com.bt.dlife.common.interceptor.aop.Message;
import com.bt.dlife.common.utils.AESUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.executor.resultset.DefaultResultSetHandler;
import org.apache.ibatis.executor.resultset.ResultSetHandler;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.*;
import org.springframework.stereotype.Component;

import java.lang.reflect.Field;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Properties;


/**
 * @author zhl
 * @date 2023/08/20
 */
@Slf4j
@Component
@Intercepts({
        @Signature(type = ResultSetHandler.class, method = "handleResultSets", args = {Statement.class})
})
public class ResultSetInterceptor implements Interceptor {


    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        // Let MyBatis map the result set first, then decrypt the mapped objects
        Object resultObject = invocation.proceed();
        if (resultObject == null) {
            return null;
        }
        // A JDK class such as String or Integer (loaded by the bootstrap class loader) that is not a List
        if (resultObject.getClass().getClassLoader() == null && !(resultObject instanceof List)) {
            try {
                return AESUtils.decrypt(String.valueOf(resultObject));
            } catch (Exception e) {
                return resultObject;
            }
        }
        DefaultResultSetHandler defaultResultSetHandler = (DefaultResultSetHandler) invocation.getTarget();
        MappedStatement mappedStatement = getMappedStatement(defaultResultSetHandler);
        if (null == mappedStatement) {
            return resultObject;
        }
        // The MappedStatement id is "<mapper interface>.<method name>"
        String[] strings = mappedStatement.getId().split("\\.");
        List<String> stringList = new ArrayList<>(Arrays.asList(strings));
        String methodName = stringList.remove(stringList.size() - 1);
        String path = String.join(".", stringList);
        Class<?> o = Class.forName(path);
        Message message = o.getAnnotation(Message.class);
        if (message == null) {
            return resultObject;
        }
        for (FieldName fieldName : message.value()) {
            if (methodName.equals(fieldName.methodName())) {
                List<String> nameList = Arrays.asList(fieldName.fieldName());
                if (resultObject instanceof List) {
                    List<?> objectList = (List<?>) resultObject;
                    List<Object> newResult = new ArrayList<>();
                    objectList.forEach(g -> {
                        if (g == null) {
                            // Null elements, if any, are kept as-is
                            newResult.add(null);
                        } else if (g.getClass().getClassLoader() == null) {
                            // JDK type such as String: try to decrypt, keep the original value on failure
                            Object result;
                            try {
                                result = AESUtils.decrypt(String.valueOf(g));
                            } catch (Exception e) {
                                result = g;
                            }
                            newResult.add(result);
                        } else {
                            // Entity object: decrypt the listed fields in place
                            result(g, nameList);
                            newResult.add(g);
                        }
                    });
                    return newResult;
                }
            }
        }
        return resultObject;
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {

    }

    private void result(Object object, List<String> nameList) {
        Field[] fields = object.getClass().getDeclaredFields();
        for (Field field : fields) {
            field.setAccessible(true);
            if (nameList.contains(field.getName())) {
                try {
                    Object value = field.get(object);
                    // Leave null fields alone instead of trying to decrypt the string "null"
                    if (value != null) {
                        field.set(object, AESUtils.decrypt(String.valueOf(value)));
                    }
                } catch (Exception e) {
                    log.error("Decryption failed:", e);
                }
            }
        }
    }

    private MappedStatement getMappedStatement(DefaultResultSetHandler defaultResultSetHandler) throws IllegalAccessException {
        Field[] fields = defaultResultSetHandler.getClass().getDeclaredFields();
        for (Field field : fields) {
            field.setAccessible(true);
            if("mappedStatement".equals(field.getName())){
                return (MappedStatement) field.get(defaultResultSetHandler);
            }
        }
        return null;
    }

} 
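Both changeFieldValue and result() walk the fields with getDeclaredFields(). That method only returns fields declared directly on the class, so properties inherited from a base entity class are not encrypted or decrypted. The sketch below is one possible shared helper under that observation. It walks up the superclasses, skips null values and static fields, and only touches String fields. FieldTransformer, DemoBase and DemoStudent are hypothetical names. In the interceptors, the function passed in would wrap AESUtils.encrypt or AESUtils.decrypt, which declare checked exceptions.

```java
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.List;
import java.util.function.UnaryOperator;

// Hypothetical entities: phone is inherited, name and age are declared on the subclass.
class DemoBase {
    String phone;
}

class DemoStudent extends DemoBase {
    String name;
    Integer age;
}

class FieldTransformer {

    // Applies fn to every non-null, non-static String field of target whose name is in fieldNames,
    // including fields inherited from superclasses.
    static void transformStringFields(Object target, List<String> fieldNames, UnaryOperator<String> fn) {
        for (Class<?> c = target.getClass(); c != null && c != Object.class; c = c.getSuperclass()) {
            for (Field field : c.getDeclaredFields()) {
                if (field.getType() != String.class || Modifier.isStatic(field.getModifiers())
                        || !fieldNames.contains(field.getName())) {
                    continue;
                }
                try {
                    field.setAccessible(true);
                    String value = (String) field.get(target);
                    if (value != null) {
                        field.set(target, fn.apply(value));
                    }
                } catch (IllegalAccessException e) {
                    throw new IllegalStateException("cannot access field " + field.getName(), e);
                }
            }
        }
    }

    public static void main(String[] args) {
        DemoStudent s = new DemoStudent();
        s.phone = "138";
        s.age = 24;
        // Fake transform for the demo; the interceptors would decrypt or encrypt here.
        transformStringFields(s, Arrays.asList("phone", "name", "age"), v -> "dec(" + v + ")");
        System.out.println(s.phone + " " + s.name + " " + s.age); // dec(138) null 24
    }
}
```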
  

That's all for implementing data masking with MyBatis interceptors. If you have any questions, leave a comment below.
