@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class})}) public class PaginationInterceptor implements Interceptor { private static final Log logger = LogFactory .getLog(PaginationInterceptor.class); private static final ObjectFactory DEFAULT_OBJECT_FACTORY = new DefaultObjectFactory(); private static final ObjectWrapperFactory DEFAULT_OBJECT_WRAPPER_FACTORY = new DefaultObjectWrapperFactory(); private static String defaultDialect = "mysql"; // 数据库类型(默认为mysql) private static String defaultPageSqlId = ".*Page$"; // 需要拦截的ID(正则匹配) private static String dialect = ""; // 数据库类型(默认为mysql) private static String pageSqlId = ""; // 需要拦截的ID(正则匹配) @Override public Object intercept(Invocation invocation) throws Throwable { StatementHandler statementHandler = (StatementHandler) invocation .getTarget(); MetaObject metaStatementHandler = MetaObject.forObject( statementHandler, DEFAULT_OBJECT_FACTORY, DEFAULT_OBJECT_WRAPPER_FACTORY); // 分离代理对象链(由于目标类可能被多个拦截器拦截,从而形成多次代理,通过下面的两次循环可以分离出最原始的的目标类) while (metaStatementHandler.hasGetter("h")) { Object object = metaStatementHandler.getValue("h"); metaStatementHandler = MetaObject.forObject(object, DEFAULT_OBJECT_FACTORY, DEFAULT_OBJECT_WRAPPER_FACTORY); } // 分离最后一个代理对象的目标类 while (metaStatementHandler.hasGetter("target")) { Object object = metaStatementHandler.getValue("target"); metaStatementHandler = MetaObject.forObject(object, DEFAULT_OBJECT_FACTORY, DEFAULT_OBJECT_WRAPPER_FACTORY); } Configuration configuration = (Configuration) metaStatementHandler .getValue("delegate.configuration"); dialect = configuration.getVariables()==null?null:configuration.getVariables().getProperty("dialect"); if (null == dialect || "".equals(dialect)) { logger.warn("Property dialect is not setted,use default 'mysql' "); dialect = defaultDialect; } pageSqlId = configuration.getVariables()==null?null:configuration.getVariables().getProperty("pageSqlId"); if (null == pageSqlId || "".equals(pageSqlId)) { logger.warn("Property pageSqlId is not setted,use default '.*Page$' "); pageSqlId = defaultPageSqlId; } MappedStatement mappedStatement = (MappedStatement) metaStatementHandler .getValue("delegate.mappedStatement"); // 只重写需要分页的sql语句。通过MappedStatement的ID匹配,默认重写以Page结尾的MappedStatement的sql if (mappedStatement.getId().matches(pageSqlId)) { BoundSql boundSql = (BoundSql) metaStatementHandler .getValue("delegate.boundSql"); Object parameterObject = boundSql.getParameterObject(); if (parameterObject == null) { throw new NullPointerException("parameterObject is null!"); } else { Pagination page = (Pagination) metaStatementHandler .getValue("delegate.boundSql.parameterObject.page"); String sql = boundSql.getSql(); // 重写sql String pageSql = buildPageSql(sql, page); metaStatementHandler.setValue("delegate.boundSql.sql", pageSql); // 采用物理分页后,就不需要mybatis的内存分页了,所以重置下面的两个参数 metaStatementHandler.setValue("delegate.rowBounds.offset", RowBounds.NO_ROW_OFFSET); metaStatementHandler.setValue("delegate.rowBounds.limit", RowBounds.NO_ROW_LIMIT); Connection connection = (Connection) invocation.getArgs()[0]; // 重设分页参数里的总页数等 setPageParameter(sql, connection, mappedStatement, boundSql, page); } } // 将执行权交给下一个拦截器 return invocation.proceed(); } /** * 从数据库里查询总的记录数并计算总页数,回写进分页参数<code>PageParameter</code>,这样调用者就可用通过 分页参数 * <code>PageParameter</code>获得相关信息。 * * @param sql * @param connection * @param mappedStatement * @param boundSql * @param page */ private void setPageParameter(String sql, Connection connection, MappedStatement mappedStatement, BoundSql boundSql, Pagination page) { // 记录总记录数 String countSql = "select count(0) from (" + sql + ") as total"; PreparedStatement countStmt = null; ResultSet rs = null; try { countStmt = connection.prepareStatement(countSql); BoundSql countBS = new BoundSql(mappedStatement.getConfiguration(), countSql, boundSql.getParameterMappings(), boundSql.getParameterObject()); setParameters(countStmt, mappedStatement, countBS, boundSql.getParameterObject()); rs = countStmt.executeQuery(); int totalCount = 0; if (rs.next()) { totalCount = rs.getInt(1); } page.setTotalCount((long)totalCount); int totalPage = totalCount / page.getPageSize() + ((totalCount % page.getPageSize() == 0) ? 0 : 1); page.setTotalPage(totalPage); } catch (SQLException e) { logger.error("Ignore this exception", e); } finally { try { rs.close(); } catch (SQLException e) { logger.error("Ignore this exception", e); } try { countStmt.close(); } catch (SQLException e) { logger.error("Ignore this exception", e); } } } /** * 对SQL参数(?)设值 * * @param ps * @param mappedStatement * @param boundSql * @param parameterObject * @throws SQLException */ private void setParameters(PreparedStatement ps, MappedStatement mappedStatement, BoundSql boundSql, Object parameterObject) throws SQLException { ParameterHandler parameterHandler = new DefaultParameterHandler( mappedStatement, parameterObject, boundSql); parameterHandler.setParameters(ps); } /** * 根据数据库类型,生成特定的分页sql * * @param sql * @param page * @return */ private String buildPageSql(String sql, Pagination page) { if (page != null) { String pageSql = new String(); if ("mysql".equals(dialect)) { pageSql = MysqlDialect.getInstance().getLimitString(sql, page.getStartIndex(), page.getPageSize()); } else if ("oracle".equals(dialect)) { pageSql = OracleDialect.getInstance().getLimitString(sql, page.getStartIndex(), page.getPageSize()); } else { return sql; } return pageSql; } else { return sql; } } @Override public Object plugin(Object target) { // 当目标类是StatementHandler类型时,才包装目标类,否者直接返回目标本身,减少目标被代理的次数 if (target instanceof StatementHandler) { return Plugin.wrap(target, this); } else { return target; } } @Override public void setProperties(Properties properties) { } }