[druid source code analysis] 10 wallFilter analysis

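Next, we move on to the druid packages outside the pool package, starting with wallFilter. Let's first write a wallFilter example. wallFilter has to be enabled in the configuration file, so we start with the configuration, shown below: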

spring:
  datasource:
    druid:
      filter:
        wall:
          enabled: true
          config:
            select-where-alway-true-check: true
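Roughly speaking, the druid Spring Boot starter binds the keys under spring.datasource.druid.filter.wall.config onto a WallConfig object, and Spring Boot's relaxed binding matches the kebab-case key select-where-alway-true-check to the camelCase property selectWhereAlwayTrueCheck (the "Alway" spelling is druid's own). The sketch below only illustrates that name mapping. RelaxedNameSketch, toCamelCase and setterName are hypothetical helpers, not Spring's actual binder.

```java
// Hypothetical helper, not Spring's binder: shows how a kebab-case key
// lines up with a camelCase JavaBean property and its setter.
final class RelaxedNameSketch {
    private RelaxedNameSketch() {}

    // "select-where-alway-true-check" -> "selectWhereAlwayTrueCheck"
    static String toCamelCase(String kebab) {
        StringBuilder out = new StringBuilder(kebab.length());
        boolean upperNext = false;
        for (char c : kebab.toCharArray()) {
            if (c == '-') {
                upperNext = true;                 // drop the dash, capitalize the next char
            } else if (upperNext) {
                out.append(Character.toUpperCase(c));
                upperNext = false;
            } else {
                out.append(c);
            }
        }
        return out.toString();
    }

    // JavaBean setter name derived from the camelCase property name
    static String setterName(String kebab) {
        String camel = toCamelCase(kebab);
        return "set" + Character.toUpperCase(camel.charAt(0)) + camel.substring(1);
    }
}
```

Keys without dashes, such as enabled, pass through unchanged.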

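First enable wallFilter, then set its config. Here select-where-alway-true-check: true enables the check for SELECT statements whose WHERE condition is always true. Besides this option, the following properties can also be configured: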

[Table: other configurable wall config properties]
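Conceptually, the always-true check asks whether the WHERE condition can be reduced to a constant true without reading any row: 1 = 1 can, id = 1 cannot, and 1 = 1 OR id = 1 is true whatever id holds. The sketch below illustrates that reasoning on a tiny hand-built expression tree. AlwaysTrueSketch and its node classes are hypothetical; druid performs its checks in the WallVisitor over the real SQL AST and covers far more expression types.

```java
// Hypothetical mini expression tree for a WHERE condition; not druid's AST.
final class AlwaysTrueSketch {
    private AlwaysTrueSketch() {}

    interface Expr {}

    // Literal value such as 1 or 'a'
    static final class Lit implements Expr {
        final Object value;
        Lit(Object value) { this.value = value; }
    }

    // Column reference: its value is unknown at check time
    static final class Col implements Expr {
        final String name;
        Col(String name) { this.name = name; }
    }

    static final class Eq implements Expr {
        final Expr left, right;
        Eq(Expr left, Expr right) { this.left = left; this.right = right; }
    }

    static final class And implements Expr {
        final Expr left, right;
        And(Expr left, Expr right) { this.left = left; this.right = right; }
    }

    static final class Or implements Expr {
        final Expr left, right;
        Or(Expr left, Expr right) { this.left = left; this.right = right; }
    }

    // TRUE/FALSE when the condition is constant, null when it depends on row data
    static Boolean constantValue(Expr e) {
        if (e instanceof Eq) {
            Eq eq = (Eq) e;
            if (eq.left instanceof Lit && eq.right instanceof Lit) {
                return ((Lit) eq.left).value.equals(((Lit) eq.right).value);
            }
            return null;
        }
        if (e instanceof And) {
            Boolean l = constantValue(((And) e).left);
            Boolean r = constantValue(((And) e).right);
            if (Boolean.FALSE.equals(l) || Boolean.FALSE.equals(r)) return Boolean.FALSE;
            if (Boolean.TRUE.equals(l) && Boolean.TRUE.equals(r)) return Boolean.TRUE;
            return null;
        }
        if (e instanceof Or) {
            Boolean l = constantValue(((Or) e).left);
            Boolean r = constantValue(((Or) e).right);
            if (Boolean.TRUE.equals(l) || Boolean.TRUE.equals(r)) return Boolean.TRUE;
            if (Boolean.FALSE.equals(l) && Boolean.FALSE.equals(r)) return Boolean.FALSE;
            return null;
        }
        return null; // bare literals and columns are treated as unknown in this sketch
    }

    // The always-true rule: reject a WHERE clause that is provably TRUE
    static boolean isAlwaysTrue(Expr where) {
        return Boolean.TRUE.equals(constantValue(where));
    }
}
```

Note that 1 = 1 AND id = 1 is not flagged here, because its value still depends on id.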

Let's test the select-where-alway-true-check: true property. Our MyBatis Mapper file contains the condition where 1 = 1; running the test produces the following error:

java.sql.SQLException: sql injection violation, dbType mysql, druid-version 1.2.8, not terminal sql, token WHEN : select
    ......
    from TABLES
    when 1 = 1
    at com.alibaba.druid.wall.WallFilter.checkInternal(WallFilter.java:859) ~[druid-1.2.8.jar:1.2.8]
    at com.alibaba.druid.wall.WallFilter.connection_prepareStatement(WallFilter.java:295) ~[druid-1.2.8.jar:1.2.8]
    at com.alibaba.druid.filter.FilterChainImpl.connection_prepareStatement(FilterChainImpl.java:568) ~[druid-1.2.8.jar:1.2.8]
    at com.alibaba.druid.filter.FilterAdapter.connection_prepareStatement(FilterAdapter.java:930) ~[druid-1.2.8.jar:1.2.8]

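As we can see, the call fails immediately with an SQL injection violation. Note that this particular message (not terminal sql, token WHEN) comes from the strict syntax check in the provider's checkInternal shown later (the lastToken != Token.EOF branch): the logged SQL reads when 1 = 1 rather than where 1 = 1, so the parser stops at the WHEN token. Following the stack trace, the call reaches WallFilter through the filter chain; the preceding filter's FilterAdapter.connection_prepareStatement simply passes it on: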


    @Override
    public PreparedStatementProxy connection_prepareStatement(FilterChain chain, ConnectionProxy connection, String sql)
            throws SQLException {
        return chain.connection_prepareStatement(connection, sql);
    }

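As covered in an earlier article, this is the chain of responsibility pattern: all Filters are loaded first, and each Filter then invokes the next one recursively through the FilterChain. Now let's look at WallFilter's own connection_prepareStatement: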

    @Override
    public PreparedStatementProxy connection_prepareStatement(FilterChain chain, ConnectionProxy connection, String sql)
            throws SQLException {
        String dbType = connection.getDirectDataSource().getDbType();
        // Create the WallContext for this database type
        WallContext context = WallContext.create(dbType);
        try {
            // Run the wall check; a violation may surface here as an SQLException
            WallCheckResult result = checkInternal(sql);
            context.setWallUpdateCheckItems(result.getUpdateCheckItems());
            // The checked (possibly rewritten) SQL continues down the filter chain
            sql = result.getSql();
            PreparedStatementProxy stmt = chain.connection_prepareStatement(connection, sql);
            setSqlStatAttribute(stmt);
            return stmt;
        } finally {
            // Always release the context, even when the check throws
            WallContext.clearContext();
        }
    }
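Putting the two methods above together, here is a minimal, hypothetical chain-of-responsibility sketch: the chain keeps a position, each filter passes the call on by calling back into the chain, and a wall-style filter checks the SQL before delegating, so a rejected statement never reaches the terminal prepare step. SqlFilter, SqlFilterChain and WallLikeFilter are illustrative names, not druid classes, and the string returned at the end of the chain stands in for the real JDBC call.

```java
import java.sql.SQLException;
import java.util.List;
import java.util.function.Predicate;

// Hypothetical filter contract, shaped like connection_prepareStatement(chain, ..., sql).
interface SqlFilter {
    String prepare(SqlFilterChain chain, String sql) throws SQLException;
}

// Hypothetical chain: remembers which filter runs next; a fresh chain is used per call.
final class SqlFilterChain {
    private final List<SqlFilter> filters;
    private int pos = 0;

    SqlFilterChain(List<SqlFilter> filters) {
        this.filters = filters;
    }

    String prepare(String sql) throws SQLException {
        if (pos < filters.size()) {
            // Hand the call to the next filter; it calls back into the chain to continue.
            return filters.get(pos++).prepare(this, sql);
        }
        // End of the chain: stands in for the real prepareStatement call.
        return "prepared: " + sql;
    }
}

// Hypothetical wall-style filter: check first, delegate only if the SQL passes.
final class WallLikeFilter implements SqlFilter {
    private final Predicate<String> violates;

    WallLikeFilter(Predicate<String> violates) {
        this.violates = violates;
    }

    @Override
    public String prepare(SqlFilterChain chain, String sql) throws SQLException {
        if (violates.test(sql)) {
            throw new SQLException("sql injection violation : " + sql);
        }
        return chain.prepare(sql);
    }
}
```

Because each filter decides whether to call chain.prepare, a filter placed earlier in the list can also rewrite the SQL (as WallFilter does with result.getSql()) before later filters see it.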

First, a WallContext is created from dbType. Nothing complicated happens here; it mainly stores dbType in the WallContext. Then checkInternal is called:

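    private WallCheckResult checkInternal(String sql) throws SQLException {
        // Delegate the actual check to the database-specific WallProvider
        WallCheckResult checkResult = provider.check(sql);
        List<Violation> violations = checkResult.getViolations();
        if (violations.size() > 0) {
            // ...... (elided in this listing; this is where the
            // "sql injection violation, dbType ..." SQLException in the stack trace above is thrown)
        }
        return checkResult;
    }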
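Note that the WallContext created in connection_prepareStatement is never passed to checkInternal as an argument; the provider-side checkInternal shown later fetches it with WallContext.current(). This works because WallContext keeps the current context in a ThreadLocal, which is also why connection_prepareStatement clears it in its finally block. A minimal sketch of that create / current / clear pattern, using the hypothetical class CheckContext:

```java
import java.util.function.Supplier;

// Hypothetical per-thread context, modeled on WallContext's create / current / clear pattern.
final class CheckContext {
    private static final ThreadLocal<CheckContext> LOCAL = new ThreadLocal<CheckContext>();

    private final String dbType;

    private CheckContext(String dbType) {
        this.dbType = dbType;
    }

    String getDbType() {
        return dbType;
    }

    // Create a context and bind it to the current thread.
    static CheckContext create(String dbType) {
        CheckContext ctx = new CheckContext(dbType);
        LOCAL.set(ctx);
        return ctx;
    }

    // Code running later on the same thread can fetch it without a parameter.
    static CheckContext current() {
        return LOCAL.get();
    }

    // Remove it so a pooled thread does not carry it into the next request.
    static void clear() {
        LOCAL.remove();
    }

    // The try/finally shape used by connection_prepareStatement.
    static <T> T runWithContext(String dbType, Supplier<T> check) {
        create(dbType);
        try {
            return check.get();
        } finally {
            clear();
        }
    }
}
```

Skipping the finally step would leave the context attached to the worker thread, and a later request on that thread could see a stale dbType.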

So WallFilter.checkInternal mainly delegates the check to provider, which is initialized when WallFilter's init method runs. Here is the relevant part of init:

            case mysql:
            case oceanbase:
            case drds:
            case mariadb:
            case h2:
            case presto:
            case trino:
                if (config == null) {
                    config = new WallConfig(MySqlWallProvider.DEFAULT_CONFIG_DIR);
                }

                provider = new MySqlWallProvider(config);
                break;
            ...

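The config passed in here is the WallFilter config we set earlier; if none was set, a default WallConfig is loaded from MySqlWallProvider.DEFAULT_CONFIG_DIR. Next, let's look at the actual checking logic: provider.check(sql) delegates to the provider's own checkInternal: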

    private WallCheckResult checkInternal(String sql) {
        checkCount.incrementAndGet();

        WallContext context = WallContext.current();

        if (config.isDoPrivilegedAllow() && ispPrivileged()) {
            WallCheckResult checkResult = new WallCheckResult();
            checkResult.setSql(sql);
            return checkResult;
        }

        // first step, check whiteList
        boolean mulltiTenant = config.getTenantTablePattern() != null && config.getTenantTablePattern().length() > 0;
        if (!mulltiTenant) {
            WallCheckResult checkResult = checkWhiteAndBlackList(sql);
            if (checkResult != null) {
                checkResult.setSql(sql);
                return checkResult;
            }
        }

        hardCheckCount.incrementAndGet();
        final List<Violation> violations = new ArrayList<Violation>();
        List<SQLStatement> statementList = new ArrayList<SQLStatement>();
        boolean syntaxError = false;
        boolean endOfComment = false;
        try {
            SQLStatementParser parser = createParser(sql);
            parser.getLexer().setCommentHandler(WallCommentHandler.instance);

            if (!config.isCommentAllow()) {
                parser.getLexer().setAllowComment(false); // deny comment
            }
            if (!config.isCompleteInsertValuesCheck()) {
                parser.setParseCompleteValues(false);
                parser.setParseValuesSize(config.getInsertValuesCheckSize());
            }
            
            parser.parseStatementList(statementList);

            final Token lastToken = parser.getLexer().token();
            if (lastToken != Token.EOF && config.isStrictSyntaxCheck()) {
                violations.add(new IllegalSQLObjectViolation(ErrorCode.SYNTAX_ERROR, "not terminal sql, token "
                                                                                     + lastToken, sql));
            }
            endOfComment = parser.getLexer().isEndOfComment();
        } catch (NotAllowCommentException e) {
            violations.add(new IllegalSQLObjectViolation(ErrorCode.COMMENT_STATEMENT_NOT_ALLOW, "comment not allow", sql));
            incrementCommentDeniedCount();
        } catch (ParserException e) {
            syntaxErrorCount.incrementAndGet();
            syntaxError = true;
            if (config.isStrictSyntaxCheck()) {
                violations.add(new SyntaxErrorViolation(e, sql));
            }
        } catch (Exception e) {
            if (config.isStrictSyntaxCheck()) {
                violations.add(new SyntaxErrorViolation(e, sql));
            }
        }

        if (statementList.size() > 1 && !config.isMultiStatementAllow()) {
            violations.add(new IllegalSQLObjectViolation(ErrorCode.MULTI_STATEMENT, "multi-statement not allow", sql));
        }

        WallVisitor visitor = createWallVisitor();
        visitor.setSqlEndOfComment(endOfComment);

        if (statementList.size() > 0) {
            boolean lastIsHint = false;
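            for (int i = 0; i < statementList.size(); i++) {
                // ...... (the loop body was lost from this listing; per step 3 of the
                // summary below, each SQLStatement is passed to the WallVisitor via accept)
            }
        }

        if (visitor.getViolations().size() > 0) {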
            violations.addAll(visitor.getViolations());
        }

        Map<String, WallSqlTableStat> tableStat = context.getTableStats();

        boolean updateCheckHandlerEnable = false;
        {
            WallUpdateCheckHandler updateCheckHandler = config.getUpdateCheckHandler();
            if (updateCheckHandler != null) {
                for (SQLStatement stmt : statementList) {
                    if (stmt instanceof SQLUpdateStatement) {
                        SQLUpdateStatement updateStmt = (SQLUpdateStatement) stmt;
                        SQLName table = updateStmt.getTableName();
                        if (table != null) {
                            String tableName = table.getSimpleName();
                            Set<String> updateCheckColumns = config.getUpdateCheckTable(tableName);
                            if (updateCheckColumns != null && updateCheckColumns.size() > 0) {
                                updateCheckHandlerEnable = true;
                                break;
                            }
                        }
                    }
                }
            }
        }

        WallSqlStat sqlStat = null;
        if (violations.size() > 0) {
            violationCount.incrementAndGet();

            if ((!updateCheckHandlerEnable) && sql.length() < MAX_SQL_LENGTH) {
                sqlStat = addBlackSql(sql, tableStat, context.getFunctionStats(), violations, syntaxError);
            }
        } else {
            if ((!updateCheckHandlerEnable) && sql.length() < MAX_SQL_LENGTH) {
                boolean selectLimit = false;
                if (config.getSelectLimit() > 0) {
                    for (SQLStatement stmt : statementList) {
                        if (stmt instanceof SQLSelectStatement) {
                            selectLimit = true;
                            break;
                        }
                    }
                }

                if (!selectLimit) {
                    sqlStat = addWhiteSql(sql, tableStat, context.getFunctionStats(), syntaxError);
                }
            }
        }
        
        if(sqlStat == null && updateCheckHandlerEnable){
            sqlStat = new WallSqlStat(tableStat, context.getFunctionStats(), violations, syntaxError);
        }

        Map<String, WallSqlTableStat> tableStats = null;
        Map<String, WallSqlFunctionStat> functionStats = null;
        if (context != null) {
            tableStats = context.getTableStats();
            functionStats = context.getFunctionStats();
            recordStats(tableStats, functionStats);
        }

        WallCheckResult result;
        if (sqlStat != null) {
            context.setSqlStat(sqlStat);
            result = new WallCheckResult(sqlStat, statementList);
        } else {
            result = new WallCheckResult(null, violations, tableStats, functionStats, statementList, syntaxError);
        }

        String resultSql;
        if (visitor.isSqlModified()) {
            resultSql = SQLUtils.toSQLString(statementList, dbType);
        } else {
            resultSql = sql;
        }
        result.setSql(resultSql);

        result.setUpdateCheckItems(visitor.getUpdateCheckItems());

        return result;
    }

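The method mainly does the following:
1. If the call runs in privileged mode and config allows it, return immediately. Otherwise, unless multi-tenant mode is enabled, check whether the SQL is already in the whitelist or blacklist, and if so return that result directly.
2. Parse the SQL into a list of SQLStatement objects, since one SQL string may contain several statements. Comment, strict-syntax (the not terminal sql error above) and multi-statement violations are recorded at this stage.
3. Pass the WallVisitor created from config to each SQLStatement's accept method. Violations found by the visitor (such as an always-true WHERE condition) are collected, and an exception thrown while visiting is treated as a syntax error and recorded in the result.
4. Record the SQL in the blacklist (addBlackSql) or whitelist (addWhiteSql) depending on whether violations were found, then build the WallCheckResult; if the visitor modified the statements, the result carries the regenerated SQL.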
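The white/black list handling in checkInternal (checkWhiteAndBlackList at the start, addWhiteSql/addBlackSql near the end) is essentially memoization by SQL text: once a statement has been fully checked, the next identical statement is answered without parsing it again, and overly long statements are not remembered (the MAX_SQL_LENGTH guard). A minimal sketch of that idea follows. CachedSqlChecker, the 1000-entry LRU bound and the 8192-character limit are assumptions for illustration, not druid's classes or values; for brevity one map holds both outcomes, whereas druid keeps separate white and black lists, and the hardCheck predicate stands in for the parse-and-visit steps.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Predicate;

// Hypothetical sketch: memoize check results by exact SQL text.
final class CachedSqlChecker {
    private static final int MAX_SQL_LENGTH = 8192; // assumed limit for this sketch
    private static final int MAX_ENTRIES = 1000;    // assumed LRU capacity for this sketch

    private final Predicate<String> hardCheck;      // true = SQL is allowed
    private int hardCheckCount = 0;

    // accessOrder = true plus removeEldestEntry gives a bounded LRU map.
    private final Map<String, Boolean> results =
            new LinkedHashMap<String, Boolean>(16, 0.75f, true) {
                @Override
                protected boolean removeEldestEntry(Map.Entry<String, Boolean> eldest) {
                    return size() > MAX_ENTRIES;
                }
            };

    CachedSqlChecker(Predicate<String> hardCheck) {
        this.hardCheck = hardCheck;
    }

    synchronized boolean isAllowed(String sql) {
        Boolean cached = results.get(sql);        // plays the role of the white/black list lookup
        if (cached != null) {
            return cached;
        }
        hardCheckCount++;
        boolean allowed = hardCheck.test(sql);    // stands in for parse + WallVisitor
        if (sql.length() < MAX_SQL_LENGTH) {
            results.put(sql, allowed);            // allowed ~ whitelist, rejected ~ blacklist
        }
        return allowed;
    }

    synchronized int getHardCheckCount() {
        return hardCheckCount;
    }
}
```

Since applications typically issue the same parameterized SQL over and over, most calls take the cached path; the checkCount and hardCheckCount counters in the real method make that split observable.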
