JDBCTemplate+JavaPOJO实现通用DAO

最近在公司实习过程中,TL(技术负责人)提出一个需求:在不使用Hibernate的情况下实现一个比较通用的DAO框架,并使用JdbcTemplate作为执行SQL语句的工具。在参考了CloudStack 3.0.2的相关源代码后,我自己实现了一个简化版的DAO框架。后来TL又决定改用Python开发,只好遗憾地把这些代码留作纪念了。

 

简单的类图参见链接:http://pan.baidu.com/share/link?shareid=115118&uk=3592520259

 

开发环境为MyEclipse 8.5 + Spring 2.5,使用的jar包为asm-3.3.1、cglib-2.2、mysql-connector。

1,编程思想

本质上是将某些通用的API(如最基础的CRUD)直接通过泛型类来实现。唯一比较难处理的是Update时如何确定哪些属性需要更新:这可以通过CGLIB类库拦截setter方法,并记录被改变的属性来实现。

 

2,代码

核心类BaseDaoImpl

package DataBaseDemo.daoimpl;

import java.lang.reflect.Field;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

import javax.sql.DataSource;

import net.sf.cglib.proxy.Enhancer;

import org.apache.log4j.Logger;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.PreparedStatementCreator;
import org.springframework.jdbc.core.PreparedStatementSetter;
import org.springframework.jdbc.datasource.DataSourceTransactionManager;
import org.springframework.jdbc.support.GeneratedKeyHolder;
import org.springframework.jdbc.support.KeyHolder;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.TransactionDefinition;
import org.springframework.transaction.support.DefaultTransactionDefinition;

import DataBaseDemo.interceptor.UpdateFactory;
import DataBaseDemo.util.DBUtils;
import DataBaseDemo.util.ModelRowMapper;

import com.mysql.jdbc.Statement;

public class BaseDaoImpl<T> {
    Logger logger=Logger.getLogger(BaseDaoImpl.class);
    //POJO类的实际类型
	Class<T> entityType;
	//简单地将POJO类名映射成数据库表名
	String table;
	public static JdbcTemplate jdbcTemplate;
    public static PlatformTransactionManager transactionManager;
    public static DefaultTransactionDefinition  transactionDef;
	@SuppressWarnings("unchecked")
	BaseDaoImpl() {
		DataSource datasource=DBUtils.configureDatasource();
		jdbcTemplate = new JdbcTemplate(datasource);
		transactionManager=new DataSourceTransactionManager(datasource);
		transactionDef=new DefaultTransactionDefinition(TransactionDefinition.PROPAGATION_REQUIRED);
		Type t = getClass().getGenericSuperclass();
		// 使用该语法后,BaseDaoImpl无法正常使用,只能通过子类调用它
		// 用于获取实际输入的Model类型
		if (t instanceof ParameterizedType) {
			entityType = (Class<T>) ((ParameterizedType) t)
					.getActualTypeArguments()[0];
		} else if (((Class<?>) t).getGenericSuperclass() instanceof ParameterizedType) {
			entityType = (Class<T>) ((ParameterizedType) ((Class<?>) t)
					.getGenericSuperclass()).getActualTypeArguments()[0];
		} else {
			entityType = (Class<T>) ((ParameterizedType) ((Class<?>) ((Class<?>) t)
					.getGenericSuperclass()).getGenericSuperclass())
					.getActualTypeArguments()[0];
		}
		this.table = DBUtils.getTable(entityType);
	}

	@SuppressWarnings("unchecked")
	public List<T> queryAll() {
		String sql = "select * from " + table;
		List<T> list = jdbcTemplate.query(sql, new ModelRowMapper(entityType));
		return list;
	}

	/**
	 * 根据ID,查询一条记录并用实体包装
	 * 
	 * @param id
	 * @return
	 */
	@SuppressWarnings("unchecked")
	public T load(int id) {
		String sql = "select * from " + table + " where id=" + id;
		List<T> list = jdbcTemplate.query(sql, new ModelRowMapper(entityType));
		return list.size() == 0 ? null : list.get(0);
	}

	/**
	 * 根据ID,删除指定表里的记录
	 * 
	 * @param id
	 */
	public void delete(int id) {
		String sql = "delete from " + table + " where id=" + id;
		jdbcTemplate.execute(sql);
	}

	/**
	 * 根据ID更新实体数据到数据库
	 * 
	 * @param entity
	 * @param id
	 */
	@SuppressWarnings("unchecked")
	public void update(T entity, int id) {
		assert Enhancer.isEnhanced(entity.getClass()) : "没有被拦截器监控到更新数据";
		StringBuilder sql = new StringBuilder();
		sql.append("update " + table + " set ");
		System.out.println(entity.hashCode());
		HashMap<String, Object> map = UpdateFactory.getChanges(entity.hashCode());
		List<String> keys = new ArrayList<String>();
		List<Object> values = new ArrayList<Object>();
		Iterator iter = map.entrySet().iterator();
		while (iter.hasNext()) {
			Map.Entry entry = (Map.Entry) iter.next();
			String key = (String) entry.getKey();
			Object val = entry.getValue();
			keys.add(key);
			values.add(val);
		}
		for (int i = 0; i < keys.size(); i++) {
			if (i == keys.size() - 1) {
				sql.append(keys.get(i) + "=? ");
			} else {
				sql.append(keys.get(i) + "=?,");
			}
		}
		sql.append("where id=?");
		logger.info("更新语句:"+sql.toString());
		values.add(id);
		jdbcTemplate.update(sql.toString(), setParams(values.toArray()));
	}

	/**
	 * 插入实体,并返回数据库自增生成的ID
	 * 
	 * @param entity
	 * @return
	 */
	@SuppressWarnings("unchecked")
	public int insert(T entity) {
		final StringBuilder sql = new StringBuilder();
		
		sql.append("insert into " + table + "(");
		HashMap<String, Object> map = getChangesForInsert(entity);
		List<String> columns = new ArrayList<String>();
		final List<Object> values = new ArrayList<Object>();
		Iterator iter = map.entrySet().iterator();
		while (iter.hasNext()) {
			Map.Entry entry = (Map.Entry) iter.next();
			String key = (String) entry.getKey();
			Object val = entry.getValue();
			columns.add(key);
			values.add(val);
		}
		for (int i = 0; i < columns.size(); i++) {
			if (i == columns.size() - 1) {
				sql.append(columns.get(i) + ") values(");
			} else {
				sql.append(columns.get(i) + ",");
			}
		}

		for (int i = 0; i < values.size(); i++) {
			if (i == values.size() - 1) {
				sql.append("?)");
			} else {
				sql.append("?,");
			}
		}
		logger.info("插入语句:"+sql.toString());
		KeyHolder key = new GeneratedKeyHolder();
		final String insertSql=sql.toString();
		jdbcTemplate.update(new PreparedStatementCreator() {

			@Override
			public PreparedStatement createPreparedStatement(Connection con)
					throws SQLException {
				// TODO Auto-generated method stub
				//必须设置Statement.RETURN_GENERATED_KEYS才能进行返回ID
				PreparedStatement ps = jdbcTemplate.getDataSource()
						.getConnection().prepareStatement(insertSql,Statement.RETURN_GENERATED_KEYS);
				for (int i = 0; i < values.size(); i++) {
					ps.setObject(i + 1, values.get(i));
				}
				return ps;
			}
		}, key);
		return key.getKey().intValue();
	}

	/**
	 * 插入实体并返回被插入的实体
	 * 
	 * @param entity
	 * @return
	 */
	public T persist(final T entity) {
		int id=insert(entity);
		//直接通过connection进行提交同样无法成功
//		transaction.commit();
		logger.info("数据库返回的自增ID为:"+id);
		T persisted=load(id);
		return persisted;
	}
	/**
	 * 
	 * @param params
	 * @return
	 */
	@SuppressWarnings("unchecked")
	public List<T> query(SearchCriteria sc){
		String where = sc.generateWhereClause();
		StringBuilder sb = new StringBuilder("select * from "+sc.getTable());
		sb.append(where);
		logger.info("查询语句"+sb.toString());
		logger.info("查询参数"+Arrays.toString(sc.generateParams()));
		List<T> list = jdbcTemplate.query(sb.toString(), setParams(sc.generateParams()),new ModelRowMapper(entityType));
		return list;
	}
	/**
	 * 返回DAO对应的数据库表的总记录数
	 * @return
	 */
	public int getTotalCount(){
		String sql="select count(*) from "+table;
		return jdbcTemplate.queryForInt(sql);
	}
	/**
	 * 以pagesize大小的页,返回第page页的数据
	 * @param page
	 * @param pagesize
	 * @return
	 */
	@SuppressWarnings("unchecked")
	public List<T> getPage(int page,int pagesize){
		if(page<0||pagesize<0){
			throw new IllegalArgumentException("页码或页大小参数不合法");
		}
		String sql="select * from "+table+" limit "+page*pagesize+","+(page+1)*pagesize;
		return jdbcTemplate.query(sql, new ModelRowMapper<T>(entityType));
	}
	/**
	 * 直接执行sql查询语句,param作为参数数组
	 * @param sql
	 * @param params
	 * 返回查询到的结果列表
	 * @return
	 */
	@SuppressWarnings("unchecked")
	public List<T> executeRawSql(String sql,Object[] params){
		return jdbcTemplate.query(sql, setParams(params)
				, new ModelRowMapper<T>(entityType));
	}

	/**
	 * 设置查询用的参数列表
	 * @param params
	 * @return
	 */
	protected PreparedStatementSetter setParams(final Object[] params) {
		return new PreparedStatementSetter() {

			@Override
			public void setValues(PreparedStatement ps) throws SQLException {
				// TODO Auto-generated method stub
				for (int i = 0; i < params.length; i++) {
					ps.setObject(i + 1, params[i]);
				}
			}
		};
	}
	/**
	 * 返回待插入实体上的所有非空属性值及属性名的Map
	 * @param entity
	 * @return
	 */
	protected HashMap<String, Object> getChangesForInsert(T entity){
		Field[] fields = entityType.getDeclaredFields();
		
		HashMap<String, Object> insertValues = new HashMap<String, Object>();
		try {
			for (Field field : fields) {
				field.setAccessible(true);
				//跳过id字段
				if("id".equalsIgnoreCase(field.getName()))
					continue;
				Object value = field.get(entity);
				if (value == null)
					continue;
				insertValues.put(field.getName(), value);
				
			}
			return insertValues;
		} catch (IllegalArgumentException e) {
			// TODO Auto-generated catch block
			e.printStackTrace();
		} catch (IllegalAccessException e) {
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
		return null;
	}

}
 

子类Dao实例:UserDaoImpl

package DataBaseDemo.daoimpl;

import java.util.List;

import DataBaseDemo.dao.UserDao;
import DataBaseDemo.model.UserVO;

/**
 * DAO for {@link UserVO} rows. All generic CRUD operations are inherited from
 * {@link BaseDaoImpl}; this class only adds user-specific wrappers.
 */
public class UserDaoImpl extends BaseDaoImpl<UserVO> implements UserDao {

	/**
	 * Placeholder implementation: builds and returns a new, empty {@code UserVO} without
	 * touching the database.
	 */
	@Override
	public UserVO queryUser() {
		return new UserVO();
	}

	/**
	 * Custom query wrapper that returns every user row via {@link #queryAll()}.
	 */
	public List<UserVO> listUsers() {
		List<UserVO> allUsers = queryAll();
		return allUsers;
	}
}
 

核心拦截工厂UpdateFactory:通过CGLIB的接口拦截Model的setter方法,并记录Model被修改的属性

package DataBaseDemo.interceptor;

import java.util.HashMap;

import net.sf.cglib.proxy.Callback;
import net.sf.cglib.proxy.Enhancer;
import net.sf.cglib.proxy.NoOp;
/**
 * Stores, per object, the properties changed through intercepted setters and their new values,
 * and creates CGLIB-enhanced instances whose setter calls are intercepted.
 * <p>
 * Changes are keyed by the object's {@code hashCode()}; each value maps a property name to the
 * last value recorded for it.
 * <p>
 * NOTE(review): keying by hashCode means distinct objects with equal hash codes share one change
 * map, and a POJO whose hashCode depends on mutable fields would change key after a setter call.
 * The static map is a plain, unsynchronized HashMap.
 *
 * @author Administrator
 */
public class UpdateFactory {

	/** Pending changes, keyed by object hashCode: property name -> most recent value. */
	public static HashMap<Integer, HashMap<String, Object>> changes;

	static {
		changes = new HashMap<Integer, HashMap<String, Object>>();
	}

	/**
	 * Records that property {@code key} of the object with hash code {@code hash} was set to
	 * {@code value}, replacing any earlier value recorded for that property.
	 *
	 * @param hash  hashCode of the changed object
	 * @param key   property name
	 * @param value new property value
	 */
	public static void addChange(Integer hash, String key, Object value) {
		HashMap<String, Object> objectChanges = changes.get(hash);
		if (objectChanges == null) {
			objectChanges = new HashMap<String, Object>();
			changes.put(hash, objectChanges);
		}
		objectChanges.put(key, value);
	}

	/**
	 * Returns all changes recorded for the object with the given hash code.
	 *
	 * @param hash hashCode of the object
	 * @return property name -> new value, or {@code null} if nothing was recorded
	 */
	public static HashMap<String, Object> getChanges(Integer hash) {
		return changes.get(hash);
	}

	/**
	 * Creates an instance of {@code clazz} through CGLIB so that its modifications can be
	 * intercepted.
	 * <p>
	 * Two callbacks are registered: index 0 is {@code NoOp} (call passes straight through), and
	 * index 1 is {@code UpdateInterceptor}. {@code SetFilter} chooses the callback per method,
	 * presumably routing setters to the interceptor.
	 *
	 * @param clazz class to instantiate; must be subclassable by CGLIB
	 * @return an enhanced instance of {@code clazz}
	 */
	public static Object createVO(Class<?> clazz) {
		// A fresh Enhancer per call. A shared static Enhancer could be reassigned by a concurrent
		// call between configuration and create(), yielding a wrongly configured instance.
		Enhancer enhancer = new Enhancer();
		enhancer.setSuperclass(clazz);
		enhancer.setCallbacks(new Callback[] { NoOp.INSTANCE, new UpdateInterceptor() });
		enhancer.setCallbackFilter(new SetFilter());
		return enhancer.create();
	}
}
 package DataBaseDemo.util;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Locale;
import java.util.Set;

import org.springframework.jdbc.core.RowMapper;

import DataBaseDemo.interceptor.UpdateFactory;

/**
 * Spring {@link RowMapper} that turns each result-set row into an instance of a POJO class.
 * <p>
 * Instances are created through {@link UpdateFactory#createVO}, so later setter calls on them
 * are intercepted via cglib/asm. Column values are copied directly into the fields by reflection
 * rather than through the setters, so populating an object does not interfere with the setter
 * interceptor.
 *
 * @author Administrator
 */
public class ModelRowMapper<T> implements RowMapper {

	/** POJO class whose declared fields are filled from the columns of the same name. */
	Class<?> clazz;

	public ModelRowMapper(Class<?> clazz) {
		this.clazz = clazz;
	}

	/**
	 * Copies values from {@code map} into the same-named instance fields of {@code entity},
	 * writing fields directly (not via setters) so that no setter interceptor is triggered.
	 * <p>
	 * The whole class hierarchy of {@code entity} is searched, because for a CGLIB-enhanced
	 * entity the POJO fields are declared on the superclass of the runtime class. A field
	 * declared in a subclass shadows a same-named field of a superclass. Null map values and
	 * static fields are skipped.
	 *
	 * @param map    field name -> value; {@code null} leaves the entity unchanged
	 * @param entity object to populate
	 * @return {@code entity}
	 * @throws IllegalArgumentException if a value's type does not fit the target field
	 */
	public static Object setValues(HashMap<String, Object> map, Object entity) {
		if (map == null) {
			return entity;
		}
		Set<String> seen = new HashSet<String>();
		for (Class<?> type = entity.getClass(); type != null && type != Object.class; type = type.getSuperclass()) {
			for (Field field : type.getDeclaredFields()) {
				if (Modifier.isStatic(field.getModifiers()) || !seen.add(field.getName())) {
					continue;
				}
				Object value = map.get(field.getName());
				if (value == null) {
					continue;
				}
				field.setAccessible(true);
				try {
					field.set(entity, value);
				} catch (IllegalAccessException e) {
					throw new IllegalStateException("Cannot write field " + field, e);
				}
			}
		}
		return entity;
	}

	/**
	 * Populates {@code entity} from the current row of {@code rs}; see {@link #mapRow}.
	 * <p>
	 * Best-effort entry point kept for compatibility: because this signature declares no
	 * checked exception, mapping errors are printed rather than thrown. {@link #mapRow}
	 * propagates them instead.
	 *
	 * @param rs     result set positioned on a row
	 * @param entity instance of {@code clazz} (or a subclass) to fill
	 */
	public void setValues(ResultSet rs, Object entity) {
		try {
			populate(rs, entity);
		} catch (SQLException e) {
			// NOTE(review): swallowed only to keep this public signature; prefer mapRow.
			e.printStackTrace();
		}
	}

	/**
	 * Writes each column value of the current row into the declared field of {@code clazz}
	 * with the same (case-insensitive) name.
	 * <p>
	 * Static, synthetic and column-less fields are skipped. A SQL NULL is not written into a
	 * primitive field, which keeps its current value.
	 *
	 * @throws SQLException on database errors, or when a column value cannot be assigned to
	 *                      its field
	 */
	private void populate(ResultSet rs, Object entity) throws SQLException {
		Set<String> columns = columnLabels(rs);
		for (Field field : clazz.getDeclaredFields()) {
			if (Modifier.isStatic(field.getModifiers()) || field.isSynthetic()) {
				continue;
			}
			String name = field.getName();
			if (!columns.contains(name.toLowerCase(Locale.ROOT))) {
				continue;
			}
			Object value = rs.getObject(name);
			if (value == null && field.getType().isPrimitive()) {
				continue;
			}
			field.setAccessible(true);
			try {
				field.set(entity, value);
			} catch (IllegalArgumentException e) {
				String valueType = value == null ? "null" : value.getClass().getName();
				throw new SQLException("Cannot assign column '" + name + "' of type " + valueType + " to field " + field, e);
			} catch (IllegalAccessException e) {
				throw new SQLException("Cannot write field " + field, e);
			}
		}
	}

	/** Returns the lower-cased labels of all columns in {@code rs}. */
	private static Set<String> columnLabels(ResultSet rs) throws SQLException {
		ResultSetMetaData meta = rs.getMetaData();
		Set<String> labels = new HashSet<String>();
		for (int i = 1; i <= meta.getColumnCount(); i++) {
			labels.add(meta.getColumnLabel(i).toLowerCase(Locale.ROOT));
		}
		return labels;
	}

	/**
	 * Creates a change-tracked instance of {@code clazz} through {@link UpdateFactory} and
	 * fills it from the current row.
	 *
	 * @throws SQLException on database errors or when a column cannot be assigned to its field
	 */
	@SuppressWarnings("unchecked")
	@Override
	public T mapRow(ResultSet rs, int rowNum) throws SQLException {
		T entity = (T) UpdateFactory.createVO(clazz);
		populate(rs, entity);
		return entity;
	}

}

 3,缺点

--目前的查询非常简单,需要进行优化

--无法支持事务管理,原因不明,进一步研究中

所有源码参照CloudStack3.0.2的相关代码编写

你可能感兴趣的:(反射,cglib,Java框架实现,通用Dao)