mybatis自定义拦截器拦截sql,处理createTime,updateTime,createBy,updateBy等问题

在数据表的设计过程中,createTime、updateTime、createBy、updateBy、delFlag 这几个公共字段几乎每张表都必不可少。使用 MyBatis 框架操作数据库时,每张表的增删改查都会涉及这几个字段,因此可以自定义拦截器对公共字段进行统一处理:xml 文件的 sql 中可以不写公共字段,由拦截器进行赋值处理,或在无法处理时抛出异常。话不多说,直接上代码。


@Intercepts
        ({
                @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}),
                @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
        })
@Component
public class MybatisInterceptor implements Interceptor {

    // Per-instance logger named after the runtime class of this interceptor.
    Logger logger = LoggerFactory.getLogger(this.getClass());

    // Positions inside Invocation.getArgs(), matching the "args" arrays of the @Signature
    // annotations above: index 0 is the MappedStatement, index 1 is the parameter object.
    public static final Integer INDEX_ZERO = 0;
    public static final Integer INDEX_ONE = 1;
    // Audit columns (snake_case, as written in SQL) and the matching Java property names
    // (camelCase, as looked up on the parameter object). "FEILD" is a historical misspelling
    // of "FIELD"; the constant names are public and therefore kept as they are.
    public static final String CREATE_TIME_COLUMN = "create_time";
    public static final String CREATE_TIME_FEILD = "createTime";
    public static final String CREATE_BY_FEILD = "createBy";
    public static final String CREATE_BY_COLUMN = "create_by";
    public static final String UPDATE_BY_COLUMN = "update_by";
    public static final String UPDATE_BY_FEILD = "updateBy";
    public static final String UPDATE_TIME_FEILD = "updateTime";
    public static final String UPDATE_TIME_COLUMN = "update_time";
    // Soft-delete flag column and property.
    public static final String DEL_FLAG_COLUMN = "del_flag";
    public static final String DEL_FLAG_FEILD = "delFlag";
    // NOTE(review): presumably the "not deleted" value of del_flag; not referenced in the visible code.
    public static final String DEL_FLAG_DEFAULT = "0";


    /**
     * Rewrites the SQL of intercepted {@code Executor.update}/{@code Executor.query} calls so that
     * the audit columns are handled centrally.
     *
     * <ul>
     *   <li>INSERT: fills createTime/createBy/delFlag on the parameter object and appends the
     *       matching columns, placeholders and parameter mappings.</li>
     *   <li>UPDATE: fills updateBy/updateTime and prepends the matching SET assignments and
     *       parameter mappings.</li>
     *   <li>SELECT: adds the soft-delete filter to the WHERE clause.</li>
     *   <li>Any other command type is passed through untouched.</li>
     * </ul>
     *
     * @param invocation the intercepted call; args[0] is the MappedStatement, args[1] the parameter object
     * @return whatever the wrapped executor method returns
     * @throws Throwable any failure raised while rewriting the statement or by the wrapped call
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        logger.info("进入公共拦截器");
        // Unpack the intercepted arguments.
        Object[] queryArgs = invocation.getArgs();
        MappedStatement mappedStatement = (MappedStatement) queryArgs[INDEX_ZERO];
        Object object = queryArgs[INDEX_ONE];
        // Resolve the SQL for this parameter object.
        BoundSql boundSql = mappedStatement.getBoundSql(object);
        String sql = boundSql.getSql();

        // Work on a private copy: the list handed out by BoundSql may be the very instance kept by
        // the statement's SqlSource, so adding to it in place would leak into later executions
        // and is not safe when two threads run the same statement.
        List<ParameterMapping> parameterMappings = new java.util.ArrayList<>(boundSql.getParameterMappings());
        List<String> propertyList = parameterMappings.stream()
                .map(ParameterMapping::getProperty)
                .collect(Collectors.toList());

        // Dispatch on the statement type.
        SqlCommandType sqlCommandType = mappedStatement.getSqlCommandType();
        if (SqlCommandType.INSERT.equals(sqlCommandType)) {
            // Fill the audit fields on the parameter object, then extend SQL and mappings.
            autoFildInsertFeild(object);
            sql = appendInsertSql(sql);
            autoFildInsertParams(mappedStatement, parameterMappings, propertyList);
        } else if (SqlCommandType.UPDATE.equals(sqlCommandType)) {
            autoFildUpdateFeild(object);
            sql = appendUpdateSql(sql);
            autoFildUpdateParams(mappedStatement, parameterMappings, propertyList);
        } else if (SqlCommandType.SELECT.equals(sqlCommandType)) {
            sql = appendQuerySql(sql);
        } else {
            // DELETE, FLUSH, UNKNOWN: nothing to rewrite, so do not rebuild the statement.
            return invocation.proceed();
        }

        // Build a BoundSql carrying the rewritten SQL and the extended mappings.
        BoundSql newBoundSql = new BoundSql(mappedStatement.getConfiguration(), sql, parameterMappings, boundSql.getParameterObject());
        // Carry over additional parameters (e.g. the per-item bindings produced by <foreach>),
        // otherwise their placeholders can no longer be resolved.
        for (ParameterMapping mapping : boundSql.getParameterMappings()) {
            String property = mapping.getProperty();
            if (boundSql.hasAdditionalParameter(property)) {
                newBoundSql.setAdditionalParameter(property, boundSql.getAdditionalParameter(property));
            }
        }
        // Swap the statement argument for a copy that always returns the rewritten BoundSql.
        MappedStatement newMs = copyFromMappedStatement(mappedStatement, new BoundSqlSqlSource(newBoundSql));
        queryArgs[INDEX_ZERO] = newMs;
        return invocation.proceed();
    }

    /** Matches the FROM keyword as a whole word, in any letter case. */
    private static final java.util.regex.Pattern FROM_KEYWORD =
            java.util.regex.Pattern.compile("\\bfrom\\b", java.util.regex.Pattern.CASE_INSENSITIVE);

    /** Matches the WHERE keyword as a whole word, in any letter case. */
    private static final java.util.regex.Pattern WHERE_KEYWORD =
            java.util.regex.Pattern.compile("\\bwhere\\b", java.util.regex.Pattern.CASE_INSENSITIVE);

    /**
     * Adds the soft-delete filter ({@code del_flag = 0}) to a SELECT statement.
     *
     * <p>The filter is inserted directly after the first WHERE keyword. The SQL is returned
     * unchanged when it has no FROM clause, when the text after FROM already mentions
     * {@code del_flag}, or when it has no WHERE clause (adding a WHERE clause was never
     * implemented, because it would have to be placed before ORDER BY / GROUP BY / LIMIT).
     *
     * <p>The checked exceptions remain in the signature for compatibility with the previous
     * version, which ran the SQL through JSqlParser without using the result; this method no
     * longer throws them.
     *
     * @param sql the SELECT statement as produced by MyBatis
     * @return the statement with the soft-delete filter, or {@code sql} itself when nothing was added
     */
    private String appendQuerySql(String sql) throws JSQLParserException, ParseException {
        java.util.regex.Matcher fromMatcher = FROM_KEYWORD.matcher(sql);
        if (!fromMatcher.find()) {
            // e.g. "select 1": there is no table to filter.
            return sql;
        }
        String afterFrom = sql.substring(fromMatcher.start()).toLowerCase(java.util.Locale.ROOT);
        if (afterFrom.contains(DEL_FLAG_COLUMN)) {
            // The statement already deals with the flag itself.
            return sql;
        }
        java.util.regex.Matcher whereMatcher = WHERE_KEYWORD.matcher(sql);
        if (!whereMatcher.find()) {
            return sql;
        }
        int insertAt = whereMatcher.end();
        // The leading space keeps the keyword and the column apart ("where del_flag", not "wheredel_flag").
        return sql.substring(0, insertAt)
                + " " + DEL_FLAG_COLUMN + " = " + DEL_FLAG_DEFAULT + " and"
                + sql.substring(insertAt);
    }

    /**
     * Registers parameter mappings for the update audit properties that the statement does not
     * bind yet. Each new mapping is inserted at the head of the list, so when both are missing
     * the resulting order is updateBy, updateTime, followed by the original mappings.
     *
     * @param mappedStatement   statement whose configuration is used to build the mappings
     * @param parameterMappings list to extend in place
     * @param propertyList      property names bound before this call
     */
    private void autoFildUpdateParams(MappedStatement mappedStatement, List<ParameterMapping> parameterMappings, List<String> propertyList) {
        Configuration configuration = mappedStatement.getConfiguration();

        boolean bindsUpdateTime = propertyList.contains(UPDATE_TIME_FEILD);
        if (!bindsUpdateTime) {
            ParameterMapping updateTimeMapping = newParameterMapping(configuration, UPDATE_TIME_FEILD, Date.class);
            parameterMappings.add(0, updateTimeMapping);
        }

        boolean bindsUpdateBy = propertyList.contains(UPDATE_BY_FEILD);
        if (!bindsUpdateBy) {
            ParameterMapping updateByMapping = newParameterMapping(configuration, UPDATE_BY_FEILD, String.class);
            parameterMappings.add(0, updateByMapping);
        }
    }

    /**
     * Adds {@code update_time=?} and {@code update_by=?} assignments to an UPDATE statement when
     * the SQL text does not mention those columns yet.
     *
     * <p>Both assignments are inserted right after the SET keyword, update_time first and
     * update_by second, which leaves update_by in front. That is the order in which
     * {@code autoFildUpdateParams} puts the parameter mappings at the head of the list.
     *
     * @param sql the UPDATE statement
     * @return the statement including the audit assignments
     */
    private String appendUpdateSql(String sql) {
        if (!sql.contains(UPDATE_TIME_COLUMN)) {
            sql = orgUpdateSql(sql, UPDATE_TIME_COLUMN);
        }
        if (!sql.contains(UPDATE_BY_COLUMN)) {
            // Must use the UPDATE helper: orgSql edits INSERT column/value lists and fails on
            // statements that have no parentheses.
            sql = orgUpdateSql(sql, UPDATE_BY_COLUMN);
        }
        return sql;
    }

    /**
     * Fills the update audit properties (updateBy, then updateTime) on the parameter object.
     *
     * @param object the statement's parameter object
     */
    private void autoFildUpdateFeild(Object object) throws NoSuchFieldException, IllegalAccessException {
        String[] auditFields = {UPDATE_BY_FEILD, UPDATE_TIME_FEILD};
        for (String auditField : auditFields) {
            setFeildValue(object, auditField);
        }
    }

    /**
     * Appends parameter mappings for the insert audit properties (createTime, createBy, delFlag,
     * in that order) that the statement does not bind yet.
     *
     * @param mappedStatement   statement whose configuration is used to build the mappings
     * @param parameterMappings list to extend in place
     * @param propertyList      property names bound before this call
     */
    private void autoFildInsertParams(MappedStatement mappedStatement, List<ParameterMapping> parameterMappings, List<String> propertyList) {
        String[] auditProperties = {CREATE_TIME_FEILD, CREATE_BY_FEILD, DEL_FLAG_FEILD};
        Class<?>[] auditTypes = {Date.class, String.class, Integer.class};

        for (int i = 0; i < auditProperties.length; i++) {
            if (propertyList.contains(auditProperties[i])) {
                continue;
            }
            parameterMappings.add(
                    newParameterMapping(mappedStatement.getConfiguration(), auditProperties[i], auditTypes[i]));
        }
    }

    /**
     * Adds the create_time, create_by and del_flag columns (in that order) to an INSERT
     * statement, skipping every column the current SQL text already mentions.
     *
     * @param sql the INSERT statement
     * @return the statement including the missing audit columns and placeholders
     */
    private String appendInsertSql(String sql) {
        String result = sql;
        for (String column : new String[]{CREATE_TIME_COLUMN, CREATE_BY_COLUMN, DEL_FLAG_COLUMN}) {
            boolean alreadyPresent = result.contains(column);
            if (!alreadyPresent) {
                result = orgSql(result, column);
            }
        }
        return result;
    }

    /**
     * Fills the insert audit properties (createTime, createBy, delFlag, in that order) on the
     * parameter object.
     *
     * @param object the statement's parameter object
     */
    private void autoFildInsertFeild(Object object) throws IllegalAccessException, NoSuchFieldException {
        String[] auditFields = {CREATE_TIME_FEILD, CREATE_BY_FEILD, DEL_FLAG_FEILD};
        for (String auditField : auditFields) {
            setFeildValue(object, auditField);
        }
    }

    /**
     * Builds a parameter mapping for one property.
     *
     * @param configuration the MyBatis configuration handed to the builder
     * @param property      name of the bound property
     * @param javaType      Java type of the property
     * @return the built mapping
     */
    private ParameterMapping newParameterMapping(Configuration configuration, String property, Class<?> javaType) {
        ParameterMapping.Builder mappingBuilder = new ParameterMapping.Builder(configuration, property, javaType);
        return mappingBuilder.build();
    }

    /**
     * Gives an audit field of the parameter object a default value when it is still null.
     *
     * <p>The field is searched on the object's class and on every superclass, so it is found
     * whether the entity declares it itself or inherits it from a base class at any depth.
     * Defaults by declared type: {@code java.util.Date} gets the current time, {@code String}
     * gets {@code "system"}, {@code Integer} gets {@code 0}. Fields of any other type, and
     * fields that already hold a value, are left untouched.
     *
     * @param object    the statement's parameter object; must not be null
     * @param filedName name of the field to fill
     * @throws NoSuchFieldException   if no class in the hierarchy declares the field
     * @throws IllegalAccessException if the field cannot be read or written reflectively
     * @throws NullPointerException   if {@code object} is null
     */
    private void setFeildValue(Object object, String filedName) throws IllegalAccessException, NoSuchFieldException {
        if (object == null) {
            throw new NullPointerException("parameter object is null, cannot fill field " + filedName);
        }

        // Walk from the concrete class up to Object and take the first matching declaration.
        Field field = null;
        for (Class<?> type = object.getClass(); type != null && field == null; type = type.getSuperclass()) {
            for (Field candidate : type.getDeclaredFields()) {
                if (candidate.getName().equals(filedName)) {
                    field = candidate;
                    break;
                }
            }
        }
        if (field == null) {
            throw new NoSuchFieldException("没有字段" + filedName);
        }

        field.setAccessible(true);
        if (field.get(object) != null) {
            // A value supplied by the caller always wins.
            return;
        }

        // Pick the default from the declared field type.
        Class<?> fieldType = field.getType();
        if (fieldType == Date.class) {
            field.set(object, new Date());
        } else if (fieldType == String.class) {
            field.set(object, "system");
        } else if (fieldType == Integer.class) {
            field.set(object, 0);
        }
    }

    /** Matches the SET keyword as a whole word, in any letter case. */
    private static final java.util.regex.Pattern SET_KEYWORD =
            java.util.regex.Pattern.compile("\\bset\\b", java.util.regex.Pattern.CASE_INSENSITIVE);

    /**
     * Inserts a {@code <param>=?} assignment directly after the SET keyword of an UPDATE statement.
     *
     * <p>SET is matched as a whole word and case-insensitively, so identifiers that merely
     * contain the letters (for example a table called {@code asset}) are not mistaken for it.
     *
     * @param sql   the UPDATE statement
     * @param param the column to assign
     * @return the statement with the new assignment in front of the existing ones
     * @throws IllegalArgumentException if the statement has no SET keyword
     */
    private String orgUpdateSql(String sql, String param) {
        java.util.regex.Matcher setMatcher = SET_KEYWORD.matcher(sql);
        if (!setMatcher.find()) {
            throw new IllegalArgumentException("No SET clause found, cannot add column " + param + " to: " + sql);
        }
        int afterSet = setMatcher.end();
        return new StringBuilder(sql.length() + param.length() + 5)
                .append(sql, 0, afterSet)
                .append(" ")
                .append(param).append("=?, ")
                .append(sql, afterSet, sql.length())
                .toString();
    }

    /**
     * Adds one column and one {@code ?} placeholder to a single-row INSERT statement of the form
     * {@code insert into t (c1, c2) values (?, ?)}.
     *
     * <p>The column goes in front of the first {@code ')'} (end of the column list) and the
     * placeholder in front of the last {@code ')'} (end of the value list). Text following the
     * last {@code ')'} is kept.
     *
     * @param sql   the INSERT statement
     * @param param the column to add
     * @return the statement with the extra column and placeholder
     * @throws IllegalArgumentException if the statement does not contain both a column list and
     *                                  a value list in parentheses
     */
    private String orgSql(String sql, String param) {
        int columnListEnd = sql.indexOf(")");
        int valueListEnd = sql.lastIndexOf(")");
        if (columnListEnd < 0 || columnListEnd == valueListEnd) {
            throw new IllegalArgumentException(
                    "Expected an INSERT with explicit column and value lists, cannot add column " + param + " to: " + sql);
        }
        return new StringBuilder(sql.length() + param.length() + 3)
                .append(sql, 0, columnListEnd)
                .append(",").append(param)
                .append(sql, columnListEnd, valueListEnd)
                .append(",?)")
                // Keep whatever follows the value list instead of silently dropping it.
                .append(sql, valueListEnd + 1, sql.length())
                .toString();
    }

    /**
     * Wraps the given target with this interceptor via {@code Plugin.wrap}.
     *
     * @param target the object MyBatis offers for interception
     * @return the result of {@code Plugin.wrap(target, this)}
     */
    @Override
    public Object plugin(Object target) {
        Object wrapped = Plugin.wrap(target, this);
        return wrapped;
    }

    /**
     * No-op: this interceptor does not read any of the supplied properties.
     *
     * @param properties ignored
     */
    @Override
    public void setProperties(Properties properties) {

    }

    /**
     * Creates a copy of a mapped statement that uses a different SQL source.
     *
     * <p>Copied settings: resource, fetch size, statement type, key generator, all key
     * properties and key columns, timeout, parameter map, result maps, result set type, cache,
     * flush-cache and use-cache flags. The id, configuration and command type are passed to the
     * builder's constructor.
     *
     * @param ms           the statement to copy
     * @param newSqlSource the SQL source the copy should use
     * @return the new statement
     */
    private MappedStatement copyFromMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
        MappedStatement.Builder builder = new MappedStatement.Builder(ms.getConfiguration(), ms.getId(), newSqlSource, ms.getSqlCommandType());
        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());
        // The builder takes a comma-separated list; pass every entry, not just the first, so
        // statements with composite generated keys keep all of them.
        String[] keyProperties = ms.getKeyProperties();
        if (keyProperties != null && keyProperties.length > 0) {
            builder.keyProperty(String.join(",", keyProperties));
        }
        String[] keyColumns = ms.getKeyColumns();
        if (keyColumns != null && keyColumns.length > 0) {
            builder.keyColumn(String.join(",", keyColumns));
        }
        builder.timeout(ms.getTimeout());
        builder.parameterMap(ms.getParameterMap());
        builder.resultMaps(ms.getResultMaps());
        builder.resultSetType(ms.getResultSetType());
        builder.cache(ms.getCache());
        builder.flushCacheRequired(ms.isFlushCacheRequired());
        builder.useCache(ms.isUseCache());
        return builder.build();
    }

    /**
     * SqlSource that always hands back the BoundSql it was created with, whatever parameter
     * object is passed in.
     */
    public class BoundSqlSqlSource implements SqlSource {
        /** The BoundSql returned by every call. */
        private final BoundSql fixedBoundSql;

        /**
         * @param boundSql the BoundSql this source returns
         */
        public BoundSqlSqlSource(BoundSql boundSql) {
            fixedBoundSql = boundSql;
        }

        /**
         * @param parameterObject ignored
         * @return the BoundSql supplied to the constructor
         */
        @Override
        public BoundSql getBoundSql(Object parameterObject) {
            return this.fixedBoundSql;
        }
    }

你可能感兴趣的:(mybatis)