处理request.getInputStream()输入流只能读取一次问题

一般我们会在InterceptorAdapter拦截器中对请求进行验证

正常普通接口请求,request.getParameter()可以获取,能多次读取

如果我们的接口是用@RequestBody来接受数据,那么我们在拦截器中

需要读取request的输入流  ,因为 ServletRequest中getReader()和getInputStream()只能调用一次

这样就会导致controller 无法拿到数据。


解决方法 :


自定义一个类 BodyReaderHttpServletRequestWrapper.java 


  1. package com.banxue.handle.web;


    import java.io.BufferedInputStream;
    import java.io.BufferedReader;
    import java.io.ByteArrayInputStream;
    import java.io.ByteArrayOutputStream;
    import java.io.IOException;
    import java.io.InputStream;
    import java.io.InputStreamReader;
    import java.io.UnsupportedEncodingException;
    import java.net.URLDecoder;
    import java.util.Collections;
    import java.util.Enumeration;
    import java.util.HashMap;
    import java.util.Map;
    import java.util.StringTokenizer;


    import javax.servlet.ServletInputStream;
    import javax.servlet.http.HttpServletRequest;
    import javax.servlet.http.HttpServletRequestWrapper;


    import org.apache.commons.lang3.StringUtils;


    import com.banxue.utils.Encodes;


    /**
     * 自定义HttpServletRequestWrapper,让request输入流重复使用多次
     * 
     * @author xiaoph
     *
     */
    public class BodyReaderHttpServletRequestWrapper extends HttpServletRequestWrapper {


        private static final String UTF_8 = "UTF-8";
        private Map paramsMap;


        private final byte[] body; // 报文


        public BodyReaderHttpServletRequestWrapper(HttpServletRequest request) throws IOException {
            super(request);
            body = readBytes(request.getInputStream());


            // 首先从POST中获取数据
            if ("POST".equals(request.getMethod().toUpperCase())) {
                paramsMap = getParamMapFromPost(this);
            } else {
                paramsMap = getParamMapFromGet(this);
            }


        }


        @Override
        public Map getParameterMap() {
            return paramsMap;
        }


        @Override
        public String getParameter(String name) {// 重写getParameter,代表参数从当前类中的map获取
            String[] values = paramsMap.get(name);
            if (values == null || values.length == 0) {
                return null;
            }
            return values[0];
        }


        @Override
        public String[] getParameterValues(String name) {// 同上
            return paramsMap.get(name);
        }


        @Override
        public Enumeration getParameterNames() {
            return Collections.enumeration(paramsMap.keySet());
        }


        @Override
        public BufferedReader getReader() throws IOException {
            return new BufferedReader(new InputStreamReader(getInputStream()));
        }


        @Override
        public ServletInputStream getInputStream() throws IOException {
            final ByteArrayInputStream bais = new ByteArrayInputStream(body);
            return new ServletInputStream() {


                @Override
                public int read() throws IOException {
                    return bais.read();
                }
            };
        }


        private Map getParamMapFromGet(HttpServletRequest request) {
            return parseQueryString(request.getQueryString());
        }


        private HashMap getParamMapFromPost(HttpServletRequest request) {


            String body = StringUtils.EMPTY;
            try {
                body = getRequestBody(request.getInputStream());
            } catch (IOException e) {
                e.printStackTrace();
            }
            HashMap result = new HashMap();


            if (null == body || 0 == body.length()) {
                return result;
            }


            return parseQueryString(body);
        }


        private String getRequestBody(InputStream stream) throws IOException {
            String line = StringUtils.EMPTY;
            StringBuilder body = new StringBuilder();
            int counter = 0;


            // 读取POST提交的数据内容
            BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
            while ((line = reader.readLine()) != null) {
                if (counter > 0) {
                    body.append("rn");
                }
                body.append(line);
                counter++;
            }
            reader.close();
            return body.toString();
        }


        public HashMap parseQueryString(String s) {
            String valArray[] = null;
            if (s == null) {
                throw new IllegalArgumentException();
            }
            HashMap ht = new HashMap();
            StringTokenizer st = new StringTokenizer(s, "&");
            while (st.hasMoreTokens()) {
                String pair = (String) st.nextToken();
                int pos = pair.indexOf('=');
                if (pos == -1) {
                    continue;
                }
                String key = pair.substring(0, pos);
                String val = pair.substring(pos + 1, pair.length());
                if (ht.containsKey(key)) {
                    String oldVals[] = (String[]) ht.get(key);
                    valArray = new String[oldVals.length + 1];
                    for (int i = 0; i < oldVals.length; i++) {
                        valArray[i] = oldVals[i];
                    }
                    valArray[oldVals.length] = decodeValue(val);
                } else {
                    valArray = new String[1];
                    valArray[0] = decodeValue(val);
                }
                ht.put(key, valArray);
            }
            return ht;
        }


        private static byte[] readBytes(InputStream in) throws IOException {
            BufferedInputStream bufin = new BufferedInputStream(in);
            final int buffSize = 1024;
            ByteArrayOutputStream out = new ByteArrayOutputStream(buffSize);


            byte[] temp = new byte[buffSize];
            int size = 0;
            while ((size = bufin.read(temp)) != -1) {
                out.write(temp, 0, size);
            }
            out.flush();
            byte[] content = out.toByteArray();
            bufin.close();
            out.close();
            return content;
        }


        /**
         * 自定义解码函数
         * 
         * @param value
         * @return
         */
        private String decodeValue(String value) {
            if (value.contains("%u")) {
                return Encodes.urlDecode(value);
            } else {
                try {
                    return URLDecoder.decode(value, UTF_8);
                } catch (UnsupportedEncodingException e) {
                    return StringUtils.EMPTY;// 非UTF-8编码
                }
            }
        }


    }


自定义Filter   HttpServletRequestReplacedFilter.java

[java]  view plain  copy
  1. package com.banxue.handle.web;


    import java.io.IOException;


    import javax.servlet.Filter;
    import javax.servlet.FilterChain;
    import javax.servlet.FilterConfig;
    import javax.servlet.ServletException;
    import javax.servlet.ServletRequest;
    import javax.servlet.ServletResponse;
    import javax.servlet.http.HttpServletRequest;


    public class HttpServletRequestReplacedFilter implements Filter {


        @Override
        public void doFilter(ServletRequest request, ServletResponse response,
                FilterChain chain) throws IOException, ServletException {


            ServletRequest requestWrapper = null;
            if (request instanceof HttpServletRequest) {
                HttpServletRequest httpServletRequest = (HttpServletRequest) request;
                if ("POST".equals(httpServletRequest.getMethod().toUpperCase())) {
                    requestWrapper = new BodyReaderHttpServletRequestWrapper(
                            (HttpServletRequest) request);
                }
            }
            if (requestWrapper == null) {
                chain.doFilter(request, response);
            } else {
                chain.doFilter(requestWrapper, response);
            }
        }


        @Override
        public void init(FilterConfig arg0) throws ServletException {


        }


        @Override
        public void destroy() {


        }


    }


在web.xml 配置 


[html]  view plain  copy

  1.       
            HttpServletRequestReplacedFilter  
            com.banxue.handle.web.HttpServletRequestReplacedFilter  
              
                encoding  
                utf-8  
           
      
       
      


Encodes 类

[html]  view plain  copy
  1. package com.banxue.utils;


    import java.io.UnsupportedEncodingException;
    import java.net.URLDecoder;
    import java.net.URLEncoder;


    import org.apache.commons.codec.DecoderException;
    import org.apache.commons.codec.binary.Base64;
    import org.apache.commons.codec.binary.Hex;
    import org.apache.commons.lang3.StringEscapeUtils;


    /**
     * 封装各种格式的编码解码工具类.
     * 
     * 1.Commons-Codec的 hex/base64 编码 2.自制的base62 编码 3.Commons-Lang的xml/html escape
     * 4.JDK提供的URLEncoder
     * 
     */
    public class Encodes {


        private static final String DEFAULT_URL_ENCODING = "UTF-8";
        private static final char[] BASE62 = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz".toCharArray();


        /**
         * Hex编码.
         */
        public static String encodeHex(byte[] input) {
            return Hex.encodeHexString(input);
        }


        /**
         * Hex解码.
         */
        public static byte[] decodeHex(String input) {
            try {
                return Hex.decodeHex(input.toCharArray());
            } catch (DecoderException e) {
                e.printStackTrace();
            }
            return null;
        }


        /**
         * Base64编码.
         */
        public static String encodeBase64(byte[] input) {
            return Base64.encodeBase64String(input);
        }


        /**
         * Base64编码, URL安全(将Base64中的URL非法字符'+'和'/'转为'-'和'_', 见RFC3548).
         */
        public static String encodeUrlSafeBase64(byte[] input) {
            return Base64.encodeBase64URLSafeString(input);
        }


        /**
         * Base64解码.
         */
        public static byte[] decodeBase64(String input) {
            return Base64.decodeBase64(input);
        }


        /**
         * Base62编码。
         */
        public static String encodeBase62(byte[] input) {
            char[] chars = new char[input.length];
            for (int i = 0; i < input.length; i++) {
                chars[i] = BASE62[(input[i] & 0xFF) % BASE62.length];
            }
            return new String(chars);
        }


        /**
         * Html 转码.
         */
        public static String escapeHtml(String html) {
            return StringEscapeUtils.escapeHtml4(html);
        }


        /**
         * Html 解码.
         */
        public static String unescapeHtml(String htmlEscaped) {
            return StringEscapeUtils.unescapeHtml4(htmlEscaped);
        }


        /**
         * Xml 转码.
         */
        public static String escapeXml(String xml) {
            return StringEscapeUtils.escapeXml(xml);
        }


        /**
         * Xml 解码.
         */
        public static String unescapeXml(String xmlEscaped) {
            return StringEscapeUtils.unescapeXml(xmlEscaped);
        }


        /**
         * URL 编码, Encode默认为UTF-8.
         */
        public static String urlEncode(String part) {
            try {
                return URLEncoder.encode(part, DEFAULT_URL_ENCODING);
            } catch (UnsupportedEncodingException e) {
                e.printStackTrace();
            }
            return null;
        }


        /**
         * URL 解码, Encode默认为UTF-8.
         */
        public static String urlDecode(String part) {
            try {
                return URLDecoder.decode(part, DEFAULT_URL_ENCODING);
            } catch (UnsupportedEncodingException e) {
                e.printStackTrace();
            }
            return null;
        }
    }

在springmvc.xml配置文件中配置拦截器

处理request.getInputStream()输入流只能读取一次问题_第1张图片

参考博客:转载自:https://blog.csdn.net/heng_ji/article/details/54893352



你可能感兴趣的:(技术问题)