动手写个java快速开发框架-(6)XSS防护

之前写的几篇文章已经把一个快速开发脚手架基本搭建起来了,但是之前一致没有考虑安全问题,今天就抛砖引玉在框架中加入XSS防护,其实类似的还有很多,例如防SQL注入、安全字符校验等这些,本篇文章里就都实现了,后续会在框架中逐渐完善。

XSS在百度百科里的解释是:

XSS攻击全称跨站脚本攻击,是为不和层叠样式表(Cascading Style Sheets, CSS)的缩写混淆,故将跨站脚本攻击缩写为XSS,XSS是一种在web应用中的计算机安全漏洞,它允许恶意web用户将代码植入到提供给其它用户使用的页面中.

在MkFramework中实现防XSS攻击主要是通过Filter来实现,这里我们需要实现XssFilter、XssHttpServletRequestWrapper,核心的Xss防护功能都在这个装饰器Wrapper中,因为涉及到对HttpRequest内容的处理,所以必须要在装饰器类中进行处理,下面我们来看下代码。

首先定义一个继承自Filter的XssFilter,在Filter中将Xss过滤核心处理类XssHttpServletRequestWrapper的对象传递给servlet或controller类进行处理。

public class XssFilter implements Filter {

   @Override
   public void init(FilterConfig config) throws ServletException {
   }

   public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {
      XssHttpServletRequestWrapper xssRequest = new XssHttpServletRequestWrapper(
            (HttpServletRequest) request);
      chain.doFilter(xssRequest, response);
   }

   @Override
   public void destroy() {
   }

}

定义好了Filter后,当然还需要将这个Filter注册到IOC容器中,那么需要再定义个config,用来在容器启动的时候将FilterBean注册到IOC里。

@Configuration
public class FilterConfig {
    @Bean
    public FilterRegistrationBean xssFilterRegistration() {
        FilterRegistrationBean registration = new FilterRegistrationBean();
        registration.setDispatcherTypes(DispatcherType.REQUEST);
        registration.setFilter(new XssFilter());
        registration.addUrlPatterns("/*");
        registration.setName("xssFilter");
        registration.setOrder(Integer.MAX_VALUE);
        return registration;
    }
}

接下来就是核心装饰器类的实现了,在XSS过滤装饰器类中会用到一个HTMLFilter类,该类直接用了Joseph写HTMLFilter,可以对XSS常见的跨站攻击脚本进行过滤,具体源码直接看git。

在XssHttpServletRequestWrapper中我们重写了getInputStream()、getParameter()、getParameterValues()、getParameterMap()、getHeader()这些方法,在这些方法中分别调用了HtmlFilter中的XSS关键字过滤方法来进行处理。

public class XssHttpServletRequestWrapper extends HttpServletRequestWrapper {
    HttpServletRequest orgRequest;
    //html过滤
    private final static HTMLFilter htmlFilter = new HTMLFilter();

    public XssHttpServletRequestWrapper(HttpServletRequest request) {
        super(request);
        orgRequest = request;
    }

    @Override
    public ServletInputStream getInputStream() throws IOException {
        //非json类型,直接返回
        if(!MediaType.APPLICATION_JSON_VALUE.equalsIgnoreCase(super.getHeader(HttpHeaders.CONTENT_TYPE))){
            return super.getInputStream();
        }

        //为空,直接返回
        String json = IOUtils.toString(super.getInputStream(), "utf-8");
        if (StringUtils.isBlank(json)) {
            return super.getInputStream();
        }

        //xss过滤
        json = xssEncode(json);
        final ByteArrayInputStream bis = new ByteArrayInputStream(json.getBytes("utf-8"));
        return new ServletInputStream() {
            @Override
            public boolean isFinished() {
                return true;
            }

            @Override
            public boolean isReady() {
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {

            }

            @Override
            public int read() throws IOException {
                return bis.read();
            }
        };
    }

    @Override
    public String getParameter(String name) {
        String value = super.getParameter(xssEncode(name));
        if (StringUtils.isNotBlank(value)) {
            value = xssEncode(value);
        }
        return value;
    }

    @Override
    public String[] getParameterValues(String name) {
        String[] parameters = super.getParameterValues(name);
        if (parameters == null || parameters.length == 0) {
            return null;
        }

        for (int i = 0; i < parameters.length; i++) {
            parameters[i] = xssEncode(parameters[i]);
        }
        return parameters;
    }

    @Override
    public Map getParameterMap() {
        Map map = new LinkedHashMap<>();
        Map parameters = super.getParameterMap();
        for (String key : parameters.keySet()) {
            String[] values = parameters.get(key);
            for (int i = 0; i < values.length; i++) {
                values[i] = xssEncode(values[i]);
            }
            map.put(key, values);
        }
        return map;
    }

    @Override
    public String getHeader(String name) {
        String value = super.getHeader(xssEncode(name));
        if (StringUtils.isNotBlank(value)) {
            value = xssEncode(value);
        }
        return value;
    }

    private String xssEncode(String input) {
        return htmlFilter.filter(input);
    }

    /**
     * 获取最原始的request
     */
    public HttpServletRequest getOrgRequest() {
        return orgRequest;
    }

    /**
     * 获取最原始的request
     */
    public static HttpServletRequest getOrgRequest(HttpServletRequest request) {
        if (request instanceof XssHttpServletRequestWrapper) {
            return ((XssHttpServletRequestWrapper) request).getOrgRequest();
        }

        return request;
    }

}

至此整个XSS过滤的模块就写完了,类似的大家还可以添加防SQL注入的模块,这里就不细描述了,思路也是一样的,都是通过Filter来实现。

本文对应的github tag为v0.6,可以通过连接下载https://github.com/feiweiwei/MkFramework4java/releases/tag/v0.6,也可以通过git clone -b v0.6 https://github.com/feiweiwei/MkFramework4java.git

1040x100

赶快点击下面的链接3折购买腾讯云产品,搭建第一个自己的服务器,只有自己动手搭建了才能快速成长。
https://cloud.tencent.com/redirect.php?redirect=1014&cps_key=62032bbf66cfc9ae14665ee20e89e29b&from=console

你可能感兴趣的:(动手写个java快速开发框架-(6)XSS防护)