MYBATIS 拦截器实现修改留痕

最近准备做修改留痕的业务模块,打算使用"注解 + MyBatis 拦截器"的方式来实现,这里展示一下代码和实现思路。目前这个版本比较简单,只能拦截根据 ID 更新数据的方法。

第一步:编写相关实体类

  1. UserMo:用来测试的、需要记录修改痕迹的实体类
/**
 * Sample entity used to exercise change tracking. Every field annotated with
 * {@link RecordField} is compared before and after an update.
 */
@Data
@TableName("t_user")
public class UserMo implements Serializable {

    @RecordField(name = "ID")
    private Long id;

    @RecordField(name = "姓名")
    private String xm;

    // RecordField labels are used as map keys when the entity is converted for diffing,
    // so they must be unique per entity. This field previously reused "姓名", which made
    // the avatar value overwrite the name value.
    @RecordField(name = "头像", type = "file")
    private String avatar;
}

  2. RecordModel:用来存放被记录的对象

/**
 * Snapshot of one row taken just before it is updated. It is queued until the surrounding
 * {@code @MarkUpdate} method returns, then compared with the row's new state.
 */
@Data
@AllArgsConstructor
@NoArgsConstructor
public class RecordModel {

    /**
     * Primary key of the updated row, as text extracted from the UPDATE's WHERE clause.
     */
    private String dataId;

    /**
     * Entity class mapped to the updated table.
     */
    private Class<?> classType;

    /**
     * Entity state read before the UPDATE ran; may be null (e.g. when the loader does not
     * handle this class).
     */
    private Object oldData;

}
  3. MarkMo:留痕记录表
/**
 * One change-trail row persisted to {@code t_mark}: a single value that differed between
 * the state before and after an update.
 * <p>
 * NOTE(review): the {@code type} column must exist in {@code t_mark}; it is populated by the
 * dispatcher (e.g. "用户信息").
 */
@Data
@TableName("t_mark")
public class MarkMo {

    /** Primary key. */
    private Long id;

    /** Label of the kind of record that changed. */
    private String type;

    /** Value before the update; null when the field had no value. */
    private String oldValue;

    /** Value after the update; null when the field has no value. */
    private String newValue;

}

第二步:编写注解类

我们实现两个注解(RecordField 标记字段,MarkUpdate 标记方法),并对 MarkUpdate 注解做一个切面。

  1. RecordField 用来标记需要被记录的字段
/**
 * Marks an entity field whose value is compared and recorded when the entity is updated.
 */
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
public @interface RecordField {

    /** Human-readable label for the field; also used as the key when values are diffed. */
    String name();

    /** Optional value tag (e.g. "file"); when non-empty the converter records "(tag)value". */
    String type() default "";
}
  2. MarkUpdate.java:用来标记需要记录修改的方法
/**
 * Marks a method whose by-id updates are captured and written to the change trail once
 * the method returns normally.
 */
@Documented
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface MarkUpdate {
}
  3. MarkUpdateAspect:对被 MarkUpdate 注解标记的方法做切面
@Aspect
@Order(99)
@Component
@AllArgsConstructor
public class MarkUpdateAspect {

    @Autowired
    MarkHandlerDispatch markHandlerDispatch;

    private static final Map> TEM_MAP = new ConcurrentHashMap<>();

    /**
     * 判断是否有该注解
     *
     * @param threadName
     * @return
     */
    public static boolean hasThread(Long threadName) {
        return TEM_MAP.containsKey(threadName);
    }

    /**
     * 放入前置数据
     *
     * @param threadName
     * @param RecordModel
     */
    public static void put(Long threadName, RecordModel recordModel) {
        if (TEM_MAP.containsKey(threadName)) {
            TEM_MAP.get(threadName).add(recordModel);
        }
    }


    @SneakyThrows
    @Before("@annotation(markUpdate)")
    public void before(JoinPoint joinPoint, MarkUpdate markUpdate) {
        // 获取线程名,使用线程名作为同一次操作记录
        Long threadName = Thread.currentThread().getId();
        TEM_MAP.remove(threadName);
        TEM_MAP.put(threadName, new LinkedList<>());
    }


    @SneakyThrows
    @AfterReturning("@annotation(markUpdate)")
    public void after(JoinPoint joinPoint, MarkUpdate markUpdate) {
        // 获取线程名,使用线程名作为同一次操作记录
        Long threadName = Thread.currentThread().getId();
        if (TEM_MAP.get(threadName) == null) {
            return;
        }
        for (RecordModel recordModel: TEM_MAP.get(threadName)) {
            markHandlerDispatch.record(recordModel);
        }
        // 移除当前线程
        TEM_MAP.remove(threadName);
    }
}

第三步:实现 MyBatis 拦截器

/**
 * MyBatis plugin that snapshots a row just before it is updated, so the change can be recorded.
 * <p>
 * It acts on an {@code UPDATE} executed while a {@link MarkUpdate} method is running on the
 * current thread. It resolves the entity class from the table name and extracts the {@code id}
 * value from the WHERE clause. It then loads the current row via
 * {@link MarkHandlerDispatch#getData}, lets the statement run, and queues a {@link RecordModel}
 * in {@link MarkUpdateAspect}. Any statement it cannot interpret is executed unchanged.
 */
@Slf4j
@Intercepts({@Signature(type = StatementHandler.class, method = "update", args = {Statement.class})})
public class MarkUpdateInterceptor extends AbstractSqlParserHandler implements Interceptor {

    @Autowired
    MarkHandlerDispatch markHandlerDispatch;

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        String sqlCommandType = getSqlCommandType(invocation);
        if (!"UPDATE".equals(sqlCommandType)) {
            return invocation.proceed();
        }
        Long threadName = Thread.currentThread().getId();
        // Outside a @MarkUpdate method nothing consumes the snapshot. Skip the SQL parsing and
        // the extra SELECT instead of paying for them on every UPDATE.
        if (!MarkUpdateAspect.hasThread(threadName)) {
            return invocation.proceed();
        }
        // Only entity classes registered with the dispatcher are tracked.
        Class<?> classType = getClassType(invocation);
        if (classType == null || !markHandlerDispatch.isCanMark(classType)) {
            return invocation.proceed();
        }
        // Only updates keyed by a literal id condition are supported.
        String id = getUpdateId(invocation);
        if (StringUtils.isEmpty(id)) {
            return invocation.proceed();
        }
        // Read the pre-update state before the statement executes.
        Object oldData = markHandlerDispatch.getData(classType, id);
        Object proceedObj = invocation.proceed();
        // Queue for comparison once the @MarkUpdate method returns.
        MarkUpdateAspect.put(threadName, new RecordModel(id, classType, oldData));
        return proceedObj;
    }


    /**
     * Returns the MyBatis command type of the intercepted statement.
     *
     * @param invocation intercepted {@code StatementHandler.update} call
     * @return the {@code SqlCommandType} name, e.g. "UPDATE"
     */
    public String getSqlCommandType(Invocation invocation) {
        StatementHandler statementHandler = PluginUtils.realTarget(invocation.getTarget());
        MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
        this.sqlParser(metaObject);
        MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
        return mappedStatement.getSqlCommandType().toString();
    }

    /**
     * Resolves the entity class mapped to the first table referenced by the statement.
     *
     * @param invocation intercepted call
     * @return the MyBatis-Plus entity class, or null if unknown or on any parsing error
     */
    public Class<?> getClassType(Invocation invocation) {
        try {
            String sql = getOriginSql(invocation);
            Collection<String> tables = new TableNameParser(sql).tables();
            if (CollectionUtils.isEmpty(tables)) {
                return null;
            }
            // Ignore backtick quoting and case so `t_user` / T_USER still resolve.
            String tableName = tables.iterator().next().replace("`", "");
            return TableInfoHelper.getTableInfos().stream()
                    .filter(item -> tableName.equalsIgnoreCase(item.getTableName()))
                    .findFirst()
                    .map(TableInfo::getEntityType)
                    .orElse(null);
        } catch (Exception e) {
            return null;
        }
    }

    /**
     * Returns the text of the underlying JDBC statement.
     * NOTE(review): relies on the driver's {@code Statement.toString()} containing the SQL with
     * bound values; this is driver-dependent — confirm for the driver in use.
     *
     * @param invocation intercepted call
     * @return the statement's string form
     */
    public String getOriginSql(Invocation invocation) {
        Statement statement = getStatement(invocation);
        return statement.toString();
    }

    /**
     * Unwraps the JDBC statement passed to {@code StatementHandler.update}.
     * <p>
     * It first unwraps a JDK proxy via its {@code h.statement} field. It then tries a
     * {@code stmt.statement} path and a {@code delegate} getter (the original notes the latter
     * is for Hikari), so that {@code toString()} reaches the driver's own statement.
     *
     * @param invocation intercepted call
     * @return the innermost statement found
     */
    public Statement getStatement(Invocation invocation) {
        Object firstArg = invocation.getArgs()[0];
        Statement statement = null;
        if (Proxy.isProxyClass(firstArg.getClass())) {
            statement = (Statement) SystemMetaObject.forObject(firstArg).getValue("h.statement");
        } else {
            statement = (Statement) firstArg;
        }
        MetaObject stmtMetaObj = SystemMetaObject.forObject(statement);
        try {
            statement = (Statement) stmtMetaObj.getValue("stmt.statement");
        } catch (Exception ignored) {
            // Not this wrapper type; keep the current statement.
        }
        if (stmtMetaObj.hasGetter("delegate")) {
            // Hikari
            try {
                statement = (Statement) stmtMetaObj.getValue("delegate");
            } catch (Exception ignored) {
                // Not unwrappable; keep the current statement.
            }
        }
        return statement;
    }

    /**
     * Extracts the value of the {@code id = ...} condition from the last {@code WHERE} clause.
     * <p>
     * Only the upper-case keyword {@code WHERE} is recognised. Surrounding quotes are removed,
     * so {@code id='1'} yields "1".
     *
     * @param invocation intercepted call
     * @return the id text, or null if there is no WHERE clause, no id condition, or an error
     */
    public String getUpdateId(Invocation invocation) {
        try {
            String sql = getOriginSql(invocation);
            int index = sql.lastIndexOf("WHERE");
            if (index < 0) {
                return null;
            }
            // Replace parentheses with spaces (not "") so "(a=1)AND(b=2)" still splits cleanly.
            String allCondition = sql.substring(index + "WHERE".length())
                    .replace('(', ' ')
                    .replace(')', ' ');
            Map<String, String> conditionMap = new HashMap<>();
            // Split only on standalone AND/OR, so values and identifiers containing those
            // letters (e.g. 'ANDY', COLOR) are not cut apart.
            for (String item : allCondition.split("(?i)\\s+(?:AND|OR)\\s+")) {
                int eq = item.indexOf('=');
                if (eq <= 0) {
                    continue;
                }
                String key = item.substring(0, eq).trim().replace("`", "").toLowerCase(java.util.Locale.ROOT);
                // Keep everything after the first '=' so values containing '=' stay intact.
                String value = stripQuotes(item.substring(eq + 1).trim());
                conditionMap.put(key, value);
            }
            return conditionMap.get("id");
        } catch (Exception e) {
            return null;
        }
    }

    /** Removes one pair of matching single or double quotes around a literal, if present. */
    private static String stripQuotes(String value) {
        if (value.length() >= 2) {
            char first = value.charAt(0);
            char last = value.charAt(value.length() - 1);
            if ((first == '\'' || first == '"') && first == last) {
                return value.substring(1, value.length() - 1);
            }
        }
        return value;
    }
}

第四步:实现记录处理类

MarkHandlerDispatch 用来记录数据到表中

@Slf4j
@Component
public class MarkHandlerDispatch {

    @Lazy
    @Autowired
    UserMapper userMapper;


    @Lazy
    @Autowired
    MarkMapper markMapper;

    @Lazy
    @Autowired
    TrueValueConventHandlerDispatch trueValueConventHandlerDispatch ;

    


    private static List needHandlerClass = new ArrayList<>();

    static {
        // 放入需要被标记的实体类
        needHandlerClass.add(UserMo.class);
    }

    /**
     * 获取实体数据
     * 
     */
    public Object getData(Class type, String id) {
        if (type == UserMo.class) {
            return userMapper.selectById(Long.valueOf(id));
        }
        return null;
    }



    /**
     * 获取实体数据
     * 这里需要异步处理,否则还是获取的老数据
     */
    @Async
    public void record(RecordModel recordModel) {
        // 获取最新的数据
        Object updateData = this.getData(recordModel.getClassType(), recordModel.getDataId());
        if (recordModel.getClassType() == UserMo.class) {
            UserMo userMo = (UserMo) updateData;
            record(recordModel.getOldData(), updateData, "用户信息");
        }
    }

    /**
     * 记录
     */
    public void record(Object oldData, Object updateData, Class type) {
        try {
            Map oldDataMap = trueValueConventHandlerDispatch.dispatchConvert(type, oldData);
            Map updateDataMap = trueValueConventHandlerDispatch.dispatchConvert(type, updateData);
            List markList = getNeedRecordMap(type, oldDataMap, updateDataMap);
            markMapper.saveBatch(markList);
        } catch (Exception e) {
            log.error("记录失败", e);
        }
    }

    /**
     *  通过比较,获取所有需要记录的列表数据
     *
     * @param oldDataMap
     * @param updateDataMap
     * @return
     */
    public List getNeedRecordMap(String type, Map oldDataMap, Map updateDataMap) {
        List markList = new ArrayList<>();
        Set keySet = new HashSet<>();
        keySet.addAll(oldDataMap.keySet());
        keySet.addAll(updateDataMap.keySet());
        for (String key : keySet) {
            String oldValue = oldDataMap.get(key);
            String newValue = updateDataMap.get(key);
            if (oldValue != null && oldValue.equals(newValue)) {
                continue;
            }
            if (newValue != null && newValue.equals(oldValue)) {
                continue;
            }
            MarkMo markMo = new MarkMo();
            markMo.setOldValue(oldValue)
            markMo.setNewValue(newValue)
            markMo.setType(type)
            markList.add(markMo);
        }
        return markList;
    }


    /**
     * 判断是否需要记录
     *
     * @param type
     * @return
     */
    public Boolean isCanMark(Class type) {
        Long count = needHandlerClass.stream().filter(item -> {
            return item == type;
        }).count();
        return count > 0;
    }

}

第五步:编写真实值转换类

/**
 * Converts an entity into a map from {@link RecordField#name()} label to a display string.
 * The maps for the states before and after an update are then diffed.
 */
@Component
@Slf4j
public class TrueValueConventHandlerDispatch {

    // Field types recorded via String.valueOf (optionally prefixed with the field's type tag).
    // NOTE(review): primitive fields (int, long, ...) and other wrapper types are not listed, so
    // annotated fields of those types are silently skipped.
    private Class[] baseTypeList = {Long.class, String.class, Integer.class};

    /**
     * Converts an entity to its recorded values.
     *
     * @param type entity class, available for per-type customisation
     * @param data entity instance; may be null
     * @return label-to-value map; empty when {@code data} is null
     */
    public Map<String, String> dispatchConvert(Class type, Object data) {
        if (data == null) {
            // e.g. no row was loaded; treat it as having no values rather than failing.
            return new HashMap<>();
        }
        // Per-type customisation hook: add or override entries for specific entity types here,
        // e.g. when type == UserMo.class.
        return baseConvert(data);
    }

    /**
     * Default conversion: every {@link RecordField}-annotated field with a non-null value is
     * rendered as text. Base types are rendered with {@code String.valueOf}, prefixed with
     * "(type)" when the annotation sets one. {@code Date} fields use "yyyy-MM-dd HH:mm:ss".
     * Only fields declared directly on the object's class are inspected.
     *
     * @param o entity instance (non-null)
     * @return label-to-value map
     * @throws BizException if reflective access fails
     */
    protected Map<String, String> baseConvert(Object o) {
        Map<String, String> result = new HashMap<>();
        try {
            for (Field declaredField : o.getClass().getDeclaredFields()) {
                // Only columns marked with RecordField are recorded.
                RecordField recordField = declaredField.getAnnotation(RecordField.class);
                if (recordField == null) {
                    continue;
                }
                declaredField.setAccessible(true);
                Object value = declaredField.get(o);
                if (value == null) {
                    continue;
                }
                Class<?> fieldType = declaredField.getType();
                if (isBaseType(fieldType)) {
                    // A non-empty type tag marks a special kind of value (e.g. a file).
                    if (StringUtils.isNotEmpty(recordField.type())) {
                        result.put(recordField.name(), "(" + recordField.type() + ")" + value);
                    } else {
                        result.put(recordField.name(), String.valueOf(value));
                    }
                } else if (fieldType == Date.class) {
                    // A new formatter per use: SimpleDateFormat is not thread-safe.
                    SimpleDateFormat simpleDateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
                    result.put(recordField.name(), simpleDateFormat.format(value));
                }
            }
        } catch (Exception e) {
            log.error("记录转换失败", e);
            throw new BizException("转换记录失败");
        }
        return result;
    }

    /**
     * Returns whether the field type is one of {@code baseTypeList}.
     *
     * @param clazz field type
     * @return true for an exact match
     */
    private Boolean isBaseType(Class clazz) {
        return Arrays.stream(baseTypeList).anyMatch(item -> clazz == item);
    }


}

第六步:注册 MyBatis 拦截器

    /**
     * Registers the change-trail MyBatis interceptor for the prod, dev and local profiles.
     * <p>
     * The interceptor autowires {@code MarkHandlerDispatch}, so registration is conditioned on
     * that bean (the previous {@code AbstractMarkHandle} type is not defined in this code).
     * NOTE(review): {@code @ConditionalOnBean} only sees beans registered so far and is intended
     * for auto-configuration classes. In a regular configuration class it may skip registration
     * depending on processing order — confirm.
     */
    @Bean
    @Profile({"prod", "dev", "local"})
    @ConditionalOnMissingBean
    @ConditionalOnBean(MarkHandlerDispatch.class)
    public MarkUpdateInterceptor markUpdateInterceptor() {
        return new MarkUpdateInterceptor();
    }

第七步:在需要留痕的方法上加上注解

/**
 * User service. Updates made through {@link #udpate} are captured by the change trail.
 */
@Service
@Slf4j
public class UserServiceImpl implements UserService {

    @Autowired
    UserMapper userMapper;

    /**
     * Updates a user by id. The {@link MarkUpdate} annotation makes the by-id UPDATE issued here
     * get snapshotted and compared once this method returns normally.
     * <p>
     * NOTE(review): "udpate" is misspelled, but the name comes from {@code UserService}
     * ({@code @Override}) and cannot be renamed here alone.
     *
     * @param userMo entity carrying the id and the new values
     */
    @Override
    @MarkUpdate
    public void udpate(UserMo userMo) {
        this.userMapper.updateById(userMo);
    }
}

你可能感兴趣的:(MYBATIS 拦截器实现修改留痕)