SpringCloud Gateway 解析获取请求参数并封装传递到Controller

目录

    • 前言
    • 代码
      • 1. 定义请求封装实体`OAuthRequest.java`
      • 2. 定义抽象请求构造工厂类`OAuthRequestFactory.java`
      • 3. 定义默认WebFlux请求构造工厂实现`WebFluxOAuthRequestFactory.java`
      • 4. 定义请求实体线程参数容器`OAuthRequestContainer.java`
      • 5. 定义抽象过滤器`AbstractGatewayFilter.java`,用来封装请求解析方法
      • 6. 定义`CertifiedEntryWebfluxFilter.java`过滤器拦截请求并解析封装请求信息
      • 7. 测试Controller输出参数

前言

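Gateway 使用全局过滤器解析请求并封装所有参数,再通过线程变量(ThreadLocal)传递给 Controller。注意:WebFlux 是非阻塞模型,同一个请求的处理可能在不同线程之间切换,因此这种方式只适用于与过滤器处于同一 Gateway 进程内、并且恰好在同一线程上执行的 Controller;被转发到下游微服务的请求无法通过 ThreadLocal 获取这些参数。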

代码

1. 定义请求封装实体OAuthRequest.java

import org.feasy.cloud.auth.core.authentication.OAuthAuthentication;
import org.feasy.cloud.auth.core.model.ClientDetails;

import java.util.Map;
import java.util.Set;

/**
 * Snapshot of an incoming HTTP request (parameters, headers, method, paths, source address),
 * optionally enriched with authentication and client information.
 *
 * <p>Instances are mutable; the fluent setters return {@code this} so calls can be chained.
 *
 * @author YangXiaohui
 * @since 2020/12/31 14:03
 */
public class OAuthRequest {
    /**
     * Request parameters, keyed by parameter name. May be {@code null} until set.
     */
    private Map<String, String> parameters;
    /**
     * Request headers, keyed by header name. May be {@code null} until set.
     */
    private Map<String, String> headers;
    /**
     * HTTP method: POST, GET, PUT, DELETE, ...
     */
    private String method;
    /**
     * Full request URL.
     */
    private String requestURL;
    /**
     * Request path.
     */
    private String requestURI;
    /**
     * Query string of the request URL.
     */
    private String queryString;
    /**
     * Address the request came from.
     */
    private String remoteHost;

    /**
     * Authentication information; {@code null} when the request is not authenticated.
     */
    private OAuthAuthentication authentication;

    /**
     * Information about the requesting client.
     */
    private ClientDetails requestClientDetails;

    public OAuthRequest() {
    }

    public OAuthRequest(Map<String, String> parameters, Map<String, String> headers, String method, String requestURL, String requestURI, String queryString, String remoteHost) {
        this(parameters, headers, method, requestURL, requestURI, queryString, remoteHost, null);
    }

    public OAuthRequest(Map<String, String> parameters, Map<String, String> headers, String method, String requestURL, String requestURI, String queryString, String remoteHost, OAuthAuthentication authentication) {
        this.parameters = parameters;
        this.headers = headers;
        this.method = method;
        this.requestURL = requestURL;
        this.requestURI = requestURI;
        this.queryString = queryString;
        this.remoteHost = remoteHost;
        this.authentication = authentication;
    }


    /**
     * Returns a single request parameter.
     *
     * @param name parameter name
     * @return the parameter value, or {@code null} if absent or if no parameters have been set
     */
    public String getParameter(String name) {
        return parameters == null ? null : parameters.get(name);
    }

    public Map<String, String> getParameters() {
        return parameters;
    }

    public OAuthRequest setParameters(Map<String, String> parameters) {
        this.parameters = parameters;
        return this;
    }

    /**
     * Returns a single request header.
     *
     * @param name header name
     * @return the header value, or {@code null} if absent or if no headers have been set
     */
    public String getHeader(String name) {
        return headers == null ? null : headers.get(name);
    }

    public Map<String, String> getHeaders() {
        return headers;
    }

    public OAuthRequest setHeaders(Map<String, String> headers) {
        this.headers = headers;
        return this;
    }

    public String getMethod() {
        return method;
    }

    public OAuthRequest setMethod(String method) {
        this.method = method;
        return this;
    }

    public String getRequestURL() {
        return requestURL;
    }

    public OAuthRequest setRequestURL(String requestURL) {
        this.requestURL = requestURL;
        return this;
    }

    public String getRequestURI() {
        return requestURI;
    }

    public OAuthRequest setRequestURI(String requestURI) {
        this.requestURI = requestURI;
        return this;
    }

    public String getQueryString() {
        return queryString;
    }

    public OAuthRequest setQueryString(String queryString) {
        this.queryString = queryString;
        return this;
    }

    public String getRemoteHost() {
        return remoteHost;
    }

    public OAuthRequest setRemoteHost(String remoteHost) {
        this.remoteHost = remoteHost;
        return this;
    }

    public OAuthAuthentication getAuthentication() {
        return authentication;
    }

    public OAuthRequest setAuthentication(OAuthAuthentication authentication) {
        this.authentication = authentication;
        return this;
    }

    public ClientDetails getRequestClientDetails() {
        return requestClientDetails;
    }

    public OAuthRequest setRequestClientDetails(ClientDetails requestClientDetails) {
        this.requestClientDetails = requestClientDetails;
        return this;
    }

    /**
     * Tells whether this request carries authentication information.
     *
     * @return {@code true} if an authentication is attached, {@code false} otherwise
     */
    public boolean isAuthenticated() {
        return authentication != null;
    }


    /**
     * Overwrites the {@code scope} parameter with the given scopes joined by commas.
     *
     * @param scope scopes to keep; must not be {@code null}
     * @return this request
     */
    public OAuthRequest narrowScope(Set<String> scope) {
        if (this.parameters == null) {
            // Parameters may not have been populated yet (e.g. created by a factory with null).
            this.parameters = new java.util.HashMap<>();
        }
        this.parameters.put("scope", String.join(",", scope));
        return this;
    }
}

2. 定义抽象请求构造工厂类OAuthRequestFactory.java

import org.feasy.cloud.util.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Map;

/**
 * Base factory that converts a framework-specific HTTP request into an {@link OAuthRequest}.
 *
 * <p>Subclasses implement {@link #createRequest(Object)} for their web stack and delegate to
 * {@link #buildRequest} to assemble the entity.
 *
 * @author YangXiaohui
 * @since 2021/1/4 11:52
 */
public abstract class OAuthRequestFactory {
    private static final Logger logger= LoggerFactory.getLogger(OAuthRequestFactory.class);

    /**
     * Header key read for the access token.
     * NOTE(review): this is the literal text "HEADER_TOKEN.toLowerCase()", not a lower-cased
     * HEADER_TOKEN constant; it looks like a constant reference was accidentally quoted. Kept
     * as-is to preserve behaviour. Replace it with the real header name.
     */
    private static final String TOKEN_HEADER_KEY = "HEADER_TOKEN.toLowerCase()";
    /**
     * Header key read for the client token.
     * NOTE(review): identical to {@link #TOKEN_HEADER_KEY} in the original code; presumably a
     * distinct client-token header was intended. TODO confirm.
     */
    private static final String CLIENT_TOKEN_HEADER_KEY = "HEADER_TOKEN.toLowerCase()";

    /**
     * Builds the request entity.
     *
     * @param httpRequest framework request object (e.g. HttpServletRequest for Spring MVC)
     * @return {@link OAuthRequest} request entity
     */
    public abstract OAuthRequest createRequest(Object httpRequest);
    /**
     * Assembles the request entity from already-extracted request data.
     *
     * <p>Token parsing is not implemented yet (see TODOs), so the returned request is always
     * unauthenticated.
     *
     * @param parameters  request parameters
     * @param headers     request headers; may be {@code null}
     * @param method      HTTP method: POST, GET, ...
     * @param requestURL  full request URL
     * @param requestURI  request path
     * @param queryString query string
     * @param remoteHost  source IP of the request
     * @return the assembled request entity
     */
    protected OAuthRequest buildRequest(Map<String, String> parameters, Map<String, String> headers, String method, String requestURL, String requestURI, String queryString, String remoteHost) {
        final String token = headerValue(headers, TOKEN_HEADER_KEY);
        final String clientToken = headerValue(headers, CLIENT_TOKEN_HEADER_KEY);
        // Check whether the request carries an authentication token.
        if (StringUtils.isNotEmpty(token)) {
            // TODO parse the access token, e.g.
            //  final OAuthAuthentication authentication = resourceServerTokenServices.loadAuthentication(token);
            if (StringUtils.isNotEmpty(clientToken)) {
                // TODO parse the client token
            }
        }
        return new OAuthRequest(parameters, headers, method, requestURL, requestURI, queryString, remoteHost);
    }

    /**
     * Null-safe header lookup.
     */
    private static String headerValue(Map<String, String> headers, String key) {
        return headers == null ? null : headers.get(key);
    }
}

3. 定义默认WebFlux请求构造工厂实现WebFluxOAuthRequestFactory.java

import org.feasy.cloud.auth.core.request.OAuthRequest;
import org.feasy.cloud.auth.core.request.OAuthRequestFactory;
import org.feasy.cloud.util.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.server.reactive.ServerHttpRequest;

import java.net.URI;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * {@link OAuthRequestFactory} for Spring WebFlux that builds an {@link OAuthRequest} from a
 * {@link ServerHttpRequest}.
 *
 * @author YangXiaohui
 * @since 2020/12/31 14:45
 */
public class WebFluxOAuthRequestFactory extends OAuthRequestFactory {
    private static final Logger logger = LoggerFactory.getLogger(WebFluxOAuthRequestFactory.class);

    /**
     * Proxy headers checked, in this order, for the original client address.
     */
    private static final String[] CLIENT_IP_HEADERS = {
            "X-Forwarded-For",    // Squid and most reverse proxies
            "Proxy-Client-IP",    // Apache
            "WL-Proxy-Client-IP", // WebLogic
            "HTTP_CLIENT_IP",     // some other proxies
            "X-Real-IP"           // nginx
    };

    /**
     * Placeholder some proxies send when the client address is not known.
     */
    private static final String UNKNOWN_IP = "unknown";

    /**
     * Builds the request entity.
     *
     * @param httpRequest must be a {@link ServerHttpRequest}
     * @return {@link OAuthRequest} whose parameter map starts empty and mutable
     * @throws ClassCastException if {@code httpRequest} is not a {@link ServerHttpRequest}
     */
    @Override
    public OAuthRequest createRequest(Object httpRequest) {
        final ServerHttpRequest request = (ServerHttpRequest) httpRequest;
        final String sourceIp = analysisSourceIp(request);
        final URI uri = request.getURI();
        final Map<String, String> headersMap = getHeadersMap(request);
        // An empty mutable map (rather than null) keeps getParameter()/narrowScope() safe until the
        // body has been parsed and the real parameters are set.
        return this.buildRequest(new HashMap<>(), headersMap, request.getMethodValue().toUpperCase(),
                buildRequestUrl(uri), uri.getPath(), uri.getQuery(), sourceIp);
    }

    /**
     * Reconstructs the full request URL.
     *
     * <p>Includes the scheme, omits the port when it is undefined (-1) and omits {@code ?} when
     * there is no query.
     */
    private static String buildRequestUrl(URI uri) {
        final StringBuilder url = new StringBuilder();
        if (uri.getScheme() != null) {
            url.append(uri.getScheme()).append("://");
        }
        if (uri.getHost() != null) {
            url.append(uri.getHost());
        }
        if (uri.getPort() != -1) {
            url.append(':').append(uri.getPort());
        }
        if (uri.getPath() != null) {
            url.append(uri.getPath());
        }
        if (uri.getQuery() != null) {
            url.append('?').append(uri.getQuery());
        }
        return url.toString();
    }

    /**
     * Determines the client IP.
     *
     * <p>The first proxy header (see {@link #CLIENT_IP_HEADERS}) that has a usable value wins.
     * With multiple proxies the value is a comma-separated list whose first entry is the client.
     * If nothing usable is found, the socket's remote address is used.
     *
     * <p>NOTE: these headers are client-controlled and can be spoofed unless a trusted proxy
     * overwrites them; do not use the result for security decisions on its own.
     *
     * @return the client IP, or {@code null} if it cannot be determined
     */
    protected String analysisSourceIp(ServerHttpRequest request) {
        String candidates = null;
        for (String header : CLIENT_IP_HEADERS) {
            candidates = request.getHeaders().getFirst(header);
            if (!isUnknown(candidates)) {
                break;
            }
        }
        String ip = null;
        if (!isUnknown(candidates)) {
            final int comma = candidates.indexOf(',');
            ip = (comma >= 0 ? candidates.substring(0, comma) : candidates).trim();
        }
        if (isUnknown(ip)) {
            // Fall back to the peer address of the connection (may be absent, e.g. in tests).
            ip = request.getRemoteAddress() != null ? request.getRemoteAddress().getHostString() : null;
        }
        return ip;
    }

    /**
     * Tells whether a header value carries no usable address.
     */
    private static boolean isUnknown(String value) {
        return value == null || value.isEmpty() || UNKNOWN_IP.equalsIgnoreCase(value);
    }

    /**
     * Collects all headers; multiple values of the same header are joined with commas.
     *
     * <p>The returned map is case-insensitive, like HTTP header names, so lookups such as
     * {@code getHeader("authorization")} match whatever casing the client sent.
     */
    private Map<String, String> getHeadersMap(ServerHttpRequest request) {
        final Map<String, String> headerMap = new java.util.TreeMap<>(String.CASE_INSENSITIVE_ORDER);
        for (Map.Entry<String, List<String>> entry : request.getHeaders().entrySet()) {
            final List<String> values = entry.getValue();
            headerMap.put(entry.getKey(), values == null || values.isEmpty() ? null : String.join(",", values));
        }
        return headerMap;
    }


}

4. 定义请求实体线程参数容器OAuthRequestContainer.java


/**
 * Holder that exposes the current {@link OAuthRequest} to code running on the same thread.
 *
 * <p>Backed by an {@link InheritableThreadLocal}, so threads created while a value is set start
 * with a copy of it.
 *
 * <p>NOTE(review): in a reactive (WebFlux) pipeline a single request may be processed on several
 * threads, and pooled threads outlive requests. Always call {@link #remove()} when done, and do
 * not expect a value set on one thread to be visible on another.
 *
 * @author YangXiaohui
 * @since 2020/11/25 10:55
 */
public final class OAuthRequestContainer {
    /**
     * Current backing thread-local. Declared volatile so that a replacement installed through
     * {@link #rewriteOAuthRequestContainer} is visible to every thread.
     */
    private static volatile ThreadLocal<OAuthRequest> local = new InheritableThreadLocal<>();

    private OAuthRequestContainer() {
    }

    /**
     * Stores the request for the current thread.
     */
    public static void set(OAuthRequest request) {
        local.set(request);
    }

    /**
     * Returns the request stored for the current thread, or {@code null} if none.
     */
    public static OAuthRequest get() {
        return local.get();
    }

    /**
     * Clears the request stored for the current thread.
     */
    public static void remove() {
        local.remove();
    }

    /**
     * Replaces the backing thread-local.
     *
     * <p>Values stored in the previous thread-local are not copied over.
     *
     * @param request the new backing thread-local
     * @throws NullPointerException if {@code request} is {@code null}
     */
    public static void rewriteOAuthRequestContainer(ThreadLocal<OAuthRequest> request){
        local = java.util.Objects.requireNonNull(request, "request");
    }
}

5. 定义抽象过滤器AbstractGatewayFilter.java,用来封装请求解析方法

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import org.feasy.cloud.util.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.cloud.gateway.filter.factory.rewrite.CachedBodyOutputMessage;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.util.MultiValueMap;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.io.UnsupportedEncodingException;
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Base class for gateway global filters that extract request parameters from the request body
 * and query string, and that replay a cached body to the rest of the filter chain.
 *
 * @author YangXiaohui
 * @since 2020/12/31 9:24
 */
public abstract class AbstractGatewayFilter implements GlobalFilter, Ordered {
    private final Logger logger = LoggerFactory.getLogger(AbstractGatewayFilter.class);

    /**
     * Multipart text part: groups are (1) the 24 digits after 28 dashes in the boundary,
     * (2) the field name and (3) the value, which stops at whitespace.
     * NOTE(review): only matches boundaries of exactly that shape; other clients' boundaries
     * will not match.
     */
    protected final static String parameterReg = "-{28}([0-9]{24})\r\n.+name=\"(\\S*)\"\r\n\r\n(\\S*)";
    /**
     * Multipart file part: groups are (1) boundary digits, (2) field name, (3) file name.
     */
    protected final static String fileParameterReg = "-{28}([0-9]{24})\r\n.+name=\"(\\S*)\"; filename=\"(\\S*)\"\r\n.*\r\n\r\n";

    /** Pre-compiled form of {@link #parameterReg}; Pattern is immutable and thread-safe. */
    private static final Pattern PARAMETER_PATTERN = Pattern.compile(parameterReg);
    /** Pre-compiled form of {@link #fileParameterReg}. */
    private static final Pattern FILE_PARAMETER_PATTERN = Pattern.compile(fileParameterReg);


    /**
     * Extracts text fields (name to value) and file fields (name to file name) from a
     * multipart/form-data body.
     */
    protected void parseRequestBody(Map<String, String> parameterMap, String parameterString) {
        this.regexParseBodyString(parameterReg, parameterMap, parameterString);
        this.regexParseBodyString(fileParameterReg, parameterMap, parameterString);
    }

    /**
     * Copies the top-level properties of a JSON object body into {@code parameterMap}.
     *
     * <p>Nested values are stored in their JSON string form. An empty body adds nothing.
     */
    protected void parseRequestJson(Map<String, String> parameterMap, String parameterString) {
        final JSONObject object = JSON.parseObject(parameterString);
        if (object == null) {
            // parseObject yields null for empty input; there is nothing to extract.
            return;
        }
        for (String key : object.keySet()) {
            parameterMap.put(key, object.getString(key));
        }
    }

    /**
     * Copies query parameters into {@code parameterMap}; repeated values are joined with commas.
     */
    protected void parseRequestQuery(Map<String, String> parameterMap, MultiValueMap<String, String> queryParamMap) {
        if (queryParamMap != null && !queryParamMap.isEmpty()) {
            for (String key : queryParamMap.keySet()) {
                final List<String> stringList = queryParamMap.get(key);
                parameterMap.put(key, stringList != null && !stringList.isEmpty() ? StringUtils.join(stringList.toArray(), ",") : null);
            }
        }
    }

    /**
     * Parses an {@code application/x-www-form-urlencoded} string ({@code a=1&b=2}) into
     * {@code parameterMap}.
     *
     * <p>Names and values are URL-decoded as UTF-8, which is consistent with the already-decoded
     * query parameters. Only the first {@code =} separates name from value. A pair without
     * {@code =} (or with nothing after it) maps to an empty string. Empty input adds nothing.
     */
    protected void parseRequestQuery(Map<String, String> parameterMap, String parameterString) {
        if (parameterString == null || parameterString.isEmpty()) {
            return;
        }
        for (String pair : parameterString.split("&")) {
            if (pair.isEmpty()) {
                continue;
            }
            final int eq = pair.indexOf('=');
            final String name = decodeFormComponent(eq >= 0 ? pair.substring(0, eq) : pair);
            final String value = eq >= 0 ? decodeFormComponent(pair.substring(eq + 1)) : "";
            // Values are not logged: form bodies commonly carry credentials.
            logger.debug("请求名:{}", name);
            parameterMap.put(name, value);
        }
    }

    /**
     * Decodes one form-urlencoded component as UTF-8, keeping the raw text if it is malformed.
     */
    private static String decodeFormComponent(String raw) {
        try {
            return URLDecoder.decode(raw, StandardCharsets.UTF_8.name());
        } catch (UnsupportedEncodingException e) {
            // Cannot happen: every JVM supports UTF-8.
            throw new IllegalStateException(e);
        } catch (IllegalArgumentException e) {
            // Malformed %-escape: keep the original text instead of failing the whole request.
            return raw;
        }
    }

    /**
     * Applies {@code reg} to {@code bodyStr}; for each match stores group 2 to group 3 in
     * {@code parameterMap}. A {@code null} body adds nothing.
     */
    protected void regexParseBodyString(String reg, Map<String, String> parameterMap, String bodyStr) {
        if (bodyStr == null) {
            return;
        }
        final Matcher matcher = patternFor(reg).matcher(bodyStr);
        while (matcher.find()) {
            parameterMap.put(matcher.group(2), matcher.group(3));
            logger.debug("请求参数编号:{}, 请求名:{}", matcher.group(1), matcher.group(2));
        }
    }

    /**
     * Returns a cached Pattern for the built-in expressions, compiling any other one on demand.
     */
    private static Pattern patternFor(String reg) {
        if (parameterReg.equals(reg)) {
            return PARAMETER_PATTERN;
        }
        if (fileParameterReg.equals(reg)) {
            return FILE_PARAMETER_PATTERN;
        }
        return Pattern.compile(reg);
    }


    /**
     * Error path of the filter: releases any cached body buffers, then re-emits {@code throwable}.
     */
    protected Mono<Void> release(ServerWebExchange exchange,
                                 CustomCachedBodyOutputMessage outputMessage, Throwable throwable) {
        if (outputMessage.isCached()) {
            return outputMessage.getBody().map(DataBufferUtils::release)
                    .then(Mono.error(throwable));
        }
        return Mono.error(throwable);
    }

    /**
     * Wraps the exchange's request so its body is served from {@code outputMessage}.
     *
     * <p>Headers are the original request headers plus either the Content-Length from
     * {@code headers} (when positive) or {@code Transfer-Encoding: chunked}.
     */
    protected ServerHttpRequestDecorator decorate(ServerWebExchange exchange, HttpHeaders headers,
                                                  CachedBodyOutputMessage outputMessage) {
        return new ServerHttpRequestDecorator(exchange.getRequest()) {
            @Override
            public HttpHeaders getHeaders() {
                long contentLength = headers.getContentLength();
                HttpHeaders httpHeaders = new HttpHeaders();
                httpHeaders.putAll(super.getHeaders());
                if (contentLength > 0) {
                    httpHeaders.setContentLength(contentLength);
                } else {
                    httpHeaders.set(HttpHeaders.TRANSFER_ENCODING, "chunked");
                }
                return httpHeaders;
            }

            @Override
            public Flux<DataBuffer> getBody() {
                return outputMessage.getBody();
            }
        };
    }

}

6. 定义CertifiedEntryWebfluxFilter.java过滤器拦截请求并解析封装请求信息

import com.alibaba.fastjson.JSON;
import org.feasy.cloud.auth.core.request.OAuthRequest;
import org.feasy.cloud.auth.core.request.OAuthRequestContainer;
import org.feasy.cloud.auth.core.request.OAuthRequestFactory;
import org.feasy.cloud.auth.starter.request.WebFluxOAuthRequestFactory;
import org.feasy.cloud.auth.starter.webflux.AbstractGatewayFilter;
import org.feasy.cloud.auth.starter.webflux.CustomCachedBodyOutputMessage;
import org.feasy.cloud.util.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.cloud.gateway.config.GatewayAutoConfiguration;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.cloud.gateway.support.BodyInserterContext;
import org.springframework.core.Ordered;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ReactiveHttpOutputMessage;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.stereotype.Component;
import org.springframework.web.reactive.function.BodyInserter;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.server.HandlerStrategies;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

/**
 * Global gateway filter that captures request parameters.
 *
 * <p>It reads the request body once and extracts parameters from it (multipart, JSON or
 * form-urlencoded, chosen by Content-Type), adds the query parameters, and stores the resulting
 * {@link OAuthRequest} in {@link OAuthRequestContainer}. It then forwards a request whose body is
 * replayed from a cache, so later filters and the route can still read it.
 *
 * <p>NOTE(review): the body is read as a String and re-encoded, which may corrupt binary
 * multipart uploads. The container is thread-local and is only reliable for code running on the
 * same thread in this process.
 *
 * @author YangXiaohui
 * @since 2021/1/4 11:47
 */
@Component
@ConditionalOnClass(GatewayAutoConfiguration.class)
public class CertifiedEntryWebfluxFilter extends AbstractGatewayFilter implements GlobalFilter, Ordered {
    private final Logger logger = LoggerFactory.getLogger(CertifiedEntryWebfluxFilter.class);

    /** The factory holds no per-request state, so one instance is shared. */
    private final OAuthRequestFactory requestFactory = new WebFluxOAuthRequestFactory();

    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        ServerRequest serverRequest = ServerRequest.create(exchange, HandlerStrategies.withDefaults().messageReaders());
        String contentType = exchange.getRequest().getHeaders().getFirst(HttpHeaders.CONTENT_TYPE);
        OAuthRequest authRequest = requestFactory.createRequest(exchange.getRequest());
        Map<String, String> requestParamsMap = new HashMap<>();
        Mono<String> modifiedBody = serverRequest.bodyToMono(String.class)
                .publishOn(Schedulers.immediate())
                .doOnNext(originalBody -> {
                    parseBody(contentType, requestParamsMap, originalBody);
                    publishParameters(exchange, authRequest, requestParamsMap);
                })
                // Bodiless requests (e.g. GET) never emit a body: still capture the query
                // parameters, but keep the forwarded body empty.
                .switchIfEmpty(Mono.<String>fromRunnable(
                        () -> publishParameters(exchange, authRequest, requestParamsMap)));
        // Re-publish the consumed body so it can be read again downstream.
        BodyInserter<Mono<String>, ReactiveHttpOutputMessage> bodyInserter = BodyInserters.fromPublisher(modifiedBody, String.class);
        HttpHeaders headers = new HttpHeaders();
        headers.putAll(exchange.getRequest().getHeaders());
        // The body is re-serialized, so the original length may no longer be accurate.
        headers.remove(HttpHeaders.CONTENT_LENGTH);
        CustomCachedBodyOutputMessage outputMessage = new CustomCachedBodyOutputMessage(
                exchange, headers);
        return bodyInserter.insert(outputMessage, new BodyInserterContext())
                .then(Mono.defer(() -> {
                    ServerHttpRequest decorator = decorate(exchange, headers, outputMessage);
                    return chain.filter(exchange.mutate().request(decorator).build());
                }))
                .onErrorResume(throwable -> release(exchange, outputMessage, throwable));
    }

    /**
     * Extracts body parameters according to the Content-Type.
     *
     * <p>Best effort: an unparsable body is logged and skipped, so the request is still forwarded
     * instead of failing in the gateway.
     */
    private void parseBody(String contentType, Map<String, String> requestParamsMap, String originalBody) {
        if (!StringUtils.isNotEmpty(contentType)) {
            return;
        }
        try {
            if (contentType.startsWith(MediaType.MULTIPART_FORM_DATA_VALUE)) {
                this.parseRequestBody(requestParamsMap, originalBody);
            } else if (contentType.startsWith(MediaType.APPLICATION_JSON_VALUE)) {
                this.parseRequestJson(requestParamsMap, originalBody);
            } else if (contentType.startsWith(MediaType.APPLICATION_FORM_URLENCODED_VALUE)) {
                this.parseRequestQuery(requestParamsMap, originalBody);
            }
        } catch (RuntimeException e) {
            logger.warn("Could not parse request body with Content-Type {}; body parameters skipped", contentType, e);
        }
    }

    /**
     * Adds the query parameters (overriding same-named body parameters), attaches the map to the
     * request entity and stores the entity in the thread-local container.
     */
    private void publishParameters(ServerWebExchange exchange, OAuthRequest authRequest, Map<String, String> requestParamsMap) {
        this.parseRequestQuery(requestParamsMap, exchange.getRequest().getQueryParams());
        if (logger.isDebugEnabled()) {
            // Debug level: parameters may contain credentials.
            logger.debug("所有参数:{}", JSON.toJSONString(requestParamsMap));
        }
        authRequest.setParameters(requestParamsMap);
        OAuthRequestContainer.set(authRequest);
    }

    @Override
    public int getOrder() {
        return Ordered.HIGHEST_PRECEDENCE + 1000;
    }
}

7. 测试Controller输出参数

import com.alibaba.fastjson.JSON;
import org.feasy.cloud.auth.core.request.OAuthRequest;
import org.feasy.cloud.auth.core.request.OAuthRequestContainer;
import org.feasy.cloud.util.result.Result;
import org.feasy.cloud.util.result.ResultBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RestController;

/**
 * Test controller that logs the {@link OAuthRequest} captured by the gateway filter.
 *
 * @author YangXiaohui
 * @since 2021/1/4 14:46
 */
@RestController
public class GatewayTestController {
    private static final Logger logger = LoggerFactory.getLogger(GatewayTestController.class);

    /**
     * Logs the request stored in {@link OAuthRequestContainer} and returns a success result.
     *
     * <p>NOTE(review): the container is thread-local; if this handler runs on a different thread
     * than the filter, the logged value is {@code null}.
     */
    @PostMapping("/save")
    public Result<Object> save() {
        try {
            final OAuthRequest authRequest = OAuthRequestContainer.get();
            logger.info("请求参数信息:{}", JSON.toJSONString(authRequest));
            return ResultBuilder.success();
        } finally {
            // Always clear the container, even if logging fails, so pooled threads do not keep
            // stale requests.
            OAuthRequestContainer.remove();
        }
    }
}

你可能感兴趣的:(Spring,Cloud,笔记,Gateway,参数封装,SpringCloud)