分析Mybatis的分页插件PageHelper的源码

In this post we analyze the source code of PageHelper and trace how it executes.

1. PageHelper version


<dependency>
    <groupId>com.github.pagehelper</groupId>
    <artifactId>pagehelper</artifactId>
    <version>3.7.5</version>
</dependency>

2. First, let's look at how it is used in code:

PageHelper.startPage(1,2);
List users = sqlSession.selectList("UserMapper.queryAllUser");
users.forEach(u -> System.out.println(u));

3. Analyzing PageHelper.startPage(1, 2), which sets the current page number and the number of rows per page

PageHelper has a static method:
public static Page startPage(int pageNum, int pageSize) {
        return startPage(pageNum, pageSize, true);
    }
The overload that is ultimately called:
    public static Page startPage(int pageNum, int pageSize, boolean count, Boolean reasonable, Boolean pageSizeZero) {
        Page page = new Page(pageNum, pageSize, count);
        page.setReasonable(reasonable);
        page.setPageSizeZero(pageSizeZero);
        SqlUtil.setLocalPage(page);
        return page;
    }
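In short, it creates a Page object, sets a few options on it, and stores it in a ThreadLocal via SqlUtil.setLocalPage(page).
The start and end rows are also computed when the Page is constructed. The public Page(pageNum, pageSize, count) constructor used above delegates to this private one, passing the count flag through the total argument: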
private Page(int pageNum, int pageSize, int total, Boolean reasonable) {
        super(0);
        if (pageNum == 1 && pageSize == Integer.MAX_VALUE) {
            pageSizeZero = true;
            pageSize = 0;
        }
        this.pageNum = pageNum;
        this.pageSize = pageSize;
        this.total = total;
        calculateStartAndEndRow(); // compute the start and end row numbers
        setReasonable(reasonable);
    }
private void calculateStartAndEndRow() {
        this.startRow = this.pageNum > 0 ? (this.pageNum - 1) * this.pageSize : 0;
        this.endRow = this.startRow + this.pageSize * (this.pageNum > 0 ? 1 : 0);
    }
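To make the arithmetic concrete, the small class below copies the two formulas from calculateStartAndEndRow() into a standalone helper. The class name PageMath is hypothetical; only the formulas come from the code above. For startPage(1, 2) it gives startRow = 0 and endRow = 2, so rows 0 and 1 (0-based) form the first page.

```java
/** Hypothetical helper that reproduces Page.calculateStartAndEndRow() in isolation. */
final class PageMath {
    private PageMath() {}

    /** 0-based offset of the first row on the page; 0 when pageNum <= 0. */
    static int startRow(int pageNum, int pageSize) {
        return pageNum > 0 ? (pageNum - 1) * pageSize : 0;
    }

    /** startRow + pageSize, i.e. one past the last 0-based row; equals startRow when pageNum <= 0. */
    static int endRow(int pageNum, int pageSize) {
        return startRow(pageNum, pageSize) + pageSize * (pageNum > 0 ? 1 : 0);
    }

    public static void main(String[] args) {
        int[][] cases = {{1, 2}, {2, 2}, {3, 10}, {0, 10}};
        for (int[] c : cases) {
            System.out.printf("pageNum=%d pageSize=%d -> startRow=%d endRow=%d%n",
                    c[0], c[1], startRow(c[0], c[1]), endRow(c[0], c[1]));
        }
    }
}
```

Note that a page number of 0 or less yields an empty range (startRow == endRow).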
// Store the created Page object in a ThreadLocal
SqlUtil.setLocalPage(page); 
public static void setLocalPage(Page page) {
        LOCAL_PAGE.set(page);
    }
private static final ThreadLocal<Page> LOCAL_PAGE = new ThreadLocal<Page>();
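Why a ThreadLocal: startPage() and the interceptor never call each other, so the Page has to travel on the current thread, and a page set on one thread must not leak into queries running on another. A minimal sketch of that handoff, assuming a hypothetical PageContext class in place of SqlUtil and a two-field PageRequest in place of Page:

```java
import java.util.concurrent.atomic.AtomicReference;

/** Hypothetical stand-in for SqlUtil's LOCAL_PAGE handling. */
final class PageContext {
    /** Minimal stand-in for Page: just the two inputs of startPage(). */
    static final class PageRequest {
        final int pageNum;
        final int pageSize;

        PageRequest(int pageNum, int pageSize) {
            this.pageNum = pageNum;
            this.pageSize = pageSize;
        }
    }

    // One slot per thread, like LOCAL_PAGE in SqlUtil.
    private static final ThreadLocal<PageRequest> LOCAL_PAGE = new ThreadLocal<>();

    static void setLocalPage(PageRequest page) { LOCAL_PAGE.set(page); }
    static PageRequest getLocalPage() { return LOCAL_PAGE.get(); }
    static void clearLocalPage() { LOCAL_PAGE.remove(); }

    public static void main(String[] args) throws InterruptedException {
        setLocalPage(new PageRequest(1, 2));
        AtomicReference<PageRequest> seenByOtherThread = new AtomicReference<>();
        Thread other = new Thread(() -> seenByOtherThread.set(getLocalPage()));
        other.start();
        other.join();
        System.out.println("main thread sees pageSize " + getLocalPage().pageSize); // 2
        System.out.println("other thread sees " + seenByOtherThread.get());         // null
        clearLocalPage();
    }
}
```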

4. PageHelper implements MyBatis's Interceptor interface, and the PageHelper class carries this annotation:

@Intercepts(@Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}))
This means it intercepts Executor's query method whose parameters are MappedStatement, Object, RowBounds and ResultHandler.

Implementing Interceptor means overriding three methods:

    /**
     * MyBatis interceptor method
     *
     * @param invocation the intercepted call
     * @return the execution result
     * @throws Throwable any exception raised
     */
    public Object intercept(Invocation invocation) throws Throwable {
        return sqlUtil.processPage(invocation);
    }

    /**
     * Only wrap Executor instances
     *
     * @param target the object MyBatis offers for wrapping
     * @return a proxy if target is an Executor, otherwise target unchanged
     */
    public Object plugin(Object target) {
        if (target instanceof Executor) {
            return Plugin.wrap(target, this);
        } else {
            return target;
        }
    }

    /**
     * Set property values
     *
     * @param p the properties
     */
    public void setProperties(Properties p) {
        // MyBatis 3.2.0 version check
        try {
            Class.forName("org.apache.ibatis.scripting.xmltags.SqlNode"); // SqlNode was added in 3.2.0
        } catch (ClassNotFoundException e) {
            throw new RuntimeException("Your MyBatis version is too old: the PageHelper paging plugin supports MyBatis 3.2.0 and above!");
        }
        // database dialect
        String dialect = p.getProperty("dialect");
        sqlUtil = new SqlUtil(dialect);
        sqlUtil.setProperties(p);
    }
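Plugin.wrap(target, this) returns a JDK dynamic proxy around the Executor (Executor is an interface, which is what makes a JDK proxy possible). The proxy's handler checks whether the invoked method matches a @Signature and, if so, routes the call to intercept(). The sketch below reproduces only that idea with java.lang.reflect.Proxy. The nested QueryExecutor interface, matching by method name alone, and the " limit ?,?" rewrite inside the handler are hypothetical simplifications, not MyBatis's Plugin class.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;

final class ProxyDemo {
    /** Hypothetical, trimmed-down analogue of MyBatis's Executor. */
    interface QueryExecutor {
        String query(String sql);
        int update(String sql);
    }

    /** Wraps target so that only query(...) is intercepted, mimicking @Signature(method = "query"). */
    static QueryExecutor wrap(QueryExecutor target) {
        InvocationHandler handler = (proxy, method, args) -> {
            if (method.getName().equals("query")) {
                // "intercept": rewrite the argument, then delegate to the real target
                String sql = (String) args[0];
                return method.invoke(target, sql + " limit ?,?");
            }
            // every other method goes straight to the target
            return method.invoke(target, args);
        };
        return (QueryExecutor) Proxy.newProxyInstance(
                QueryExecutor.class.getClassLoader(),
                new Class<?>[] {QueryExecutor.class},
                handler);
    }

    public static void main(String[] args) {
        QueryExecutor real = new QueryExecutor() {
            public String query(String sql) { return "ran: " + sql; }
            public int update(String sql) { return 1; }
        };
        QueryExecutor wrapped = wrap(real);
        System.out.println(wrapped.query("select * from user")); // ran: select * from user limit ?,?
        System.out.println(wrapped.update("delete from user"));  // 1
    }
}
```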

5. The paging logic lives in intercept(Invocation invocation), so that is the method we analyze

public Object intercept(Invocation invocation) throws Throwable {
        return sqlUtil.processPage(invocation);
    }
// which calls this method
public Object processPage(Invocation invocation) throws Throwable {
        try {
            Object result = _processPage(invocation);
            return result; // when paging applies, this is actually a Page object
        } finally {
            clearLocalPage();
        }
    }
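The finally block is what makes startPage() affect only the next query on the current thread: the page is consumed and removed whether the query succeeds or throws. Without it, a pooled thread would carry a stale page into the next, unrelated query. A minimal sketch of this consume-once pattern; OneShotPaging and its string "queries" are hypothetical, and it always pages from offset 0 to stay short:

```java
/** Hypothetical model of processPage(): use the thread's page setting once, then always clear it. */
final class OneShotPaging {
    private static final ThreadLocal<Integer> LOCAL_PAGE_SIZE = new ThreadLocal<>();

    static void startPage(int pageSize) {
        LOCAL_PAGE_SIZE.set(pageSize);
    }

    /** Stands in for intercept(): pages the query only if a page size was set on this thread. */
    static String query(String sql) {
        try {
            Integer pageSize = LOCAL_PAGE_SIZE.get();
            if (pageSize == null) {
                return sql;                        // no paging requested: run as-is
            }
            return sql + " limit 0," + pageSize;   // paging requested (first page only, for brevity)
        } finally {
            LOCAL_PAGE_SIZE.remove();              // like clearLocalPage(): runs even on exceptions
        }
    }

    public static void main(String[] args) {
        startPage(2);
        System.out.println(query("select * from user")); // select * from user limit 0,2
        System.out.println(query("select * from role")); // select * from role
    }
}
```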
// next
private Object _processPage(Invocation invocation) throws Throwable {
        final Object[] args = invocation.getArgs();
        RowBounds rowBounds = (RowBounds) args[2];
        if (SqlUtil.getLocalPage() == null && rowBounds == RowBounds.DEFAULT) {
            return invocation.proceed();
        } else {
            // Ignore RowBounds; otherwise MyBatis would also apply its built-in in-memory paging
            args[2] = RowBounds.DEFAULT;
            // paging information
            Page page = getPage(rowBounds);
            // pageSizeZero check
            if ((page.getPageSizeZero() != null && page.getPageSizeZero()) && page.getPageSize() == 0) {
                // run the normal (unpaged) query
                Object result = invocation.proceed();
                // collect the results
                page.addAll((List) result);
                // equivalent to querying the first page
                page.setPageNum(1);
                // in this case pageSize = total
                page.setPageSize(page.size());
                // total still has to be set
                page.setTotal(page.size());
                // still return a Page so callers can handle the result type uniformly
                return page;
            }
            // get the original MappedStatement
            MappedStatement ms = (MappedStatement) args[0];
            SqlSource sqlSource = ms.getSqlSource();
            // whether to run a count query is decided simply from the value of total
            if (page.isCount()) {
                // replace the MappedStatement in args with the new qs
                msUtils.processCountMappedStatement(ms, sqlSource, args);
                // query the total
                Object result = invocation.proceed();
                // set the total
                page.setTotal((Integer) ((List) result).get(0));
                if (page.getTotal() == 0) {
                    return page;
                }
            }
            // Run the paged query only when pageSize > 0; when pageSize <= 0 it is skipped, so effectively only the count is returned
            if (page.getPageSize() > 0 &&
                    ((rowBounds == RowBounds.DEFAULT && page.getPageNum() > 0)
                            || rowBounds != RowBounds.DEFAULT)) {
                // replace the MappedStatement in args with the new qs
                msUtils.processPageMappedStatement(ms, sqlSource, page, args);
                // run the paged query
                Object result = invocation.proceed();
                // collect the results
                page.addAll((List) result);
            }
            // return the result
            return page;
        }
    }
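The branches of _processPage() boil down to one decision: which of the original, count and page queries actually run. The hypothetical PagingPlan below models just those conditions with plain values (pageSizeZero is a boolean here, whereas the real code treats a null Boolean as false); it touches no MyBatis code.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical model of the branches in _processPage(); returns the queries it would run. */
final class PagingPlan {
    static List<String> plan(boolean hasLocalPage, boolean defaultRowBounds,
                             boolean pageSizeZero, int pageNum, int pageSize,
                             boolean count, int countResult) {
        List<String> queries = new ArrayList<>();
        if (!hasLocalPage && defaultRowBounds) {
            queries.add("ORIGINAL");        // not a paging call: proceed untouched
            return queries;
        }
        if (pageSizeZero && pageSize == 0) {
            queries.add("ORIGINAL");        // pageSize 0 with pageSizeZero means "return everything"
            return queries;
        }
        if (count) {
            queries.add("COUNT");
            if (countResult == 0) {
                return queries;             // nothing to page through
            }
        }
        if (pageSize > 0 && ((defaultRowBounds && pageNum > 0) || !defaultRowBounds)) {
            queries.add("PAGE");
        }
        return queries;
    }

    public static void main(String[] args) {
        System.out.println(plan(true, true, false, 1, 2, true, 10));  // [COUNT, PAGE]
        System.out.println(plan(true, true, false, 1, 2, true, 0));   // [COUNT]
        System.out.println(plan(false, true, false, 0, 0, false, 0)); // [ORIGINAL]
        System.out.println(plan(true, true, false, 1, 2, false, 0));  // [PAGE]
    }
}
```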

6. Paging actually takes two steps: first, count the total number of rows (COUNT); second, fetch the rows of the current page (LIMIT ?,?). We start with the count code.
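Together, the two statements behave like the following in-memory analogue, assuming the MySQL form LIMIT offset,count with offset = startRow and count = pageSize. TwoStepPaging is a hypothetical illustration of the semantics, not PageHelper code:

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/** Hypothetical in-memory analogue of a count query followed by "... LIMIT offset,count". */
final class TwoStepPaging {
    /** Step 1: the count query. */
    static int count(List<?> rows) {
        return rows.size();
    }

    /** Step 2: LIMIT offset,count semantics (fewer rows on the last page, none past the end). */
    static <T> List<T> limit(List<T> rows, int offset, int count) {
        if (offset >= rows.size()) {
            return Collections.emptyList();
        }
        return rows.subList(offset, Math.min(offset + count, rows.size()));
    }

    public static void main(String[] args) {
        List<String> users = Arrays.asList("u1", "u2", "u3", "u4", "u5");
        int pageNum = 3, pageSize = 2;
        int startRow = (pageNum - 1) * pageSize;                            // same formula as Page
        System.out.println("total = " + count(users));                      // total = 5
        System.out.println("page 3 = " + limit(users, startRow, pageSize)); // page 3 = [u5]
    }
}
```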

# Analyzing the count source
The main code is:
MappedStatement ms = (MappedStatement) args[0];
SqlSource sqlSource = ms.getSqlSource();
// whether to run a count query is decided simply from the value of total
if (page.isCount()) {
    // replace the MappedStatement in args with the new qs
    msUtils.processCountMappedStatement(ms, sqlSource, args);
    // query the total
    Object result = invocation.proceed();
    // set the total
    page.setTotal((Integer) ((List) result).get(0));
    if (page.getTotal() == 0) {
        return page;
    }
}
# Key code: msUtils.processCountMappedStatement(ms, sqlSource, args); // replaces the MappedStatement in args with the new qs
public void processCountMappedStatement(MappedStatement ms, SqlSource sqlSource, Object[] args) {
        args[0] = getMappedStatement(ms, sqlSource, args[1], SUFFIX_COUNT);
    }

public MappedStatement getMappedStatement(MappedStatement ms, SqlSource sqlSource, Object parameterObject, String suffix) {
        MappedStatement qs = null;
        if (ms.getId().endsWith(SUFFIX_PAGE) || ms.getId().endsWith(SUFFIX_COUNT)) {
            throw new RuntimeException("Paging plugin misconfigured: do not configure more than one paging plugin (when using Spring, configure it either in mybatis-config.xml or in the Spring configuration, not both)!");
        }
        if (parser.isSupportedMappedStatementCache()) {
            try {
                qs = ms.getConfiguration().getMappedStatement(ms.getId() + suffix);
            } catch (Exception e) {
                //ignore
            }
        }
        if (qs == null) {
            // create a new MappedStatement
            qs = newMappedStatement(ms, getsqlSource(ms, sqlSource, parameterObject, suffix == SUFFIX_COUNT), suffix);
            if (parser.isSupportedMappedStatementCache()) {
                try {
                    ms.getConfiguration().addMappedStatement(qs);
                } catch (Exception e) {
                    //ignore
                }
            }
        }
        return qs;
    }
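In short, getMappedStatement() builds the count or page variant once, registers it in the Configuration under the original id plus a suffix, and reuses it afterwards (when the parser supports statement caching). The suffix check at the top catches a second PageHelper in the interceptor chain: the inner interceptor receives a statement that the outer one has already rewritten. The sketch below models this with a HashMap; StatementCache is hypothetical, and the values "_PageHelper" and "_PageHelper_Count" are assumptions about SUFFIX_PAGE and SUFFIX_COUNT.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

/** Hypothetical model of getMappedStatement(): build a suffixed variant once, then reuse it. */
final class StatementCache {
    static final String SUFFIX_PAGE = "_PageHelper";            // assumed value
    static final String SUFFIX_COUNT = SUFFIX_PAGE + "_Count";  // assumed value

    private final Map<String, String> statements = new HashMap<>(); // id -> SQL, stands in for Configuration
    int builds = 0;                                                  // how many variants were built

    String get(String id, String sql, String suffix, Function<String, String> rewrite) {
        if (id.endsWith(SUFFIX_PAGE) || id.endsWith(SUFFIX_COUNT)) {
            // an already-rewritten statement reached us: a second paging plugin is in the chain
            throw new IllegalStateException("multiple paging plugins configured");
        }
        String key = id + suffix;
        String cached = statements.get(key);
        if (cached == null) {
            builds++;
            cached = rewrite.apply(sql);
            statements.put(key, cached);
        }
        return cached;
    }

    public static void main(String[] args) {
        StatementCache cache = new StatementCache();
        Function<String, String> toCount = sql -> "select count(0) from (" + sql + ") tmp_count";
        cache.get("UserMapper.queryAllUser", "select * from user", SUFFIX_COUNT, toCount);
        cache.get("UserMapper.queryAllUser", "select * from user", SUFFIX_COUNT, toCount);
        System.out.println("builds = " + cache.builds); // builds = 1
    }
}
```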
# Key code: the getsqlSource(ms, sqlSource, parameterObject, suffix == SUFFIX_COUNT) call inside qs = newMappedStatement(ms, getsqlSource(ms, sqlSource, parameterObject, suffix == SUFFIX_COUNT), suffix);
public SqlSource getsqlSource(MappedStatement ms, SqlSource sqlSource, Object parameterObject, boolean count) {
        if (sqlSource instanceof DynamicSqlSource) { // dynamic SQL
            MetaObject msObject = SystemMetaObject.forObject(ms);
            SqlNode sqlNode = (SqlNode) msObject.getValue("sqlSource.rootSqlNode");
            MixedSqlNode mixedSqlNode;
            if (sqlNode instanceof MixedSqlNode) {
                mixedSqlNode = (MixedSqlNode) sqlNode;
            } else {
                List<SqlNode> contents = new ArrayList<SqlNode>(1);
                contents.add(sqlNode);
                mixedSqlNode = new MixedSqlNode(contents);
            }
            return new PageDynamicSqlSource(this, ms.getConfiguration(), mixedSqlNode, count);
        } else if (sqlSource instanceof ProviderSqlSource) { // annotation-based (provider) SQL
            return new PageProviderSqlSource(parser, ms.getConfiguration(), (ProviderSqlSource) sqlSource, count);
        } else if (count) { // RawSqlSource and StaticSqlSource
            return getStaticCountSqlSource(ms.getConfiguration(), sqlSource, parameterObject);
        } else {
            return getStaticPageSqlSource(ms.getConfiguration(), sqlSource, parameterObject);
        }
    }
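The branch split follows from when the final SQL text is known. RawSqlSource and StaticSqlSource produce the same SQL for any parameter object, so the rewritten text can be computed once and frozen into a StaticSqlSource. A DynamicSqlSource (built from <if>, <where> and similar tags) yields different SQL per call, so PageDynamicSqlSource has to rewrite on every getBoundSql(). A hypothetical sketch of why freezing would break dynamic SQL, using a Function from parameters to SQL as the "SQL source" (DynamicVsStatic and its count form are illustrative assumptions):

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

/** Hypothetical comparison: rewrite once (fine for static SQL) vs rewrite per call (needed for dynamic SQL). */
final class DynamicVsStatic {
    // A "dynamic" source: like <if test="name != null">, the WHERE clause depends on the parameters.
    static final Function<Map<String, Object>, String> DYNAMIC = params ->
            params.containsKey("name") ? "select * from user where name = ?" : "select * from user";

    static String toCount(String sql) {
        return "select count(0) from (" + sql + ") tmp_count";
    }

    public static void main(String[] args) {
        Map<String, Object> withName = new HashMap<>();
        withName.put("name", "tom");
        Map<String, Object> empty = Collections.emptyMap();

        // Rewrite once with the first call's parameters and keep the text (the static-source approach).
        String bakedOnce = toCount(DYNAMIC.apply(withName));
        // Rewrite on every call (the PageDynamicSqlSource approach).
        Function<Map<String, Object>, String> perCall = p -> toCount(DYNAMIC.apply(p));

        System.out.println(bakedOnce);            // still filters by name: wrong for a call without "name"
        System.out.println(perCall.apply(empty)); // select count(0) from (select * from user) tmp_count
    }
}
```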
# Key code
else if (count) { // RawSqlSource and StaticSqlSource
    return getStaticCountSqlSource(ms.getConfiguration(), sqlSource, parameterObject);
}
public SqlSource getStaticCountSqlSource(Configuration configuration, SqlSource sqlSource, Object parameterObject) {
        BoundSql boundSql = sqlSource.getBoundSql(parameterObject);
        return new StaticSqlSource(configuration, parser.getCountSql(boundSql.getSql()), boundSql.getParameterMappings());
    }
# Key code: parser.getCountSql(boundSql.getSql())
public String getCountSql(final String sql) {
        return sqlParser.getSmartCountSql(sql);
    }
# This next method rewrites the SQL into a count query
public String getSmartCountSql(String sql) {
        // check that the SQL is supported
        isSupportedSql(sql);
        if (CACHE.get(sql) != null) {
            return CACHE.get(sql);
        }
        // parse the SQL
        Statement stmt = null;
        try {
            stmt = CCJSqlParserUtil.parse(sql);
        } catch (Throwable e) {
            // SQL that cannot be parsed gets a count statement built the simple way
            String countSql = getSimpleCountSql(sql);
            CACHE.put(sql, countSql);
            return countSql;
        }
        Select select = (Select) stmt;
        SelectBody selectBody = select.getSelectBody();
        // process the body: remove ORDER BY
        processSelectBody(selectBody);
        // process the WITH items: remove ORDER BY
        processWithItemsList(select.getWithItemsList());
        // convert into a count query
        sqlToCount(select);
        String result = select.toString();
        CACHE.put(sql, result);
        return result;
    }
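getSmartCountSql() pairs a cache keyed by the original SQL with a fallback for SQL that JSqlParser cannot handle, and it caches the fallback result too. A minimal sketch of that control flow: CountSqlCache is hypothetical, its parseAndRewrite() is a crude stand-in for JSqlParser plus sqlToCount() that only understands "select ... from ...", and the fallback shape select count(0) from (...) tmp_count is our assumption about getSimpleCountSql().

```java
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/** Hypothetical model of getSmartCountSql(): cache lookup, smart rewrite, simple fallback. */
final class CountSqlCache {
    private static final Map<String, String> CACHE = new ConcurrentHashMap<>();
    static int parseAttempts = 0;

    /** Crude stand-in for JSqlParser + sqlToCount(): only handles "select ... from ...". */
    static String parseAndRewrite(String sql) {
        parseAttempts++;
        String lower = sql.toLowerCase(Locale.ROOT);
        int from = lower.indexOf(" from ");
        if (!lower.startsWith("select ") || from < 0) {
            throw new IllegalArgumentException("cannot parse: " + sql);
        }
        return "select count(0)" + sql.substring(from);
    }

    /** Fallback when parsing fails: wrap the whole statement (assumed shape). */
    static String simpleCountSql(String sql) {
        return "select count(0) from (" + sql + ") tmp_count";
    }

    static String smartCountSql(String sql) {
        String cached = CACHE.get(sql);
        if (cached != null) {
            return cached;
        }
        String countSql;
        try {
            countSql = parseAndRewrite(sql);
        } catch (RuntimeException e) {
            countSql = simpleCountSql(sql);
        }
        CACHE.put(sql, countSql);   // both the smart and the fallback result are cached
        return countSql;
    }

    public static void main(String[] args) {
        System.out.println(smartCountSql("select id, name from user where age > ?")); // select count(0) from user where age > ?
        System.out.println(smartCountSql("select id, name from user where age > ?")); // served from the cache
        System.out.println(smartCountSql("SELECT* FROM t"));                          // select count(0) from (SELECT* FROM t) tmp_count
        System.out.println("parse attempts = " + parseAttempts);                      // parse attempts = 2
    }
}
```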
# Key code: sqlToCount(select);
public void sqlToCount(Select select) {
        SelectBody selectBody = select.getSelectBody();
        // can the count query be simplified?
        if (selectBody instanceof PlainSelect && isSimpleCount((PlainSelect) selectBody)) {
            ((PlainSelect) selectBody).setSelectItems(COUNT_ITEM);
        } else {
            PlainSelect plainSelect = new PlainSelect();
            SubSelect subSelect = new SubSelect();
            subSelect.setSelectBody(selectBody);
            subSelect.setAlias(TABLE_ALIAS);
            plainSelect.setFromItem(subSelect);
            plainSelect.setSelectItems(COUNT_ITEM);
            select.setSelectBody(plainSelect);
        }
    }
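sqlToCount() produces one of two shapes. If the select list can simply be swapped for the count item, FROM and WHERE stay as they are; otherwise the whole query becomes an aliased subquery. The string-based sketch below only illustrates those two shapes, together with the earlier ORDER BY removal. It is not JSqlParser: CountShapes is hypothetical, its "simple" test (no GROUP BY, no leading DISTINCT) is a rough assumption standing in for isSimpleCount(), and count(0) and table_count are assumptions about COUNT_ITEM and TABLE_ALIAS.

```java
import java.util.Locale;

/** Hypothetical, string-based illustration of the two shapes sqlToCount() can produce. */
final class CountShapes {
    /** Rough stand-in for removing a trailing top-level ORDER BY (processSelectBody). */
    static String stripOrderBy(String sql) {
        return sql.replaceAll("(?is)\\s+order\\s+by\\s+[^()]*$", "");
    }

    /** Rough stand-in for isSimpleCount(): assumes GROUP BY or DISTINCT forces the subquery form. */
    static boolean isSimple(String sql) {
        String lower = sql.toLowerCase(Locale.ROOT);
        return !lower.contains(" group by ") && !lower.startsWith("select distinct ");
    }

    static String toCount(String sql) {
        String body = stripOrderBy(sql);
        int from = body.toLowerCase(Locale.ROOT).indexOf(" from ");
        if (isSimple(body) && from >= 0) {
            return "select count(0)" + body.substring(from);        // replace the select list
        }
        return "select count(0) from (" + body + ") table_count";    // wrap as a subquery
    }

    public static void main(String[] args) {
        System.out.println(toCount("select id, name from user where age > ? order by id desc"));
        // select count(0) from user where age > ?
        System.out.println(toCount("select dept, count(*) from user group by dept"));
        // select count(0) from (select dept, count(*) from user group by dept) table_count
    }
}
```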
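# Finally, the original SQL has been rewritten into count SQL and is returned
Object result = invocation.proceed(); // run the count SQL
page.setTotal((Integer) ((List) result).get(0)); // set the total
if (page.getTotal() == 0) {
    return page; // if the count is 0, skip the LIMIT query and return the page object directly
}
# That completes the count step.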

7. Analyzing the LIMIT query step

// replace the MappedStatement in args with the new qs
msUtils.processPageMappedStatement(ms, sqlSource, page, args);
Object result = invocation.proceed();
// collect the results
page.addAll((List) result);

# Key code: msUtils.processPageMappedStatement(ms, sqlSource, page, args);
public void processPageMappedStatement(MappedStatement ms, SqlSource sqlSource, Page page, Object[] args) {
        args[0] = getMappedStatement(ms, sqlSource, args[1], SUFFIX_PAGE);
        // process the parameter object
        args[1] = setPageParameter((MappedStatement) args[0], args[1], page);
    }
# First the SQL, which simply gets limit ?,? appended (args[0] = getMappedStatement(ms, sqlSource, args[1], SUFFIX_PAGE);)
public MappedStatement getMappedStatement(MappedStatement ms, SqlSource sqlSource, Object parameterObject, String suffix) {
        MappedStatement qs = null;
        if (ms.getId().endsWith(SUFFIX_PAGE) || ms.getId().endsWith(SUFFIX_COUNT)) {
            throw new RuntimeException("Paging plugin misconfigured: do not configure more than one paging plugin (when using Spring, configure it either in mybatis-config.xml or in the Spring configuration, not both)!");
        }
        if (parser.isSupportedMappedStatementCache()) {
            try {
                qs = ms.getConfiguration().getMappedStatement(ms.getId() + suffix);
            } catch (Exception e) {
                //ignore
            }
        }
        if (qs == null) {
            // create a new MappedStatement
            qs = newMappedStatement(ms, getsqlSource(ms, sqlSource, parameterObject, suffix == SUFFIX_COUNT), suffix);
            if (parser.isSupportedMappedStatementCache()) {
                try {
                    ms.getConfiguration().addMappedStatement(qs);
                } catch (Exception e) {
                    //ignore
                }
            }
        }
        return qs;
    }
# Key code: getsqlSource(ms, sqlSource, parameterObject, suffix == SUFFIX_COUNT)
public SqlSource getsqlSource(MappedStatement ms, SqlSource sqlSource, Object parameterObject, boolean count) {
        if (sqlSource instanceof DynamicSqlSource) { // dynamic SQL
            MetaObject msObject = SystemMetaObject.forObject(ms);
            SqlNode sqlNode = (SqlNode) msObject.getValue("sqlSource.rootSqlNode");
            MixedSqlNode mixedSqlNode;
            if (sqlNode instanceof MixedSqlNode) {
                mixedSqlNode = (MixedSqlNode) sqlNode;
            } else {
                List<SqlNode> contents = new ArrayList<SqlNode>(1);
                contents.add(sqlNode);
                mixedSqlNode = new MixedSqlNode(contents);
            }
            return new PageDynamicSqlSource(this, ms.getConfiguration(), mixedSqlNode, count);
        } else if (sqlSource instanceof ProviderSqlSource) { // annotation-based (provider) SQL
            return new PageProviderSqlSource(parser, ms.getConfiguration(), (ProviderSqlSource) sqlSource, count);
        } else if (count) { // RawSqlSource and StaticSqlSource
            return getStaticCountSqlSource(ms.getConfiguration(), sqlSource, parameterObject);
        } else {
            return getStaticPageSqlSource(ms.getConfiguration(), sqlSource, parameterObject); // for our static SQL, the page query takes this branch
        }
    }
# Key code: return getStaticPageSqlSource(ms.getConfiguration(), sqlSource, parameterObject);
public SqlSource getStaticPageSqlSource(Configuration configuration, SqlSource sqlSource, Object parameterObject) {
        BoundSql boundSql = sqlSource.getBoundSql(parameterObject);
        return new StaticSqlSource(configuration, parser.getPageSql(boundSql.getSql()), parser.getPageParameterMapping(configuration, boundSql));
    }
# Key code: parser.getPageSql(boundSql.getSql())
public String getPageSql(String sql) {
        StringBuilder sqlBuilder = new StringBuilder(sql.length() + 14);
        sqlBuilder.append(sql);
        sqlBuilder.append(" limit ?,?");
        return sqlBuilder.toString();
    }
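# It simply appends " limit ?,?" to the SQL and returns the result (this is the MySQL dialect's parser; other dialects generate their own paging syntax).
# Next, how the parameters are handled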
args[1] = setPageParameter((MappedStatement) args[0], args[1], page);
public Map setPageParameter(MappedStatement ms, Object parameterObject, Page page) {
        BoundSql boundSql = ms.getBoundSql(parameterObject);
        return parser.setPageParameter(ms, parameterObject, boundSql, page);
    }
public Map setPageParameter(MappedStatement ms, Object parameterObject, BoundSql boundSql, Page page) {
        Map paramMap = super.setPageParameter(ms, parameterObject, boundSql, page);
        paramMap.put(PAGEPARAMETER_FIRST, page.getStartRow()); // first LIMIT parameter: the offset
        paramMap.put(PAGEPARAMETER_SECOND, page.getPageSize()); // second LIMIT parameter: the number of rows
        return paramMap;
    }
# In other words, it sets the two parameters that follow LIMIT: First_PageHelper (the page's startRow) and Second_PageHelper (the page's pageSize)
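For the MySQL parser the data flow is therefore: the SQL gains " limit ?,?", and the parameter object becomes a Map with two extra entries whose values fill the two new placeholders. The sketch below reproduces only that flow; MySqlPaging is hypothetical, and copying the caller's parameters with new HashMap<>(original) is a simplification of what super.setPageParameter() does.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical model of the MySQL paging rewrite: SQL suffix plus two extra parameters. */
final class MySqlPaging {
    static final String PAGEPARAMETER_FIRST = "First_PageHelper";   // parameter name given in the article
    static final String PAGEPARAMETER_SECOND = "Second_PageHelper"; // parameter name given in the article

    static String pageSql(String sql) {
        return sql + " limit ?,?";
    }

    static Map<String, Object> pageParameters(Map<String, Object> original, int pageNum, int pageSize) {
        Map<String, Object> paramMap = new HashMap<>(original);     // keep the query's own parameters
        int startRow = pageNum > 0 ? (pageNum - 1) * pageSize : 0;  // same formula as Page
        paramMap.put(PAGEPARAMETER_FIRST, startRow);                // binds the first ?: offset
        paramMap.put(PAGEPARAMETER_SECOND, pageSize);               // binds the second ?: row count
        return paramMap;
    }

    public static void main(String[] args) {
        Map<String, Object> params = new HashMap<>();
        params.put("age", 18);
        System.out.println(pageSql("select * from user where age > ?"));
        // select * from user where age > ? limit ?,?
        Map<String, Object> paged = pageParameters(params, 2, 10);
        System.out.println(paged.get("age") + " " + paged.get(PAGEPARAMETER_FIRST) + " " + paged.get(PAGEPARAMETER_SECOND));
        // 18 10 10
    }
}
```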
# With the SQL replaced and the parameters added, the paged query runs directly
Object result = invocation.proceed();
# Add the results to page (Page extends ArrayList)
page.addAll((List) result);
Finally, the page object is returned.

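8. That completes the source analysis. Note that Page extends ArrayList: the Page returned by the interceptor is the very List that sqlSession.selectList() hands back, so the caller can cast it to Page to read paging data such as the total.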
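The design is easy to reproduce: a List subclass that also carries paging metadata can be returned wherever a List is expected, and a caller that knows about it casts to read the extra fields. A self-contained sketch with a hypothetical PagedList in place of Page and a hypothetical selectList() in place of the intercepted query:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

final class PagedListDemo {
    /** Hypothetical stand-in for Page: a List that also carries paging metadata. */
    static class PagedList<E> extends ArrayList<E> {
        private long total;

        long getTotal() { return total; }
        void setTotal(long total) { this.total = total; }
    }

    /** Plays the interceptor: declares List, actually returns the subclass. */
    static List<String> selectList() {
        PagedList<String> page = new PagedList<>();
        page.setTotal(5);
        page.addAll(Arrays.asList("u1", "u2"));
        return page;
    }

    public static void main(String[] args) {
        List<String> users = selectList();
        System.out.println(users);                                 // [u1, u2]
        if (users instanceof PagedList) {
            System.out.println(((PagedList<?>) users).getTotal()); // 5
        }
    }
}
```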
