五、持久层框架设计实现及MyBatis源码分析-自定义持久层框架(五)

在前面章节,我们在自定义持久层框架当中,实现了查询方法,本章节对增、删、改方法进行实现,首先先来说说实现增、删、改方法的基本思路

1、在Excutor接口中增加update方法,因为之前写的都是查询方法,有针对结果集的封装,而增删改的操作,没有前面繁琐的结果集封装,只需要返回一个简单的执行操作影响的行数常量就行,所以单独提取一个update方法进行处理

2、在SqlSession接口中增加delete方法和update两个方法,因为其实不论是添加、删除或是更新,其实底层执行的都是update方法

3、对SqlSession中getMapper方法的改造,之前实现查询方法的时候,只是针对返回结果集是否进行了泛型参数化进行了判断,如果是则说明当前返回对象是一个集合,则调用selectList方法,否则就调用selectOne方法,但是在返回的结果集当中如果返回的是一个未进行泛型参数化的时候还存在是insert、update、delete的情况,所以还要在否则条件中进行更进一步的判断,具体的判断方法就是通过statementId在MappedStatement中获取到对应执行的SQL语句,并对SQL语句进行一个toLowerCase转换(避免SQL语句因为大小写差异,因为有时候我们编写的习惯不同,可能存在SQL语句编写大小写不统一的情况),然后再通过SQL语句是以delete、select或都其它开头,调用对应的delete、selectOne或update方法,以delete方法开头就调用delete方法,以select开关就调用selectOne方法,否则就调用update方法

4、对XMLMapperBuilder解析类的改造,因为之前编写查询方法的时候,只对select标签进行了解析,后面需要添加对insert、update、delete标签的解析

思路分析完成,接下来我们看具体的实现

1、在Excutor接口增加update方法,并在SimpleExecutor实现类中进行具体的实现
Excutor接口代码

package study.lagou.com.sqlSession;

import study.lagou.com.pojo.Configuration;
import study.lagou.com.pojo.MappedStatement;

import java.util.List;

public interface Executor {

    /**
     * 查询方法
     * @param configuration
     * @param mappedStatement
     * @param params
     * @param 
     * @return
     * @throws Exception
     */
     List query(Configuration configuration, MappedStatement mappedStatement, Object... params) throws Exception;

    /**
     * 更新方法,主要用于实现统一更新、删除和添加方法
     * @param configuration
     * @param mappedStatement
     * @param params
     */
    Integer update(Configuration configuration,MappedStatement mappedStatement,Object... params) throws Exception;
}

SimpleExecutor实现类代码,因为获取数据库连接、转换SQL、获取PreparedStatement和设置参数的步骤查询和修改方法都是一致的,所以我们对这些步骤进行了一个简单的提取,提到到getPreparedStatement方法当中

package study.lagou.com.sqlSession;

import study.lagou.com.config.BoundSql;
import study.lagou.com.pojo.Configuration;
import study.lagou.com.pojo.MappedStatement;
import study.lagou.com.utils.GenericTokenParser;
import study.lagou.com.utils.ParameterMapping;
import study.lagou.com.utils.ParameterMappingTokenHandler;
import study.lagou.com.utils.SelefJavaClass;

import java.beans.PropertyDescriptor;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.sql.*;
import java.util.ArrayList;
import java.util.List;

/**
 * @Description: 功能描述
 * @Author houjh
 * @Email: [email protected]
 * @Date: 2021-2-22 22:10
 */
public class SimpleExecutor implements Executor {

    @Override
    public  List query(Configuration configuration, MappedStatement mappedStatement, Object... params) throws Exception {
        //5、执行SQL
        ResultSet resultSet = getPreparedStatement(configuration,mappedStatement,params).executeQuery();
        //6、封装返回结果集
        String resultType = mappedStatement.getResultType();
        Class resultTypeClass = getClassType(resultType);
        ArrayList objectArrayList = new ArrayList<>();
        while (resultSet.next()){
            Object resultObj = resultTypeClass.newInstance();
            //获取到resultSet的元数据
            ResultSetMetaData metaData = resultSet.getMetaData();
            for (int i = 1; i <= metaData.getColumnCount(); i++) {
                //获取元数据中的字段名,注意获取字段名的时候下标需要从1开始,所以我们这里的for循环遍历下标是从1开始的
                String columnName = metaData.getColumnName(i);
                //获取字段值
                Object value = resultSet.getObject(columnName);
                //使用反射或内省,根据数据库表和实体的对应关系,完成封装
                PropertyDescriptor propertyDescriptor = new PropertyDescriptor(columnName, resultTypeClass);
                //获取到PropertyDescriptor对象的写方法
                Method writeMethod = propertyDescriptor.getWriteMethod();
                //通过写方法将值写入到对象中
                writeMethod.invoke(resultObj,value);
            }
            objectArrayList.add(resultObj);
        }
        return (List) objectArrayList;
    }

    @Override
    public Integer update(Configuration configuration, MappedStatement mappedStatement, Object... params) throws Exception {
        //5、执行SQL
        int i = getPreparedStatement(configuration,mappedStatement,params).executeUpdate();
        return i;
    }

    /**
     * 将获取PreparedStatement方法和设置参数方法提取到公共方法中
     * @param configuration
     * @param mappedStatement
     * @param params
     * @return
     * @throws Exception
     */
    private PreparedStatement getPreparedStatement(Configuration configuration,MappedStatement mappedStatement,Object... params) throws Exception {
        //1、注册驱动、获取连接
        Connection connection = configuration.getDataSource().getConnection();
        //2、获取SQL信息,并对SQL语句进行转换,在转换过程中需要对#{}里面的值进行解析和存储
        String sql = mappedStatement.getSql();
        BoundSql boundSql = getBoundSql(sql);
        //3、获取预处理对象preparedStatement
        PreparedStatement preparedStatement = connection.prepareStatement(boundSql.getSqlText());
        //4、设置参数信息
        String paramterType = mappedStatement.getParamterType();
        //4.1、通过参数全路径获取到参数对象
        Class paramterTypeClass = getClassType(paramterType);
        //4.2、通过反射设置参数信息,先获取到标记解析器解析出来的#{}中的属性信息,然后循环遍历进行参数设置
        List parameterMappingList = boundSql.getParameterMappingList();

        //这里需要判断一下参数类型,如果参数类型是java常用的数据类型,那么直接传入参数,否则才通过反射处理设置参数信息
        if(SelefJavaClass.isJavaClass(paramterTypeClass)){
            for(int i= 0;i getClassType(String paramterType) throws ClassNotFoundException {
        if(paramterType != null){
            Class aClass = Class.forName(paramterType);
            return aClass;
        }
        return null;
    }

    /**
     * 完成对#{}的解析工作:1、将#{}使用?进行代替 2、解析出#{}里面的值进行存储
     * @param sql
     * @return
     */
    private BoundSql getBoundSql(String sql) {
        //标记处理类:配合标记处理器完成对占位符的处理工作
        ParameterMappingTokenHandler parameterMappingTokenHandler = new ParameterMappingTokenHandler();
        //标记处理器中的三个参数分别是开始标记、结束标记和标记处理类信息
        GenericTokenParser tokenParser = new GenericTokenParser("#{", "}", parameterMappingTokenHandler);
        //通过标记处理器的parse方法解析出来的SQL
        String parseSql = tokenParser.parse(sql);
        //由标记处理类从#{}里面解析出来的参数名称
        List parameterMappingList = parameterMappingTokenHandler.getParameterMappings();
        //将解析出来的SQL和#{}中的参数封装到BoundSql对象中
        BoundSql boundSql = new BoundSql(parseSql, parameterMappingList);
        return boundSql;
    }

}


2、在SqlSession接口中增加delete方法和update两个方法,并对getMapper方法进行改造,判断当前操作是要执行查询、新增、修改还是删除分别调用不同的方法

SqlSession接口类代码如下

package study.lagou.com.sqlSession;

import java.util.List;

public interface SqlSession {

    /**
     * 查询列表信息
     * @param statementId
     * @param params
     * @param 
     * @return
     */
     List selectList(String statementId, Object... params) throws Exception;

    /**
     * 查询单个数据信息
     * @param statementId
     * @param params
     * @param 
     * @return
     */
     T selectOne(String statementId,Object... params) throws Exception;

    /**
     * 删除信息
     * @param statementId
     * @param params
     */
    Integer delete(String statementId,Object... params) throws Exception;

    /**
     * 更新数据信息
     * @param statementId
     * @param params
     */
    Integer update(String statementId,Object... params) throws Exception;

    /**
     * 通过接口
     * @param mapperInterfaceClass
     * @param 
     * @return
     */
     T getMapper(Class mapperInterfaceClass);
}

DefaultSqlSession实现类具体代码,注意此处的代码包含对getMapper方法的改造

package study.lagou.com.sqlSession;

import study.lagou.com.pojo.Configuration;
import study.lagou.com.pojo.MappedStatement;

import java.lang.reflect.*;
import java.util.List;

/**
 * @Description: 功能描述
 * @Author houjh
 * @Email: [email protected]
 * @Date: 2021-2-22 21:24
 */
public class DefaultSqlSession implements SqlSession {

    private Configuration configuration;

    private Executor simpleExecutor = new SimpleExecutor();

    public DefaultSqlSession(Configuration configuration) {
        this.configuration = configuration;
    }

    @Override
    public  List selectList(String statementId, Object... params) throws Exception {
        List query = simpleExecutor.query(configuration, getMappedStatement(statementId), params);
        return (List) query;
    }

    @Override
    public  T selectOne(String statementId, Object... params) throws Exception {
        //selectOne方法和selectList方法操作数据库的方式一致,所以我们此处直接调用selectList方法即可
        List objects = selectList(statementId, params);
        if(objects.size() == 1){
            return (T) objects.get(0);
        } else {
            throw new RuntimeException("查询到的结果集为空或返回的结果集较多!");
        }
    }

    @Override
    public Integer delete(String statementId, Object... params) throws Exception {
        return this.update(statementId,params);
    }

    @Override
    public Integer update(String statementId, Object... params) throws Exception {
        Integer updateNum = simpleExecutor.update(configuration, getMappedStatement(statementId), params);
        return updateNum;
    }

    /**
     * 通过statementId获取MappedStatement
     * @param statementId
     * @return
     */
    private MappedStatement getMappedStatement(String statementId){
        MappedStatement mappedStatement = configuration.getMappedStatementMap().get(statementId);
        return mappedStatement;
    }

    @Override
    public  T getMapper(Class mapperInterfaceClass) {
        Object proxyInstance = Proxy.newProxyInstance(DefaultSqlSession.class.getClassLoader(), new Class[]{mapperInterfaceClass}, new InvocationHandler() {
            //proxy 为当前代理对象
            //method 当前被调用方法的引用
            //args 传递的参数信息
            @Override
            public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
                //底层还是去执行JDBC代码,只是我们需要根据不同情况,来选择调用selectList或是selectOne方法,如果是查询
                //我们可以根据返回值类型是否进行了 泛型类型参数化(通俗讲就是判断返回值是否是一个泛型)进行判断,如果是就
                //说明当前返回对象是一个集合,则调用selectList方法,否则就调用selectOne

                //1、准备参数statementId,在xml文件中statementId是由namespace.id来组成,但是我们此处只能获取到对应的
                //接口信息,并不能获取到xml文件,所以此处约定让namespace.id = 接口全限定名.方法名的方式来进行处理,这
                //也是Mybatis的Mapper.xml文件中namespace为什么要和接口路径保持一致的原因

                //获取方法名称
                String methodName = method.getName();
                //通过方法获取到接口类的全限定名
                String className = method.getDeclaringClass().getName();
                //拼接生成statementId
                String statementId = className+"."+methodName;


                //2、根据被调用方法的返回值类型,判断应该调用selectOne还是selectList方法
                Type genericReturnType = method.getGenericReturnType();
                if(genericReturnType instanceof ParameterizedType){
                    return selectList(statementId,args);
                } else {
                    //如果返回的结果集不是一个泛型,我们就通过MappedStatement对象中的SQL语句进行以什么开头来进行判断
                    MappedStatement mappedStatement = getMappedStatement(statementId);
                    //为避免SQL大小写差异统一将SQL语句转换成小写
                    String sql = mappedStatement.getSql().toLowerCase();
                    if(sql.startsWith("delete")){
                        return delete(statementId,args);
                    } else if(sql.startsWith("select")){
                        return selectOne(statementId,args);
                    } else {
                        return update(statementId,args);
                    }
                }
            }
        });
        return (T) proxyInstance;
    }
}


3、对XMLMapperBuilder解析类的改造,因为之前编写查询方法的时候,只对select标签进行了解析,后面需要添加对insert、、update、delete标签的解析,此处只需要在rootElement.selectNodes获取节点的时候添加//insert|//update|//delete标签就行

具体实现代码如下

package study.lagou.com.config;

import org.dom4j.Document;
import org.dom4j.DocumentException;
import org.dom4j.Element;
import org.dom4j.io.SAXReader;
import study.lagou.com.pojo.Configuration;
import study.lagou.com.pojo.MappedStatement;

import java.io.InputStream;
import java.util.List;

/**
 * @Description: 功能描述
 * @Author houjh
 * @Email: [email protected]
 * @Date: 2021-1-28 23:32
 */
public class XMLMapperBuilder {

    private Configuration configuration;

    public XMLMapperBuilder(Configuration configuration) {
        this.configuration = configuration;
    }

    public void parse(InputStream inputStream) throws DocumentException {
        Document document = new SAXReader().read(inputStream);
        Element rootElement = document.getRootElement();
        String namespace = rootElement.attributeValue("namespace");

        List list = rootElement.selectNodes("//select|//insert|//update|//delete");
        for (Element element : list) {
            String id = element.attributeValue("id");
            String resultType = element.attributeValue("resultType");
            String paramterType = element.attributeValue("paramterType");
            String sqlText = element.getTextTrim();
            MappedStatement mappedStatement = new MappedStatement();
            mappedStatement.setId(id);
            mappedStatement.setResultType(resultType);
            mappedStatement.setParamterType(paramterType);
            mappedStatement.setSql(sqlText);
            String statementId = namespace+"."+id;
            configuration.getMappedStatementMap().put(statementId,mappedStatement);
        }
    }
}

最后编写测试类对方法进行测试,注意测试方法中需要注意数据库中是否在对应的数据,此处将源代码贴出来,大家注意根据自己数据库中的数据进行相应的调整

package study.lagou.com.test;

import org.dom4j.DocumentException;
import org.junit.Test;
import study.lagou.com.io.Resources;
import study.lagou.com.persistence.test.dao.IUserDao;
import study.lagou.com.persistence.test.pojo.User;
import study.lagou.com.sqlSession.SqlSession;
import study.lagou.com.sqlSession.SqlSessionFactory;
import study.lagou.com.sqlSession.SqlSessionFactoryBuilder;

import java.beans.PropertyVetoException;
import java.io.InputStream;
import java.util.List;

/**
 * @Description: 功能描述
 * @Author houjh
 * @Email: [email protected]
 * @Date: 2021-1-28 21:19
 */
public class PersistenceMapperTest {

    public IUserDao getUserDao() throws PropertyVetoException, DocumentException {
        InputStream resourceAsStream = Resources.getResourceAsStream("sqlMapConfig.xml");
        SqlSessionFactory build = new SqlSessionFactoryBuilder().build(resourceAsStream);
        SqlSession sqlSession = build.openSession();
        IUserDao userDao = sqlSession.getMapper(IUserDao.class);
        return userDao;
    }

    @Test
    public void testOne() throws Exception {
        User user = new User();
        user.setId(1);
        user.setUsername("xiaozhangsan");
        User resultUser = getUserDao().findByCondition(user);
        System.out.println(resultUser);
    }

    @Test
    public void testList() throws Exception {
        List userList = getUserDao().findAll();
        for (User user : userList) {
            System.out.println(user);
        }
    }

    @Test
    public void testUpdate() throws Exception {
        User user = new User();
        user.setId(1);
        user.setUsername("zhangsan");
        user.setPassword("777");
        user.setNickname("张三");

        Integer update = getUserDao().updateById(user);
        System.out.println(update);
    }

    @Test
    public void testDelete() throws Exception {
        Integer delete = getUserDao().deleteById(3);
        System.out.println(delete);
    }

    @Test
    public void testInsert() throws Exception {
        User user  = new User();
        user.setId(5);
        user.setUsername("zhaoliu");
        user.setPassword("888");
        user.setNickname("赵六");
        getUserDao().insert(user);
    }
}

具体代码对应下载地址:https://gitee.com/happymima/mybatis.git

你可能感兴趣的:(五、持久层框架设计实现及MyBatis源码分析-自定义持久层框架(五))