mybatis 多数据库通用分页插件PageInterceptor

这个分页插件基于 mybatis-3.2.5.jar、commons-lang3-3.3.2.jar 以及 mybatis-spring-1.2.0.jar;工具类 ReflectUtil 另外还用到了 commons-beanutils 和 Spring 的 org.springframework.util.Assert。

 

下面是实现了 Interceptor 接口的插件类:

package dwz.common.mybatis;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.Properties;

import javax.xml.bind.PropertyException;
import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.binding.MapperMethod;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.RoutingStatementHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;

import dwz.common.mybatis.Page;
import dwz.common.util.ReflectUtil;

@Intercepts({ @Signature(method = "prepare", type = StatementHandler.class, args = { Connection.class }) })
@SuppressWarnings("rawtypes")
public class PageInterceptor implements Interceptor {

	/**
	 * Database dialect that selects the paging SQL: "mysql", "oracle" or
	 * "sqlserver" (case-insensitive). Any other value, including null, leaves the
	 * SQL un-paged. Kept per instance so that several configured interceptors
	 * cannot overwrite each other's dialect.
	 */
	private String databaseType = "";

	/**
	 * Intercepts {@code StatementHandler.prepare(Connection)}.
	 *
	 * <p>When the statement's parameter object is a {@link Page}, or a
	 * {@code MapperMethod.ParamMap} holding one, it does two things:
	 * <ul>
	 * <li>runs a count query to fill in the page totals;</li>
	 * <li>rewrites the bound SQL so that only the requested page is fetched.</li>
	 * </ul>
	 *
	 * @param invocation the intercepted {@code prepare} call
	 * @return the result of the original {@code prepare} call
	 * @throws Throwable whatever the count query or the original call throws
	 */
	public Object intercept(Invocation invocation) throws Throwable {
		RoutingStatementHandler handler = (RoutingStatementHandler) invocation.getTarget();
		// RoutingStatementHandler only dispatches; the real handler is its "delegate" field.
		StatementHandler delegate = (StatementHandler) ReflectUtil.getFieldValue(handler, "delegate");
		BoundSql boundSql = delegate.getBoundSql();
		Object params = boundSql.getParameterObject();
		Page page = findPage(params);

		if (page != null) {
			MappedStatement mappedStatement = (MappedStatement) ReflectUtil
					.getFieldValue(delegate, "mappedStatement");
			Connection connection = (Connection) invocation.getArgs()[0];
			String sql = boundSql.getSql();
			// Count against the original SQL before it is rewritten below.
			this.setTotalRecord(page, params, mappedStatement, connection, boundSql);
			String pageSql = this.getPageSql(page, sql);
			ReflectUtil.setFieldValue(boundSql, "sql", pageSql);
		}
		return invocation.proceed();
	}

	/**
	 * Locates the paging request among the statement parameters.
	 *
	 * @param params the statement's parameter object (may be null)
	 * @return {@code params} itself if it is a Page; otherwise the first Page value
	 *         in a {@code MapperMethod.ParamMap}; otherwise null
	 */
	private Page findPage(Object params) {
		if (params instanceof Page) {
			return (Page) params;
		}
		if (params instanceof MapperMethod.ParamMap) {
			for (Object value : ((MapperMethod.ParamMap) params).values()) {
				if (value instanceof Page) {
					return (Page) value;
				}
			}
		}
		return null;
	}

	/**
	 * Wraps the target object in a proxy that routes the intercepted calls here.
	 */
	public Object plugin(Object target) {
		return Plugin.wrap(target, this);
	}

	/**
	 * Reads the {@code databaseType} plugin property.
	 *
	 * <p>A missing or empty value is reported on stderr but is not fatal. In that
	 * case {@link #getPageSql} leaves the SQL unchanged.
	 *
	 * @param p the properties declared for this plugin in the MyBatis configuration
	 */
	public void setProperties(Properties p) {
		databaseType = p.getProperty("databaseType");
		if (StringUtils.isEmpty(databaseType)) {
			System.err.println("PageInterceptor: databaseType is not found!");
		}
	}

	/**
	 * Builds the dialect-specific SQL that returns only the rows of the requested page.
	 *
	 * @param page the paging request
	 * @param sql  the original statement SQL
	 * @return the paged SQL, or {@code sql} unchanged for an unknown dialect
	 */
	private String getPageSql(Page page, String sql) {
		StringBuilder sqlBuffer = new StringBuilder(sql);
		if ("mysql".equalsIgnoreCase(databaseType)) {
			return getMysqlPageSql(page, sqlBuffer);
		} else if ("oracle".equalsIgnoreCase(databaseType)) {
			return getOraclePageSql(page, sqlBuffer);
		} else if ("sqlserver".equalsIgnoreCase(databaseType)) {
			return getSqlserverPageSql(page, sqlBuffer);
		}
		return sqlBuffer.toString();
	}

	/**
	 * Returns the 0-based index of the first row of the requested page.
	 * Page numbers below 1 are treated as the first page, so no negative offset
	 * (which would be invalid SQL) is ever produced.
	 */
	private int firstRowOffset(Page page) {
		int pageNum = Math.max(page.getPageNum(), 1);
		return (pageNum - 1) * page.getPageSize();
	}

	/**
	 * SQL Server paging via ROW_NUMBER(), which numbers rows starting from 1.
	 */
	private String getSqlserverPageSql(Page page, StringBuilder sqlBuffer) {
		int startRowNum = firstRowOffset(page) + 1;
		// Inclusive upper bound, so exactly pageSize rows are selected.
		int endRowNum = startRowNum + page.getPageSize() - 1;
		return "select appendRowNum.row,* from (select ROW_NUMBER() OVER (order by (select 0)) AS row,* from ("
				+ sqlBuffer.toString()
				+ ") as innerTable"
				+ ")as appendRowNum where appendRowNum.row >= "
				+ startRowNum
				+ " AND appendRowNum.row <= " + endRowNum;
	}

	/**
	 * MySQL paging via {@code LIMIT offset,count}; the offset is 0-based.
	 */
	private String getMysqlPageSql(Page page, StringBuilder sqlBuffer) {
		sqlBuffer.append(" limit ").append(firstRowOffset(page)).append(",").append(page.getPageSize());
		return sqlBuffer.toString();
	}

	/**
	 * Oracle paging via ROWNUM, which starts at 1.
	 *
	 * <p>ROWNUM can only be bounded from above in the query that assigns it. The
	 * inner query therefore caps the row count, and the outer query filters the
	 * lower bound on the aliased column {@code r}.
	 */
	private String getOraclePageSql(Page page, StringBuilder sqlBuffer) {
		int firstRowNum = firstRowOffset(page) + 1;
		sqlBuffer.insert(0, "select u.*, rownum r from (").append(") u where rownum < ")
			.append(firstRowNum + page.getPageSize());
		sqlBuffer.insert(0, "select * from (").append(") where r >= ").append(firstRowNum);
		return sqlBuffer.toString();
	}

	/**
	 * Runs a count query over the original SQL and stores the result in {@code page}.
	 *
	 * <p>The count SQL only wraps the original statement, so its '?' placeholders
	 * appear in the same order. The original BoundSql's parameter mappings and
	 * additional parameters (e.g. {@code <foreach>} bindings) therefore bind it
	 * directly.
	 *
	 * @param page            paging request that receives the total record count
	 * @param params          the statement's parameter object
	 * @param mappedStatement the mapped statement being executed
	 * @param connection      connection the intercepted statement is prepared on; not closed here
	 * @param boundSql        the statement's BoundSql, still holding the original (un-paged) SQL
	 * @throws SQLException if preparing or executing the count query fails
	 */
	private void setTotalRecord(Page page, Object params, MappedStatement mappedStatement,
			Connection connection, BoundSql boundSql) throws SQLException {
		String countSql = this.getCountSql(boundSql.getSql());
		ParameterHandler parameterHandler = new DefaultParameterHandler(
				mappedStatement, params, boundSql);
		PreparedStatement pstmt = connection.prepareStatement(countSql);
		try {
			parameterHandler.setParameters(pstmt);
			ResultSet rs = pstmt.executeQuery();
			try {
				if (rs.next()) {
					page.setTotalRecord(rs.getInt(1));
				}
			} finally {
				rs.close();
			}
		} finally {
			pstmt.close();
		}
	}

	/**
	 * Wraps the original SQL in a {@code select count(*)} query.
	 *
	 * <p>The derived table is aliased without {@code AS}. Oracle rejects {@code AS}
	 * before a table alias, while MySQL and SQL Server accept either form.
	 *
	 * @param sql the original statement SQL
	 * @return SQL that selects the number of rows {@code sql} returns
	 */
	private String getCountSql(String sql) {
		return "select count(*) from (" + sql + ") countRecord";
	}

}

所需的工具类 ReflectUtil 以及 Page 类如下:

package dwz.common.util;

import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.beanutils.BeanUtils;
import org.springframework.util.Assert;


/**
 * Reflection helpers.
 *
 * <p>They read and write fields regardless of visibility, invoke getters and
 * arbitrary methods, and copy same-named properties between objects.
 */
public class ReflectUtil {
	/**
	 * Reads the value of a field, searching the object's class and its superclasses.
	 *
	 * @param obj
	 *            target object
	 * @param fieldName
	 *            name of the field
	 * @return the field value, or null if the field is not found or cannot be read
	 *         (read errors are printed to stderr)
	 */
	public static Object getFieldValue(Object obj, String fieldName) {
		Object result = null;
		Field field = ReflectUtil.getField(obj, fieldName);
		if (field != null) {
			field.setAccessible(true);
			try {
				result = field.get(obj);
			} catch (IllegalArgumentException e) {
				e.printStackTrace();
			} catch (IllegalAccessException e) {
				e.printStackTrace();
			}
		}
		return result;
	}

	/**
	 * Finds a declared field by name, walking up the class hierarchy (excluding Object).
	 *
	 * @param obj
	 *            target object
	 * @param fieldName
	 *            name of the field
	 * @return the first matching field, or null if no class in the hierarchy declares it
	 */
	private static Field getField(Object obj, String fieldName) {
		Field field = null;
		for (Class clazz = obj.getClass(); clazz != Object.class; clazz = clazz
				.getSuperclass()) {
			try {
				field = clazz.getDeclaredField(fieldName);
				break;
			} catch (NoSuchFieldException e) {
				// Not declared here; a superclass may declare it. Null is returned if none does.
			}
		}
		return field;
	}

	/**
	 * Sets a String-typed field, searching the object's class and its superclasses.
	 * Does nothing if the field is not found; set errors are printed to stderr.
	 *
	 * @param obj
	 *            target object
	 * @param fieldName
	 *            name of the field
	 * @param fieldValue
	 *            value to assign
	 */
	public static void setFieldValue(Object obj, String fieldName,
			String fieldValue) {
		Field field = ReflectUtil.getField(obj, fieldName);
		if (field != null) {
			try {
				field.setAccessible(true);
				field.set(obj, fieldValue);
			} catch (IllegalArgumentException e) {
				e.printStackTrace();
			} catch (IllegalAccessException e) {
				e.printStackTrace();
			}
		}
	}

	/**
	 * Copies the fields declared directly on {@code source}'s class into same-named
	 * properties of {@code dest}.
	 *
	 * <p>Each value is read through its {@code getXxx()} getter and written with
	 * BeanUtils.setProperty.
	 *
	 * <p>The following fields are skipped instead of aborting the copy:
	 * <ul>
	 * <li>static fields (e.g. serialVersionUID);</li>
	 * <li>compiler-generated (synthetic) fields;</li>
	 * <li>fields without a no-arg {@code getXxx()} getter.</li>
	 * </ul>
	 *
	 * <p>Failures while setting a property on {@code dest} are printed to stderr,
	 * and the copy continues. (The method name's spelling is kept for existing callers.)
	 *
	 * @param dest   object receiving the values
	 * @param source object providing the values
	 * @throws IllegalArgumentException
	 * @throws IllegalAccessException
	 * @throws InvocationTargetException if a getter throws
	 */
	@SuppressWarnings("unchecked")
	public static void copyPorperties(Object dest, Object source) throws IllegalAccessException, IllegalArgumentException, InvocationTargetException{
		Class srcCla = source.getClass();
		Field[] fsF = srcCla.getDeclaredFields();

		for (Field s : fsF)
		{
			if (java.lang.reflect.Modifier.isStatic(s.getModifiers()) || s.isSynthetic()) {
				continue;
			}
			String name = s.getName();
			String getterName = "get" + StringUtils.capitalize(name);
			if (getDeclaredMethod(source, getterName, new Class[] {}) == null) {
				continue;
			}
			Object srcObj = invokeGetterMethod(source, name);
			try
			{
				BeanUtils.setProperty(dest, name, srcObj);
			}
			catch (Exception e){
				e.printStackTrace();
			}
		}
	}

	/**
	 * Invokes the no-arg getter {@code get<PropertyName>()} on {@code target}.
	 *
	 * @throws IllegalArgumentException if no such getter exists
	 * @throws IllegalAccessException
	 * @throws InvocationTargetException if the getter throws
	 */
	public static Object invokeGetterMethod(Object target, String propertyName) throws IllegalAccessException, IllegalArgumentException, InvocationTargetException
	{
		String getterMethodName = "get" + StringUtils.capitalize(propertyName);
		return invokeMethod(target, getterMethodName, new Class[] {},
				new Object[] {});
	}

	/**
	 * Invokes a method directly, ignoring private/protected modifiers.
	 *
	 * @throws IllegalArgumentException if the method is not found
	 * @throws IllegalAccessException
	 * @throws InvocationTargetException if the invoked method throws
	 */
	public static Object invokeMethod(final Object object,
			final String methodName, final Class[] parameterTypes,
			final Object[] parameters) throws IllegalAccessException, IllegalArgumentException, InvocationTargetException{
		Method method = getDeclaredMethod(object, methodName, parameterTypes);
		if (method == null)
		{
			// Arrays.toString lists the types instead of printing the array's identity hash.
			throw new IllegalArgumentException("Could not find method ["
					+ methodName + "] parameterType " + java.util.Arrays.toString(parameterTypes)
					+ " on target [" + object + "]");
		}

		method.setAccessible(true);
		return method.invoke(object, parameters);
	}

	/**
	 * Finds a declared method, walking up the class hierarchy.
	 *
	 * <p>Returns null if the method is not found before reaching Object.
	 */
	protected static Method getDeclaredMethod(Object object, String methodName,
			Class[] parameterTypes)
	{
		Assert.notNull(object, "object不能为空");

		for (Class superClass = object.getClass(); superClass != Object.class; superClass = superClass
				.getSuperclass())
		{
			try{
				return superClass.getDeclaredMethod(methodName, parameterTypes);
			}
			catch (NoSuchMethodException e)
			{// NOSONAR
				// Not declared on this class; keep walking up the hierarchy.
			}
		}
		return null;
	}
}

Page 类如下(可以重写 toString 方法,直接输出分页信息):

package dwz.common.mybatis;

import java.util.List;

/**
 * Holds a paging request (page number and page size) and its result.
 *
 * <p>The result consists of the total record count, the total page count and
 * the records of the current page.
 */
public class Page{
    private int pageNum = 1; // page number, defaults to the first page
    private int pageSize = 5; // records per page, defaults to 5
    private int totalRecord; // total number of records
    private int total; // total number of records; kept equal to totalRecord by setTotalRecord
    private int totalPage; // total number of pages
    private List results; // records of the current page

    public int getTotal() {
        return total;
    }

    public void setTotal(int total) {
        this.total = total;
    }

    public int getPageNum() {
        return pageNum;
    }

    public void setPageNum(int pageNum) {
        this.pageNum = pageNum;
    }

    public int getPageSize() {
        return pageSize;
    }

    public void setPageSize(int pageSize) {
        this.pageSize = pageSize;
    }

    public int getTotalRecord() {
        return totalRecord;
    }

    /**
     * Stores the total record count, mirrors it into {@code total}, and derives
     * {@code totalPage} from the current page size.
     *
     * <p>Uses ceiling division. A non-positive page size yields 0 pages instead
     * of dividing by zero.
     *
     * @param totalRecord total number of records matched by the query
     */
    public void setTotalRecord(int totalRecord) {
        this.totalRecord = totalRecord;
        this.total = totalRecord;
        int pages = 0;
        if (pageSize > 0) {
            // Written without (a + b - 1) / b so large record counts cannot overflow.
            pages = totalRecord % pageSize == 0 ? totalRecord / pageSize : totalRecord / pageSize + 1;
        }
        this.setTotalPage(pages);
    }

    public int getTotalPage() {
        return totalPage;
    }

    public void setTotalPage(int totalPage) {
        this.totalPage = totalPage;
    }

    /**
     * Returns the records of the current page.
     *
     * @return the records, or null when none were set or the list is empty
     */
    public List getResults() {
        if (null != results && results.size() == 0) {
            return null;
        }
        return results;
    }

    public void setResults(List results) {
        this.results = results;
    }
}

最后在 mybatis-config.xml 中配置该插件,并为 databaseType 赋值(支持 mysql、oracle、sqlserver,不区分大小写):


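<plugins>
	<plugin interceptor="dwz.common.mybatis.PageInterceptor">
		<property name="databaseType" value="mysql"/>
	</plugin>
</plugins>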

 

你可能感兴趣的:(mybatis 多数据库通用分页插件PageInterceptor)