重写webFlux中的webFilter,随意跳转到指定的过滤器,其实很简单

背景:WebFlux 中的 WebFilter 本身不支持按请求路径进行分流。过滤器链就像一根管道,默认情况下各个过滤器的执行顺序是固定的,谁也不能插队。因此,我希望能够根据请求,把执行流程分发到不同的过滤器上。

原理:重写 webFlux中 的 DefaultWebFilterChain.java实现类 与 WebFilterChain.java接口

DefaultWebFilterChain.java代码如下:

 

/*
 * Copyright 2002-2017 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.web.server.handler;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

import lombok.extern.slf4j.Slf4j;
import reactor.core.publisher.Mono;

import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import org.springframework.web.server.WebHandler;

/**
 * Default implementation of {@link WebFilterChain}, extended so that a filter can
 * resume the chain at an explicit position or at a filter of a given type instead
 * of always delegating to its immediate successor.
 *
 * <p>An instance holds the filter list, the terminal {@link WebHandler}, and the
 * position of the filter that {@link #filter(ServerWebExchange)} will invoke.
 *
 * @author Rossen Stoyanchev
 * @since 5.0
 */
@Slf4j
public class DefaultWebFilterChain implements WebFilterChain {

   private final List<WebFilter> filters;

   private final WebHandler handler;

   private final int index;


   /**
    * Create a chain positioned at the first filter.
    * @param handler the handler invoked once no filter is left; must not be {@code null}
    * @param filters the filters in invocation order; may be empty
    */
   public DefaultWebFilterChain(WebHandler handler, WebFilter... filters) {
      Assert.notNull(handler, "WebHandler is required");
      this.filters = ObjectUtils.isEmpty(filters) ? Collections.emptyList() : Arrays.asList(filters);
      this.handler = handler;
      this.index = 0;
   }

   /**
    * Create a chain that shares the parent's filters and handler but starts at
    * the given position.
    */
   private DefaultWebFilterChain(DefaultWebFilterChain parent, int index) {
      this.filters = parent.getFilters();
      this.handler = parent.getHandler();
      this.index = index;
   }


   /**
    * Return the filters of this chain, in invocation order.
    */
   public List<WebFilter> getFilters() {
      return this.filters;
   }

   /**
    * Return the handler invoked after the filters.
    */
   public WebHandler getHandler() {
      return this.handler;
   }


   /**
    * Invoke the filter at this chain's current position, or the handler when
    * the position is past the last filter.
    * @param exchange the current server exchange
    * @return {@code Mono} to indicate when request handling is complete
    */
   @Override
   public Mono<Void> filter(ServerWebExchange exchange) {
      return continueAt(exchange, this.index);
   }

   /**
    * Invoke the filter at the given position, handing it a chain that continues
    * with the filter right after it. The handler is invoked instead when
    * {@code index} is not smaller than the number of filters.
    * @param exchange the current server exchange
    * @param index zero-based position of the filter to execute
    * @return {@code Mono} to indicate when request handling is complete
    */
   @Override
   public Mono<Void> filter(ServerWebExchange exchange, int index) {
      return continueAt(exchange, index);
   }

   /**
    * Invoke the filter that is an instance of the given type, handing it a chain
    * that continues with the filter right after it. If several filters match,
    * the one closest to the end of the chain is used. If none matches (or
    * {@code cls} is {@code null}), a warning is logged and the handler is
    * invoked directly, skipping every filter.
    * @param exchange the current server exchange
    * @param cls the type of the filter to execute
    * @return {@code Mono} to indicate when request handling is complete
    */
   @Override
   @SuppressWarnings("rawtypes") // parameter stays raw so this remains an override of a raw-typed interface method
   public Mono<Void> filter(ServerWebExchange exchange, Class cls) {
      // Deferred like the other overloads so nothing runs before subscription.
      return Mono.defer(() -> {
         int position = lastIndexOfType(cls);
         if (position < 0) {
            log.warn("No WebFilter of type {} found in the chain; invoking the handler directly", cls);
            return this.handler.handle(exchange);
         }
         return continueAt(exchange, position);
      });
   }

   /**
    * Lazily invoke the filter at {@code position} with a chain positioned right
    * after it, or the handler when {@code position} is past the last filter.
    */
   private Mono<Void> continueAt(ServerWebExchange exchange, int position) {
      return Mono.defer(() -> {
         if (position < this.filters.size()) {
            WebFilter current = this.filters.get(position);
            WebFilterChain remainder = new DefaultWebFilterChain(this, position + 1);
            return current.filter(exchange, remainder);
         }
         return this.handler.handle(exchange);
      });
   }

   /**
    * Find the highest position whose filter is an instance of {@code type}.
    * @return the position, or {@code -1} when there is no such filter
    */
   private int lastIndexOfType(Class<?> type) {
      if (type == null) {
         return -1;
      }
      // Scanning from the end lets us stop at the first hit while still
      // selecting the last matching filter.
      for (int i = this.filters.size() - 1; i >= 0; i--) {
         if (type.isInstance(this.filters.get(i))) {
            return i;
         }
      }
      return -1;
   }
}

 WebFilterChain.java代码如下:

/*
 * Copyright 2002-2015 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.web.server;

import reactor.core.publisher.Mono;

/**
 * Contract to allow a {@link WebFilter} to delegate to the next in the chain.
 *
 * @author Rossen Stoyanchev
 * @since 5.0
 */
public interface WebFilterChain {

   /**
    * Delegate to the next {@code WebFilter} in the chain.
    * @param exchange the current server exchange
    * @return {@code Mono} to indicate when request handling is complete
    */
   Mono filter(ServerWebExchange exchange);
   /**
    * Delegate to the next {@code WebFilter} in the chain.
    * @param exchange the current server exchange
    * @param index the filter index,you can choose filter to execute
    * @return {@code Mono} to indicate when request handling is complete
    */
   Mono filter(ServerWebExchange exchange,int index);

   /**
    * Delegate to the next {@code WebFilter} in the chain.
    * @param exchange the current server exchange
    * @param cls appoint one filter to execute
    * @return {@code Mono} to indicate when request handling is complete
    */
   Mono filter(ServerWebExchange exchange, Class cls);

} 

接下来,编写自己的过滤器,主要包括三部分:

1、入口的过滤器(DispatcherFilter.java) ,主要用于分发到不同的过滤器

2、业务过滤器(例如下文的 CheckSignFilter.java,可根据业务需要自行编写)

3、底层过滤器(BaseFilter.java)

一、BaseFilter.java代码如下:

 

package com.test.filter;

import io.netty.buffer.ByteBufAllocator;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.NettyDataBufferFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.io.ByteArrayOutputStream;
import java.net.URI;
import java.nio.CharBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;

/**
 * 
 */
@Slf4j
@Component
public abstract class BaseFilter implements WebFilter, Ordered {


    protected   String bodyData="";

    protected  void readRequestData(ServerHttpRequest serverHttpRequest)
    {
        String method = serverHttpRequest.getMethodValue();
        if ("POST".equals(method)) {
            bodyData=resolveBodyFromRequest(serverHttpRequest);
        }
        else if("GET".equals(method))
        {
            bodyData=getContentByGetRequest(serverHttpRequest);
        }
    }


    public void after(ServerWebExchange exchange) {
    }

    @Override
    public Mono filter(ServerWebExchange exchange, WebFilterChain chain) {
        System.out.println("BaseFilter is comming");
        ServerHttpRequest serverHttpRequest = exchange.getRequest();
        ServerHttpResponse response = exchange.getResponse();

        String method = serverHttpRequest.getMethodValue();
        if ("POST".equals(method)) {
            //下面的将请求体再次封装写回到request里,传到下一级,否则,由于请求体已被消费,后续的服务将取不到值
            URI uri = serverHttpRequest.getURI();
            ServerHttpRequest request = serverHttpRequest.mutate().uri(uri).build();
            DataBuffer bodyDataBuffer = stringBuffer(bodyData);
            Flux bodyFlux = Flux.just(bodyDataBuffer);
            request = new ServerHttpRequestDecorator(request) {
                @Override
                public Flux getBody() {
                    return bodyFlux;
                }
            };
            //封装request,传给下一级
            return chain.filter(exchange.mutate().request(request).build());
        } else if ("GET".equals(method)) {
            return chain.filter(exchange);
        }
        return chain.filter(exchange);
    }

    private String getContentByGetRequest(ServerHttpRequest serverHttpRequest) {
        Map requestQueryParams = serverHttpRequest.getQueryParams().toSingleValueMap();
        StringBuffer s = new StringBuffer();
        String content = "";
        requestQueryParams.forEach((k, v) -> {
            System.out.println(k + "  v:" + v);
            s.append(k + "=" + v + "&");
        });
        if (StringUtils.isNotEmpty(s.toString())) {
            content = s.toString().substring(0, s.toString().lastIndexOf('&'));
        }
        return content;
    }

    /**
     * 从Flux中获取字符串的方法
     *
     * @return 请求体
     */
    private String resolveBodyFromRequest(ServerHttpRequest serverHttpRequest) {
        //获取请求体
        Flux body = serverHttpRequest.getBody();

        AtomicReference bodyRef = new AtomicReference<>();
        body.subscribe(buffer -> {
            CharBuffer charBuffer = StandardCharsets.UTF_8.decode(buffer.asByteBuffer());
            DataBufferUtils.release(buffer);
            bodyRef.set(charBuffer.toString());
        });
        //获取request body
        return bodyRef.get();
    }

    private DataBuffer stringBuffer(String value) {
        byte[] bytes = value.getBytes(StandardCharsets.UTF_8);

        NettyDataBufferFactory nettyDataBufferFactory = new NettyDataBufferFactory(ByteBufAllocator.DEFAULT);
        DataBuffer buffer = nettyDataBufferFactory.allocateBuffer(bytes.length);
        buffer.write(bytes);
        return buffer;
    }

    @Override
    public int getOrder() {
        return 0;
    }


}
 

二、DispatcherFilter.java(分流过滤器)代码如下:

package org.test.filter
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilterChain;
import reactor.core.publisher.Mono;
@Slf4j
@Component
public class DispatcherFilter extends BaseFilter {

   @Override
   public int getOrder() {
      //数值越小,优先级越高
      return Integer.MIN_VALUE;
   }

   @Override
   public Mono filter(ServerWebExchange exchange, WebFilterChain chain) {
      super.readRequestData(exchange.getRequest());
      List signUrls = cantaloupeConfig.getCheckSignWhiteUrls();
      List cacheUrls = cantaloupeConfig.getUpdateCacheWhiteUrls();

      //根据不同的请求分流到不同的过滤器
      if (checkUrl(exchange,signUrls)) {
         return chain.filter(exchange,CheckSignFilter.class);
      }else if (checkUrl(exchange,cacheUrls))
      {
         return chain.filter(exchange,UpdateCacheFilter.class);
      }
      return super.filter(exchange,chain);
   }

}

三、各业务过滤器,举个例子:CheckSignFilter.java

package com.test.filter;

import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.http.HttpHeaders;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilterChain;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.util.List;

/**
 * Signature-verification filter.
 */
@Slf4j
@Component
public class CheckSignFilter extends BaseFilter {

   /**
    * Placeholder for the signature check; currently only delegates to
    * {@link BaseFilter#filter}.
    *
    * @param exchange the current server exchange
    * @param chain the chain to delegate to
    * @return {@code Mono} to indicate when request handling is complete
    */
   @Override
   public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
      // TODO: verify the request signature here before delegating
      return super.filter(exchange, chain);
   }

   /**
    * No post-processing is needed for this filter.
    */
   @Override
   public void after(ServerWebExchange exchange) {
   }

   /**
    * Order value of this filter within the chain.
    */
   @Override
   public int getOrder() {
      return -100;
   }
}

是不是很简单地就可以让过滤器听你的话?以上代码亲测可用,如有问题,请留言!

你可能感兴趣的:(webFlux,原创)