XSS过滤器

今天天气晴朗,气候炎热,是时候来一点java代码来降降温了。

在做web开发的时候,可能会遇到无聊的人,或者对你的网站图谋不轨的人,当然这个是假如。俗话说:害人之心不可有,防人之心不可无。那么,我们就来做一个XSS过滤器,来抵挡一部分恶意的访问吧。

讲一下原理:把request中的请求参数放入我们自己写的包装request类中,然后写一个过滤器对这个包装request中的参数进行过滤。

为什么需要对Request进行一次包装呢?直接用不是更好吗。其实这里面主要是针对post带请求体这种的,如果要拿请求体中的参数,就要从request获取去参数,也就是从流中获取参数。但是一旦从流中获取到参数,在后面流也就关闭了,无法再次获取,再次获取就会报错。就好比男生在飞机完就会进入贤者模式,此刻你再来一个刺激,那就事情不妙了。在代码中也是一样的,也就是我们的Controler无法获取到参数。所以,才有了包装request的事。

好,开启代码之旅吧。

首先,就是定义一些常量,后面的代码会用得着:

package cn.wjp.mydaily.common.filter;

public class HttpConst {
    /**
     * 几种常见的Content-Type
     */
    public static final String FORM_URLENCODED_CONTENT_TYPE ="application/x-www-form-urlencoded";

    public static final String JSON_CONTENT_TYPE = "application/json";

    public static final String MULTIPART_CONTENT_TYPE = "multipart/form-data";

    /**
     * 常见的post/get请求方式
     */
    public static final String POST_METHOD = "post";

    public static final String GET_METHOD = "get";

    public static final String OPTIONS_METHOD = "options";

}

接下来需要分别针对不同的表单提交类型写包装类,对于普通的get/post请求的包装类:

package cn.wjp.mydaily.common.filter;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.IOException;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;
import java.util.Vector;

/**
 * 适用于普通get post请求  不包含multipart/form-data   application/json等请求
 * 缓存请求参数   重写获取参数的方法
 */
public class HttpServletRequestNormalWrapper extends HttpServletRequestWrapper {

    private Map parameterMap = new HashMap<>(); // 所有参数的Map集合

    public HttpServletRequestNormalWrapper(HttpServletRequest request){
        super(request);
        Enumeration params = request.getParameterNames();//获得所有请求参数名
        StringBuffer paramsValue = new StringBuffer("");
        while (params.hasMoreElements()) {
            String name = params.nextElement().toString(); //得到参数名
            String[] value = request.getParameterValues(name);//得到参数对应值
            parameterMap.put(name,value);
        }
    }

    /**
     * 获取所有参数名
     *
     * @return 返回所有参数名
     */
    @Override
    public Enumeration getParameterNames() {
        Vector vector = new Vector(parameterMap.keySet());
        return vector.elements();
    }

    /**
     * 获取指定参数名的值,如果有重复的参数名,则返回第一个的值 接收一般变量 ,如text类型
     *
     * @param name 指定参数名
     * @return 指定参数名的值
     */
    @Override
    public String getParameter(String name) {
        String[] values = parameterMap.get(name);
        if(values==null||values.length==0){
            return null;
        }
        return values[0];
        //如果有多个参数值的,请放开该注释  我这里没有做这么细致
        /*StringBuffer sb = new StringBuffer();
        for(int i=0;i getParameterMap() {
        return parameterMap;
    }

    public void setParameterMap(Map parameterMap) {
        this.parameterMap = parameterMap;
    }
}
对于post+application/json的请求,对应的包装类:

package cn.wjp.mydaily.common.filter;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;

/**
 * 将body中的数据缓存起来  重写getInputStream  getReader 等方法  适用于application/json的post请求
 */
public class HttpServletRequestBodyReaderWrapper extends HttpServletRequestWrapper{

    private String body ="{}";//缓存请求体的内容

    public HttpServletRequestBodyReaderWrapper(HttpServletRequest request) throws IOException {
        super(request);
        StringBuilder stringBuilder = new StringBuilder("");
        BufferedReader bufferedReader = null;
        try {
            InputStream inputStream = request.getInputStream();
            if (inputStream != null) {
                bufferedReader = new BufferedReader(new InputStreamReader(inputStream));
                char[] charBuffer = new char[1024];
                int bytesRead = -1;
                while ((bytesRead = bufferedReader.read(charBuffer)) > 0) {
                    stringBuilder.append(charBuffer, 0, bytesRead);
                }
            }
        } catch (IOException ex) {
            throw ex;
        } finally {
            if (bufferedReader != null) {
                try {
                    bufferedReader.close();
                } catch (IOException ex) {
                    throw ex;
                }
            }
        }
        body = stringBuilder.toString();
    }

    @Override
    public ServletInputStream getInputStream() throws IOException {
        final ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(body.getBytes("utf-8"));
        ServletInputStream servletInputStream = new ServletInputStream() {
            public boolean isFinished() {
                return false;
            }
            public boolean isReady() {
                return true;
            }
            public void setReadListener(ReadListener readListener) {}
            public int read() throws IOException {
                return byteArrayInputStream.read();
            }
        };
        return servletInputStream;

    }
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(this.getInputStream()));
    }
    public String getBody() {
        return this.body;
    }

    public void setBody(String body) {
        this.body = body;
    }
}

接下来就是XSS校验啦:

package cn.wjp.mydaily.common.filter;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.TypeReference;

import javax.servlet.*;
import javax.servlet.FilterConfig;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class XssFilter implements Filter {

    public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) throws IOException, ServletException {
        HttpServletRequest request =(HttpServletRequest)req;
        HttpServletResponse response =(HttpServletResponse)res;

        String contentType = request.getContentType();//获取contentType请求头
        String method = request.getMethod();//获取请求方法  post/get
        //1 处理get请求  get请求的Content-Type一般为application/x-www-form-urlencoded  或者  text/html
        if(method.trim().equalsIgnoreCase(HttpConst.GET_METHOD)){
            HttpServletRequestNormalWrapper wrapperRequest = new HttpServletRequestNormalWrapper(request);
            Map parameterMap = wrapperRequest.getParameterMap();
            parameterMap =cleanXSSForNormalRequest(parameterMap);
            wrapperRequest.setParameterMap(parameterMap);
            chain.doFilter(wrapperRequest, response);
            return;
        }
        //2 处理post请求  只处理application/x-www-form-urlencoded  application/json,对于multipart/form-data,直接放行
        if(method.trim().equalsIgnoreCase(HttpConst.POST_METHOD)){
            if(contentType.trim().toLowerCase().contains(HttpConst.MULTIPART_CONTENT_TYPE)){
                chain.doFilter(request, response);
                return;
            }
            //处理application/x-www-form-urlencoded
            if(contentType.trim().toLowerCase().contains(HttpConst.FORM_URLENCODED_CONTENT_TYPE)){
                HttpServletRequestNormalWrapper wrapperRequest = new HttpServletRequestNormalWrapper(request);
                Map parameterMap = wrapperRequest.getParameterMap();
                parameterMap =cleanXSSForNormalRequest(parameterMap);
                wrapperRequest.setParameterMap(parameterMap);
                chain.doFilter(wrapperRequest, response);
                return;
            }
            //处理application/json
            if(contentType.trim().toLowerCase().contains(HttpConst.JSON_CONTENT_TYPE)){
                HttpServletRequestBodyReaderWrapper requestWrapper = new HttpServletRequestBodyReaderWrapper(request);
                String body = requestWrapper.getBody();
                body =cleanXSSForPostJsonRequest(body);
                requestWrapper.setBody(body);
                chain.doFilter(requestWrapper, response);
                return;
            }
        }
        chain.doFilter(request, response);
        return;

    }

    public String cleanXSS(String value) {
        if(value==null||value.trim().isEmpty()){
            return null;
        }
        value = value.replaceAll("<", "& lt;").replaceAll(">", "& gt;");
        value = value.replaceAll("\\(", "& #40;").replaceAll("\\)", "& #41;");
        value = value.replaceAll("'", "& #39;");
        value = value.replaceAll("\"", "& #34;");
        value = value.replaceAll("`", "");
        value = value.replaceAll("eval\\((.*)\\)", "");
        value = value.replaceAll("[\\\"\\\'][\\s]*javascript:(.*)[\\\"\\\']", "\"\"");
        value = value.replaceAll("script", "");
        return value;
    }

    /**
     * 普通的post/get请求
     * @param parameterMap
     */
    public Map cleanXSSForNormalRequest(Map parameterMap){
        Map cleanMap = new HashMap<>();
        if(parameterMap==null||parameterMap.size()==0){
            return cleanMap;
        }
        for (Map.Entry entry : parameterMap.entrySet()) {
            String key = entry.getKey();
            String[] value = entry.getValue();
            String cleanKey = cleanXSS(key);
            String[] cleanValue = null;
            if(value!=null&&value.length>0){
                cleanValue = new String[value.length];
                for(int i=0;i entry1 : cleanMap.entrySet()) {
            printStr.append(entry1.getKey()).append("=").append(Arrays.asList(entry1.getValue())).append("&");
        }
        System.out.println("XssFilter:发送的请求参数:"+JSON.toJSONString(printStr));
        return cleanMap;
    }

    /**
     * post的application/json请求
     * @param body
     */
    public String cleanXSSForPostJsonRequest(String body){
        String cleanBody = "{}";
        if(body==null||body.trim().isEmpty()||body.trim().equalsIgnoreCase("{}")||!body.trim().contains(":")){
            return cleanBody;
        }
        Map map = JSON.parseObject(body,new TypeReference>(){});
        if(map==null||map.size()==0){
            return cleanBody;
        }
        Map cleanMap = new HashMap<>();
        for (Map.Entry entry : map.entrySet()) {
            String key = entry.getKey();
            Object value = entry.getValue();
            String valueStr = String.valueOf(value);
            if(valueStr==null||valueStr.trim().isEmpty()||valueStr.trim().equalsIgnoreCase("null")){
                valueStr = null;
            }
            cleanMap.put(cleanXSS(key),cleanXSS(valueStr));
        }
        cleanBody = JSON.toJSONString(cleanMap);
        System.out.println("XssFilter:发送的请求参数:"+cleanBody);
        return cleanBody;
    }


    @Override
    public void destroy() {
    }

    @Override
    public void init(FilterConfig arg0) {
    }
}

“好,那个谁,到我办公室来一下。今天的xss过滤工作做得不错,就奖励你今晚一个加班吧,要知道,这个加班名额真的好不容易,我费了很大的劲才说服咱们组今晚早点下班的,好让你有加班的机会,你可以不要让我失望哟。好,你继续忙吧”。

你可能感兴趣的:(XSS过滤器)