模拟Mybatis框架---简单版

思路:

            1 采用xml解析技术对Mybatis主配置文件进行解析,获取到数据库连接参数信息,并提取出来,用jdbc获取连接,把连接通过MySqlSessionfactory传递给MySqlSession中,也就是一个MySqlSessionfactory关联一个连接

            2 读取主配置文件中引入的mappper文件的信息,获取到mapper文件中每个子标签(select,insert,update,delete)中的id属性,resultType属性,parameterType属性以及标签中的sql语句

             3 把获取到的mapper文件中子标签的各种属性以及值封装成一个类MySqlElementEntity,也就是一个select标签对应一个类,然后在自定义的SqlSessionfactory中引入此类,定义一个map集合,key为标签的id,value为一个MySqlElementEntity的对象

             4 在自定义的MySqlSession类中自定义五个方法,和mybatis中一样,然后通过增删改查的标签的id和标签名来判断执行查询还是更新操作等,并获取到标签中的sql语句进行处理,替换占位符为“?”

             5 对执行方法参数的判断,可以是一个简单数据类型的,也可以是多个简单数据类型的,还可以是一个实体类对象,要对其进行判断,以便给占位符赋值

下面请看具体代码

首先创建一个工具类 MyUtil,用来获取document对象和数据类型的判断

public class MyUtil {
	// 由xml获取document对象
	public static Document xml2Doc(File file) {
		try {
			// 获取文档解析器工厂对象
			DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
			// 获取文档解析器对象
			DocumentBuilder builder = factory.newDocumentBuilder();// 实例工厂模式:
			// 通过解析器对象的parse 由xml获document对象
			return builder.parse(file);
		} catch (Exception e) {
			throw new RuntimeException(e);
		}
	}
    //判断是否数据类型,基本数据类型,包装类,String
	public static boolean pdBasicTypeOrString(Object obj){
		Class cla=obj.getClass();
		if(cla.isPrimitive()){
			return true;
		}
		return cla==String.class||
			   cla==Byte.class||
			   cla==Short.class||
			   cla==Integer.class||
			   cla==Long.class||
			   cla==double.class||
			   cla==Float.class||
			   cla==Boolean.class||
		       cla==Character.class;
	}
	public static void main(String[] args) {
		System.out.println(Integer.class.isPrimitive());
	}
	
}

创建一个MySqlElementEntity类,用来存储mapper文件下面的标签的各种信息

public class MySqlElementEntity   {
	  private String nodeName;
	  private String nodeId;
	  private String resultType;
	  private String parameterType;
	  private String sqlMybatis;//update student set sname=#{sname},sage=#{sage},sex=#{sex} where sid=#{sid}
	  //private String sqlJdbc;
	  //update student set sname=?,sage=?,sex=? where sid=?
	 // private List parameterName;
	public String getNodeName() {
		return nodeName;
	}
	public void setNodeName(String nodeName) {
		this.nodeName = nodeName;
	}
	public String getNodeId() {
		return nodeId;
	}
	public void setNodeId(String nodeId) {
		this.nodeId = nodeId;
	}
	public String getResultType() {
		return resultType;
	}
	public void setResultType(String resultType) {
		this.resultType = resultType;
	}
	public String getParameterType() {
		return parameterType;
	}
	public void setParameterType(String parameterType) {
		this.parameterType = parameterType;
	}
	public String getSqlMybatis() {
		return sqlMybatis;
	}
	public void setSqlMybatis(String sqlMybatis) {
		this.sqlMybatis = sqlMybatis;
	}
	public String getSqlJdbc() {
		return sqlMybatis.replaceAll("\\#\\{[a-zA-Z0-9]+\\}", "?");
	}
	//获取占位符里面的对象属性名
	public List getParameterName() {
		List parameterName=new ArrayList<>();
		int startIndex=0;
		while(true){
			int index1=sqlMybatis.indexOf('{',startIndex);
			if(index1==-1){
				break;
			}
			int index2=sqlMybatis.indexOf("}",startIndex);
			parameterName.add(sqlMybatis.substring(index1+1, index2));
			if(index2==sqlMybatis.length()-1){
				break;
			}
			startIndex=index2+1;
		}
		return parameterName;
	}
//	public void setParameterName(List parameterName) {
//		this.parameterName = parameterName;
//	}
	@Override
	public String toString() {
		return "MySqlElementEntity [nodeName=" + nodeName + ", nodeId=" + nodeId + ", resultType=" + resultType
				+ ", parameterType=" + parameterType + ", sqlMybatis=" + sqlMybatis + ", sqlJdbc=" + getSqlJdbc()
				+ ", parameterName=" + getParameterName() + "]";
	}
	  
	public static void main(String[] args) {
		String sql1="update student set sname=#{sname},sage=#{sage},sex=#{sex} where sid=#{sid}";
		//System.out.println(sql1.replaceAll("\\#\\{[a-zA-Z]+\\}", "?"));
		List parameterName=new ArrayList<>();
		int startIndex=0;
		while(true){
			int index1=sql1.indexOf('{',startIndex);
			if(index1==-1){
				break;
			}
			int index2=sql1.indexOf("}",startIndex);
			parameterName.add(sql1.substring(index1+1, index2));
			if(index2==sql1.length()-1){
				break;
			}
			startIndex=index2+1;
		}
		System.out.println(parameterName);
	}
}

创建一个MySqlSessionFactroy类,用来读取主配置文件和mapper文件,并获取数据库的连接

public class MySqlSessionFactroy {
	private static String driverClass,userName,userPwd,url;
	public static Map  sqlEntityMap;
	//private static Document[] mapperDocs;
    //1读取核心配置文件和使用的sql映射文件
	//2 获取sqlsession:每个sqlsession关联一个connection
	public MySqlSessionFactroy(File file){
		init(file);
	}
	public MySqlSession  openSession(){
		//获取一个MySqlSession  并关联一个connection对象
	
	    //获取连接
		Connection con=null;
		try {
			con=DriverManager.getConnection(url, userName,userPwd);
		} catch (Exception e) {
			throw new RuntimeException(e);//异常转换
		}
		MySqlSession session=new MySqlSession();
		session.setCon(con);
		return session;
	}
	private void init(File file){
		  //读取连接数据库的四大参数 获取连接
//			
//				
//					
//					
//					
//					
//				
//			
		    Map jdbcMap=new HashMap<>();
		   Document docCore=MyUtil.xml2Doc(file);
		   Element dataSourceElement=(Element)docCore.getElementsByTagName("dataSource").item(0);
		   //获取器所有的property子标签
		   NodeList proList=dataSourceElement.getElementsByTagName("property");
		   for (int i = 0; i < proList.getLength(); i++) {
			      Element ele=(Element)proList.item(i);
			      jdbcMap.put(ele.getAttribute("name"), ele.getAttribute("value"));
	       }
		   driverClass=jdbcMap.get("driver");
		   userName=jdbcMap.get("username");
		   userPwd=jdbcMap.get("password");
		   url=jdbcMap.get("url");
		   //注册驱动:
		    try {
				Class.forName(driverClass);
			} catch (ClassNotFoundException e) {
				throw new RuntimeException(e);//异常转换
			}
		   //读取配置文件:
		    Element mapperEle=(Element)docCore.getElementsByTagName("mapper").item(0);
		    String resource=mapperEle.getAttribute("resource");
		    resource=resource.replace(".", "/");//把所有的.转换为/
		    resource=resource.substring(0, resource.lastIndexOf('/'))+".xml";
		    resource="src/"+resource;
			Document mapperDoc=MyUtil.xml2Doc(new File(resource));
		    sqlEntityMap=new HashMap();
		     //读取sql映射文件中sql标签
		     Element root=(Element)mapperDoc.getElementsByTagName("mapper").item(0);
		     NodeList childNodes=root.getChildNodes();
		     for (int i = 0; i < childNodes.getLength(); i++) {
				  if(childNodes.item(i) instanceof Element){
					    Element eleSql=(Element)childNodes.item(i);
					    //把每个sql标签封装为MySqlElementEntity.java对象
					    MySqlElementEntity sqlEntity=new MySqlElementEntity();
					    sqlEntity.setNodeId(eleSql.getAttribute("id").trim());
					    sqlEntity.setNodeName(eleSql.getNodeName());
					    sqlEntity.setParameterType(eleSql.getAttribute("parameterType").trim());
					    sqlEntity.setResultType(eleSql.getAttribute("resultType").trim());
					    sqlEntity.setSqlMybatis(eleSql.getTextContent().trim());
					    sqlEntityMap.put(eleSql.getAttribute("id").trim(), sqlEntity);
				  }
			}
	}
	public static void main(String[] args) {
		MySqlSessionFactroy factroy=new MySqlSessionFactroy(new File("src/mybatis_conf.xml"));
		System.out.println(factroy.driverClass+":"+factroy.url+":"+factroy.userName+":"+factroy.userPwd);
		System.out.println(factroy.sqlEntityMap);
	}
}

创建一个MySqlSession对象,主要提供通用的增删改查的几个方法

public class MySqlSession {
	 private Connection con;
	 
     public void setCon(Connection con) {
		this.con = con;
	 }
	//提供五个方法
	 public Object selectOne(String id,Object...args){
		  List list=(List)execute("select", id, args);
		  return list.get(0);
	 }
	 public List selectList(String id,Object...args){
		  return (List)execute("select", id, args);
	 }
	 public int delete(String id,Object...args){
		 return (Integer)execute("delete", id, args);
	}
	 public int insert(String id,Object...args){
		 return (Integer)execute("insert", id, args);
     }
	 public int update(String id,Object...args){
		 return (Integer)execute("update", id, args);
	 }
	 public void close(){
		 try {
			con.close();
		} catch (SQLException e) {
			throw new RuntimeException(e);
		}
	 }
	 private Object execute(String nodeName,String id,Object...params){
		//获取参数id对应的MySqlElementEntity对象
		 MySqlElementEntity entity=MySqlSessionFactroy.sqlEntityMap.get(id);
		 //判断是不是指定类型
		 if(!entity.getNodeName().equals(nodeName)){
			 throw new RuntimeException("必须是"+nodeName+"对应的sql标签!");
		 }
		
		
		 int hang=0;
		 List  list=null;
		 try {
			 PreparedStatement pre=null;
			 ResultSet set=null;
			
			 pre=con.prepareStatement(entity.getSqlJdbc());
			 //给占位符赋值:一个对象类型的参数:拿对象的属性给占位符赋值
			 //          一个基本数据类型+String类型的参数:直接给唯一的占位符赋值
			 //          多个类型(基本数据类型+String类型):按顺序给占位符赋值
			 if(params!=null&¶ms.length!=0){
				  //System.out.println(params[0]+":::MyUtil.pdBasicTypeOrString(params[0])="+MyUtil.pdBasicTypeOrString(params[0]));
				  if(params.length!=1){
					   for (int i = 0; i < params.length; i++) {
						   pre.setObject(i+1, params[i]);
					   }
				  }else if(MyUtil.pdBasicTypeOrString(params[0])){
					   pre.setObject(1, params[0]);
				  }else{
					  //说明参数是对象:拿对象的属性的值赋值给同名的占位符
					  Class claObj=params[0].getClass();
					  //获取占位符的所有名字
					  List pnList=entity.getParameterName();
					  for (int i = 0; i < pnList.size(); i++) {
						    String fieldName=pnList.get(i);//拿到属性名
						    //拿到属性对象
						    Field field=claObj.getDeclaredField(fieldName);
						    field.setAccessible(true);
						    Object fieldValue=field.get(params[0]);
						    //给占位符赋值
						    pre.setObject(i+1, fieldValue);
					  }
				  }
			 }
			 //执行execute方法
			 if(nodeName.equals("select")){
				 set=pre.executeQuery();
				 //处理结果集
				 list=new ArrayList();
				 while(set.next()){
					 //把没一行封装为对象:
					 Class claResult=Class.forName(entity.getResultType());
					 Object obj=claResult.newInstance();
					 //给对象的属性封装
					 Field[] fields=claResult.getDeclaredFields();
					 for (Field field : fields) {
						  field.setAccessible(true);
						  field.set(obj, set.getObject(field.getName()));
					 }
					 list.add(obj);
				 }
			 }else{
				 hang=pre.executeUpdate();
			 }
			 if(set!=null){set.close();}
			 if(pre!=null){pre.close();}
		} catch (Exception e) {
			throw new RuntimeException(e);
		}
		return hang==0?list:hang;
	 }
}

至此就大功告成了,简单的模拟mybatis框架就完成了!

你可能感兴趣的:(java,java)