MyBatis分页插件的实现

分页插件的核心由两部分组成:PageInterceptor 拦截器类,以及数据库方言接口 Dialect(由各数据库的方言类实现,例如下文的 MySqlDialect)。

1、PageInterceptor拦截器类

package com.pbyang.web.pagination;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ResultMap;
import org.apache.ibatis.mapping.ResultMapping;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Properties;

/**
 * MyBatis general-purpose pagination interceptor.
 *
 * <p>Intercepts {@code Executor.query(MappedStatement, Object, RowBounds, ResultHandler)}
 * and delegates every paging decision to the configured {@link Dialect}:
 * whether to skip paging, whether to run a count query first, and how the
 * count / page SQL is generated. The dialect implementation is chosen through
 * the plugin property {@code dialect} (see {@link #setProperties(Properties)}).
 */
@SuppressWarnings({"rawtypes", "unchecked"})
@Intercepts(
	@Signature(
		type = Executor.class, 
		method = "query", 
		args = {MappedStatement.class, Object.class, 
				RowBounds.class, ResultHandler.class}
	)
)
public class PageInterceptor implements Interceptor{
    /** Shared empty result-mapping list used for the synthetic count ResultMap. */
    private static final List EMPTY_RESULTMAPPING
    		= new ArrayList(0);
    /** Dialect instance created in {@link #setProperties(Properties)}. */
    private Dialect dialect;
    /** Reflective handle on the private {@code BoundSql.additionalParameters} field. */
    private Field additionalParametersField;

    /**
     * Runs the optional count query and the paged query in place of the
     * original query, or falls through to the original query when the
     * dialect decides paging does not apply.
     *
     * @param invocation the intercepted {@code Executor.query} call
     * @return the value of {@code dialect.afterPage(...)} when paging ran
     *         (or when the count was zero), otherwise the result of the
     *         original invocation
     * @throws Throwable anything thrown by the executor, the dialect or the
     *         reflective field access
     */
    public Object intercept(Invocation invocation) throws Throwable {
        // Arguments of the intercepted Executor.query(ms, parameter, rowBounds, resultHandler)
        Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[0];
        Object parameterObject = args[1];
        RowBounds rowBounds = (RowBounds) args[2];

        // The dialect decides whether paging applies at all; if not, run the query untouched.
        if (dialect.skip(ms.getId(), parameterObject, rowBounds)) {
            return invocation.proceed();
        }

        ResultHandler resultHandler = (ResultHandler) args[3];
        // The intercepted target is the executor itself
        Executor executor = (Executor) invocation.getTarget();
        BoundSql boundSql = ms.getBoundSql(parameterObject);
        // Dynamic SQL may register temporary parameters on the BoundSql; they are
        // only reachable through the private field, hence the reflection.
        Map<String, Object> additionalParameters =
                (Map<String, Object>) additionalParametersField.get(boundSql);

        // Optional count query
        if (dialect.beforeCount(ms.getId(), parameterObject, rowBounds)) {
            // Derive a statement from ms whose result type is Long
            MappedStatement countMs = newMappedStatement(ms, Long.class);
            // Cache key for the count query
            CacheKey countKey = executor.createCacheKey(
                    countMs,
                    parameterObject,
                    RowBounds.DEFAULT,
                    boundSql);
            // Ask the dialect for the count SQL
            String countSql = dialect.getCountSql(
                    boundSql,
                    parameterObject,
                    rowBounds,
                    countKey);
            BoundSql countBoundSql = new BoundSql(
                    ms.getConfiguration(),
                    countSql,
                    boundSql.getParameterMappings(),
                    parameterObject);
            // Temporary parameters must be carried over to the new BoundSql by hand
            copyAdditionalParameters(additionalParameters, countBoundSql);
            // Run the count query
            List countResult = executor.query(
                    countMs,
                    parameterObject,
                    RowBounds.DEFAULT,
                    resultHandler,
                    countKey,
                    countBoundSql);
            Long count = (Long) countResult.get(0);
            // Let the dialect record the total
            dialect.afterCount(count, parameterObject, rowBounds);
            if (count == 0L) {
                // Nothing to page through: return an empty result straight away
                return dialect.afterPage(
                        new ArrayList(),
                        parameterObject,
                        rowBounds);
            }
        }

        // Optional paged query
        if (dialect.beforePage(ms.getId(), parameterObject, rowBounds)) {
            // Cache key for the paged query
            CacheKey pageKey = executor.createCacheKey(
                    ms,
                    parameterObject,
                    rowBounds,
                    boundSql);
            // Ask the dialect for the paging SQL
            String pageSql = dialect.getPageSql(
                    boundSql,
                    parameterObject,
                    rowBounds,
                    pageKey);
            BoundSql pageBoundSql = new BoundSql(
                    ms.getConfiguration(),
                    pageSql,
                    boundSql.getParameterMappings(),
                    parameterObject);
            copyAdditionalParameters(additionalParameters, pageBoundSql);
            // The SQL itself now carries the paging clause, so the executor is
            // called with RowBounds.DEFAULT.
            List resultList = executor.query(
                    ms,
                    parameterObject,
                    RowBounds.DEFAULT,
                    resultHandler,
                    pageKey,
                    pageBoundSql);

            return dialect.afterPage(resultList, parameterObject, rowBounds);
        }

        // Neither branch returned: run the default query
        return invocation.proceed();
    }

    /**
     * Copies every temporary (dynamic SQL) parameter onto a freshly created BoundSql.
     *
     * @param source additional parameters read from the original BoundSql
     * @param target BoundSql that receives the parameters
     */
    private static void copyAdditionalParameters(
            Map<String, Object> source, BoundSql target) {
        for (Map.Entry<String, Object> entry : source.entrySet()) {
            target.setAdditionalParameter(entry.getKey(), entry.getValue());
        }
    }

    /**
     * Builds a new MappedStatement from an existing one, replacing its result
     * maps with a single ResultMap of the given type. The new statement id is
     * the original id with the suffix {@code "_Count"}.
     *
     * @param ms         statement to copy settings from
     * @param resultType result type of the new statement
     * @return the newly built statement
     */
    public MappedStatement newMappedStatement(
    		MappedStatement ms, Class<?> resultType) {
        MappedStatement.Builder builder = new MappedStatement.Builder(
                ms.getConfiguration(),
                ms.getId() + "_Count",
                ms.getSqlSource(),
                ms.getSqlCommandType()
        );
        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());
        String[] keyProps = ms.getKeyProperties();
        if (keyProps != null && keyProps.length != 0) {
            // The builder expects the key properties as one comma-separated string
            StringBuilder joined = new StringBuilder();
            for (String keyProperty : keyProps) {
                if (joined.length() > 0) {
                    joined.append(",");
                }
                joined.append(keyProperty);
            }
            builder.keyProperty(joined.toString());
        }
        builder.timeout(ms.getTimeout());
        builder.parameterMap(ms.getParameterMap());
        // The new statement returns resultType (Long for the count query)
        List<ResultMap> resultMaps = new ArrayList<ResultMap>();
        ResultMap resultMap = new ResultMap.Builder(
                ms.getConfiguration(),
                ms.getId(),
                resultType,
                EMPTY_RESULTMAPPING).build();
        resultMaps.add(resultMap);
        builder.resultMaps(resultMaps);
        builder.resultSetType(ms.getResultSetType());
        builder.cache(ms.getCache());
        builder.flushCacheRequired(ms.isFlushCacheRequired());
        builder.useCache(ms.isUseCache());
        return builder.build();
    }

    /**
     * Wraps the target with this interceptor via {@code Plugin.wrap}.
     *
     * @param target object MyBatis offers for interception
     * @return the value returned by {@code Plugin.wrap(target, this)}
     */
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /**
     * Instantiates the dialect named by the {@code dialect} property, passes
     * the properties on to it, and prepares reflective access to
     * {@code BoundSql.additionalParameters}.
     *
     * @param properties plugin properties; must contain {@code dialect} with
     *                   the fully qualified name of a {@link Dialect}
     *                   implementation that has a no-argument constructor
     * @throws RuntimeException if the dialect cannot be created (the original
     *         failure is attached as the cause) or the BoundSql field is missing
     */
    public void setProperties(Properties properties) {
        String dialectClass = properties.getProperty("dialect");
        try {
            dialect = (Dialect) Class.forName(dialectClass)
                    .getDeclaredConstructor().newInstance();
        } catch (Exception e) {
            // Keep the cause: a missing property, a mistyped class name and a
            // class that is not a Dialect all end up here.
            throw new RuntimeException(
            		"使用 PageInterceptor 分页插件时,必须设置 dialect 属性", e);
        }
        dialect.setProperties(properties);
        try {
            // Reflectively open up BoundSql's private additionalParameters field
            additionalParametersField = BoundSql.class.getDeclaredField(
            		"additionalParameters");
            additionalParametersField.setAccessible(true);
        } catch (NoSuchFieldException e) {
            throw new RuntimeException(e);
        }
    }

}

2、数据库方言接口Dialect

package com.pbyang.web.pagination;
import java.util.List;
import java.util.Properties;

import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.session.RowBounds;

/**
 * Database dialect; implemented once per database. PageInterceptor calls
 * these methods to decide whether to page and to obtain the count / paging SQL.
 * 
 * @author liuzh
 */
@SuppressWarnings("rawtypes")
public interface Dialect {
	/**
	 * Decides whether both the count query and the paged query are skipped.
	 * 
	 * @param msId full id of the MyBatis statement being executed
	 * @param parameterObject statement parameter
	 * @param rowBounds paging parameter
	 * @return true to skip paging and return the default query result, false to run the paging logic
	 */
	boolean skip(String msId, Object parameterObject, RowBounds rowBounds);
	
	/**
	 * Called before paging. Returning true makes the interceptor run a count query;
	 * returning false makes it continue with the {@code beforePage} check.
	 * 
	 * @param msId full id of the MyBatis statement being executed
	 * @param parameterObject statement parameter
	 * @param rowBounds paging parameter
	 * @return true if a count query should be executed first
	 */
	boolean beforeCount(String msId, Object parameterObject, RowBounds rowBounds);
	
	/**
	 * Generates the SQL of the count query.
	 * 
	 * @param boundSql bound SQL object of the original statement
	 * @param parameterObject statement parameter
	 * @param rowBounds paging parameter
	 * @param countKey cache key of the count query
	 * @return the count SQL
	 */
	String getCountSql(BoundSql boundSql, Object parameterObject, RowBounds rowBounds, CacheKey countKey);
	
	/**
	 * Called after the count query has been executed.
	 * 
	 * @param count total number of rows returned by the count query
	 * @param parameterObject statement parameter
	 * @param rowBounds paging parameter
	 */
	void afterCount(long count, Object parameterObject, RowBounds rowBounds);
	
	/**
	 * Called before the paged query. Returning true makes the interceptor run the
	 * paged query; returning false makes it return the default query result.
	 * 
	 * @param msId full id of the MyBatis statement being executed
	 * @param parameterObject statement parameter
	 * @param rowBounds paging parameter
	 * @return true if the paged query should be executed
	 */
	boolean beforePage(String msId, Object parameterObject, RowBounds rowBounds);
	
	/**
	 * Generates the SQL of the paged query.
	 * 
	 * @param boundSql bound SQL object of the original statement
	 * @param parameterObject statement parameter
	 * @param rowBounds paging parameter
	 * @param pageKey cache key of the paged query
	 * @return the paging SQL
	 */
	String getPageSql(BoundSql boundSql, Object parameterObject, RowBounds rowBounds, CacheKey pageKey);
	
	/**
	 * Post-processes the paged result. The interceptor returns this method's
	 * return value directly (also with an empty list when the count was 0).
	 * 
	 * @param pageList result of the paged query
	 * @param parameterObject statement parameter
	 * @param rowBounds paging parameter
	 * @return the object handed back to the caller of the query
	 */
	Object afterPage(List pageList, Object parameterObject, RowBounds rowBounds);
	
	/**
	 * Receives the plugin configuration.
	 * 
	 * @param properties plugin properties
	 */
	void setProperties(Properties properties);
}

3、可以记录 total 的分页参数

package com.pbyang.web.pagination;

import org.apache.ibatis.session.RowBounds;

/**
 * A {@link RowBounds} that can additionally hold the total row count, so the
 * caller can read the total back after a paged query.
 */
public class PageRowBounds extends RowBounds{

	/** Total row count; stays 0 until {@link #setTotal(long)} is called. */
	private long total;

	/** Creates bounds using the no-argument {@code RowBounds} constructor. */
	public PageRowBounds() {
	}

	/**
	 * Creates bounds with an explicit offset and limit.
	 *
	 * @param offset passed through to {@code RowBounds(int, int)}
	 * @param limit  passed through to {@code RowBounds(int, int)}
	 */
	public PageRowBounds(int offset, int limit) {
		super(offset, limit);
	}

	/**
	 * Stores the total row count.
	 *
	 * @param total total number of rows
	 */
	public void setTotal(long total) {
		this.total = total;
	}

	/**
	 * @return the stored total row count
	 */
	public long getTotal() {
		return this.total;
	}
}

4、MySqlDialect实现

package com.pbyang.web.pagination;

import java.util.List;
import java.util.Properties;

import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.session.RowBounds;

/**
 * {@link Dialect} implementation for MySQL, driven purely by the RowBounds
 * argument of the query.
 */
@SuppressWarnings("rawtypes")
public class MySqlDialect implements Dialect {

	/**
	 * Paging is keyed off RowBounds: a call without a RowBounds argument gets
	 * {@code RowBounds.DEFAULT}, and only that exact instance is skipped.
	 */
	public boolean skip(String msId, Object parameterObject, RowBounds rowBounds) {
		return rowBounds == RowBounds.DEFAULT;
	}

	/**
	 * Only a {@link PageRowBounds} can store the total, so counting is
	 * pointless for any other RowBounds.
	 */
	public boolean beforeCount(String msId, Object parameterObject, RowBounds rowBounds) {
		return rowBounds instanceof PageRowBounds;
	}

	/**
	 * Builds the count statement by wrapping the original SQL in a sub-select.
	 */
	public String getCountSql(BoundSql boundSql, Object parameterObject, RowBounds rowBounds, CacheKey countKey) {
		StringBuilder countSql = new StringBuilder("select count(*) from (");
		countSql.append(boundSql.getSql());
		countSql.append(") temp");
		return countSql.toString();
	}

	/**
	 * Stores the total on the row bounds. The cast is unconditional because
	 * {@code beforeCount} only allows a count query for PageRowBounds.
	 */
	public void afterCount(long count, Object parameterObject, RowBounds rowBounds) {
		PageRowBounds pageRowBounds = (PageRowBounds) rowBounds;
		pageRowBounds.setTotal(count);
	}

	/**
	 * A paged query runs for every RowBounds other than {@code RowBounds.DEFAULT}.
	 */
	public boolean beforePage(String msId, Object parameterObject, RowBounds rowBounds) {
		return rowBounds != RowBounds.DEFAULT;
	}

	/**
	 * Appends a MySQL {@code limit offset,limit} clause to the original SQL.
	 */
	public String getPageSql(BoundSql boundSql, Object parameterObject, RowBounds rowBounds, CacheKey pageKey) {
		// pageKey affects caching; updating it with a fixed "RowBounds" marker
		// keeps the second-level cache effective (per the original author).
		pageKey.update("RowBounds");
		StringBuilder pageSql = new StringBuilder();
		pageSql.append(boundSql.getSql());
		pageSql.append(" limit ");
		pageSql.append(rowBounds.getOffset());
		pageSql.append(",");
		pageSql.append(rowBounds.getLimit());
		return pageSql.toString();
	}

	/**
	 * Returns the page list unchanged.
	 */
	public Object afterPage(List pageList, Object parameterObject, RowBounds rowBounds) {
		return pageList;
	}

	/**
	 * This dialect reads no properties.
	 */
	public void setProperties(Properties properties) {
		// intentionally empty
	}

}

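5、要实现拦截器,需要在 mybatis-config.xml 中进行如下配置(原文的 XML 在转载时丢失,以下内容根据 PageInterceptor.setProperties 读取的 dialect 属性还原):

<plugins>
  <plugin interceptor="com.pbyang.web.pagination.PageInterceptor">
     <property name="dialect" value="com.pbyang.web.pagination.MySqlDialect"/>
  </plugin>
</plugins>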

6、SqlConfigOnlyTest.xml配置







  
     
  


	
	
		
			
			
			  
			
			
			
				
				
				
				
			
		
	
	
	
	
	   
	
	

7、ArticleMapper.xml












   
   

8、ArticleDao

package com.pbyang.dao;

import java.util.List;

import org.apache.ibatis.session.RowBounds;

import com.pbyang.entity.Article;
import org.springframework.stereotype.Repository;

/**
 * MyBatis mapper interface for {@link Article} records.
 */
@Repository
public interface ArticleDao {

	/**
	 * Selects articles, limited by the given row bounds.
	 *
	 * @param rowBounds paging parameter; with the pagination plugin installed,
	 *                  passing a {@code PageRowBounds} also records the total
	 *                  row count on it
	 * @return the selected articles
	 */
	List<Article> selectAll(RowBounds rowBounds);
}

9、测试代码实现

package com.pbyang.dao;

import static org.junit.Assert.*;

import java.io.InputStream;
import java.util.List;

import org.apache.ibatis.io.Resources;
import org.apache.ibatis.session.RowBounds;
import org.apache.ibatis.session.SqlSession;
import org.apache.ibatis.session.SqlSessionFactory;
import org.apache.ibatis.session.SqlSessionFactoryBuilder;
import org.apache.log4j.Logger;
import org.junit.Before;
import org.junit.Test;

import com.pbyang.entity.Article;
import com.pbyang.web.pagination.PageRowBounds;

public class ArticleDaoTest {
	
	
	private SqlSessionFactory sqlSessionFactory;
	
	Logger LOGGER = Logger.getLogger(ArticleDaoTest.class);
	
	@Before
	public void init() throws Exception {
		InputStream inputStream = Resources.getResourceAsStream("SqlConfigOnlyTest.xml");
		this.sqlSessionFactory = new SqlSessionFactoryBuilder().build(inputStream);
	}
	@Test
	public void testSelectAll() {
		SqlSession sqlSession = this.sqlSessionFactory.openSession();
		ArticleDao articleDao = sqlSession.getMapper(ArticleDao.class);
		try {
		RowBounds rowBounds = new RowBounds(1, 3);  //// offset起始行 // limit是当前页显示多少条数据
		List
articleList = articleDao.selectAll(rowBounds); for (Article article : articleList) { System.out.println("标题:" +article.getTitle()); } System.out.println("++++++++++xxx+++++++ " + articleList.size()); PageRowBounds pageRowbounds = new PageRowBounds(1, 3); List
list = articleDao.selectAll(pageRowbounds); System.out.println("查询总数" + pageRowbounds.getTotal()); for(Article article : list) { System.out.println("标题:" +article.getTitle()); } } catch (Exception e) { // TODO: handle exception } finally { sqlSession.close(); } //分页插件可以支持一级缓存和二级缓存 } }

你可能感兴趣的:(Mybatis)