SpringBoot自定义Mybatis数据源,SpringBoot集成Druid数据库连接池,自定义Mybatis拦截器转义特殊字符\%_

SpringBoot自定义Mybatis数据源,SpringBoot集成Druid数据库连接池,自定义Mybatis拦截器转义特殊字符%_

文章目录

  • SpringBoot自定义Mybatis数据源,SpringBoot集成Druid数据库连接池,自定义Mybatis拦截器转义特殊字符\%_
      • 六、Springboot中使用Druid自定义数据库源
      • 七,自定义Mybatis拦截器(解决sql注入)

六、Springboot中使用Druid自定义数据库源

  1. 首先引入相关依赖:

    <dependency>
        <groupId>com.alibabagroupId>
        <artifactId>druid-spring-boot-starterartifactId>
        <version>1.2.15version>
    dependency>
    <dependency>
        <groupId>org.mybatis.spring.bootgroupId>
        <artifactId>mybatis-spring-boot-starterartifactId>
        <version>3.0.0version>
    dependency>
    <dependency>
        <groupId>com.baomidougroupId>
        <artifactId>mybatis-plus-boot-starterartifactId>
        <version>3.5.3version>
    dependency>
    <dependency>
        <groupId>com.github.pagehelpergroupId>
        <artifactId>pagehelper-spring-boot-starterartifactId>
        <version>1.4.6version>
    dependency>
    
  2. 在配置文件中配置数据库连接信息,以及一些需要的druid配置

    spring.datasource.type=com.alibaba.druid.pool.DruidDataSource
    spring.datasource.url=jdbc:p6spy:mysql://127.0.0.1:3306/reggie?characterEncoding=utf-8&serverTimezone=UTC
    spring.datasource.username=root
    spring.datasource.password=mysqlroot
    # druid
    spring.datasource.druid.enable=true
    spring.datasource.druid.name=mysql
    spring.datasource.druid.initial-size=10
    spring.datasource.druid.min-idle=10
    spring.datasource.druid.max-active=20
    spring.datasource.druid.max-wait=60000
    spring.datasource.druid.validation-query=SELECT 1 FROM DUAL
    spring.datasource.druid.pool-prepared-statements=false
    
  3. 配置一个数据源,如果项目中有多个可以在此处配置多个,用@primary注解标注主数据源。在这里可以使用@ConfigurationProperties(prefix = "spring.datasource.druid")将配置文件中 的配置项自动加载进来,也可以先@Value注解将数据先读取出来,再设置进去

    @Primary
    @Bean(name = "masterDataSource")
    @ConfigurationProperties(prefix = "spring.datasource.druid")
    public DruidDataSource masterDataSource() {
        return DruidDataSourceBuilder.create().build();
    }
    
  4. 为Mybatis设置配置数据源、事务、mapper文件路径以及拦截器,这样数据源就设置好了,Springboot项目启动时就会自动加载我们的sqlSessionFactory吗,达到我们自定义数据源的目的了

    @Configuration
    @MapperScan(basePackages = {"com.shadowy.mapper"}, sqlSessionFactoryRef = "sqlSessionFactory") // 首先定义mapperscan扫描包,可指定多个包路径
    public class MybatisConfig {
        @Primary
        @Bean(name = "transactionManager")
        public DataSourceTransactionManager transactionManager(@Qualifier("masterDataSource") DataSource dataSource) {
            return new DataSourceTransactionManager(dataSource);
        }
    
        @Primary
        @Bean
        public SqlSessionFactory sqlSessionFactory(@Qualifier("masterDataSource") DataSource dataSource) throws Exception { // 将上一步的DruidDataSource引入
            // 为MybatisSqlSessionFactoryBean设置数据源,以及xml资源路径,一般mapper.xml都放在Resources目录下的
            final MybatisSqlSessionFactoryBean sessionFactory = new MybatisSqlSessionFactoryBean();
            sessionFactory.setDataSource(dataSource);
    
            String mapperLocation = "mapper/*.xml";
            Resource[] resources = new PathMatchingResourcePatternResolver().getResources(
                ResourceLoader.CLASSPATH_URL_PREFIX + mapperLocation
            );
            sessionFactory.setMapperLocations(resources);
          
            // 为Mybatis注册一些需要用到的插件
            Properties prop = new Properties();
            prop.setProperty("offsetAsPageNum", "true");
            prop.setProperty("rowBoundsWithCount", "true");
            prop.setProperty("reasonable", "true");
            prop.setProperty("helperDialect", "mysql");
            PageInterceptor pageInterceptor = new PageInterceptor();
            pageInterceptor.setProperties(prop);
    
            // ES转义插件
            EscapeInterceptor escapeInterceptor = new EscapeInterceptor();
    
            List<Interceptor> interceptors = new ArrayList<>();
            interceptors.add(pageInterceptor);
            interceptors.add(escapeInterceptor);
    
            sessionFactory.setPlugins(interceptors.toArray(new Interceptor[1]));
            return sessionFactory.getObject();
        }
    }
    
  5. 检查数据源是否连接正常,定时任务健康检查,这里的dataSource就是我们前面定义的masterDataSource,注入进来即可

    @Scheduled(cron = "${db.datasource.health.check}")
    private void healthCheck() {
        try (
            Connection connection = dataSource.getConnection();
            PreparedStatement preparedStatement = connection.prepareStatement(testSql);
            ResultSet resultSet = preparedStatement.executeQuery()
        ) {
            if (resultSet.next()) {
                LogUtil.info("DataSource connected is normal");
            }
        } catch (SQLException e) {
            LogUtil.error("DataSource is disconnected");
            throw new RuntimeException(e);
        }
    }
    

七,自定义Mybatis拦截器(解决sql注入)

  项目上发现一个问题,我们有时候会用到like %?%这种查询语句,有时是mybatisplus中的QueryMapper.like(?)构建的语句,有时候是我们在mapper层自己定义的like concat('%', #{id}, '%')这种语句,当传递进来的参数是若干个下划线_时会查询到所有的数据而不是包含下划线的数据,因为下划线代表任意一个字符,这就会导致sql注入风险,导致原有的语义改变。

  针对这种情况,我们就需要自定义一个拦截器来拦截执行的sql语句,对其做转义处理,参考:https://www.jianshu.com/p/f4d3e6ffeee8

  首先明确需求,目标是让参数进行转义,例如下面这样:

SELECT * FROM DISH WHERE ID LIKE CONCAT('%', '_', '%'); # 查询所有数据
# 转换为↓
SELECT * FROM DISH WHERE ID LIKE CONCAT('%', '\_', '%'); # 查询id中有_的数据
  1. 定义一个mybatis拦截器,拦截select查询

    @Intercepts({
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class,
            RowBounds.class, ResultHandler.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class,
            RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class})
    })
    public class EscapeInterceptor implements Interceptor {
    }
    
  2. 重写接口中的方法,针对intercept方法进行增强即可,另外两个默认就行

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        final Object[] args = invocation.getArgs();
        // 拦截sql得到参数与MappedStatement对象,里面存放着sql的所有信息
        MappedStatement statement = (MappedStatement) args[0];
        parameterObject = args[1];
    
        // BoundSql就是我们实际进行查询的sql对象,里面放着sql语句
        BoundSql boundSql = statement.getBoundSql(parameterObject);
        String sql = boundSql.getSql().toLowerCase(Locale.ROOT);
    
        // 处理特殊字符,有些sql可能不需要特殊处理,不要一刀切
        if (isLikeSql(sql, boundSql)) {
            // 处理了特殊字符后,需要把增强后的对象重新赋值回去,因为下面转类型时丢失引用了(为了避免强制类型转换告警)
            args[0] = buildMappedStatement(statement, boundSql);
            args[1] = parameterObject;
        }
    
        return invocation.proceed();
    }
    
  3. 拦截器详细代码

    package com.shadowy.datasource.Interceptor;
    
    import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
    import com.shadowy.utils.JsonUtil;
    import com.shadowy.utils.StringUtils;
    import org.apache.commons.lang3.ObjectUtils;
    import org.apache.ibatis.cache.CacheKey;
    import org.apache.ibatis.executor.Executor;
    import org.apache.ibatis.mapping.BoundSql;
    import org.apache.ibatis.mapping.MappedStatement;
    import org.apache.ibatis.mapping.ParameterMapping;
    import org.apache.ibatis.mapping.SqlSource;
    import org.apache.ibatis.plugin.Interceptor;
    import org.apache.ibatis.plugin.Intercepts;
    import org.apache.ibatis.plugin.Invocation;
    import org.apache.ibatis.plugin.Signature;
    import org.apache.ibatis.session.ResultHandler;
    import org.apache.ibatis.session.RowBounds;
    
    import java.util.HashSet;
    import java.util.Locale;
    import java.util.Map;
    import java.util.Properties;
    import java.util.Set;
    
    /**
     * 自定义Mybatis拦截器,处理模糊查询中的特殊字符_\%,参考:https://www.jianshu.com/p/f4d3e6ffeee8
     *
     * @Author shadowy
     * @Since 2023/3/21 23:13
     */
    @Intercepts({
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class,
            RowBounds.class, ResultHandler.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class,
            RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class})
    })
    public class EscapeInterceptor implements Interceptor {
        private Object parameterObject;
    
        /**
         * 拦截sql做增强处理,实测RESTFUL风格的接口可能无法转义
         *
         * @param invocation invocation
         * @return proceed
         * @throws Throwable exception
         */
        @Override
        public Object intercept(Invocation invocation) throws Throwable {
            final Object[] args = invocation.getArgs();
            MappedStatement statement = (MappedStatement) args[0];
            parameterObject = args[1];
    
            BoundSql boundSql = statement.getBoundSql(parameterObject);
            String sql = boundSql.getSql().toLowerCase(Locale.ROOT);
    
            // 处理特殊字符
            if (isLikeSql(sql, boundSql)) {
                // 对sql进行增强处理
                args[0] = buildMappedStatement(statement, boundSql);
                args[1] = parameterObject;
            }
    
            return invocation.proceed();
        }
    
        // 判断是否是模糊查询语句
        private boolean isLikeSql(String sql, BoundSql boundSql) {
            if (!sql.contains(" like ") || !sql.contains("?")) {
                return false;
            }
    
            // 获取关键字的个数(去重)
            String[] strList = sql.split("\\?");
            Set<String> keyNames = new HashSet<>();
            for (int i = 0; i < strList.length; i++) {
                if (strList[i].toLowerCase().contains(" like ")) {
                    String keyName = boundSql.getParameterMappings().get(i).getProperty();
                    keyNames.add(keyName);
                }
            }
    
            // 对关键字进行特殊字符清洗,如果有特殊字符的,在特殊字符前添加转义字符\
            boolean ismodify = false;
            for (String keyName : keyNames) {
                Layer layer = Layer.MAPPER;
                Map<String, Object> parameter = JsonUtil.object2ObjectMap(parameterObject);
    
                if (sql.toLowerCase().contains(" like ?")) {
                    layer = keyName.contains("ew.paramNameValuePairs.") ? Layer.SERVICES_CONSTRUCTOR : Layer.SERVICES_NON_CONSTRUCTOR;
                }
                ismodify = modifyParam(layer, keyName, parameter) || ismodify;
            }
            return ismodify;
        }
    
        // 判断修改参数
        private boolean modifyParam(Layer layer, String keyName, Map<String, Object> parameter) {
            String param;
            switch (layer) {
                case SERVICES_CONSTRUCTOR -> {
                    if (parameter.get("ew") instanceof QueryWrapper<?> wrapper) {
                        parameter = wrapper.getParamNameValuePairs();
                    }
                    String[] keyList = keyName.split("\\.");
                    param = (String) ObjectUtils.defaultIfNull(parameter.get(keyList[2]), "");
    
                    if (isNeedEscape(param)) {
                        parameter.put(keyList[2], "%" + escapeChar(param.substring(1, param.length() - 1)) + "%");
                    }
                    return true;
                }
                case SERVICES_NON_CONSTRUCTOR -> {
                    param = (String) ObjectUtils.defaultIfNull(parameter.get(keyName), "");
                    if (isNeedEscape(param)) {
                        parameter.put(keyName, "%" + escapeChar(param.substring(1, param.length() - 1)) + "%");
                    }
                    parameterObject = parameter;
                    return true;
                }
                case MAPPER -> {
                    param = (String) ObjectUtils.defaultIfNull(parameter.get(keyName), "");
                    if (isNeedEscape(param)) {
                        parameter.put(keyName, escapeChar(param));
                    }
                    parameterObject = parameter;
                    return true;
                }
                case DEFAULT -> {}
            }
            return false;
        }
    
        // 转义特殊字符
        private String escapeChar(String before){
            if (StringUtils.isNotBlank(before)) {
                before = before.replaceAll("\\\\", "\\\\\\\\");
                before = before.replaceAll("_", "\\\\_");
                before = before.replaceAll("%", "\\\\%");
            }
            return before;
        }
    
        private boolean isNeedEscape(String targetStr) {
            return targetStr.contains("_") || targetStr.contains("\\") || targetStr.contains("%");
        }
    
        private MappedStatement buildMappedStatement(MappedStatement ms, BoundSql boundSql) {
            String newSql = boundSql.getSql();
            BoundSql newBoundSql = new BoundSql(ms.getConfiguration(), newSql,
                boundSql.getParameterMappings(), parameterObject);
            MappedStatement newMs = copyFromMappedStatement(ms, new BoundSqlSqlSource(newBoundSql));
    
            for (ParameterMapping mapping : boundSql.getParameterMappings()) {
                String prop = mapping.getProperty();
                if (boundSql.hasAdditionalParameter(prop)) {
                    newBoundSql.setAdditionalParameter(prop, boundSql.getAdditionalParameter(prop));
                }
            }
            return newMs;
        }
    
        private MappedStatement copyFromMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
            MappedStatement.Builder builder = new MappedStatement.Builder(ms.getConfiguration(), ms.getId(), newSqlSource, ms.getSqlCommandType());
            builder.resource(ms.getResource());
            builder.fetchSize(ms.getFetchSize());
            builder.statementType(ms.getStatementType());
            builder.keyGenerator(ms.getKeyGenerator());
            if (ms.getKeyProperties() != null && ms.getKeyProperties().length > 0) {
                builder.keyProperty(ms.getKeyProperties()[0]);
            }
            builder.timeout(ms.getTimeout());
            builder.parameterMap(ms.getParameterMap());
            builder.resultMaps(ms.getResultMaps());
            builder.resultSetType(ms.getResultSetType());
            builder.cache(ms.getCache());
            builder.flushCacheRequired(ms.isFlushCacheRequired());
            builder.useCache(ms.isUseCache());
            return builder.build();
        }
    
        public static class BoundSqlSqlSource implements SqlSource {
            private final BoundSql boundSql;
    
            public BoundSqlSqlSource(BoundSql boundSql) {
                this.boundSql = boundSql;
            }
    
            public BoundSql getBoundSql(Object parameterObject) {
                return boundSql;
            }
        }
    
        private enum Layer {
            // 第一种情况:在业务层进行条件构造产生的模糊查询关键字
            SERVICES_CONSTRUCTOR,
            // 第二种情况:未使用条件构造器,但是在service层进行了查询关键字与模糊查询符%手动拼接
            SERVICES_NON_CONSTRUCTOR,
            // 第三种情况:在Mapper类的注解SQL中进行了模糊查询的拼接
            MAPPER,
            // 默认情况:不处理
            DEFAULT
        }
    
        @Override
        public Object plugin(Object target) {
            return Interceptor.super.plugin(target);
        }
    
        @Override
        public void setProperties(Properties properties) {
            Interceptor.super.setProperties(properties);
        }
    }
    
    // object2ObjectMap
    public static Map<String, Object> object2ObjectMap(Object srcObj) {
        Map<String, Object> targetObject = new HashMap<>();
        if (srcObj instanceof Map<?, ?> objectMap) {
            for (Map.Entry<?, ?> entry : objectMap.entrySet()) {
                targetObject.put((String) entry.getKey(), entry.getValue());
            }
        } else {
            targetObject = JsonUtil.jsonToMap(JsonUtil.objectToJson(srcObj));
        }
        return targetObject;
    }
    
  4. 如果是mysql数据库那么做到这就已经可以了,实测中我们发现,在一些特殊的数据库中可能还需要增加ESCAPE关键字才可以,比如gauss,需要注意ESCAPE并不会影响原有语句的逻辑

    # 高斯数据库中查询语句,前两者有相同的作用:
    SELECT * FROM DISH WHERE ID LIKE CONCAT('%', '_', '%'); # 查询所有数据
    # 转换为↓
    SELECT * FROM DISH WHERE ID LIKE CONCAT('%', '\_', '%'); # 查询所有数据
    # 增强为↓,会查询id中有_的数据
    SELECT * FROM DISH WHERE ID LIKE CONCAT('%', '\_', '%') ESCAPE '\'; 
    

      因此我们需要对原有的sql语句做增强,对于MybatisPlus,因为语法是固定的,可直接在下图的位置进行sql替换

    SpringBoot自定义Mybatis数据源,SpringBoot集成Druid数据库连接池,自定义Mybatis拦截器转义特殊字符\%__第1张图片

      对于我们在xml中定义的sql语句,也可以采用上述的方式进行全局的匹配替换,但不建议这么做,一是xml中定义的sql由于多人开发语法可能并不规范,二是我们在实际场景中可能会对这种模糊查询语句再嵌套一层函数比如to_char()、to_lower()等等,所以建议在xml中自己增强sql,避免出现意想不到的问题

    SpringBoot自定义Mybatis数据源,SpringBoot集成Druid数据库连接池,自定义Mybatis拦截器转义特殊字符\%__第2张图片

你可能感兴趣的:(第三方组件,mybatis,spring,boot,数据库)