request.getInputStream()输入流只能读取一次问题

背景

通常对安全性有要求的接口都会对请求参数做一些签名验证,而我们一般会把验签的逻辑统一放到过滤器或拦截器里,这样就不用每个接口都去重复编写验签的逻辑。

在一个项目中会有很多的接口,而不同的接口可能接收不同类型的数据,例如表单数据和json数据,表单数据还好说,调用request的getParameterMap就能全部取出来。而json数据就有些麻烦了,因为json数据放在body中,我们需要通过request的输入流去读取。

但问题在于request的输入流只能读取一次不能重复读取,所以我们在过滤器或拦截器里读取了request的输入流之后,请求走到controller层时就会报错。而本文的目的就是介绍如何解决在这种场景下遇到HttpServletRequest的输入流只能读取一次的问题。

HttpServletRequest的输入流只能读取一次的原因

我们先来看看为什么HttpServletRequest的输入流只能读一次,当我们调用getInputStream()方法获取输入流时得到的是一个InputStream对象,而实际类型是ServletInputStream,它继承于InputStream。

InputStream的read()方法内部有一个postion,标志当前流被读取到的位置,每读取一次,该标志就会移动一次,如果读到最后,read()会返回-1,表示已经读取完了。如果想要重新读取则需要调用reset()方法,position就会移动到上次调用mark的位置,mark默认是0,所以就能从头再读了。调用reset()方法的前提是已经重写了reset()方法,当然能否reset也是有条件的,它取决于markSupported()方法是否返回true。

InputStream默认不实现reset(),并且markSupported()默认也是返回false,这一点查看其源码便知:
request.getInputStream()输入流只能读取一次问题_第1张图片



我们再来看看`ServletInputStream`,可以看到该类没有重写mark(),reset()以及markSupported()方法:

request.getInputStream()输入流只能读取一次问题_第2张图片


综上,InputStream默认不实现reset的相关方法,而ServletInputStream也没有重写reset的相关方法,这样就无法重复读取流,这就是我们从request对象中获取的输入流就只能读取一次的原因。



使用自定义的HttpServletRequest包装类

既然ServletInputStream不支持重复读流中的数据,那么我们就自定义一个子类,对HttpSerlvetRequest类进行包装增强,将ServletInputStream中的流数据保存起来, 在读request流数据时,将保存起来的数据返回即可。

所幸JavaEE API中提供了一个HttpServletRequestWrapper类,从类名可以看出它是一个Http请求包装器,是基于装饰者模式并实现了HttpServletRequest接口。
request.getInputStream()输入流只能读取一次问题_第3张图片
从上图源码可以看到,该类并没有真正去实现HttpServletRequest的方法,而是在方法内又去调用ServletRequest的方法。
所以我们可以通过继承该类并实现想要重新定义的方法已达到包装原生SerlvetRequest对象的目的。




解决方案:
首先我们需要定义一个容器,将输入流中的数据保存进容器中,然后重写HttpServletRequestWrapper中的getInputStream()方法,每次都从这个容器中读数据。这样request的输入流就可以重复读了。

/**
 * 解决request流只读取一次的问题
 */
@Slf4j
public class RequestWrapper extends HttpServletRequestWrapper {

    /**
     * 存储body数据的容器
     */
    private final byte[] body;

    public RequestWrapper(HttpServletRequest request) throws IOException {
        super(request);

        // 将body数据存储起来
        body = getBodyString(request).getBytes(Charset.defaultCharset());
    }

    /**
     * 获取请求Body
     *
     * @param request request
     * @return String
     */
    public String getBodyString(final ServletRequest request) {
        try {
            return inputStream2String(request.getInputStream());
        } catch (IOException e) {
            log.error("", e);
            throw new RuntimeException(e);
        }
    }

    /**
     * 获取请求Body
     *
     * @return String
     */
    public String getBodyString() {
        final InputStream inputStream = new ByteArrayInputStream(body);

        return inputStream2String(inputStream);
    }

    /**
     * 将inputStream里的数据读取出来并转换成字符串
     *
     * @param inputStream inputStream
     * @return String
     */
    private String inputStream2String(InputStream inputStream) {
        StringBuilder sb = new StringBuilder();
        BufferedReader reader = null;

        try {
            reader = new BufferedReader(new InputStreamReader(inputStream, Charset.defaultCharset()));
            String line;
            while ((line = reader.readLine()) != null) {
                sb.append(line);
            }
        } catch (IOException e) {
            log.error("", e);
            throw new RuntimeException(e);
        } finally {
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException e) {
                    log.error("", e);
                }
            }
        }

        return sb.toString();
    }

    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(getInputStream()));
    }

    @Override
    public ServletInputStream getInputStream() throws IOException {

        final ByteArrayInputStream inputStream = new ByteArrayInputStream(body);

        return new ServletInputStream() {
            @Override
            public int read() throws IOException {
                return inputStream.read();
            }

            @Override
            public boolean isFinished() {
                return false;
            }

            @Override
            public boolean isReady() {
                return false;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
            }
        };
    }

}


现在有了包装类,该怎么将原本的`HttpServletRequest`替换成我们自己的request呢?使用Filter在doFilter(req,resp)中将req替换成我们自己的`RequestWrapper`就可以实现。

/**
 * 解决request流只读取一次的问题
 */
@Slf4j
public class ReplaceStreamFilter implements Filter {

    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
        ServletRequest requestWrapper = new RequestWrapper((HttpServletRequest) request);
        chain.doFilter(requestWrapper, response);
    }
}

@Configuration
public class FilterConfig {
    /**
     * 注册过滤器
     *
     * @return FilterRegistrationBean
     */
    @Bean
    public FilterRegistrationBean someFilterRegistration() {
        FilterRegistrationBean registration = new FilterRegistrationBean();
        registration.setFilter(replaceStreamFilter());
        registration.addUrlPatterns("/*");
        registration.setName("streamFilter");
        return registration;
    }
 
    /**
     * 实例化StreamFilter
     *
     * @return Filter
     */
    @Bean(name = "replaceStreamFilter")
    public Filter replaceStreamFilter() {
        return new ReplaceStreamFilter();
    }
}


然后我们就可以在拦截器中愉快的获取request流中的数据,麻麻再也不用担心Controller了~

@Slf4j
public class SignatureInterceptor implements HandlerInterceptor {
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        log.info("[preHandle] executing... request uri is {}", request.getRequestURI());
        if (isJson(request)) {
            // 获取json字符串
            String jsonParam = new RequestWrapper(request).getBodyString();
            log.info("[preHandle] json数据 : {}", jsonParam);
 
            // 验签逻辑...略...
        }
 
        return true;
    }

    /**
     * 判断本次请求的数据类型是否为json
     *
     * @param request request
     * @return boolean
     */
    private boolean isJson(HttpServletRequest request) {
        if (request.getContentType() != null) {
            return request.getContentType().equals(MediaType.APPLICATION_JSON_VALUE) ||
                    request.getContentType().equals(MediaType.APPLICATION_JSON_UTF8_VALUE);
        }
 
        return false;
    }
}

接下来我们就可以测试一下在拦截器中读取了输入流后在controller层是否还能正常接收数据,代码如下:

@RestController
@RequestMapping("/user")
public class DemoController {
 
    @PostMapping("/register")
    public UserParam register(@RequestBody UserParam userParam){
        return userParam;
    }
}

启动项目,请求结果如下,可以看到controller正常接收到数据并返回了

你可能感兴趣的:(学习笔记,java,spring,boot,spring,servlet)