MyBatis 二级缓存：关联查询缓存清除拦截器

依赖 JSqlParser 解析 SQL 中引用的表名。


import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.insert.Insert;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.update.Update;
import net.sf.jsqlparser.util.TablesNamesFinder;
import org.apache.ibatis.cache.Cache;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;

import java.io.Serializable;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

@Slf4j
@Intercepts({
        @Signature(method = "query", type = Executor.class, args = {MappedStatement.class, Object.class,
                RowBounds.class, ResultHandler.class}),
        @Signature(method = "update", type = Executor.class, args = {MappedStatement.class, Object.class})})
public class CachingInterceptor implements Interceptor {
    //table-[namespace...]
    private final static ConcurrentMap namespaceConcurrentMap = new ConcurrentHashMap<>();
    //id-[table...]
    private final static ConcurrentMap tableConcurrentMap = new ConcurrentHashMap<>();

    public Object intercept(Invocation invocation) throws Throwable {
        Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[0];

        if (!ms.getConfiguration()
                .isCacheEnabled() || ms.getCache() == null) {
            return invocation.proceed();
        }

        BoundSql boundSql = ms.getBoundSql(args[1]);
        final String operate = invocation.getMethod()
                .getName();
        //todo 缓存id->table,table->namespace
        cacheIfRequired(ms.getCache()
                .getId(), ms.getId(), boundSql.getSql());
        if ("update".equals(operate)) {
            //todo 根据id->table,清除table->namespace对应缓存
            flushCacheIfRequired(ms.getId(), ms.getConfiguration());
        }
        return invocation.proceed();
    }


    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    public void setProperties(Properties properties) {
    }


    public void cacheIfRequired(String namespace, String id, String bindSql) {
        StatementTable statementTable = tableConcurrentMap.get(id);
        if (statementTable != null && statementTable.getTables() != null && statementTable.getBindSql()
                .equals(bindSql)) {
            if (log.isDebugEnabled()) {
                for (String table : statementTable.getTables()) {
                    StatementNamespace statementNamespace = namespaceConcurrentMap.get(table);
                    if (statementNamespace == null) {
                        continue;
                    }
                    log.debug("already mapped [ id -> table - namespace ]:{} - {} - {}", id, table, statementNamespace.getNamespaces());
                }
            }
            return;
        }

        final List tableNames = getTableList(bindSql);
        if (tableNames == null || tableNames.size() == 0) {
            return;
        }
        tableConcurrentMap.put(id, new StatementTable(bindSql, new HashSet<>(tableNames)));
        for (String table : tableNames) {
            StatementNamespace statementNamespace = namespaceConcurrentMap.get(table);
            if (statementNamespace != null) {
                statementNamespace.addNamespace(namespace);
            } else {
                namespaceConcurrentMap.put(table, new StatementNamespace(namespace));
            }

            log.debug("resolved and cached [ id -> table - namespace ]:{} - {} - {}", id, table, namespaceConcurrentMap.get(table)
                    .getNamespaces());
        }
    }

    private void flushCacheIfRequired(String id, Configuration configuration) {
        StatementTable statementTable = tableConcurrentMap.get(id);
        if (statementTable == null || statementTable.getTables() == null || statementTable.getTables()
                .size() == 0) {
            return;
        }


        for (String table : statementTable.getTables()) {
            StatementNamespace statementNamespace = namespaceConcurrentMap.get(table);
            if (statementNamespace == null || statementNamespace.getNamespaces()
                    .isEmpty()) {
                continue;
            }

            for (String namespace : statementNamespace.getNamespaces()) {
                Cache cache = configuration.getCache(namespace);
                if (cache != null) {
                    log.debug("clearing mapped [ id -> table - namespace -({})]:{} - {} - {}", cache.getSize(), id, table, namespace);
                    cache.clear();
                    log.debug("clear mapped [ id -> table - namespace -({})]:{} - {} - {}", cache.getSize(), id, table, namespace);
                }
            }
        }

    }

    /**
     * 解析SQL中包含的表名
     *
     * @param sql
     * @return
     */
    public List getTableList(final String sql) {

        try {
            String[] sqls = sql.split(";");
            List tableList = new ArrayList<>();
            for (String _sql : sqls) {
                if (_sql.trim()
                        .length() == 0) {
                    continue;
                }
                Statement statement = CCJSqlParserUtil.parse(_sql);
                TablesNamesFinder tablesNamesFinder = new TablesNamesFinder();
                if (statement instanceof Select) {
                    Select selectStatement = (Select) statement;
                    tableList.addAll(tablesNamesFinder.getTableList(selectStatement));
                } else if (statement instanceof Update) {
                    Update updateStatement = (Update) statement;
                    tableList.addAll(tablesNamesFinder.getTableList(updateStatement));
                } else if (statement instanceof Insert) {
                    Insert insertStatement = (Insert) statement;
                    tableList.addAll(tablesNamesFinder.getTableList(insertStatement));
                } else if (statement instanceof Delete) {
                    Delete deleteStatement = (Delete) statement;
                    tableList.addAll(tablesNamesFinder.getTableList(deleteStatement));
                }
            }
            log.debug("SQL - [{}] - contains:" + tableList, sql);
            return tableList;
        } catch (JSQLParserException e) {
            log.error(" resolve sql error:" + sql, e);
        }
        return null;
    }


    @Data
    @NoArgsConstructor
    class StatementTable implements Serializable {
        private Set tables = new HashSet<>();
        private String bindSql;

        public StatementTable(String bindSql, Set tables) {
            this.bindSql = bindSql;
            this.tables.addAll(tables);
        }
    }

    @Data
    @NoArgsConstructor
    class StatementNamespace implements Serializable {
        private Set namespaces = new HashSet<>();

        public StatementNamespace(String namespace) {
            this.namespaces.add(namespace);
        }

        public void addNamespace(String namespace) {
            this.namespaces.add(namespace);
        }

    }
}

 

你可能感兴趣的:(Java,ee,MyBatis)