<strong>1.定义Repository接口</strong>
<strong>package com.bjhy.platform.persist.dao; import java.io.Serializable; import java.util.List; import java.util.Map; import org.springframework.data.jpa.repository.JpaRepository; import org.springframework.data.repository.NoRepositoryBean; import com.bjhy.platform.commons.pager.Condition; import com.bjhy.platform.commons.pager.Order; import com.bjhy.platform.commons.pager.PageBean; /** * @author wbw * 自定义repository的方法接口 */ @NoRepositoryBean public interface CommonRepository<T, ID extends Serializable> extends JpaRepository<T, ID>{ /** * 保存对象<br/> * 注意:如果对象id是字符串,并且没有赋值,该方法将自动设置为uuid值 * @param item * 持久对象,或者对象集合 * @throws Exception */ public void store(Object... item); /** * 更新对象数据 * * @param item * 持久对象,或者对象集合 * @throws Exception */ public void update(Object... item); /** * 执行ql语句 * @param qlString 基于jpa标准的ql语句 * @param values ql中的?参数值,单个参数值或者多个参数值 * @return 返回执行后受影响的数据个数 */ public int executeUpdate(String qlString, Object... values); /** * 执行ql语句 * @param qlString 基于jpa标准的ql语句 * @param params key表示ql中参数变量名,value表示该参数变量值 * @return 返回执行后受影响的数据个数 */ public int executeUpdate(String qlString, Map<String, Object> params); /** * 执行ql语句,可以是更新或者删除操作 * @param qlString 基于jpa标准的ql语句 * @param values ql中的?参数值 * @return 返回执行后受影响的数据个数 * @throws Exception */ public int executeUpdate(String qlString, List<Object> values); /** * 结合提供的分页信息,获取指定条件下的数据对象 * @param pageBean 分页信息 * @param qlString 基于jpa标准的ql语句 */ public void doPager(PageBean pageBean, String qlString); /** * 结合提供的分页信息,获取指定条件下的数据对象 * @param pageBean 分页信息 * @param qlString 基于jpa标准的ql语句 * @param cacheable 是否启用缓存查询 */ public void doPager(PageBean pageBean,String qlString,boolean cacheable); /** * 结合提供的分页信息,获取指定条件下的数据对象 * @param pageBean 分页信息 * @param qlString 基于jpa标准的ql语句 * @param params key表示ql中参数变量名,value表示该参数变量值 */ public void doPager(PageBean pageBean, String qlString, Map<String, Object> params); /** * 结合提供的分页信息,获取指定条件下的数据对象 * @param pageBean 分页信息 * @param qlString 基于jpa标准的ql语句 * @param values ql中的?参数值 */ public void 
doPager(PageBean pageBean, String qlString, List<Object> values); /** * 结合提供的分页信息,获取指定条件下的数据对象 * @param pageBean 分页信息 * @param qlString 基于jpa标准的ql语句 * @param values ql中的?参数值 */ public void doPager(PageBean pageBean, String qlString, Object... values); /** * 批量删除数据对象 * @param entityClass * @param primaryKeyValues * @return */ public int batchDeleteByQl(Class<?> entityClass,Object...primaryKeyValues); /** * 批量删除数据对象 * @param entityClass * @param pKeyVals 主键是字符串形式的 * @return * @throws Exception */ public int batchDeleteByQl(Class<?> entityClass,String...pKeyVals); /** * @param qlString 查询hql语句 * @param values hql参数值 * @param conditions 查询条件 * @param orders 排序条件 * @return */ public List<?> doList(String qlString, List<Object> values, List<Condition> conditions, List<Order> orders, boolean sqlable); public List<?> doList(String qlString, List<Condition> conditions, List<Order> orders, boolean sqlable);</strong> }
<strong>2.自定义repository的方法接口实现类</strong>
<strong></strong><pre name="code" class="java">package com.bjhy.platform.persist.dao; import java.io.Serializable; import java.lang.reflect.Field; import java.lang.reflect.Method; import java.text.ParseException; import java.text.SimpleDateFormat; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.Date; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.regex.Matcher; import java.util.regex.Pattern; import javax.persistence.EntityManager; import javax.persistence.Id; import javax.persistence.Query; import org.apache.log4j.Logger; import org.hibernate.annotations.QueryHints; import org.springframework.data.jpa.repository.support.SimpleJpaRepository; import org.springframework.util.StringUtils; import com.bjhy.platform.commons.pager.Condition; import com.bjhy.platform.commons.pager.Operation; import com.bjhy.platform.commons.pager.Order; import com.bjhy.platform.commons.pager.OrderType; import com.bjhy.platform.commons.pager.PageBean; import com.bjhy.platform.commons.pager.RelateType; import com.bjhy.platform.util.CglibBeanUtil; import com.bjhy.platform.util.UUIDUtil; /** * @author wbw * 自定义repository的方法接口实现类 */ public class CommonRepositoryImpl<T, ID extends Serializable> extends SimpleJpaRepository<T, Serializable> implements CommonRepository<T, Serializable>{ Logger logger = Logger.getLogger(CommonRepositoryImpl.class); private final EntityManager entityManager; public CommonRepositoryImpl(Class<T> domainClass, EntityManager em) { super(domainClass, em); this.entityManager = em; } @Override public void store(Object... item) { if (null != item) { for (Object entity : item) { innerSave(entity); } } } @Override public void update(Object... item) { if (null != item) { for (Object entity : item) { entityManager.merge(entity); } } } @Override public int executeUpdate(String qlString, Object... 
values) { Query query = entityManager.createQuery(qlString); if (values != null) { for (int i = 0; i < values.length; i++) { query.setParameter(i + 1, values[i]); } } return query.executeUpdate(); } @Override public int executeUpdate(String qlString, Map<String, Object> params) { Query query = entityManager.createQuery(qlString); for (String name : params.keySet()) { query.setParameter(name, params.get(name)); } return query.executeUpdate(); } @Override public int executeUpdate(String qlString, List<Object> values) { Query query = entityManager.createQuery(qlString); for (int i = 0; i < values.size(); i++) { query.setParameter(i + 1, values.get(i)); } return query.executeUpdate(); } @Override public void doPager(PageBean pageBean, String qlString) { doPager(pageBean, qlString, new ArrayList<Object>()); } @Override public void doPager(PageBean pageBean, String qlString, boolean cacheable) { doPager(pageBean, qlString, new ArrayList<Object>(),cacheable); } @Override public void doPager(PageBean pageBean, String qlString, Map<String, Object> params) { List<Object> values = new ArrayList<Object>();// 来自params的values,值的顺序按照qlString的参数顺序存放 qlString = preQLAndParam(qlString, params, values);// 2.将name形式的参数解析成?号形式的参数,并返回参数值集合 doPager(pageBean, qlString, values); } @Override public void doPager(PageBean pageBean, String qlString, List<Object> values){ doPager(pageBean, qlString, values, false); } @Override public void doPager(PageBean pageBean, String qlString, Object... 
values){ List<Object> list = new ArrayList<Object>(); if(values!=null){ for (Object value : values) { list.add(value); } } doPager(pageBean, qlString, list); } private void doPager(PageBean pageBean, String qlString, List<Object> values,boolean cacheable){ if (values == null) { values = new ArrayList<Object>(); } qlString = convertQL(qlString);// 1.转换ql为指定格式的语句 List<Object> conValues = new ArrayList<Object>();// 条件参数值集合,来自pageBean.conditions的values List<Object> list = preConditionJPQL(qlString, pageBean.getConditions(), conValues);// 解析条件语句,获取条件参数集合 int conBeginIndex = (Integer) list.get(0);// 返回条件集合在hql起始位置 String condition_jpql = (String) list.get(1);// 获取条件ql语句 String order_jpql = preOrderJPQL(pageBean.getOrders());// 获取排序ql语句 String list_ql = preQL(qlString, condition_jpql, order_jpql);// 获取完整的list ql语句 String count_ql = preCountJPQL(qlString, condition_jpql);// 获取完整的count ql语句 logger.info("count_ql = {" + count_ql + "}"); // 合并数据 for (int i = conValues.size() - 1; i >= 0; i--) { values.add(conBeginIndex, conValues.get(i)); } executeCount(pageBean, count_ql, values, conBeginIndex,cacheable);// 执行count语句,将填充pageBean中的totalRows executeList(pageBean, list_ql, values, cacheable, false, true);// 执行list语句,将填充pageBean中的items } public List<?> doList(String qlString, List<Condition> conditions, List<Order> orders, boolean sqlable){ return doList(qlString, new ArrayList<Object>(), conditions, orders, sqlable); } public List<?> doList(String qlString, List<Object> values, List<Condition> conditions, List<Order> orders, boolean sqlable){ if (values == null) { values = new ArrayList<Object>(); } qlString = convertQL(qlString);// 1.转换ql为指定格式的语句 List<Object> conValues = new ArrayList<Object>();// 条件参数值集合,来自pageBean.conditions的values List<Object> list = preConditionJPQL(qlString, conditions, conValues);// 解析条件语句,获取条件参数集合 int conBeginIndex = (Integer) list.get(0);// 返回条件集合在hql起始位置 String condition_jpql = (String) list.get(1);// 获取条件ql语句 String order_jpql = 
preOrderJPQL(orders);// 获取排序ql语句 String list_ql = preQL(qlString, condition_jpql, order_jpql);// 获取完整的list ql语句 // 合并数据 for (int i = conValues.size() - 1; i >= 0; i--) { values.add(conBeginIndex, conValues.get(i)); } PageBean pageBean = new PageBean(); executeList(pageBean, list_ql, values, false, sqlable, false); return pageBean.getItems();// 执行list语句,将填充pageBean中的items } @Override public int batchDeleteByQl(Class<?> entityClass, Object... primaryKeyValues) { if(primaryKeyValues==null || primaryKeyValues.length<1)return 0; StringBuilder sb = new StringBuilder("delete from " + entityClass.getSimpleName()+ " where " + getIdName(entityClass) + " in ("); for (int i = 0; i < primaryKeyValues.length; i++) { sb.append("?,"); } sb.replace(sb.length()-1, sb.length(), ""); sb.append(")"); Query query = entityManager.createQuery(sb.toString()); for (int i = 1; i <= primaryKeyValues.length; i++) { query.setParameter(i, primaryKeyValues[i-1]); } int result = query.executeUpdate(); entityManager.clear(); return result; } @Override public int batchDeleteByQl(Class<?> entityClass, String... 
pKeyVals) { Object[]pvals = pKeyVals; return batchDeleteByQl(entityClass, pvals); } private Serializable innerSave(Object item) { try { if(item==null)return null; Class<?> clazz = item.getClass(); Field idField = getIdField(clazz); Method getMethod = null; if(idField!=null){ Class<?> type = idField.getType(); Object val = idField.get(item); if(type == String.class && (val==null || "".equals(val))){ idField.set(item, UUIDUtil.uuid()); } }else{ Method[] methods = clazz.getDeclaredMethods(); for (Method method : methods) { Id id = method.getAnnotation(Id.class); if (id != null) { Object val = method.invoke(item); if(val==null || "".equals(val)){ String methodName = "s" + method.getName().substring(1); Method setMethod = clazz.getDeclaredMethod(methodName, method.getReturnType()); if(setMethod!=null){ setMethod.invoke(item, UUIDUtil.uuid()); } } getMethod = method; break; } } } entityManager.persist(item); entityManager.flush(); if(idField!=null){ return (Serializable) idField.get(item); } if(getMethod!=null){ return (Serializable)getMethod.invoke(item); } return null; } catch (Exception e) { e.printStackTrace(); throw new RuntimeException(e); } } private String getIdName(Class<?> clazz) { Field idField = getIdField(clazz); if (idField != null) { return idField.getName(); } return null; } private Field getIdField(Class<?> clazz) { Field[] fields = clazz.getDeclaredFields(); Field item = null; for (Field field : fields) { Id id = field.getAnnotation(Id.class); if (id != null) { field.setAccessible(true); item = field; break; } } if(item==null){ Class<?> superclass = clazz.getSuperclass(); if(superclass!=null){ item = getIdField(superclass); } } return item; } private void setParameter(Query query, int position, Object value){ try { query.setParameter(position, value); } catch (IllegalArgumentException e) { logger.info("WARN : " + e.getMessage()); Pattern p = Pattern.compile("(\\w+\\.\\w+\\.\\w+)"); Matcher matcher = p.matcher(e.getMessage()); while(matcher.find()){ 
String clazz = matcher.group(1); if (Integer.class.getName().equals(clazz)) { value = Integer.parseInt(value.toString()); }else if(Long.class.getName().equals(clazz)){ value = Long.parseLong(value.toString()); }else if(Double.class.getName().equals(clazz)){ value = Double.parseDouble(value.toString()); }else if(Boolean.class.getName().equals(clazz)){ value = Boolean.parseBoolean(value.toString()); }else if(Date.class.getName().equals(clazz)){ String temp = value.toString().replaceAll("%", ""); String pattern = null; if(temp.matches("^\\d{4}-\\d{2}-\\d{2}$")){ pattern = "yyyy-MM-dd"; }else if(temp.matches("^\\d{8}$")){ pattern = "yyyyMMdd"; }else if(temp.matches("^\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}$")){ pattern = "yyyy-MM-dd HH:mm:ss"; } if(pattern!=null){ SimpleDateFormat sdf = new SimpleDateFormat(pattern); try { value = sdf.parse(temp); } catch (ParseException e1) { throw new RuntimeException(e1); } }else{ throw new RuntimeException("传递的日期值"+temp+"不能够被识别"); } } query.setParameter(position, value); } } } /** * 执行集合语句 */ protected void executeList(PageBean pageBean, String list_ql, List<Object> values, boolean cacheable, boolean sqlable, boolean isPage){ Query query = null; if(sqlable){ query = entityManager.createNativeQuery(list_ql); query.setHint(QueryHints.CACHEABLE, cacheable); }else{ query = entityManager.createQuery(list_ql); } for (int i = 0; i < values.size(); i++) { setParameter(query, i + 1, values.get(i)); } if(!sqlable && isPage){ int firstResult = (pageBean.getCurrentPage() - 1) * pageBean.getRowsPerPage(); int maxResults = pageBean.getRowsPerPage(); query.setFirstResult(firstResult) .setMaxResults(maxResults); } List<?> list = query.getResultList(); if(list.size()>0){ Object item = list.get(0); if(item.getClass().isArray()){ String[] fieldArray = preFieldInfo(list_ql); Map<String, Class<?>> propertyMap = preProp(fieldArray,(Object[]) list.get(0)); List<Object> items = new ArrayList<Object>(); for (Object object : list) { Object[] entity = 
(Object[]) object; CglibBeanUtil bean = new CglibBeanUtil(propertyMap); for (int i = 0; i < fieldArray.length; i++) { bean.setValue(fieldArray[i], entity[i]); } items.add(bean.getObject()); } list = items; } } pageBean.setItems(list); } public Map<String, Class<?>> preProp(String[] fieldArray, Object[] fieldValArray) { Map<String, Class<?>> propertyMap = new HashMap<String, Class<?>>(); Class<?> clazz = null; for (int i = 0; i < fieldArray.length; i++) { clazz = fieldValArray[i]!=null?fieldValArray[i].getClass():Object.class; propertyMap.put(fieldArray[i],clazz); } return propertyMap; } public String[] preFieldInfo(String list_ql) { String[] fieldArray; int firstFormIndex = list_ql.indexOf("FROM"); logger.info("firstFromIndex={" + firstFormIndex + "}"); String prefixFrom = list_ql.substring(0, firstFormIndex); logger.info("prefixFrom={" + prefixFrom + "}"); fieldArray = prefixFrom.replace("SELECT", "").trim().split(","); for (int i = 0; i < fieldArray.length; i++) { String field = fieldArray[i]; String[] s = field.split(" as | AS | "); if(s.length==2){ fieldArray[i] = s[1]; } String[]tempArray = fieldArray[i].split("\\."); fieldArray[i] = tempArray[tempArray.length-1]; } return fieldArray; } /** * 解析ql语句 */ protected String preQL(String qlString, String condition_jpql, String order_jpql) { if (qlString.endsWith("ASC") || qlString.endsWith("DESC")) { order_jpql = order_jpql.replace("ORDER BY", ","); } qlString = qlString.replaceAll("WHERE 1=1", " WHERE 1=1 " + condition_jpql) + order_jpql;// 排序位置有待修改 logger.info("list_ql = {" + qlString + "}"); return qlString; } /** * 解析统计ql语句 * * @param qlString * @return * @throws Exception */ protected String preCountJPQL(String qlString, String condition_jpql){ String countField = "*"; String distinctField = findDistinctField(qlString); if(distinctField!=null){ countField = "DISTINCT " + distinctField; } if (qlString.matches("^FROM.+")) { qlString = "SELECT count("+countField+") " + qlString; } else { int beginIndex = 
qlString.indexOf("FROM"); qlString = "SELECT count("+countField+") " + qlString.substring(beginIndex); } qlString = qlString.replaceAll("WHERE 1=1", " WHERE 1=1 " + condition_jpql); return qlString; } private String findDistinctField(String qlString) { int index = qlString.indexOf("DISTINCT"); if(index!=-1){ String subql = qlString.substring(index+9); int end = -1; for (int i = 0; i < subql.length(); i++) { if(subql.substring(i,i+1).matches(" |,")){ end = i; break; } } return subql.substring(0, end); } return null; } /** * 解析排序语句 * @param orders * @return */ protected String preOrderJPQL(List<Order> orders) { if (orders.size() == 0) { return ""; } StringBuffer c = new StringBuffer(" ORDER BY "); for (Order order : orders) { String propertyName = order.getPropertyName(); OrderType orderType = order.getOrderType(); c.append(propertyName + " " + orderType + ","); } if (orders.size() > 0) { c.replace(c.length() - 1, c.length(), ""); } logger.info("order = {" + c.toString() + "}"); return c.toString(); } /** * 格式化ql * @param qlString * @return */ protected String convertQL(String qlString) { String result = qlString.replaceAll("from", "FROM") .replaceAll("distinct", "DISTINCT") .replaceAll("select", "SELECT").replaceAll("where", "WHERE") .replaceAll("order by", "ORDER BY").replaceAll("asc", "ASC") .replaceAll("desc", "DESC").trim(); logger.info("qlString = {" + result + "}"); return result; } /** * 解析ql语句和参数 * * @param qlString * @param params * @param values * @return */ protected String preQLAndParam(String qlString, Map<String, Object> params, List<Object> values) { logger.info("开始解析qlString:{" + qlString + "}"); Map<Integer, Object> map_values = new HashMap<Integer, Object>(); Map<Integer, String> map_names = new HashMap<Integer, String>(); List<Integer> list = new ArrayList<Integer>(); String preQL = qlString; if (null != params) { for (String key : params.keySet()) { int index = qlString.indexOf(":" + key); Object value = params.get(key); preQL = 
preQL.replaceAll(":" + key + " ", "? "); list.add(index); map_values.put(index, value); map_names.put(index, key); } logger.info("解析完成qlString:{" + preQL + "}"); Collections.sort(list); logger.info("最终参数值顺序(参数名->参数位置->参数值):"); for (Integer position : list) { if (logger.isDebugEnabled()) { System.out.println(map_names.get(position) + "->" + position + "->" + map_values.get(position)); } values.add(map_values.get(position)); } } return preQL; } private String preConditionJPQL(List<Condition> conditions, List<Object> values){ StringBuffer c = new StringBuffer(); if (conditions != null && conditions.size() > 0) { c.append(RelateType.AND.toString() + " ( "); for (int i = 0; i < conditions.size(); i++) { Condition condition = conditions.get(i); String groupPrefixBrackets = condition.getGroupPrefixBrackets(); String propertyName = condition.getPropertyName(); Object value = condition.getPropertyValue(); boolean isPrefixBrackets = condition.isPrefixBrackets(); boolean isSuffixBrackets = condition.isSuffixBrackets(); Operation operation = condition.getOperation(); RelateType relateType = condition.getRelateType(); String related = ""; if(i!=0){ if(relateType==null){ relateType = RelateType.AND; } related = relateType + (isPrefixBrackets?" ( ": " "); }else{ related = "" + (isPrefixBrackets?" 
( ": " "); } c.append(groupPrefixBrackets); switch (operation) { case NC: case CN: String[] list = value.toString().split("[, ]"); if(list.length>1){ c.append(related + " ( " + propertyName + operation + "?"); values.add("%" + list[0] + "%"); for (int j = 1; j < list.length; j++) { c.append(RelateType.OR + propertyName + operation + "?"); values.add("%" + list[j] + "%"); } c.append(" ) "); }else{ c.append(related + propertyName + operation + "?"); values.add("%" + value + "%"); } break; case BN: case BW: c.append(related + propertyName + operation + "?"); values.add(value + "%"); break; case EN: case EW: c.append(related + propertyName + operation + "?"); values.add("%" + value); break; case BETWEEN: c.append(related + propertyName + operation + "?" + " AND " + "?"); Object[] params = new Object[2]; if (value instanceof String) { String[] array = value.toString().split("#|,"); params[0] = array[0]; params[1] = array[1]; } else { params = (Object[]) value; } values.add(params[0]); values.add(params[1]); break; case NI: case IN: c.append(related + propertyName + operation + "("); if(value!=null){ Class<?> clazz = value.getClass(); if (clazz.isArray()) { Object[] array = (Object[])value; for (Object object : array) { c.append("?,"); values.add(object); } if(array.length>0){ c.replace(c.length() - 1, c.length(), ""); } } else if (value instanceof Collection<?>) { Collection<?> coll = (Collection<?>) value; for (Object object : coll) { c.append("?,"); values.add(object); } if(coll.size()>0){ c.replace(c.length() - 1, c.length(), ""); } }else if(value instanceof String){ if(StringUtils.isEmpty((String)value)){ c.append("NULL"); }else{ String[]array = ((String) value).split(","); for (String val : array) { c.append("?,"); values.add(val); } if(array.length>0){ c.replace(c.length() - 1, c.length(), ""); } } } }else{ c.append("NULL"); } c.append(")"); break; case EQ: case GE: case GT: case LE: case LT: case NE: c.append(related + propertyName + operation + "?"); 
values.add(value); break; case NN: case NU: c.append(related + propertyName + operation); break; default: break; } c.append(isSuffixBrackets?" ) ": " "); } c.append(" ) "); } logger.info("condition = {" + c.toString() + "}"); return c.toString(); } /** * 解析条件语句 * * @param qlString * @param conditions * @param conValues * @throws Exception */ protected List<Object> preConditionJPQL(String qlString, List<Condition> conditions, List<Object> conValues) { List<Object> list = new ArrayList<Object>(); int conBeginIndex = getConBeginIndex(qlString); list.add(conBeginIndex); list.add(preConditionJPQL(conditions, conValues).toString()); return list; } protected int getConBeginIndex(String qlString) { int conIndex = qlString.indexOf("WHERE 1=1"); if (conIndex == -1) { throw new RuntimeException("ql中没有WHERE 1=1"); } logger.info("conIndex = {" + conIndex + "}"); String conBefore = qlString.substring(0, conIndex); String conAfter = qlString.substring(conIndex); int conBeforeCount = counter(conBefore, '?'); int conAfterCount = counter(conAfter, '?'); logger.info("条件前的?个数:{" + conBeforeCount + "}"); logger.info("条件后的?个数:{" + conAfterCount + "}"); logger.info("条件的起始位置:{" + conBeforeCount + "}"); return conBeforeCount; } protected int counter(String s, char c) { int count = 0; for (int i = 0; i < s.length(); i++) { if (s.charAt(i) == c) { count++; } } return count; } /** * 执行统计语句 */ protected void executeCount(PageBean pageBean, String count_ql, List<Object> values, int conBeginIndex,boolean cacheable){ Query query = entityManager.createQuery(count_ql); query.setHint(QueryHints.CACHEABLE, cacheable); for (int i = 0; i + conBeginIndex < values.size(); i++) { setParameter(query, i + 1, values.get(i + conBeginIndex)); } List<?> list = null; list = query.getResultList(); if (list.size() == 1) { int totalRows = Integer.parseInt(list.get(0).toString()); logger.info("executeCount totalRows = {" + totalRows + "}"); pageBean.setTotalRows(totalRows); } else { 
pageBean.setTotalRows(list.size()); } } }3.自定义CommonRepositoryFactoryBean
<strong></strong><pre name="code" class="java">package com.bjhy.platform.persist.config.jpadata; import java.io.Serializable; import javax.persistence.EntityManager; import org.springframework.data.jpa.repository.JpaRepository; import org.springframework.data.jpa.repository.support.JpaRepositoryFactory; import org.springframework.data.jpa.repository.support.JpaRepositoryFactoryBean; import org.springframework.data.repository.core.RepositoryMetadata; import org.springframework.data.repository.core.support.RepositoryFactorySupport; import com.bjhy.platform.persist.dao.CommonRepositoryImpl; /** * @author wbw * * 扩展jpaRepository,让所有的repository共享起自定义的方法 */ public class CommonRepositoryFactoryBean<R extends JpaRepository<T, I>, T, I extends Serializable> extends JpaRepositoryFactoryBean<R, T, I> { protected RepositoryFactorySupport createRepositoryFactory(EntityManager em) { return new CommonRepositoryFactory(em); } private static class CommonRepositoryFactory<T, I extends Serializable> extends JpaRepositoryFactory { private final EntityManager em; public CommonRepositoryFactory(EntityManager em) { super(em); this.em = em; } protected Object getTargetRepository(RepositoryMetadata metadata) { return new CommonRepositoryImpl<T, I>( (Class<T>) metadata.getDomainType(), em); } protected Class<?> getRepositoryBaseClass(RepositoryMetadata metadata) { return CommonRepositoryImpl.class; } } }
<strong>4.定义SpringJpaDataConfig</strong>
<strong>package com.bjhy.platform.persist.config.jpadata; import org.springframework.context.annotation.Configuration; import org.springframework.data.jpa.repository.config.EnableJpaRepositories; import org.springframework.data.web.config.EnableSpringDataWebSupport; @Configuration @EnableJpaRepositories(basePackages = "com.bjhy.**.dao", repositoryFactoryBeanClass = CommonRepositoryFactoryBean.class) @EnableSpringDataWebSupport public class SpringJpaDataConfig { } </strong>
<strong></strong>