【MyBatis】plugin原理及分页插件实现

MyBatis拦截器

我们可以选择在这些被拦截的方法执行前后加上某些逻辑,也可以在执行这些被拦截的方法时执行自己的逻辑而不再执行被拦截的方法。
Mybatis拦截器设计的一个初衷就是为了供用户在某些时候可以实现自己的逻辑而不必去动Mybatis固有的逻辑。打个比方,对于Executor,Mybatis中有几种实现:BatchExecutor、ReuseExecutor、SimpleExecutorCachingExecutor。这个时候如果你觉得这几种实现对于Executor接口的query方法都不能满足你的要求,那怎么办呢?是要去改源码吗?当然不。我们可以建立一个Mybatis拦截器用于拦截Executor接口的query方法,在拦截之后实现自己的query方法逻辑,之后可以选择是否继续执行原来的query方法。

Interceptor

对于拦截器Mybatis为我们提供了一个Interceptor接口,通过实现该接口就可以定义我们自己的拦截器。

package org.apache.ibatis.plugin; 
import java.util.Properties; 
public interface Interceptor {
 Object intercept(Invocation invocation) throws Throwable; 
Object plugin(Object target); 
void setProperties(Properties properties);
 } 
  • intercept
    它将直接覆盖所拦截的对象的原有方法,它是插件的核心方法。intercept里面有一个参数Invocation对象,通过它可以反射调度原来对象的方法
  • plugin
    target是被拦截的对象,他的作用是给被拦截对象生成一个代理对象,并返回她。为了方便MyBatis使用Plugin.wrap()提供生成代理对象,我们往往使用plugin方法便可以生成一个代理对象了
  • setProperties
    允许在plugin元素中配置所需参数,方法在插件初始化的时候就被调用一次,然后把插件对象存入到配置中,以便后面取出。
    Mybatis拦截器只能拦截四种类型的接口:ExecutorStatementHandlerParameterHandlerResultSetHandler
定义插件签名
@Intercepts({@Signature(type = Executor.class, //确定要拦截的对象
        method = "update", //确定要拦截的方法
        args = {MappedStatement.class,Object.class} //拦截方法的参数
    )})
public class MyPlugin implements Interceptor {
......
}

@Intercepts说明他是一个拦截器。@Signature是注册拦截器签名的地方,只有签名满足条件才能拦截,type可以是四大对象中的一个。method代表要拦截的四大对象中的某一种接口的方法,args是该方法的参数,需要根据拦截对象方法的参数进行设置。

实现一个简单的插件

package com.excelib.plugin;

import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.*;

import java.util.Properties;

@Intercepts({@Signature(type = Executor.class, //确定要拦截的对象
        method = "update", //确定要拦截的方法
        args = {MappedStatement.class,Object.class} //拦截方法的参数
    )})
public class MyPlugin implements Interceptor {
    Properties properties=null;
    /**
     * 拦截方法的处理
     * @param invocation 责任链对象
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        System.err.println("before ....");
        //如果当前代理是一个非代理对象,那么它就会调用真实拦截对象的方法,如果不是他会回调下一个代理对象的代理接口的方法
        Object result = invocation.proceed();
        System.err.println("after .....");
        return result;
    }
    /**
     * 生成对象的代理,这里常用MyBatis提供的Plugin类的wrap方法
     * @param target 被代理的对象
     */
    @Override
    public Object plugin(Object target) {
        if (target instanceof Executor){
            //只是Executor才生成代理
            System.err.println("调用生成代理对象"+target.getClass());
            return Plugin.wrap(target,this);
        }
       return target;
    }

    /**
     * 获取配置文件的属性,我们在MyBatis的配置文件里面去配置
     * @param properties 是MyBatis配置的参数
     */
    @Override
    public void setProperties(Properties properties) {
        System.err.println(properties.get("dbType"));
        this.properties = properties;
    }
}

SqlSessionFactoryBean中配置plugins属性

  
            
            
            
            

            
            
                
                    
                
            
        

输出结果:

19:02:12,974 DEBUG SqlSessionUtils:54 - Creating a new SqlSession
调用生成代理对象class org.apache.ibatis.executor.ReuseExecutor
19:02:12,991 DEBUG SqlSessionUtils:54 - SqlSession [org.apache.ibatis.session.defaults.DefaultSqlSession@3a94964] was not registered for synchronization because synchronization is not active
19:02:13,008 DEBUG DataSourceUtils:110 - Fetching JDBC Connection from DataSource
19:02:13,260 DEBUG SpringManagedTransaction:54 - JDBC Connection [jdbc:mysql://47.94.102.25:3306/test?characterEncoding=UTF-8, [email protected], MySQL-AB JDBC Driver] will not be managed by Spring
19:02:13,263 DEBUG getUser:54 - ==>  Preparing: select * from user where id = ? 
19:02:13,321 DEBUG getUser:54 - ==> Parameters: 17(Integer)
19:02:13,359 DEBUG getUser:54 - <==      Total: 1
19:02:13,364 DEBUG SqlSessionUtils:54 - Closing non transactional SqlSession [org.apache.ibatis.session.defaults.DefaultSqlSession@3a94964]
19:02:13,364 DEBUG DataSourceUtils:327 - Returning JDBC Connection to DataSource
User{id=17, name='Jack', age=8}

分页插件

步骤
  1. 拦截StatementHandler
  2. 获取原查询sql,构建查询总条数的sql,获取connection,然后新建 BoundSql、ParameterHandler,处理之后执行查询操作
  3. 校验分页参数
  4. 改写原sql为分页sql,设置分页参数到statement中
  5. 执行查询操作,回填分页数据
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
import org.apache.ibatis.session.Configuration;

import java.lang.reflect.InvocationTargetException;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.*;

@Intercepts({@Signature(type = StatementHandler.class, //确定要拦截的对象
        method = "prepare", //确定要拦截的方法
        args = {Connection.class,Integer.class} //拦截方法的参数
)})
public class PagingPlugin implements Interceptor {
    private Integer defaultPage;//默认页码
    private Integer defaultPageSize;//默认每页条数
    private Boolean defaultUseFlag;//默认是否启用插件
    private Boolean defaultCheckFlag;//默认是否检查当前页码的正确性



    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler stmtHandler = getUnProxyObject(invocation);
        MetaObject metaStatementHandler = SystemMetaObject.forObject(stmtHandler);
        String sql = (String) metaStatementHandler.getValue("delegate.boundSql.sql");
        if (!checkSelect(sql)){
            return invocation.proceed();
        }
        BoundSql boundSql = (BoundSql) metaStatementHandler.getValue("delegate.boundSql");
        Object parameterObject = boundSql.getParameterObject();
        PageParams pageParams = getPageParams(parameterObject);
        if (pageParams == null)
            return invocation.proceed();
        //获取分页参数,获取不到的时候使用默认值
        Integer page = pageParams.getPage()==null?this.defaultPage:pageParams.getPage();
        Integer pageSize = pageParams.getPageSize()==null?this.defaultPageSize:pageParams.getPageSize();
        Boolean useFlag = pageParams.getUseFlag() == null?this.defaultUseFlag:pageParams.getUseFlag();
        Boolean checkFlag = pageParams.getCheckFlag()==null?this.defaultCheckFlag:pageParams.getCheckFlag();
        if (!useFlag){
            return invocation.proceed();
        }
        int total = getTotal(invocation, metaStatementHandler, boundSql);
        //回填总数到分页参数里
        setTotalToPageParams(pageParams,total,pageSize);
        //检查当前页码的有效性
        checkPage(checkFlag,page,pageParams.getTotalPage());
        //修改SQL
        return changeSQL(invocation,metaStatementHandler,boundSql,page,pageSize);
    }

    /**
     * 从代理对象中分离出真实对象
     */
    private StatementHandler getUnProxyObject(Invocation invocation) {
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        MetaObject metaStatementHandler = SystemMetaObject.forObject(statementHandler);
        //分离代理对象链(由于目标类可能被多个拦截器拦截,从而形成多次代理,通过循环可以分离出最原始的目标类)
        Object object = null;
        while (metaStatementHandler.hasGetter("h")){
            object = metaStatementHandler.getValue("h");
        }
        if (object==null){
            return statementHandler;
        }
        return (StatementHandler) object;
    }
    /**
     * 判断是否是selec语句
     */
    private boolean checkSelect(String sql){
        String trimSql = sql.trim();
        int idx = trimSql.toLowerCase().indexOf("select");
        return idx == 0;
    }

    /**
     * 获取分页参数
     */
    private PageParams getPageParams(Object parameterObject){
        if (parameterObject==null)
            return null;
        PageParams pageParams = null;
        if (parameterObject instanceof Map){
            Map paramMap = (Map)parameterObject;
            Set keySet = paramMap.keySet();
            Iterator iterator = keySet.iterator();
            while (iterator.hasNext()){
                String key = iterator.next();
                Object value = paramMap.get(key);
                if (value instanceof PageParams){
                    return (PageParams) value;
                }
            }
        }else if (parameterObject instanceof PageParams){
            pageParams = (PageParams) parameterObject;
        }
        return pageParams;
    }

    /**
     * 获取总数
     * @param ivt
     * @param metaStatementHandler
     * @param boundSql
     * @return
     * @throws SQLException
     */
    private int getTotal(Invocation ivt, MetaObject metaStatementHandler, BoundSql boundSql) throws SQLException {
        MappedStatement mappedStatement = (MappedStatement) metaStatementHandler.getValue("delegate.mappedStatement");
        Configuration configuration = mappedStatement.getConfiguration();
        String sql = (String) metaStatementHandler.getValue("delegate.boundSql.sql");
        String countSql = "select count(*) as total from ( " +sql+" )$_paging";
        Connection connection = (Connection) ivt.getArgs()[0];
        PreparedStatement ps = null;
        int total = 0;
        try {
            ps = connection.prepareStatement(countSql);
            BoundSql countBoundSql = new BoundSql(configuration, countSql, boundSql.getParameterMappings(), boundSql.getParameterObject());
            ParameterHandler handler = new DefaultParameterHandler(mappedStatement, boundSql.getParameterObject(), countBoundSql);
            handler.setParameters(ps);
            //执行查询
            ResultSet resultSet = ps.executeQuery();
            while (resultSet.next()){
                total = resultSet.getInt("total");
            }
            return total;
        } finally {
            //这里不能关闭Connection,否则后续的SQL就没法继续了
            if (ps != null){
                ps.close();
            }
        }
    }

    /**
     * 回填总页数和总条数到分页参数
     * @param pageParams
     * @param total
     * @param pageSize
     */
    private void setTotalToPageParams(PageParams pageParams,int total,int pageSize){
        pageParams.setTotal(total);
        int totalPage = total%pageSize==0?total/pageSize:total/pageSize+1;
        pageParams.setTotalPage(totalPage);
    }

    /**
     * 检查当前页码的有效性
     * @param checkFlag
     * @param pageNum
     * @param pageTotal
     */
    private void checkPage(Boolean checkFlag,Integer pageNum,Integer pageTotal){
        if (checkFlag){
            //检查页码page是否合法
            if (pageNum>pageTotal){
                throw new IllegalArgumentException("查询失败,查询页码【"+pageNum+"】大于总页数【"+pageTotal+"】!!");
            }
        }
    }

    /**
     * 修改当前查询的SQL
     * @param invocation
     * @param metaStatementHandler
     * @param boundSql
     * @param page
     * @param pageSize
     * @return
     * @throws InvocationTargetException
     * @throws IllegalAccessException
     * @throws SQLException
     */
    private Object changeSQL(Invocation invocation,MetaObject metaStatementHandler,BoundSql boundSql,int page,int pageSize) throws InvocationTargetException, IllegalAccessException, SQLException {
        String sql = (String) metaStatementHandler.getValue("delegate.boundSql.sql");
        //修改SQL,这里使用的是MySQL,如果是其他数据库则需要修改
        String newSql = "select * from ( " +sql+" ) $_paging_table limit ?,?";
        //修改当前需要执行的SQL
        metaStatementHandler.setValue("delegate.boundSql.sql",newSql);
        //相当于套用了StatementHandler的prepare方法,预编译了当前SQL并设置原有的参数,但是少了两个分页参数,它返回的是一个PreParedStatement对象
        PreparedStatement ps = (PreparedStatement) invocation.proceed();
        //计算SQL总参数个数
        int count = ps.getParameterMetaData().getParameterCount();
        ps.setInt(count-1,Math.max(page-1,0)*pageSize);
        ps.setInt(count,pageSize);
        return ps;
    }

    @Override
    public Object plugin(Object target) {
        if (target instanceof StatementHandler){
            return Plugin.wrap(target,this);
        }
        return target;
    }

    @Override
    public void setProperties(Properties props) {
        defaultPage = Integer.parseInt(props.getProperty("default.page", "1"));
        defaultPageSize = Integer.parseInt(props.getProperty("default.pageSize", "20"));
        defaultUseFlag = Boolean.parseBoolean(props.getProperty("default.useFlag", "false"));
        defaultCheckFlag = Boolean.parseBoolean(props.getProperty("default.checkFlag", "false"));
    }

    public Integer getDefaultPage() {
        return defaultPage;
    }

    public void setDefaultPage(Integer defaultPage) {
        this.defaultPage = defaultPage;
    }

    public Integer getDefaultPageSize() {
        return defaultPageSize;
    }

    public void setDefaultPageSize(Integer defaultPageSize) {
        this.defaultPageSize = defaultPageSize;
    }

    public Boolean getDefaultUseFlag() {
        return defaultUseFlag;
    }

    public void setDefaultUseFlag(Boolean defaultUseFlag) {
        this.defaultUseFlag = defaultUseFlag;
    }

    public Boolean getDefaultCheckFlag() {
        return defaultCheckFlag;
    }

    public void setDefaultCheckFlag(Boolean defaultCheckFlag) {
        this.defaultCheckFlag = defaultCheckFlag;
    }
}

    
            
            
            
            

            
            
                
                    
                    
                        
                        
                        
                        
                        
                        
                        
                        
                    
                
            
        
    @Test
    public void testPagingPlugin(){
        ClassPathXmlApplicationContext context = new ClassPathXmlApplicationContext("classpath:applicationContext.xml");
        UserMapper userMapper = (UserMapper) context.getBean("userMapper");
        PageParams pageParams = new PageParams();
        pageParams.setPage(0);
        pageParams.setPageSize(2);
        User user = new User();
        user.setName("Jack");
        User result = userMapper.getUserByCondition(user,pageParams);
        System.out.println(result.toString());
    }

源码理解

我们先以Executor的拦截器为例看一下拦截器的调用过程,创建sqlSession的时候会首先创建一个Executor

public class DefaultSqlSessionFactory implements SqlSessionFactory {
    @Override
  public SqlSession openSession() {
    return openSessionFromDataSource(configuration.getDefaultExecutorType(), null, false);
  }
  private SqlSession openSessionFromDataSource(ExecutorType execType, TransactionIsolationLevel level, boolean autoCommit) {
    Transaction tx = null;
    try {
      final Environment environment = configuration.getEnvironment();
      final TransactionFactory transactionFactory = getTransactionFactoryFromEnvironment(environment);
      tx = transactionFactory.newTransaction(environment.getDataSource(), level, autoCommit);
      final Executor executor = configuration.newExecutor(tx, execType);
      return new DefaultSqlSession(configuration, executor, autoCommit);
    } catch (Exception e) {
      closeTransaction(tx); // may have fetched a connection so lets call close()
      throw ExceptionFactory.wrapException("Error opening session.  Cause: " + e, e);
    } finally {
      ErrorContext.instance().reset();
    }
  }
}

//org.apache.ibatis.session.Configuration#newExecutor(org.apache.ibatis.transaction.Transaction, org.apache.ibatis.session.ExecutorType)
  public Executor newExecutor(Transaction transaction, ExecutorType executorType) {
    executorType = executorType == null ? defaultExecutorType : executorType;
    executorType = executorType == null ? ExecutorType.SIMPLE : executorType;
    Executor executor;
    if (ExecutorType.BATCH == executorType) {
      executor = new BatchExecutor(this, transaction);
    } else if (ExecutorType.REUSE == executorType) {
      executor = new ReuseExecutor(this, transaction);
    } else {
      executor = new SimpleExecutor(this, transaction);
    }
    if (cacheEnabled) {
      executor = new CachingExecutor(executor);
    }
    //使用拦截器链创建代理
    executor = (Executor) interceptorChain.pluginAll(executor);
    return executor;
  }

创建Executor的时候调用了interceptorChain.pluginAll(executor);并将其返回值传给了sqlSession。进入pluginAll()可以知道InterceptorChain中维护了一个interceptors的list,调用pluginAll();时会遍历所有的interceptors调用interceptor.plugin();并将前一个interceptor.plugin()的返回结果作为下一个interceptor.plugin()的参数,其实每次调用interceptor.plugin()我们一般都会生成一个动态代理,这样就形成了一个动态代理链。

public class InterceptorChain {
  private final List interceptors = new ArrayList();
  public Object pluginAll(Object target) {
    for (Interceptor interceptor : interceptors) {
      target = interceptor.plugin(target);
    }
    return target;
  }
  public void addInterceptor(Interceptor interceptor) {
    interceptors.add(interceptor);
  }
}

而在plugin中我们判断如果是我们需要拦截的类的实例才生成代理,需要这样判断是因为InterceptorChain中会不加区别的调用interceptor.plugin()然后传给下一层,并没有处理我们的注解。生成动态代理的时候一般式调用Plugin.wrap(),这个方法生成的动态代理在调用方法的时候会判断是否是我们需要拦截的方法,如果是的话就会回调interceptor.intercept()。

@Override
    public Object plugin(Object target) {
        if (target instanceof StatementHandler){
            return Plugin.wrap(target,this);
        }
        return target;
    }

public class Plugin implements InvocationHandler {

  private final Object target;
  private final Interceptor interceptor;
  private final Map, Set> signatureMap;

  private Plugin(Object target, Interceptor interceptor, Map, Set> signatureMap) {
    this.target = target;
    this.interceptor = interceptor;
    this.signatureMap = signatureMap;
  }

  public static Object wrap(Object target, Interceptor interceptor) {
    //获取方法签名,key是拦截的类value是锁兰姐的方法 
    Map, Set> signatureMap = getSignatureMap(interceptor);
    Class type = target.getClass();
    //获取拦截类实现的所有接口
    Class[] interfaces = getAllInterfaces(type, signatureMap);
    if (interfaces.length > 0) {
      //创建代理对象
      return Proxy.newProxyInstance(
          type.getClassLoader(),
          interfaces,
          new Plugin(target, interceptor, signatureMap));
    }
    return target;
  }

  @Override
  public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
    try {
      Set methods = signatureMap.get(method.getDeclaringClass());
      if (methods != null && methods.contains(method)) {
       //如果调用的代理类的方法是我需要拦截的方法则回调interceptor.intercept()
        return interceptor.intercept(new Invocation(target, method, args));
      }
      //如果不是我们需要拦截的方法则直接调用原方法
      return method.invoke(target, args);
    } catch (Exception e) {
      throw ExceptionUtil.unwrapThrowable(e);
    }
  }
  .....
}

如果还需要调用下一层代理或者原对象的方法则直接调用Invocation.proceed()

public class Invocation {
......
  private final Object target;
  private final Method method;
  private final Object[] args;
  public Object proceed() throws InvocationTargetException, IllegalAccessException {
    return method.invoke(target, args);
  }
}

在interceptor.intercept中我们可以做自己的逻辑处理,比如分页、缓存(拦截Excutor.doQuery)等插件实现起来都比较方便。

你可能感兴趣的:(【MyBatis】plugin原理及分页插件实现)