解决在 Filter 中读取 Request 的输入流后, Controller 中 @RequestBody 参数无法注入而导致 400 错误的问题

摘要: 大家知道, SpringMVC 中的 @RequestBody 是通过读取请求输入流来获取参数的; 如果在此之前流已经被读取过, 再次读取时就没有数据了.

我的 Filter 需要校验请求参数(包括 Request Payload 中的数据)是否含有非法符号(防止 SQL 注入), 因此必须先在 Filter 中读取一次请求体, 再让后续的 @RequestBody 能够重新读取.

package com.ks.tow.common.filter;


import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.io.Reader;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;


import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;


import org.apache.commons.io.IOUtils;
import org.springframework.web.util.ContentCachingRequestWrapper;


import com.alibaba.fastjson.JSONObject;
import com.ks.shop.util.secure.CheckSQLInjectionUtil;
import com.ks.tow.common.enums.HttpStatus;
import com.ks.tow.util.StringUtil;


/**
 * Anti SQL-injection filter.
 * <p>
 * Collects every request parameter value and, for {@code application/json} and {@code text/plain}
 * requests, the string values found in the request body, then checks each one with
 * {@code CheckSQLInjectionUtil.validate}. If any value is rejected, a JSON error is written to the
 * response and the filter chain is not continued.
 * <p>
 * Because the body has to be read here, such requests are wrapped in a
 * {@link ReaderReuseHttpServletRequestWrapper} so that downstream consumers (e.g. Spring MVC
 * {@code @RequestBody}) can read the body again. Requests whose body is not inspected are passed
 * through unwrapped and unconsumed.
 *
 * @author LIU
 */
public class CheckSQLInjectionFilter implements Filter {

	/**
	 * Request paths excluded from the check, relative to the context path and without a leading
	 * '/'. An entry ending in '*' matches any path that starts with the text before the '*';
	 * any other entry must match the path exactly.
	 */
	private List<String> excludes = new ArrayList<>();

	/**
	 * Replaces the exclude list; {@code null} is treated as an empty list.
	 *
	 * @param excludes path patterns to skip (see the field documentation for the syntax)
	 */
	public void setExcludes(List<String> excludes) {
		this.excludes = (excludes != null) ? excludes : new ArrayList<String>();
	}

	/**
	 * @return the live exclude list
	 */
	public List<String> getExcludes() {
		return excludes;
	}

	/**
	 * Reads the comma-separated {@code excludes} init parameter and appends each trimmed,
	 * non-empty entry to the exclude list.
	 */
	@Override
	public void init(FilterConfig filterConfig) throws ServletException {
		String excludeConfig = filterConfig.getInitParameter("excludes");
		if (StringUtil.isNotBlank(excludeConfig)) {
			for (String url : excludeConfig.split(",")) {
				// Trim so that "a, b" in the config matches path "b".
				String entry = url.trim();
				if (!entry.isEmpty()) {
					this.excludes.add(entry);
				}
			}
		}
	}

	/**
	 * Validates the request and either rejects it with a JSON error or continues the chain.
	 */
	@Override
	public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
			throws IOException, ServletException {
		HttpServletRequest req = (HttpServletRequest) request;
		HttpServletResponse resp = (HttpServletResponse) response;

		if (isExcluded(relativePath(req))) {
			chain.doFilter(req, resp);
			return;
		}

		// Every value is kept, not a name->value map. Otherwise a same-named query parameter
		// could overwrite a malicious body value, and hide it from the check.
		List<String> values = new ArrayList<>();

		// Read parameters BEFORE touching the body. Servlet containers (e.g. Tomcat) generally do
		// not parse form POST parameters once the body stream has already been read.
		collectParameterValues(req, values);

		String type = req.getContentType();
		boolean json = startsWithIgnoreCase(type, "application/json");
		boolean plainText = startsWithIgnoreCase(type, "text/plain");

		HttpServletRequest forwarded = req;
		if (json || plainText) {
			// Buffer the body so it can be inspected here and read again downstream.
			ReaderReuseHttpServletRequestWrapper wrapper = new ReaderReuseHttpServletRequestWrapper(req);
			String payload = IOUtils.toString(wrapper.getReader());
			if (json) {
				collectJsonValues(payload, values);
			} else {
				collectPlainTextValues(payload, values);
			}
			forwarded = wrapper;
		}

		for (String value : values) {
			if (!CheckSQLInjectionUtil.validate(value)) {
				writeRejection(resp);
				return;
			}
		}
		chain.doFilter(forwarded, resp);
	}

	@Override
	public void destroy() {

	}

	/**
	 * Returns the request URI with the context path and leading '/' removed, and trailing '/'
	 * characters stripped. Safe when the URI is exactly the context path.
	 */
	private static String relativePath(HttpServletRequest req) {
		String path = req.getRequestURI();
		String contextPath = req.getContextPath();
		if (contextPath != null && path.startsWith(contextPath)) {
			path = path.substring(contextPath.length());
		}
		if (path.startsWith("/")) {
			path = path.substring(1);
		}
		while (path.endsWith("/")) { // guard against a trailing '/' in the URI
			path = path.substring(0, path.length() - 1);
		}
		return path;
	}

	/** Returns whether {@code path} matches any configured exclude entry. */
	private boolean isExcluded(String path) {
		for (String pattern : excludes) {
			if (pattern.endsWith("*")
					&& path.startsWith(pattern.substring(0, pattern.length() - 1))) {
				return true;
			}
			if (pattern.equals(path)) {
				return true;
			}
		}
		return false;
	}

	/** Case-insensitive prefix test that tolerates a {@code null} value. */
	private static boolean startsWithIgnoreCase(String value, String prefix) {
		return value != null && value.regionMatches(true, 0, prefix, 0, prefix.length());
	}

	/** Adds every value of every request parameter, including all values of multi-valued ones. */
	private static void collectParameterValues(HttpServletRequest req, List<String> values) {
		Enumeration<String> names = req.getParameterNames();
		while (names.hasMoreElements()) {
			String[] paramValues = req.getParameterValues(names.nextElement());
			if (paramValues == null) {
				continue;
			}
			for (String value : paramValues) {
				if (value != null) {
					values.add(value);
				}
			}
		}
	}

	/**
	 * Parses a JSON object body and adds every string value found in it, at any nesting depth.
	 * Parsing errors propagate to the caller, as in the original implementation.
	 */
	private static void collectJsonValues(String payload, List<String> values) {
		JSONObject jsonObject = JSONObject.parseObject(payload);
		if (jsonObject != null) {
			collectStrings(jsonObject, values);
		}
	}

	/** Recursively adds string leaves of maps (JSON objects) and iterables (JSON arrays). */
	private static void collectStrings(Object node, List<String> values) {
		if (node instanceof String) {
			values.add((String) node);
		} else if (node instanceof Map) {
			for (Object child : ((Map<?, ?>) node).values()) {
				collectStrings(child, values);
			}
		} else if (node instanceof Iterable) {
			for (Object child : (Iterable<?>) node) {
				collectStrings(child, values);
			}
		}
	}

	/**
	 * Treats a text/plain body as {@code k1=v1&k2=v2} and adds each value. A value runs from the
	 * first '=' to the next '&', so text after a second '=' is still checked. A segment without '='
	 * is checked as a whole instead of crashing.
	 */
	private static void collectPlainTextValues(String payload, List<String> values) {
		for (String pair : payload.split("&")) {
			if (pair.isEmpty()) {
				continue;
			}
			int eq = pair.indexOf('=');
			values.add(eq < 0 ? pair : pair.substring(eq + 1));
		}
	}

	/** Writes the security-violation JSON payload to the response. */
	private static void writeRejection(HttpServletResponse resp) throws IOException {
		resp.setContentType("application/json;charset=UTF-8");
		PrintWriter writer = resp.getWriter();
		writer.write("{\"success\":false,\"msg\":\""+HttpStatus.SECURITY.getName()+"\",\"code\":"+HttpStatus.SECURITY.getCode()+"}");
		writer.flush();
	}

	/**
	 * Request wrapper that makes the body re-readable.
	 * <p>
	 * {@code getInputStream()} and {@code getReader()} on a plain request may only be used once,
	 * because the body is a stream. This wrapper reads the raw body bytes once in the constructor.
	 * Each later call returns a fresh stream or reader over those bytes.
	 * <p>
	 * The bytes are kept exactly as received, not re-encoded. The reader decodes them with the
	 * request's character encoding, falling back to UTF-8, so non-ASCII (e.g. Chinese) text
	 * survives.
	 *
	 * @author LIU
	 */
	public static class ReaderReuseHttpServletRequestWrapper extends HttpServletRequestWrapper {

		private final byte[] body;

		/**
		 * Buffers the full request body.
		 *
		 * @param request the request whose input stream has not been consumed yet
		 * @throws IOException if reading the body fails
		 */
		public ReaderReuseHttpServletRequestWrapper(HttpServletRequest request)
				throws IOException {
			super(request);
			body = IOUtils.toByteArray(request.getInputStream());
		}

		@Override
		public BufferedReader getReader() throws IOException {
			return new BufferedReader(new InputStreamReader(getInputStream(), bodyCharset()));
		}

		/** Charset declared by the request, or UTF-8 if none or unknown. */
		private Charset bodyCharset() {
			String encoding = getCharacterEncoding();
			if (encoding != null) {
				try {
					return Charset.forName(encoding);
				} catch (IllegalArgumentException ignored) {
					// Unknown or illegal charset name: fall back to UTF-8 below.
				}
			}
			return Charset.forName("UTF-8");
		}

		/**
		 * Returns a new stream over the buffered body.
		 * NOTE(review): on Servlet 3.1+ ServletInputStream also declares isFinished/isReady/
		 * setReadListener, which would need implementing there. This code assumes an older API.
		 */
		@Override
		public ServletInputStream getInputStream() throws IOException {
			final ByteArrayInputStream source = new ByteArrayInputStream(body);
			return new ServletInputStream() {

				@Override
				public int read() throws IOException {
					return source.read();
				}

				@Override
				public int read(byte[] buffer, int offset, int length) throws IOException {
					return source.read(buffer, offset, length);
				}

				@Override
				public int available() throws IOException {
					return source.available();
				}
			};
		}
	}
}

请注意这里的编码: 读取和还原请求体时最好统一按 UTF-8 处理, 不然获取到的中文可能会是乱码. 我自己也习惯使用 UTF-8 编码.

这样子就应该差不多了哦~

以下是校验sql注入的关键代码

/**
 * Heuristic SQL-injection detector.
 * <p>
 * A value is rejected if it contains any of the following (case-insensitive):
 * <ul>
 *   <li>a single quote;</li>
 *   <li>a {@code --} line comment;</li>
 *   <li>a {@code /* ... *&#47;} block comment;</li>
 *   <li>one of the listed SQL keywords as a whole word.</li>
 * </ul>
 * Note that ordinary words such as "and"/"or" are on the list, so this is intentionally strict.
 */
public final class CheckSQLInjectionUtil {
	// The block-comment part uses a single character class ([\s\S]*?) instead of the alternation
	// group (?:.|[\n\r])*?. Java's regex engine evaluates a repeated alternation group
	// recursively, so a long unterminated "/*" input could overflow the stack.
	private static final String SQL_REG = "(?:')|(?:--)|(/\\*[\\s\\S]*?\\*/)|"
			+ "(\\b(select|update|and|or|delete|insert|truncate|char|into|substr|"
			+ "ascii|declare|exec|count|master|drop|execute)\\b)";

	private static final Pattern PATTERN = Pattern.compile(SQL_REG, Pattern.CASE_INSENSITIVE);

	/**
	 * Checks a single value for SQL injection patterns.
	 *
	 * @param str the value to check; {@code null} is treated as safe
	 * @return {@code true} if the value is safe, {@code false} if a suspicious pattern was found
	 */
	public static boolean validate(String str) {
		return str == null || !PATTERN.matcher(str).find();
	}

	/**
	 * Checks several values for SQL injection patterns.
	 *
	 * @param strs the values to check; a {@code null} array or {@code null} elements are safe
	 * @return {@code true} if every value is safe, {@code false} as soon as one is not
	 */
	public static boolean validate(String[] strs) {
		if (strs == null) {
			return true;
		}
		for (String str : strs) {
			if (!validate(str)) {
				return false;
			}
		}
		return true;
	}

}


你可能感兴趣的:(web,spring)