最近在公司实习时,TL提出一个需求:在不使用Hibernate的情况下实现一个比较通用的DAO框架,使用JdbcTemplate作为SQL语句的执行工具。在参考了CloudStack 3.0.2的相关源代码后,我自己实现了一个简化版的DAO框架。后来TL又决定改用Python开发,只好遗憾地把这些代码留作纪念。
简单的类图参见链接:http://pan.baidu.com/share/link?shareid=115118&uk=3592520259
环境为MyEclipse 8.5 + Spring 2.5,使用的jar包为asm-3.3.1、cglib-2.2和mysql-connector。
1,编程思想
本质上是将某些通用的API(如最基础的CRUD)直接通过泛型类来实现。唯一比较难处理的是Update:需要知道哪些属性被修改过、需要更新。这可以通过CGLIB类库拦截setter方法,并记录被改变的属性来实现。
2,代码
核心类BaseDaoImpl
package DataBaseDemo.daoimpl; import java.lang.reflect.Field; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.SQLException; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; import javax.sql.DataSource; import net.sf.cglib.proxy.Enhancer; import org.apache.log4j.Logger; import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.jdbc.core.PreparedStatementCreator; import org.springframework.jdbc.core.PreparedStatementSetter; import org.springframework.jdbc.datasource.DataSourceTransactionManager; import org.springframework.jdbc.support.GeneratedKeyHolder; import org.springframework.jdbc.support.KeyHolder; import org.springframework.transaction.PlatformTransactionManager; import org.springframework.transaction.TransactionDefinition; import org.springframework.transaction.support.DefaultTransactionDefinition; import DataBaseDemo.interceptor.UpdateFactory; import DataBaseDemo.util.DBUtils; import DataBaseDemo.util.ModelRowMapper; import com.mysql.jdbc.Statement; public class BaseDaoImpl<T> { Logger logger=Logger.getLogger(BaseDaoImpl.class); //POJO类的实际类型 Class<T> entityType; //简单地将POJO类名映射成数据库表名 String table; public static JdbcTemplate jdbcTemplate; public static PlatformTransactionManager transactionManager; public static DefaultTransactionDefinition transactionDef; @SuppressWarnings("unchecked") BaseDaoImpl() { DataSource datasource=DBUtils.configureDatasource(); jdbcTemplate = new JdbcTemplate(datasource); transactionManager=new DataSourceTransactionManager(datasource); transactionDef=new DefaultTransactionDefinition(TransactionDefinition.PROPAGATION_REQUIRED); Type t = getClass().getGenericSuperclass(); // 使用该语法后,BaseDaoImpl无法正常使用,只能通过子类调用它 // 用于获取实际输入的Model类型 if (t instanceof ParameterizedType) { entityType = (Class<T>) ((ParameterizedType) 
t) .getActualTypeArguments()[0]; } else if (((Class<?>) t).getGenericSuperclass() instanceof ParameterizedType) { entityType = (Class<T>) ((ParameterizedType) ((Class<?>) t) .getGenericSuperclass()).getActualTypeArguments()[0]; } else { entityType = (Class<T>) ((ParameterizedType) ((Class<?>) ((Class<?>) t) .getGenericSuperclass()).getGenericSuperclass()) .getActualTypeArguments()[0]; } this.table = DBUtils.getTable(entityType); } @SuppressWarnings("unchecked") public List<T> queryAll() { String sql = "select * from " + table; List<T> list = jdbcTemplate.query(sql, new ModelRowMapper(entityType)); return list; } /** * 根据ID,查询一条记录并用实体包装 * * @param id * @return */ @SuppressWarnings("unchecked") public T load(int id) { String sql = "select * from " + table + " where id=" + id; List<T> list = jdbcTemplate.query(sql, new ModelRowMapper(entityType)); return list.size() == 0 ? null : list.get(0); } /** * 根据ID,删除指定表里的记录 * * @param id */ public void delete(int id) { String sql = "delete from " + table + " where id=" + id; jdbcTemplate.execute(sql); } /** * 根据ID更新实体数据到数据库 * * @param entity * @param id */ @SuppressWarnings("unchecked") public void update(T entity, int id) { assert Enhancer.isEnhanced(entity.getClass()) : "没有被拦截器监控到更新数据"; StringBuilder sql = new StringBuilder(); sql.append("update " + table + " set "); System.out.println(entity.hashCode()); HashMap<String, Object> map = UpdateFactory.getChanges(entity.hashCode()); List<String> keys = new ArrayList<String>(); List<Object> values = new ArrayList<Object>(); Iterator iter = map.entrySet().iterator(); while (iter.hasNext()) { Map.Entry entry = (Map.Entry) iter.next(); String key = (String) entry.getKey(); Object val = entry.getValue(); keys.add(key); values.add(val); } for (int i = 0; i < keys.size(); i++) { if (i == keys.size() - 1) { sql.append(keys.get(i) + "=? 
"); } else { sql.append(keys.get(i) + "=?,"); } } sql.append("where id=?"); logger.info("更新语句:"+sql.toString()); values.add(id); jdbcTemplate.update(sql.toString(), setParams(values.toArray())); } /** * 插入实体,并返回数据库自增生成的ID * * @param entity * @return */ @SuppressWarnings("unchecked") public int insert(T entity) { final StringBuilder sql = new StringBuilder(); sql.append("insert into " + table + "("); HashMap<String, Object> map = getChangesForInsert(entity); List<String> columns = new ArrayList<String>(); final List<Object> values = new ArrayList<Object>(); Iterator iter = map.entrySet().iterator(); while (iter.hasNext()) { Map.Entry entry = (Map.Entry) iter.next(); String key = (String) entry.getKey(); Object val = entry.getValue(); columns.add(key); values.add(val); } for (int i = 0; i < columns.size(); i++) { if (i == columns.size() - 1) { sql.append(columns.get(i) + ") values("); } else { sql.append(columns.get(i) + ","); } } for (int i = 0; i < values.size(); i++) { if (i == values.size() - 1) { sql.append("?)"); } else { sql.append("?,"); } } logger.info("插入语句:"+sql.toString()); KeyHolder key = new GeneratedKeyHolder(); final String insertSql=sql.toString(); jdbcTemplate.update(new PreparedStatementCreator() { @Override public PreparedStatement createPreparedStatement(Connection con) throws SQLException { // TODO Auto-generated method stub //必须设置Statement.RETURN_GENERATED_KEYS才能进行返回ID PreparedStatement ps = jdbcTemplate.getDataSource() .getConnection().prepareStatement(insertSql,Statement.RETURN_GENERATED_KEYS); for (int i = 0; i < values.size(); i++) { ps.setObject(i + 1, values.get(i)); } return ps; } }, key); return key.getKey().intValue(); } /** * 插入实体并返回被插入的实体 * * @param entity * @return */ public T persist(final T entity) { int id=insert(entity); //直接通过connection进行提交同样无法成功 // transaction.commit(); logger.info("数据库返回的自增ID为:"+id); T persisted=load(id); return persisted; } /** * * @param params * @return */ @SuppressWarnings("unchecked") public List<T> 
query(SearchCriteria sc){ String where = sc.generateWhereClause(); StringBuilder sb = new StringBuilder("select * from "+sc.getTable()); sb.append(where); logger.info("查询语句"+sb.toString()); logger.info("查询参数"+Arrays.toString(sc.generateParams())); List<T> list = jdbcTemplate.query(sb.toString(), setParams(sc.generateParams()),new ModelRowMapper(entityType)); return list; } /** * 返回DAO对应的数据库表的总记录数 * @return */ public int getTotalCount(){ String sql="select count(*) from "+table; return jdbcTemplate.queryForInt(sql); } /** * 以pagesize大小的页,返回第page页的数据 * @param page * @param pagesize * @return */ @SuppressWarnings("unchecked") public List<T> getPage(int page,int pagesize){ if(page<0||pagesize<0){ throw new IllegalArgumentException("页码或页大小参数不合法"); } String sql="select * from "+table+" limit "+page*pagesize+","+(page+1)*pagesize; return jdbcTemplate.query(sql, new ModelRowMapper<T>(entityType)); } /** * 直接执行sql查询语句,param作为参数数组 * @param sql * @param params * 返回查询到的结果列表 * @return */ @SuppressWarnings("unchecked") public List<T> executeRawSql(String sql,Object[] params){ return jdbcTemplate.query(sql, setParams(params) , new ModelRowMapper<T>(entityType)); } /** * 设置查询用的参数列表 * @param params * @return */ protected PreparedStatementSetter setParams(final Object[] params) { return new PreparedStatementSetter() { @Override public void setValues(PreparedStatement ps) throws SQLException { // TODO Auto-generated method stub for (int i = 0; i < params.length; i++) { ps.setObject(i + 1, params[i]); } } }; } /** * 返回待插入实体上的所有非空属性值及属性名的Map * @param entity * @return */ protected HashMap<String, Object> getChangesForInsert(T entity){ Field[] fields = entityType.getDeclaredFields(); HashMap<String, Object> insertValues = new HashMap<String, Object>(); try { for (Field field : fields) { field.setAccessible(true); //跳过id字段 if("id".equalsIgnoreCase(field.getName())) continue; Object value = field.get(entity); if (value == null) continue; insertValues.put(field.getName(), value); } return 
insertValues; } catch (IllegalArgumentException e) { // TODO Auto-generated catch block e.printStackTrace(); } catch (IllegalAccessException e) { // TODO Auto-generated catch block e.printStackTrace(); } return null; } }
子类Dao实例:UserDaoImpl
package DataBaseDemo.daoimpl; import java.util.List; import DataBaseDemo.dao.UserDao; import DataBaseDemo.model.UserVO; public class UserDaoImpl extends BaseDaoImpl<UserVO> implements UserDao { @Override public UserVO queryUser() { // TODO Auto-generated method stub UserVO user=new UserVO(); return user; } //自定义的高级查询包装 public List<UserVO> listUsers(){ return queryAll(); } }
核心拦截工厂UpdateFactory:通过CGLIB的接口拦截Model的setter方法,并记录被修改的属性
package DataBaseDemo.interceptor; import java.util.HashMap; import net.sf.cglib.proxy.Callback; import net.sf.cglib.proxy.Enhancer; import net.sf.cglib.proxy.NoOp; /** * 使用UpdateFactory存放对象被改变的属性及其值 * 以对象的hashCode为key,值为被改变的HashMap * @author Administrator * */ public class UpdateFactory { public static HashMap<Integer,HashMap<String,Object>> changes; private static Enhancer enhancer; //以字典的方式记录每个对象的改变属性值 static{ changes=new HashMap<Integer, HashMap<String,Object>>(); } /** * 根据对象的hashCode存储对象被改变的属性值 * @param hash * @param key * @param value */ public static void addChange(Integer hash,String key,Object value){ HashMap<String, Object> orginal=changes.get(hash); if(orginal==null){ orginal=new HashMap<String, Object>(); orginal.put(key, value); }else{ orginal.put(key, value); } changes.put(hash, orginal); } /** * 以对象的hashCode取出对象的所有变更 * @param hash * @return */ public static HashMap<String, Object> getChanges(Integer hash){ return changes.get(hash); } // 通过工厂生成对象,并产生拦截器,拦截set方法生成被改变的值Map /** * 根据对象class生成对应实例,并使它的修改能够被CGLIB拦截 */ public static Object createVO(Class<?> clazz) { enhancer = new Enhancer(); enhancer.setSuperclass(clazz); Callback[] callbacks; callbacks = new Callback[] { NoOp.INSTANCE, new UpdateInterceptor() }; enhancer.setCallbacks(callbacks); enhancer.setCallbackFilter(new SetFilter()); return enhancer.create(); } }package DataBaseDemo.util;
import java.lang.reflect.Field; import java.sql.ResultSet; import java.sql.SQLException; import java.util.HashMap; import org.springframework.jdbc.core.RowMapper; import DataBaseDemo.interceptor.UpdateFactory; /** * 使用包cglib和asm来创建对某一对象setters方法的拦截器 * * @author Administrator * */ public class ModelRowMapper<T> implements RowMapper { /** * @param args */ Class<?> clazz; public ModelRowMapper(Class<?> clazz) { this.clazz = clazz; } // RowMapper中直接通过field给字段设值,避免干扰set拦截器的使用 public static Object setValues(HashMap<String, Object> map, Object entity) { Field[] fields = entity.getClass().getDeclaredFields(); try { for (Field field : fields) { Object value = map.get(field.getName()); if (value != null) { field.setAccessible(true); field.set(entity, value); } } } catch (IllegalArgumentException e) { // TODO Auto-generated catch block e.printStackTrace(); } catch (IllegalAccessException e) { // TODO Auto-generated catch block e.printStackTrace(); } return entity; } public void setValues(ResultSet rs, Object entity) { Field[] fields = clazz.getDeclaredFields(); try { for (Field field : fields) { Object value = rs.getObject(field.getName()); field.setAccessible(true); field.set(entity, value); } } catch (SQLException e) { // TODO Auto-generated catch block e.printStackTrace(); } catch (IllegalArgumentException e) { // TODO Auto-generated catch block e.printStackTrace(); } catch (IllegalAccessException e) { // TODO Auto-generated catch block e.printStackTrace(); } } @SuppressWarnings("unchecked") @Override public T mapRow(ResultSet rs, int rowNum) throws SQLException { //通过更新工厂的静态方法创建类实例,使它被CGLIB监控 T entity = (T) UpdateFactory.createVO(clazz); setValues(rs, entity); return entity; } }
3,缺点
--目前的查询非常简单,需要进行优化
--无法支持事务管理,原因不明,进一步研究中
所有源码参照CloudStack3.0.2的相关代码编写