Spring Boot 添加 XssFilter 过滤器

第一部分代码：过滤器 XssFilter。使用前需要先在 ResultEnum 中新增枚举值 XSS_ERROR(90006, "入参含有非法字符")。

@Component
@Slf4j
@WebFilter(filterName = "xssFilter", urlPatterns = "/*")
@Order(5)
public class XssFilter implements Filter {

    private static final String SCRIPT_LOW_REGEX = ".*((((\\%3C)|<)[^\\n]+((\\%3E)|>))|(((\\%22)|"|(\\%27)|')[(\\%20) ]*((\\%2B)|\\+|(\\%3B)|;))|(((\\%3D)|=)[(\\%20) ]*((\\%22)|"|(\\%27)|'))).*";
    private static final String SCRIPT_UPPER_REGEX = ".*((((\\%3C)|<)[^\\n]+((\\%3E)|>))|(((\\%22)|"|(\\%27)|')[(\\%20) ]*((\\%2B)|\\+|(\\%3B)|;))|(((\\%3D)|=)[(\\%20) ]*((\\%22)|"|(\\%27)|'))).*";
    private static final String SQL_REGEX = ".*((\\%27)|(\\')).*((\\-\\-)|(((\\%6F)|o|(\\%4F))((\\%72)|r|(\\%52)))|((\\%3B)|(;))).*";
    private static final String NOT_PROTECT = "/undefined$";

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {

    }

    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        System.out.println("*************执行过滤xssFilter******");
        HttpServletResponse res = (HttpServletResponse) servletResponse;
        res.setCharacterEncoding("UTF-8");
        res.setContentType("application/json; charset=utf-8");
        HttpServletRequest request ;
        if (servletRequest instanceof HttpServletRequest) {
            request = new XssHttpServletRequestWrapper((HttpServletRequest) servletRequest);
        } else {
            filterChain.doFilter(servletRequest, servletResponse);
            return;
        }

        Map<String, String[]> parameterMap = getParams(request);
        if (Objects.isNull(parameterMap) || parameterMap.size() == 0) {
            parameterMap = request.getParameterMap();
        }

        String url = request.getRequestURI();
        final Pattern scriptLowRegex = Pattern.compile(SCRIPT_LOW_REGEX);
        final Pattern sqlRegex = Pattern.compile(SQL_REGEX);
        final Pattern notProject = Pattern.compile(NOT_PROTECT);
        final Pattern scriptUpperRegex = Pattern.compile(SCRIPT_UPPER_REGEX);
        if (!notProject.matcher(url).find()) {
            if (Objects.nonNull(parameterMap) && parameterMap.size() > 0) {
                Iterator<Map.Entry<String, String[]>> iterator = parameterMap.entrySet().iterator();
                while (iterator.hasNext()) {
                    Map.Entry<String, String[]> next = iterator.next();
                    String[] value = next.getValue();
                    for (int i = 0; i < value.length; i++) {
                        String paraValue = value[i].toUpperCase();
                        paraValue = replaceStr(paraValue);
                        if (scriptLowRegex.matcher(paraValue).matches()|| sqlRegex.matcher(paraValue).matches() || scriptUpperRegex.matcher(paraValue).matches()) {
                            PrintWriter writer = null;
                            try {
                                writer = res.getWriter();
                                String resultVal = JSON.toJSONString(ApiRes.getInstance(ResultEnum.XSS_ERROR));
                                writer.write(resultVal);
                                writer.flush();
                                writer.close();
                            } catch (Exception e) {
                                log.error("AuthPathFilter Error" + e.getMessage(), e);
                            } finally {
                                if (null != writer) {
                                    writer.close();
                                }
                            }
                            return;
                        }
                    }
                }
            }
        }
        filterChain.doFilter(request, servletResponse);
        return;
    }

    @Override
    public void destroy() {

    }

    private Map<String, String[]> getParams(HttpServletRequest request) {
        Map<String, String[]> paras = null;
        BufferedReader streamReader = null;
        try {
            streamReader = new BufferedReader(new InputStreamReader(request.getInputStream(), "UTF-8"));
            StringBuilder responseStrBuilder = new StringBuilder();
            String inputStr;
            while ((inputStr = streamReader.readLine()) != null) {
                responseStrBuilder.append(inputStr);
            }
            String jString = responseStrBuilder.toString();
            if (StringUtils.isNotBlank(jString)) {
                boolean valid = isJSONValid(jString);
                if (valid) {
                    paras = new HashMap<>();
                    JSONObject jsonObject = JSONObject.parseObject(jString);
                    Iterator<Map.Entry<String, Object>> iterator = jsonObject.entrySet().iterator();
                    while (iterator.hasNext()) {
                        Map.Entry<String, Object> next = iterator.next();
                        String key = next.getKey();
                        if (Objects.isNull(next.getValue())) {
                            paras.put(key, new String[]{""});
                        } else {
                            String value = next.getValue().toString();
                            paras.put(key, new String[]{value});
                        }
                    }
                }
            }
        } catch (Exception e) {
            log.error("参数解析错误", e);
        }
        return paras;
    }

    private boolean isJSONValid(String jsonInString) {
        try {
            final ObjectMapper mapper = new ObjectMapper();
            mapper.readTree(jsonInString);
            return true;
        } catch (IOException e) {
            return false;
        }
    }

    private String replaceStr(String value) {
        value = value.replace("+", "%2B")
                //.replace("/", "%2F")
                .replace("?", "%3F")
                .replace("%", "%25")
                .replace("#", "%23")
                .replace("&", "%26")
                .replace("=", "%3D")
                .replace("@", "%40")
                .replace(":", "%3A")
                .replace(";", "%3B")
                .replace("<", "%3C")
                .replace(">", "%3E")
                .replace("\\", "%5C")
                .replace("|", "%7C")
                .replace("$", "%24")
                .replace("^", "%5E")
                .replace(",", "%2C")
                .replace("'", "%27")
                .replace("=", "%3D")
                .replace("[", "%5B")
                .replace("]", "%5D")
                .replace("{", "%7B")
                .replace("}", "%7D")
                .replace("\"", "%22");
        return value;
    }

}

第二部分代码：请求包装类 XssHttpServletRequestWrapper（缓存请求体，使输入流可以被重复读取）：

/**
 * Request wrapper that buffers the request body in memory so it can be read more than once.
 *
 * <p>Every call to {@link #getInputStream()} or {@link #getReader()} returns a fresh
 * stream over the same buffered bytes.
 */
public class XssHttpServletRequestWrapper extends HttpServletRequestWrapper {

    /** Raw request body captured at construction time. */
    private final byte[] bytes;

    /**
     * Constructs a request object wrapping the given request and buffers its body.
     *
     * @param request The request to wrap
     * @throws IllegalArgumentException if the request is null
     * @throws IOException if the request body cannot be read
     */
    public XssHttpServletRequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        // Let the container parse form parameters before the body is drained below;
        // once the input stream has been read directly, form-encoded POST parameters
        // are no longer made available through getParameter*().
        request.getParameterMap();
        // Read the request body from the input stream and keep it in bytes.
        bytes = IOUtils.toByteArray(request.getInputStream());
    }

    @Override
    public ServletInputStream getInputStream() throws IOException {
        return new BufferedServletInputStream(this.bytes);
    }

    /**
     * Returns a reader over the buffered body so that reader-based consumers also see
     * the content instead of the already-consumed original stream. Uses the request's
     * character encoding, falling back to UTF-8 when none is set.
     */
    @Override
    public java.io.BufferedReader getReader() throws IOException {
        String encoding = getCharacterEncoding();
        return new java.io.BufferedReader(
                new java.io.InputStreamReader(getInputStream(), encoding != null ? encoding : "UTF-8"));
    }

    /** Servlet input stream backed by an in-memory byte array. */
    static class BufferedServletInputStream extends ServletInputStream {
        private final ByteArrayInputStream inputStream;

        public BufferedServletInputStream(byte[] buffer) {
            // The byte array is wrapped, not copied, by ByteArrayInputStream.
            this.inputStream = new ByteArrayInputStream(buffer);
        }

        @Override
        public int available() throws IOException {
            return inputStream.available();
        }

        @Override
        public int read() throws IOException {
            return inputStream.read();
        }

        @Override
        public int read(byte[] b, int off, int len) throws IOException {
            return inputStream.read(b, off, len);
        }

        @Override
        public boolean isFinished() {
            // Finished once every buffered byte has been consumed.
            return inputStream.available() == 0;
        }

        @Override
        public boolean isReady() {
            // Data is already in memory, so a read never blocks.
            return true;
        }

        @Override
        public void setReadListener(ReadListener listener) {
            // Non-blocking reads are not supported: the listener is never invoked.
        }
    }
}

你可能感兴趣的:(Spring,Boot)