MyBatis 之自己实现一个MyBatis框架

整体流程图

MyBatis 之自己实现一个MyBatis框架_第1张图片

  • 首先创建 SqlSessionFactory 实例,SqlSessionFactory 就是创建 SqlSession 的工厂类。
  • 加载配置文件创建 Configuration 对象,配置文件包括数据库相关配置文件以及我们在 XML 文件中写的 SQL。
  • 通过 SqlSessionFactory 创建 SqlSession
  • 通过 SqlSession 获取 mapper 接口动态代理。
  • 动态代理回调 SqlSession 中某查询方法。
  • SqlSession 将查询方法转发给 Executor
  • Executor 基于 JDBC 访问数据库获取数据,最后还是通过 JDBC 操作数据库。
  • Executor 通过反射将数据转换成 POJO 并返回给 SqlSession
  • 将数据返回给调用者。
    MyBatis 之自己实现一个MyBatis框架_第2张图片

SQL脚本和Maven依赖

DROP TABLE IF EXISTS `t_user`;
CREATE TABLE `t_user` (
  `userId` bigint(20) NOT NULL,
  `userName` varchar(255) DEFAULT NULL,
  `sex` int(11) DEFAULT NULL,
  `role` varchar(255) DEFAULT NULL,
  PRIMARY KEY (`userId`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
<dependency>
    <groupId>dom4jgroupId>
    <artifactId>dom4jartifactId>
    <version>1.6.1version>
dependency>
<dependency>
    <groupId>mysqlgroupId>
    <artifactId>mysql-connector-javaartifactId>
    <version>5.1.38version>
dependency>

首先按照我们以前的使用 MyBatis 代码时的流程,创建 Mapper 接口、XML 文件,和 POJO 以及集一些配置文件,这几个文件我们和 mybatis-demo 创建一样的即可,方便我们比较结果。

Mapper 接口,这里面定义两个抽象方法,根据主键查找用户和查找所有用户:

import com.demo.entity.User;
import java.util.List;

public interface UserMapper {
    User selectByPrimaryKey(long userId);
    List<User> selectAll();
}

XML 文件里面是上面两个抽象方法的具体 SQL 实现,完全模仿官方 XML 文件的写法,需要注意 namespace、id、resultType、SQL 语句这几个点,都是我们后面代码需要处理的。



<mapper namespace="com.demo.mapper.UserMapper">
    <select id="selectByPrimaryKey" resultType="com.demo.entity.User">
            select *
            from t_user
            where userId = #{userId}
        select>
    <select id="selectAll" resultType="com.demo.entity.User">
            select *
            from t_user
        select>
mapper>

实体类,它的属性与数据库的表相对应:

import lombok.Data;

@Data
public class User {
    private Long userId;
    private String userName;
    private Integer sex;
    private String role;
}

最后一个配置文件,数据库连接配置文件 db.properties:

  jdbc.driver=com.mysql.jdbc.Driver
  jdbc.url=jdbc:mysql://localhost:3370/test?useSSL=false
  jdbc.username=root
  jdbc.password=123456

配置文件和一些测试的必须类已经写完了,首先我们需要把这些配置信息加载到 Configuration 配置类中。
先定义一个类来加载写 SQL 语句的 XML 文件,上面我们说过要注意四个点,namespace、id、resultType、SQL 语句,我们写对应的属性来保存它。

import lombok.Data;

/**
 * XML 中的 sql 配置信息加载到这个类中
 */
@Data
public class MappedStatement {
    private String namespace;
    private String id;
    private String resultType;
    private String sql;
}

接下来我们定义一个 Configuration 总配置类,来保存 db.propeties 里面的属性和 XML 文件的 SQL 信息,Configuration 类里面的文件对应我们配置文件中的属性。

import lombok.Getter;
import lombok.Setter;

import java.util.HashMap;
import java.util.Map;

/**
 * 所有的配置信息
 */
@Setter
@Getter
public class Configuration {
    private String jdbcDriver;
    private String jdbcUrl;
    private String jdbcPassword;
    private String jdbcUsername;
    private Map<String, MappedStatement> mappedStatement = new HashMap<>();
}

按照上面的流程图,我们来创建一个 SqlSessionFactory 工厂类,这个类有两个功能,一个是加载配置文件信息到 Configuration 类中,另一个是创建 SqlSession。

SqlSessionFactory 抽象模版:

import com.demo.sqlsession.SqlSession;

public interface SqlSessionFactory {
    SqlSession openSession();
}

创建 SqlSessionFactory 的 Default 实现类,Default 实现类主要完成了两个功能,加载配置信息到 Configuration 对象里,实现创建 SqlSession 的功能。

import com.demo.configuration.Configuration;
import com.demo.configuration.MappedStatement;
import com.demo.sqlsession.DefaultSqlSession;
import com.demo.sqlsession.SqlSession;
import org.dom4j.Document;
import org.dom4j.DocumentException;
import org.dom4j.Element;
import org.dom4j.io.SAXReader;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;

/**
 * 1.初始化时就完成了 configuration 的实例化
 * 2.工厂类,生成 sqlSession
 */
public class DefaultSqlSessionFactory implements SqlSessionFactory {

    private final Configuration configuration = new Configuration();
    // xml 文件存放的位置
    private static final String MAPPER_CONFIG_LOCATION = "mappers";
    // 数据库信息存放的位置
    private static final String DB_CONFIG_FILE = "db.properties";

    public DefaultSqlSessionFactory() {
        loadDBInfo();
        loadMapperInfo();
    }

    private void loadDBInfo() {
        InputStream db = this.getClass().getClassLoader().getResourceAsStream(DB_CONFIG_FILE);
        Properties properties = new Properties();
        try {
            properties.load(db);
        } catch (IOException e) {
            e.printStackTrace();
        }
        //将配置信息写入Configuration 对象
        configuration.setJdbcDriver(properties.get("jdbc.driver").toString());
        configuration.setJdbcUrl(properties.get("jdbc.url").toString());
        configuration.setJdbcUsername(properties.get("jdbc.username").toString());
        configuration.setJdbcPassword(properties.get("jdbc.password").toString());
    }

    //解析并加载xml文件
    private void loadMapperInfo() {
        URL resources = this.getClass().getClassLoader().getResource(MAPPER_CONFIG_LOCATION);
        File mappers = new File(resources.getFile());
        //读取文件夹下面的文件信息
        if (mappers.isDirectory()) {
            File[] files = mappers.listFiles();
            for (File file : files) {
                loadMapperInfo(file);
            }
        }
    }

    private void loadMapperInfo(File file) {
        SAXReader reader = new SAXReader();
        //通过read方法读取一个文件转换成 Document 对象
        Document document = null;
        try {
            document = reader.read(file);
        } catch (DocumentException e) {
            e.printStackTrace();
        }
        //获取根结点元素对象
        Element e = document.getRootElement();
        //获取命名空间namespace
        String namespace = e.attribute("namespace").getData().toString();
        //获取select,insert,update,delete子节点列表
        List<Element> selects = e.elements("select");
        List<Element> inserts = e.elements("insert");
        List<Element> updates = e.elements("update");
        List<Element> deletes = e.elements("delete");

        List<Element> allElement = new ArrayList<>();
        allElement.addAll(selects);
        allElement.addAll(inserts);
        allElement.addAll(updates);
        allElement.addAll(deletes);

        //遍历节点,组装成 MappedStatement 然后放入到configuration 对象中
        for (Element element : allElement) {
            MappedStatement mappedStatement = new MappedStatement();
            String id = element.attribute("id").getData().toString();
            String resultType = element.attribute("resultType").getData().toString();
            String sql = element.getData().toString();

            mappedStatement.setId(namespace + "." + id);
            mappedStatement.setResultType(resultType);
            mappedStatement.setNamespace(namespace);
            mappedStatement.setSql(sql);
            // xml 文件中的每个 sql 方法都组装成 mappedStatement 对象,以 namespace+"."+id 为 key, 放入configuration 配置类中
            configuration.getMappedStatement().put(namespace + "." + id, mappedStatement);
        }
    }

    @Override
    public SqlSession openSession() {
        // openSession 方法创建一个 DefaultSqlSession,configuration 配置类作为 构造函数参数传入
        return new DefaultSqlSession(configuration);
    }
}

在 SqlSessionFactory 里创建了 DefaultSqlSession,我们看看它的具体实现。SqlSession 里面应该封装了所有数据库的具体操作和一些获取 mapper 实现类的方法。

SqlSession 接口,定义模版方法

import java.util.List;

/**
 * 封装了所有数据库的操作
 * 所有功能都是基于 Executor 来实现的,Executor 封装了 JDBC 操作
 */
public interface SqlSession {
    /**
     * 根据传入的条件查询单一结果
     *
     * @param statement namespace+id,可以用做 key,去 configuration 里面获取 sql 语句,resultType
     * @param parameter 要传入 sql 语句中的查询参数
     * @param        返回指定的结果对象
     * @return
     */
    <T> T selectOne(String statement, Object parameter);

    <T> List<T> selectList(String statement, Object parameter);

    /**
     * 获取Mapper实现类
     * @param type
     * @param 
     * @return
     */
    <T> T getMapper(Class<T> type);
}

Default 的 SqlSession 实现类。里面需要传入 Executor,这个 Executor 里面封装了 JDBC 操作数据库的流程。我们重点关注 getMapper 方法,使用动态代理生成一个加强类。这里面最终还是把数据库的相关操作转给 SqlSession,使用 Mapper 能使编程更加优雅。

import com.demo.bind.MapperProxy;
import com.demo.configuration.Configuration;
import com.demo.configuration.MappedStatement;
import com.demo.executor.Executor;
import com.demo.executor.SimpleExecutor;

import java.lang.reflect.Proxy;
import java.util.List;

public class DefaultSqlSession implements SqlSession {
    private final Configuration configuration;

    private Executor executor;

    public DefaultSqlSession(Configuration configuration) {
        super();
        this.configuration = configuration;
        executor = new SimpleExecutor(configuration);
    }

    @Override
    public <T> T selectOne(String statement, Object parameter) {
        List<T> selectList = this.selectList(statement, parameter);
        if (selectList == null || selectList.size() == 0) {
            return null;
        }
        if (selectList.size() == 1) {
            return (T) selectList.get(0);
        } else {
            throw new RuntimeException("too many result");
        }
    }

    @Override
    public <T> List<T> selectList(String statement, Object parameter) {
        MappedStatement ms = configuration.getMappedStatement().get(statement);
        // 我们的查询方法最终还是交给了 Executor 去执行,Executor 里面封装了 JDBC 操作。传入参数包含了 sql 语句和 sql 语句需要的参数。
        return executor.query(ms, parameter);
    }

    @Override
    public <T> T getMapper(Class<T> type) {
        //通过动态代理生成了一个实现类,我们重点关注,动态代理的实现,它是一个 InvocationHandler,传入参数是 this,就是 sqlSession 的一个实例。
        MapperProxy mp = new MapperProxy(this);
        //给我一个接口,还你一个实现类
        return (T) Proxy.newProxyInstance(type.getClassLoader(), new Class[]{type}, mp);
    }
}

来看看我们的 InvocationHandler 如何实现 invoke 方法:

import com.demo.sqlsession.SqlSession;

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.util.Collection;

/**
 * 将请求转发给 sqlSession
 */
public class MapperProxy implements InvocationHandler {
    private SqlSession sqlSession;

    public MapperProxy(SqlSession sqlSession) {
        this.sqlSession = sqlSession;
    }

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        System.out.println(method.getDeclaringClass().getName() + "." + method.getName());
        //最终还是将执行方法转给 sqlSession,因为 sqlSession 里面封装了 Executor
        //根据调用方法的类名和方法名以及参数,传给 sqlSession 对应的方法
        if (Collection.class.isAssignableFrom(method.getReturnType())) {
            return sqlSession.selectList(method.getDeclaringClass().getName() + "." + method.getName(), args == null ? null : args[0]);
        } else {
            return sqlSession.selectOne(method.getDeclaringClass().getName() + "." + method.getName(), args == null ? null : args[0]);
        }
    }
}

获取 Mapper 接口的实现类我们已经实现了,通过动态代理调用 sqlSession 的方法。那么就剩最后一个重要的工作了,那就是实现 Exectuor 类去操作数据库,封装 JDBC。

Executor 抽象模版,我们只实现了 query、update 等操作慢慢增加。

import com.demo.configuration.MappedStatement;

import java.util.List;

/**
 * mybatis 核心接口之一,定义了数据库操作的最基本的方法,JDBC,sqlSession的所有功能都是基于它来实现的
 */
public interface Executor {
    /**
     *
     * 查询接口
     * @param ms 封装sql 语句的 mappedStatemnet 对象,里面包含了 sql 语句,resultType 等。
     * @param parameter 传入sql 参数
     * @param  将数据对象转换成指定对象结果集返回
     * @return
     */
    <E> List<E> query(MappedStatement ms, Object parameter);
}
import com.demo.configuration.Configuration;
import com.demo.configuration.MappedStatement;
import com.demo.util.ReflectionUtil;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;

/**
 * Executor 接口的实现类,主要是对 JDBC 的封装,和利用反射方法将结果映射到 resultType 对应的实体类中
 */
public class SimpleExecutor implements Executor {
    private final Configuration configuration;

    public SimpleExecutor(Configuration configuration) {
        this.configuration = configuration;
    }

    @Override
    public <E> List<E> query(MappedStatement ms, Object parameter) {
        System.out.println(ms.getSql().toString());

        List<E> ret = new ArrayList<>(); //返回结果集
        try {
            Class.forName(configuration.getJdbcDriver());
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
        }

        Connection connection = null;
        PreparedStatement preparedStatement = null;
        ResultSet resultSet = null;
        try {
//            connection = DriverManager.getConnection("jdbc:mysql://localhost:3370/test?useSSL=false","root","123456");
            connection = DriverManager.getConnection(configuration.getJdbcUrl(), configuration.getJdbcUsername(), configuration.getJdbcPassword());
            String regex = "#\\{([^}])*\\}";
            // 将 sql 语句中的 #{userId} 替换为 ?
            String sql = ms.getSql().replaceAll(regex, "?");
            preparedStatement = connection.prepareStatement(sql);
            //处理占位符,把占位符用传入的参数替换
            parametersize(preparedStatement, parameter);
            resultSet = preparedStatement.executeQuery();
            handlerResultSet(resultSet, ret, ms.getResultType());
        } catch (SQLException e) {
            e.printStackTrace();
        } finally {
            try {
                resultSet.close();
                preparedStatement.close();
                connection.close();
            } catch (Exception e) {
                e.printStackTrace();
            }
        }

        return ret;
    }

    private void parametersize(PreparedStatement preparedStatement, Object parameter) throws SQLException {
        if (parameter instanceof Integer) {
            preparedStatement.setInt(1, (int) parameter);
        } else if (parameter instanceof Long) {
            preparedStatement.setLong(1, (Long) parameter);
        } else if (parameter instanceof String) {
            preparedStatement.setString(1, (String) parameter);
        }
    }

    private <E> void handlerResultSet(ResultSet resultSet, List<E> ret, String className) {
        Class<E> clazz = null;
        //通过反射获取类对象
        try {
            clazz = (Class<E>) Class.forName(className);
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
        }


        try {
            while (resultSet.next()) {
                Object entity = clazz.newInstance();
                //通过反射工具 将 resultset 中的数据填充到 entity 中
                ReflectionUtil.setPropToBeanFromResultSet(entity, resultSet);
                ret.add((E) entity);
            }
        } catch (Exception e) {
            e.printStackTrace();
        }

    }
}

到目前未知,我们简单版的 MyBatis 框架已经实现了,我们来写一个测试类测试一下。

public class TestDemo {
    public static void main(String[] args) {
        SqlSessionFactory sqlSessionFactory = new DefaultSqlSessionFactory();
        SqlSession sqlSession = sqlSessionFactory.openSession();
        UserMapper mapper = sqlSession.getMapper(UserMapper.class);
        User user = mapper.selectByPrimaryKey(1001L);
        System.out.println(user.toString());
    }
}

看一下测试的结果,整个 MyBatis 框架已经实现完成了,当然有很多地方需要完善,比如 XML 中的 SQL 语句处处理还缺很多功能,目前只支持 select 等。

mybatis的简单模拟
自己实现一个 MyBatis 框架

你可能感兴趣的:(MyBatis)