Spring Cloud Gateway里添加前置过滤器校验参数格式

目前Spring MVC做了很宽松的兼容处理：即使请求参数的JSON格式有错误，也能只取JSON串前面一段能匹配上的数据，而不报错。

凡事有利有弊，这种宽松处理的坏处是存在注入风险，所以需要校验入参的格式。接下来就要考虑在哪里做校验：因为校验入参格式是公共操作，所以最好放在网关层统一拦截校验，而不是下放到应用层。

我们项目采用的是Spring Cloud Gateway实现的网关。我尝试在GlobalFilter的实现类中读取请求参数并校验JSON格式，这可以做到，但引发了一个问题：请求体在流里面，被读取过一次后就没有了，于是校验通过的请求转发到应用层时就会报错。

查找了一些资料,和多次尝试后,解决了这个问题。步骤如下:


  1. 版本信息pom.xml
<dependencies>
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-dependencies</artifactId>
        <version>2.3.4.RELEASE</version>
    </dependency>

    <dependency>
        <groupId>org.springframework.cloud</groupId>
        <artifactId>spring-cloud-starter-gateway</artifactId>
        <version>2.2.2.RELEASE</version>
    </dependency>
</dependencies>

  1. 添加一个前置Filter，先把请求体读出来缓存一份；它的getOrder()返回Ordered.HIGHEST_PRECEDENCE（order值最小，即优先级最高），使它先于其他Filter执行

import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/**
 * Pre-filter that caches JSON request bodies so later filters can read them and the routed
 * request still carries the full body.
 *
 * <p>A reactive request body can normally be consumed only once. For {@code POST}/{@code PUT}
 * requests whose {@code Content-Type} is compatible with {@code application/json} (including
 * variants with parameters such as {@code ;charset=UTF-8}), this filter joins the body into a
 * single buffer. It then replaces the request with a decorator whose {@code getBody()} can be
 * subscribed any number of times.
 *
 * <p>NOTE(review): {@code application/*+json} media types are not matched by this check.
 *
 * @Date 2021/7/16 16:29
 */
@Component
public class DataBufferFilter implements Ordered, GlobalFilter {

    /**
     * Caches the body of JSON write requests and passes a replayable request down the chain.
     * Other requests are forwarded untouched.
     *
     * @param exchange current exchange
     * @param chain    remaining gateway filters
     * @return completion of the rest of the chain
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        ServerHttpRequest request = exchange.getRequest();
        if (!isJsonWrite(request)) {
            return chain.filter(exchange);
        }
        return DataBufferUtils.join(request.getBody())
                // join() completes empty for an empty body. Substitute an empty buffer so the
                // chain is still invoked; otherwise the request would complete without routing.
                .defaultIfEmpty(exchange.getResponse().bufferFactory().wrap(new byte[0]))
                .flatMap(dataBuffer -> {
                    int length = dataBuffer.readableByteCount();
                    // Each subscriber receives its own retained slice and is responsible for
                    // releasing it. This filter keeps the joined buffer's original reference and
                    // releases it once the rest of the chain has finished.
                    Flux<DataBuffer> cachedFlux = Flux.defer(
                            () -> Flux.just(DataBufferUtils.retain(dataBuffer.slice(0, length))));
                    ServerHttpRequest mutatedRequest = new ServerHttpRequestDecorator(request) {
                        @Override
                        public Flux<DataBuffer> getBody() {
                            return cachedFlux;
                        }
                    };
                    return chain.filter(exchange.mutate().request(mutatedRequest).build())
                            .doFinally(signal -> DataBufferUtils.release(dataBuffer));
                });
    }

    /**
     * Returns whether the request is a POST/PUT whose content type is compatible with JSON.
     * {@code isCompatibleWith} ignores media-type parameters such as the charset, unlike
     * {@code equals}.
     */
    private static boolean isJsonWrite(ServerHttpRequest request) {
        HttpMethod method = request.getMethod();
        MediaType contentType = request.getHeaders().getContentType();
        return (HttpMethod.POST.equals(method) || HttpMethod.PUT.equals(method))
                && contentType != null
                && MediaType.APPLICATION_JSON.isCompatibleWith(contentType);
    }

    /** Highest precedence, so the body is cached before global filters with a larger order value. */
    @Override
    public int getOrder() {
        return Ordered.HIGHEST_PRECEDENCE;
    }
}


  1. 在业务校验Filter里读取请求体并校验
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import com.xxx.xxx.util.JsonValidator;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import java.nio.CharBuffer;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

@Slf4j
@Component
public class ValidateFilter implements GlobalFilter, Ordered {

    /**
     * Runs the body-format check, then rebuilds the request through {@code validateToken}
     * (defined elsewhere in this class) and forwards it down the chain.
     *
     * <p>{@code validateParam} throws when the JSON body is malformed, so such requests never
     * reach {@code chain.filter}.
     *
     * @param exchange current exchange
     * @param chain    remaining gateway filters
     * @return completion of the rest of the chain
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        final ServerHttpRequest incoming = exchange.getRequest();
        validateParam(incoming);
        final ServerHttpRequest forwarded = validateToken(incoming).build();
        final ServerWebExchange mutatedExchange = exchange.mutate().request(forwarded).build();
        return chain.filter(mutatedExchange);
    }

    /**
     * Rejects POST/PUT requests with content type {@code application/json} whose body is not
     * well-formed JSON.
     *
     * <p>NOTE(review): {@code MediaType.equals} does not match
     * {@code application/json;charset=UTF-8}, so such requests skip validation. Widening this
     * check is only safe if the body-caching filter (DataBufferFilter) caches at least the same
     * set of requests. Otherwise this method would consume a body that cannot be replayed.
     *
     * @param request incoming request
     * @throws ValidException with message {@code "Param Error"} when the body fails
     *                        {@code JsonValidator.validate}
     */
    private void validateParam(ServerHttpRequest request) {
        HttpMethod reqMethod = request.getMethod();
        if ((HttpMethod.POST.equals(reqMethod) || HttpMethod.PUT.equals(reqMethod))
                && MediaType.APPLICATION_JSON.equals(request.getHeaders().getContentType())) {
            String param = resolveBodyFromRequest(request);
            if (!JsonValidator.validate(param)) {
                // Pass the body as an SLF4J argument so it fills the {} placeholder and is only
                // formatted when debug logging is enabled.
                log.debug("param:{}", param);
                throw new ValidException("Param Error");
            }
        }
    }
    
    /**
     * Reads the whole request body as a UTF-8 string.
     *
     * <p>The body is subscribed to and read synchronously. A value is only returned when
     * {@code getBody()} emits synchronously, as a cached in-memory body does.
     * NOTE(review): for a body still streaming from the network this most likely returns
     * {@code null}.
     *
     * @param serverHttpRequest request whose body is read
     * @return the decoded body, or {@code null} if the body is empty or was not emitted
     *         synchronously
     */
    private String resolveBodyFromRequest(ServerHttpRequest serverHttpRequest) {
        AtomicReference<String> bodyRef = new AtomicReference<>();
        // Join all chunks before decoding. This ensures a multi-chunk body is read in full, and
        // multi-byte UTF-8 sequences that span chunk boundaries decode correctly.
        DataBufferUtils.join(serverHttpRequest.getBody()).subscribe(buffer -> {
            try {
                bodyRef.set(StandardCharsets.UTF_8.decode(buffer.asByteBuffer()).toString());
            } finally {
                DataBufferUtils.release(buffer);
            }
        });
        return bodyRef.get();
    }

  1. 自定义的JSON格式校验类
import org.apache.commons.lang3.StringUtils;
import java.text.CharacterIterator;
import java.text.StringCharacterIterator;

/**
 * Strict JSON syntax validator following the RFC 8259 grammar.
 *
 * <p>Only well-formedness is checked; no object model is built. Each {@link #validate(String)}
 * call uses its own parser instance, so the class can be used from concurrent threads.
 *
 * @Date 2021/7/16 17:12
 */
public class JsonValidator {
    /** Cursor over the input being validated. */
    private CharacterIterator it;
    /** Character under the cursor ({@link CharacterIterator#DONE} once past the end). */
    private char c;
    /** 1-based column of {@link #c}; used only in diagnostic messages. */
    private int col;

    private JsonValidator() {
    }

    /**
     * Checks whether {@code input} is exactly one well-formed JSON value, optionally surrounded
     * by JSON whitespace.
     *
     * @param input text to check; {@code null}, empty and whitespace-only input count as valid
     * @return {@code true} if the input is blank or well-formed JSON, otherwise {@code false}
     */
    public static boolean validate(String input) {
        // The parser keeps its cursor in instance fields, so an instance must never be shared
        // between threads. A fresh instance per call is cheap and race-free.
        return new JsonValidator().valid(input);
    }

    private boolean valid(String input) {
        if (isBlank(input)) {
            return true;
        }
        it = new StringCharacterIterator(input);
        c = it.first();
        col = 1;
        skipWhiteSpace();
        if (!value()) {
            return error("value", 1);
        }
        skipWhiteSpace();
        if (!atEnd()) {
            return error("end", col);
        }
        return true;
    }

    /** Returns true for null, empty, or all-{@link Character#isWhitespace} input. */
    private static boolean isBlank(String s) {
        if (s == null) {
            return true;
        }
        for (int i = 0; i < s.length(); i++) {
            if (!Character.isWhitespace(s.charAt(i))) {
                return false;
            }
        }
        return true;
    }

    private boolean value() {
        // JSON is LL(1): the first character alone decides which production applies. Dispatching
        // here, instead of trying each alternative in turn, matters because a failed alternative
        // has already consumed input. The next alternative would otherwise resume mid-token and
        // could accept malformed text such as "[tt1]" or "[-[]]".
        switch (c) {
            case 't':
                return literal("true");
            case 'f':
                return literal("false");
            case 'n':
                return literal("null");
            case '"':
                return string();
            case '{':
                return object();
            case '[':
                return array();
            default:
                // number() returns false without consuming input unless c is '-' or a digit.
                return number();
        }
    }

    private boolean literal(String text) {
        CharacterIterator ci = new StringCharacterIterator(text);
        char t = ci.first();
        if (c != t) {
            return false;
        }

        int start = col;
        boolean ret = true;
        for (t = ci.next(); t != CharacterIterator.DONE; t = ci.next()) {
            if (t != nextCharacter()) {
                ret = false;
                break;
            }
        }
        // Step past the last matched character.
        nextCharacter();
        if (!ret) {
            error("literal " + text, start);
        }
        return ret;
    }

    private boolean array() {
        return aggregate('[', ']', false);
    }

    private boolean object() {
        return aggregate('{', '}', true);
    }

    /**
     * Parses an array or object. When {@code prefix} is true, each element must be preceded by
     * a string key and a colon.
     */
    private boolean aggregate(char entryCharacter, char exitCharacter, boolean prefix) {
        if (c != entryCharacter) {
            return false;
        }
        nextCharacter();
        skipWhiteSpace();
        if (c == exitCharacter) {
            nextCharacter();
            return true;
        }

        for (;;) {
            if (prefix) {
                int start = col;
                if (!string()) {
                    return error("string", start);
                }
                skipWhiteSpace();
                if (c != ':') {
                    return error("colon", col);
                }
                nextCharacter();
                skipWhiteSpace();
            }
            if (value()) {
                skipWhiteSpace();
                if (c == ',') {
                    nextCharacter();
                } else if (c == exitCharacter) {
                    break;
                } else {
                    return error("comma or " + exitCharacter, col);
                }
            } else {
                return error("value", col);
            }
            skipWhiteSpace();
        }

        nextCharacter();
        return true;
    }

    private boolean number() {
        if (!isDigit(c) && c != '-') {
            return false;
        }
        int start = col;
        if (c == '-') {
            nextCharacter();
        }
        if (c == '0') {
            // A leading zero may not be followed by further integer digits.
            nextCharacter();
        } else if (isDigit(c)) {
            while (isDigit(c)) {
                nextCharacter();
            }
        } else {
            return error("number", start);
        }
        if (c == '.') {
            nextCharacter();
            if (isDigit(c)) {
                while (isDigit(c)) {
                    nextCharacter();
                }
            } else {
                return error("number", start);
            }
        }
        if (c == 'e' || c == 'E') {
            nextCharacter();
            if (c == '+' || c == '-') {
                nextCharacter();
            }
            if (isDigit(c)) {
                while (isDigit(c)) {
                    nextCharacter();
                }
            } else {
                return error("number", start);
            }
        }
        return true;
    }

    private boolean string() {
        if (c != '"') {
            return false;
        }

        int start = col;
        boolean escaped = false;
        // Use the iterator position rather than CharacterIterator.DONE to detect the end:
        // DONE is U+FFFF, which may legitimately appear inside a JSON string.
        for (nextCharacter(); !atEnd(); nextCharacter()) {
            if (escaped) {
                if (!escape()) {
                    return false;
                }
                escaped = false;
            } else if (c == '\\') {
                escaped = true;
            } else if (c == '"') {
                nextCharacter();
                return true;
            } else if (c < 0x20) {
                // RFC 8259 section 7: control characters U+0000..U+001F must be escaped.
                return error("escaped control character", col);
            }
        }
        return error("quoted string", start);
    }

    private boolean escape() {
        int start = col - 1;
        if ("\\\"/bfnrtu".indexOf(c) < 0) {
            return error("escape sequence  \\\",\\\\,\\/,\\b,\\f,\\n,\\r,\\t  or  \\uxxxx ", start);
        }
        if (c == 'u') {
            if (!ishex(nextCharacter()) || !ishex(nextCharacter()) || !ishex(nextCharacter())
                    || !ishex(nextCharacter())) {
                return error("unicode escape sequence  \\uxxxx ", start);
            }
        }
        return true;
    }

    private boolean ishex(char d) {
        return "0123456789abcdefABCDEF".indexOf(d) >= 0;
    }

    /** JSON digits are ASCII only; {@link Character#isDigit} would accept other Unicode digits. */
    private static boolean isDigit(char ch) {
        return ch >= '0' && ch <= '9';
    }

    private char nextCharacter() {
        c = it.next();
        ++col;
        return c;
    }

    /** Skips JSON insignificant whitespace: space, horizontal tab, line feed, carriage return. */
    private void skipWhiteSpace() {
        while (c == ' ' || c == '\t' || c == '\n' || c == '\r') {
            nextCharacter();
        }
    }

    /** Returns whether the cursor has moved past the last character of the input. */
    private boolean atEnd() {
        return it.getIndex() >= it.getEndIndex();
    }

    private boolean error(String type, int col) {
        System.out.printf("type: %s, col: %s%s", type, col, System.getProperty("line.separator"));
        return false;
    }

}


你可能感兴趣的:(SpringCloud,spring,cloud,gateway,GlobalFilter)