最近在做文件上传下载，需要从 request 中获取文件流。我发现 POST 请求的请求体（InputStream）只能读取一次，读过之后再读就是空的。琢磨了一下午找到了解决办法：核心思路是先把请求体的字节完整保存下来，再用一个包装类重写 request，让每次调用 getInputStream() 都基于保存的字节返回一个新的流，这样请求体就可以反复读取了。代码如下：
package com.openailab.oascloud.gateway.util;
import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.parser.Feature;
import com.google.common.collect.Maps;
import com.netflix.zuul.context.RequestContext;
import com.netflix.zuul.http.HttpServletRequestWrapper;
import com.netflix.zuul.http.ServletInputStreamWrapper;
import com.openailab.oascloud.common.util.IPUtil;
import com.openailab.oascloud.gateway.filter.CharacterEncodeFilter;
import org.apache.commons.fileupload.FileItem;
import org.apache.commons.fileupload.FileUploadException;
import org.apache.commons.fileupload.disk.DiskFileItemFactory;
import org.apache.commons.fileupload.servlet.ServletFileUpload;
import org.apache.commons.fileupload.servlet.ServletRequestContext;
import org.apache.tomcat.util.http.fileupload.FileItemIterator;
import org.apache.tomcat.util.http.fileupload.FileItemStream;
import org.apache.tomcat.util.http.fileupload.util.Streams;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.StreamUtils;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UnsupportedEncodingException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.*;
/**
* @Classname: com.openailab.oascloud.gateway.util.ParamUtil
* @Description: 描述
* @Author: zxzhang
* @Date: 2019/6/27
*/
public class ParamUtil {
private static final Logger LOGGER = LoggerFactory.getLogger(CharacterEncodeFilter.class);
/**
* 获取请求参数(Get||Post)
*
* @param ctx
* @return java.util.Map
* @author zxzhang
* @date 2019/6/27
*/
public static Map getRequestParams(RequestContext ctx) {
String method = ctx.getRequest().getMethod();
String uri = ctx.getRequest().getRequestURI();
//判断是POST请求还是GET请求(不同请求获取参数方式不同)
LinkedHashMap param = null;
try {
if (uri.startsWith("/zuul")) {
byte[] requestByte = saveaIns(ctx.getRequest().getInputStream());
rewriteRequest(ctx, requestByte);
param = Maps.newLinkedHashMap();
DiskFileItemFactory factory = new DiskFileItemFactory();
ServletFileUpload upload = new ServletFileUpload(factory);
upload.setHeaderEncoding("UTF-8");
List list = upload.parseRequest(new ServletRequestContext(ctx.getRequest()));
for (FileItem item : list) {
String name = item.getFieldName();
if (item.isFormField()) {
String value = item.getString("UTF-8");
param.put(name, value);
} else {
String filename = item.getName();
param.put(name, filename);
}
}
rewriteRequest(ctx, requestByte);
} else {
if ("GET".equals(method.toUpperCase())) {
Map> map = ctx.getRequestQueryParams();
if (!(Objects.isNull(map) || map.isEmpty())) {
param = Maps.newLinkedHashMap();
for (Map.Entry> entry : map.entrySet()) {
param.put(entry.getKey(), entry.getValue().get(0));
}
}
} else if ("POST".equals(method.toUpperCase())) {
try (InputStream inputStream = ctx.getRequest().getInputStream()) {
String body = StreamUtils.copyToString(inputStream, Charset.forName("UTF-8"));
LOGGER.info("***************原始参数:{}***************", body);
param = JSONObject.parseObject(body, LinkedHashMap.class, Feature.OrderedField);
} catch (IOException e) {
e.printStackTrace();
}
}
}
} catch (Exception e) {
LOGGER.error("***************ParamUtil->getRequestParams throw Exception:{}***************", e);
}
return param;
}
/**
* 设置请求参数(Get||Post)
*
* @param ctx
* @return void
* @author zxzhang
* @date 2019/6/28
*/
public static void setRequestParams(RequestContext ctx) {
String uri = ctx.getRequest().getRequestURI();
String source = uri.indexOf("/platform") > 0 ? "platform" : "application";
String client = uri.indexOf("/web") > 0 ? "web" : "app";
String method = ctx.getRequest().getMethod();
HttpServletRequest request = ctx.getRequest();
if (uri.startsWith("/zuul")) {
return;
} else {
//判断是POST请求还是GET请求(不同请求获取参数方式不同)
if ("GET".equals(method.toUpperCase())) {
Map> param = ctx.getRequestQueryParams();
if (Objects.isNull(param) || param.isEmpty()) {
param = Maps.newHashMap();
}
param.put("source", Arrays.asList(source));
param.put("client", Arrays.asList(client));
param.put("ip", Arrays.asList(IPUtil.getClientIp(request)));
ctx.setRequestQueryParams(param);
} else if ("POST".equals(method.toUpperCase())) {
try (InputStream inputStream = ctx.getRequest().getInputStream()) {
String body = StreamUtils.copyToString(inputStream, Charset.forName("UTF-8"));
Map param = JSONObject.parseObject(body);
if (Objects.isNull(param) || param.isEmpty()) {
param = Maps.newHashMap();
}
param.put("source", source);
param.put("client", client);
param.put("ip", IPUtil.getClientIp(request));
// 重写上下文的HttpServletRequestWrapper
final byte[] paramBytes = param.toString().getBytes();
rewriteRequest(ctx, paramBytes);
} catch (IOException e) {
e.printStackTrace();
}
}
}
}
/**
* 重新写入生成request
*
* @param ctx
* @param paramBytes
* @return void
* @author zxzhang
* @date 2019/10/21
*/
private static void rewriteRequest(RequestContext ctx, byte[] paramBytes) {
ctx.setRequest(new HttpServletRequestWrapper(ctx.getRequest()) {
@Override
public ServletInputStream getInputStream() throws IOException {
return new ServletInputStreamWrapper(paramBytes);
}
@Override
public int getContentLength() {
return paramBytes.length;
}
@Override
public long getContentLengthLong() {
return paramBytes.length;
}
});
}
/**
* 保存流对象(输入流在第二次使用的时候会失效)
* 在需要用到InputStream的地方再封装成InputStream
* ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(buf);
* Workbook wb = new HSSFWorkbook(byteArrayInputStream);//byteArrayInputStream 继承了InputStream,故这样用并没有问题
* 如果只需要用到一次inputstream流,就不用这样啦,直接用就OK
*
* @param ins
*/
public static byte[] saveaIns(InputStream ins) {
byte[] buf = null;
try {
if (ins != null) {
buf = org.apache.commons.io.IOUtils.toByteArray(ins);//ins为InputStream流
}
} catch (IOException e) {
e.printStackTrace();
}
return buf;
}
}
以上代码用于在 Zuul 过滤器中缓存 POST 请求的请求体（包括上传的文件流），使其可以被多次重复读取和使用。