解决 Java 中 POST 请求的 request 输入流只能读取一次的问题

最近在做文件上传下载，需要从 request 中获取文件流。我发现 POST 请求的 request 输入流只能读取一次，之后就读不到数据了。思考了一下午，想到了解决方法：读取时先把请求体完整保存为字节数组，再用一个包装后的 request 替换原来的 request，让它每次调用 getInputStream() 时都基于这份字节数组返回一个新的流，这样就可以反复读取了。代码如下：

package com.openailab.oascloud.gateway.util;

import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.parser.Feature;
import com.google.common.collect.Maps;
import com.netflix.zuul.context.RequestContext;
import com.netflix.zuul.http.HttpServletRequestWrapper;
import com.netflix.zuul.http.ServletInputStreamWrapper;
import com.openailab.oascloud.common.util.IPUtil;
import com.openailab.oascloud.gateway.filter.CharacterEncodeFilter;
import org.apache.commons.fileupload.FileItem;
import org.apache.commons.fileupload.FileUploadException;
import org.apache.commons.fileupload.disk.DiskFileItemFactory;
import org.apache.commons.fileupload.servlet.ServletFileUpload;
import org.apache.commons.fileupload.servlet.ServletRequestContext;
import org.apache.tomcat.util.http.fileupload.FileItemIterator;
import org.apache.tomcat.util.http.fileupload.FileItemStream;
import org.apache.tomcat.util.http.fileupload.util.Streams;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.StreamUtils;

import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UnsupportedEncodingException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.*;

/**
 * @Classname: com.openailab.oascloud.gateway.util.ParamUtil
 * @Description: 描述
 * @Author: zxzhang
 * @Date: 2019/6/27
 */
public class ParamUtil {

    private static final Logger LOGGER = LoggerFactory.getLogger(CharacterEncodeFilter.class);

    /**
     * 获取请求参数(Get||Post)
     *
     * @param ctx
     * @return java.util.Map
     * @author zxzhang
     * @date 2019/6/27
     */
    public static Map getRequestParams(RequestContext ctx) {
        String method = ctx.getRequest().getMethod();
        String uri = ctx.getRequest().getRequestURI();
        //判断是POST请求还是GET请求(不同请求获取参数方式不同)
        LinkedHashMap param = null;
        try {
            if (uri.startsWith("/zuul")) {
                byte[] requestByte = saveaIns(ctx.getRequest().getInputStream());
                rewriteRequest(ctx, requestByte);
                param = Maps.newLinkedHashMap();
                DiskFileItemFactory factory = new DiskFileItemFactory();
                ServletFileUpload upload = new ServletFileUpload(factory);
                upload.setHeaderEncoding("UTF-8");
                List list = upload.parseRequest(new ServletRequestContext(ctx.getRequest()));
                for (FileItem item : list) {
                    String name = item.getFieldName();
                    if (item.isFormField()) {
                        String value = item.getString("UTF-8");
                        param.put(name, value);
                    } else {
                        String filename = item.getName();
                        param.put(name, filename);
                    }
                }
                rewriteRequest(ctx, requestByte);
            } else {
                if ("GET".equals(method.toUpperCase())) {
                    Map> map = ctx.getRequestQueryParams();
                    if (!(Objects.isNull(map) || map.isEmpty())) {
                        param = Maps.newLinkedHashMap();
                        for (Map.Entry> entry : map.entrySet()) {
                            param.put(entry.getKey(), entry.getValue().get(0));
                        }
                    }
                } else if ("POST".equals(method.toUpperCase())) {
                    try (InputStream inputStream = ctx.getRequest().getInputStream()) {
                        String body = StreamUtils.copyToString(inputStream, Charset.forName("UTF-8"));
                        LOGGER.info("***************原始参数:{}***************", body);
                        param = JSONObject.parseObject(body, LinkedHashMap.class, Feature.OrderedField);
                    } catch (IOException e) {
                        e.printStackTrace();
                    }
                }
            }
        } catch (Exception e) {
            LOGGER.error("***************ParamUtil->getRequestParams throw Exception:{}***************", e);
        }
        return param;
    }

    /**
     * 设置请求参数(Get||Post)
     *
     * @param ctx
     * @return void
     * @author zxzhang
     * @date 2019/6/28
     */
    public static void setRequestParams(RequestContext ctx) {
        String uri = ctx.getRequest().getRequestURI();
        String source = uri.indexOf("/platform") > 0 ? "platform" : "application";
        String client = uri.indexOf("/web") > 0 ? "web" : "app";
        String method = ctx.getRequest().getMethod();
        HttpServletRequest request = ctx.getRequest();
        if (uri.startsWith("/zuul")) {
            return;
        } else {
            //判断是POST请求还是GET请求(不同请求获取参数方式不同)
            if ("GET".equals(method.toUpperCase())) {
                Map> param = ctx.getRequestQueryParams();
                if (Objects.isNull(param) || param.isEmpty()) {
                    param = Maps.newHashMap();
                }
                param.put("source", Arrays.asList(source));
                param.put("client", Arrays.asList(client));
                param.put("ip", Arrays.asList(IPUtil.getClientIp(request)));
                ctx.setRequestQueryParams(param);
            } else if ("POST".equals(method.toUpperCase())) {
                try (InputStream inputStream = ctx.getRequest().getInputStream()) {
                    String body = StreamUtils.copyToString(inputStream, Charset.forName("UTF-8"));
                    Map param = JSONObject.parseObject(body);
                    if (Objects.isNull(param) || param.isEmpty()) {
                        param = Maps.newHashMap();
                    }
                    param.put("source", source);
                    param.put("client", client);
                    param.put("ip", IPUtil.getClientIp(request));
                    // 重写上下文的HttpServletRequestWrapper
                    final byte[] paramBytes = param.toString().getBytes();
                    rewriteRequest(ctx, paramBytes);
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
    }

    /**
     * 重新写入生成request
     *
     * @param ctx
     * @param paramBytes
     * @return void
     * @author zxzhang
     * @date 2019/10/21
     */
    private static void rewriteRequest(RequestContext ctx, byte[] paramBytes) {
        ctx.setRequest(new HttpServletRequestWrapper(ctx.getRequest()) {
            @Override
            public ServletInputStream getInputStream() throws IOException {
                return new ServletInputStreamWrapper(paramBytes);
            }

            @Override
            public int getContentLength() {
                return paramBytes.length;
            }

            @Override
            public long getContentLengthLong() {
                return paramBytes.length;
            }
        });
    }


    /**
     * 保存流对象(输入流在第二次使用的时候会失效)
     * 在需要用到InputStream的地方再封装成InputStream
     * ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(buf);
     * Workbook wb = new HSSFWorkbook(byteArrayInputStream);//byteArrayInputStream 继承了InputStream,故这样用并没有问题
     * 如果只需要用到一次inputstream流,就不用这样啦,直接用就OK
     *
     * @param ins
     */
    public static byte[] saveaIns(InputStream ins) {
        byte[] buf = null;
        try {
            if (ins != null) {
                buf = org.apache.commons.io.IOUtils.toByteArray(ins);//ins为InputStream流
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return buf;
    }
}

这段代码用于在 Zuul 过滤器中多次读取 POST 请求的（文件）流：先缓存请求体，再把缓存的内容重新写回请求上下文，供后续过滤器和下游服务再次读取。

你可能感兴趣的:(Java文件流处理)