目录
一、定义
二、谓词使用
1、After
2、Before
3、Between
4、Cookie
5、Header
6、Host
7、Method
8、Path
9、Query
10、RemoteAddr
11、Weight
SpringCloudGateway中三个重要词汇:
路由(Route):配置网关中的一个完整路由,包括命名,地址,谓词集合(规则),过滤器集合。
谓词、断言(Predicate):这是一个 Java 8 函数式谓词(java.util.function.Predicate),输入类型是 Spring 框架的 ServerWebExchange,因此开发人员可以匹配 HTTP 请求中的任何内容,例如请求头或请求参数。简单说就是判断请求是否符合谓词中的规则:符合则命中该路由并进行转发,不符合则该路由不生效。
过滤器(Filter):这些是由特定工厂构建的 GatewayFilter 实例,可以在向下游发送请求之前或之后修改请求和响应。简单说就是负责在代理服务之前或之后做一些处理。
Gateway中有很多已经实现好的谓词,可以查看GatewayPredicate下的实现类:
接下来单独用这11个实现类举例,Gateway项目可以参考:SpringCloudGateway--自动路由映射与手动路由映射_雨欲语的博客-CSDN博客
在指定时间之后,顾名思义就是只有在指定时间之后的请求才会生效:
routes:
  - id: service-one
    uri: lb://service-one
    predicates:
      - After=2022-11-26T00:00:00.000+08:00[Asia/Shanghai]
    filters:
      - StripPrefix=1
    metadata:
      connect-timeout: 15000 #ms
      response-timeout: 15000 #ms
进入AfterRoutePredicateFactory实现里面,可以看到是由工厂方法加匿名内部类实现:
在SpringCloudGateway中可以找到加载这些实现类的工厂方法:
其余的实现类也是同理
指定时间之前,只有在指定时间之前的请求才会生效:
routes:
  - id: service-one
    uri: lb://service-one
    predicates:
      - Before=2022-11-26T00:00:00.000+08:00[Asia/Shanghai]
    filters:
      - StripPrefix=1
    metadata:
      connect-timeout: 15000 #ms
      response-timeout: 15000 #ms
package org.springframework.cloud.gateway.handler.predicate;
import java.time.ZonedDateTime;
import java.util.Collections;
import java.util.List;
import java.util.function.Predicate;
import org.springframework.web.server.ServerWebExchange;
/**
 * Route predicate factory that matches a request only while the current time is
 * strictly before the configured date-time.
 *
 * @author Spencer Gibb
 */
public class BeforeRoutePredicateFactory extends AbstractRoutePredicateFactory<BeforeRoutePredicateFactory.Config> {

    /**
     * DateTime key.
     */
    public static final String DATETIME_KEY = "datetime";

    /**
     * Creates the factory and registers {@link Config} as its configuration class.
     */
    public BeforeRoutePredicateFactory() {
        super(Config.class);
    }

    /**
     * Declares that the single shortcut argument (e.g. {@code Before=<datetime>})
     * is bound to {@link Config#setDatetime(ZonedDateTime)}.
     *
     * @return a one-element list containing {@link #DATETIME_KEY}
     */
    @Override
    public List<String> shortcutFieldOrder() {
        return Collections.singletonList(DATETIME_KEY);
    }

    /**
     * Builds the predicate for one route.
     *
     * @param config holder of the cut-off date-time
     * @return a predicate that is {@code true} while "now" is before the cut-off
     */
    @Override
    public Predicate<ServerWebExchange> apply(Config config) {
        return new GatewayPredicate() {
            @Override
            public boolean test(ServerWebExchange serverWebExchange) {
                // The clock is read on every request, so the route stops matching
                // as soon as the cut-off is reached.
                final ZonedDateTime now = ZonedDateTime.now();
                return now.isBefore(config.getDatetime());
            }

            @Override
            public String toString() {
                return String.format("Before: %s", config.getDatetime());
            }
        };
    }

    /**
     * Configuration for {@link BeforeRoutePredicateFactory}.
     */
    public static class Config {

        /** Cut-off instant; requests match only before it. */
        private ZonedDateTime datetime;

        public ZonedDateTime getDatetime() {
            return datetime;
        }

        public void setDatetime(ZonedDateTime datetime) {
            this.datetime = datetime;
        }

    }

}
需要在设定的时间范围之内才能进行请求转发:
routes:
  - id: service-one
    uri: lb://service-one
    predicates:
      - Between=2022-11-26T00:00:00.000+08:00[Asia/Shanghai],2022-11-30T00:00:00.000+08:00[Asia/Shanghai]
    filters:
      - StripPrefix=1
    metadata:
      connect-timeout: 15000 #ms
      response-timeout: 15000 #ms
源码:
package org.springframework.cloud.gateway.handler.predicate;
import java.time.ZonedDateTime;
import java.util.Arrays;
import java.util.List;
import java.util.function.Predicate;
import javax.validation.constraints.NotNull;
import org.springframework.util.Assert;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.server.ServerWebExchange;
/**
 * Route predicate factory that matches a request only while the current time lies
 * strictly between two configured date-times.
 *
 * @author Spencer Gibb
 */
public class BetweenRoutePredicateFactory extends AbstractRoutePredicateFactory<BetweenRoutePredicateFactory.Config> {

    /**
     * DateTime 1 key.
     */
    public static final String DATETIME1_KEY = "datetime1";

    /**
     * DateTime 2 key.
     */
    public static final String DATETIME2_KEY = "datetime2";

    /**
     * Creates the factory and registers {@link Config} as its configuration class.
     */
    public BetweenRoutePredicateFactory() {
        super(Config.class);
    }

    /**
     * Declares the order of the two shortcut arguments: lower bound first, then
     * upper bound.
     *
     * @return the list {@code [datetime1, datetime2]}
     */
    @Override
    public List<String> shortcutFieldOrder() {
        return Arrays.asList(DATETIME1_KEY, DATETIME2_KEY);
    }

    /**
     * Builds the predicate for one route.
     *
     * @param config holder of the lower and upper bound
     * @return a predicate that is {@code true} while "now" is after
     * {@code datetime1} and before {@code datetime2}
     * @throws IllegalArgumentException if {@code datetime1} is not before
     * {@code datetime2}
     */
    @Override
    public Predicate<ServerWebExchange> apply(Config config) {
        // Fail when the predicate is built rather than silently never matching.
        Assert.isTrue(config.getDatetime1().isBefore(config.getDatetime2()),
                config.getDatetime1() + " must be before " + config.getDatetime2());

        return new GatewayPredicate() {
            @Override
            public boolean test(ServerWebExchange serverWebExchange) {
                final ZonedDateTime now = ZonedDateTime.now();
                // Both bounds are exclusive.
                return now.isAfter(config.getDatetime1()) && now.isBefore(config.getDatetime2());
            }

            @Override
            public String toString() {
                return String.format("Between: %s and %s", config.getDatetime1(), config.getDatetime2());
            }
        };
    }

    /**
     * Configuration for {@link BetweenRoutePredicateFactory}; setters return
     * {@code this} so calls can be chained.
     */
    @Validated
    public static class Config {

        /** Lower (exclusive) bound. */
        @NotNull
        private ZonedDateTime datetime1;

        /** Upper (exclusive) bound. */
        @NotNull
        private ZonedDateTime datetime2;

        public ZonedDateTime getDatetime1() {
            return datetime1;
        }

        public Config setDatetime1(ZonedDateTime datetime1) {
            this.datetime1 = datetime1;
            return this;
        }

        public ZonedDateTime getDatetime2() {
            return datetime2;
        }

        public Config setDatetime2(ZonedDateTime datetime2) {
            this.datetime2 = datetime2;
            return this;
        }

    }

}
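要求请求中包含指定名称的Cookie,并且该Cookie的值满足指定的正则表达式。Cookie谓词需要两个参数,用逗号分隔:第一个是Cookie的名称,第二个是用于匹配Cookie值的正则表达式。注意第二个参数是正则而不是通配符,并且要求整个值完全匹配:例如admin*匹配的是"admi"后面跟零个或多个"n",如果想匹配以admin开头的值应写成admin.*: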
routes:
  - id: service-one
    uri: lb://service-one
    predicates:
      - Cookie=username,admin*
    filters:
      - StripPrefix=1
    metadata:
      connect-timeout: 15000 #ms
      response-timeout: 15000 #ms
package org.springframework.cloud.gateway.handler.predicate;
import java.util.Arrays;
import java.util.List;
import java.util.function.Predicate;
import javax.validation.constraints.NotEmpty;
import org.springframework.http.HttpCookie;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.server.ServerWebExchange;
/**
 * Route predicate factory that matches when the request carries a cookie with the
 * configured name whose value matches the configured regular expression.
 *
 * @author Spencer Gibb
 */
public class CookieRoutePredicateFactory extends AbstractRoutePredicateFactory<CookieRoutePredicateFactory.Config> {

    /**
     * Name key.
     */
    public static final String NAME_KEY = "name";

    /**
     * Regexp key.
     */
    public static final String REGEXP_KEY = "regexp";

    /**
     * Creates the factory and registers {@link Config} as its configuration class.
     */
    public CookieRoutePredicateFactory() {
        super(Config.class);
    }

    /**
     * Declares the order of the shortcut arguments: cookie name, then regexp.
     *
     * @return the list {@code [name, regexp]}
     */
    @Override
    public List<String> shortcutFieldOrder() {
        return Arrays.asList(NAME_KEY, REGEXP_KEY);
    }

    /**
     * Builds the predicate for one route.
     *
     * @param config cookie name and value regexp
     * @return a predicate that is {@code true} when at least one cookie with the
     * configured name has a value matching the regexp
     */
    @Override
    public Predicate<ServerWebExchange> apply(Config config) {
        return new GatewayPredicate() {
            @Override
            public boolean test(ServerWebExchange exchange) {
                List<HttpCookie> cookies = exchange.getRequest().getCookies().get(config.name);
                if (cookies == null) {
                    // No cookie with that name was sent.
                    return false;
                }
                for (HttpCookie cookie : cookies) {
                    // String.matches requires the whole value to match the regexp.
                    if (cookie.getValue().matches(config.regexp)) {
                        return true;
                    }
                }
                return false;
            }

            @Override
            public String toString() {
                return String.format("Cookie: name=%s regexp=%s", config.name, config.regexp);
            }
        };
    }

    /**
     * Configuration for {@link CookieRoutePredicateFactory}; setters return
     * {@code this} so calls can be chained.
     */
    @Validated
    public static class Config {

        /** Name of the cookie to look up. */
        @NotEmpty
        private String name;

        /** Regular expression the cookie value must match in full. */
        @NotEmpty
        private String regexp;

        public String getName() {
            return name;
        }

        public Config setName(String name) {
            this.name = name;
            return this;
        }

        public String getRegexp() {
            return regexp;
        }

        public Config setRegexp(String regexp) {
            this.regexp = regexp;
            return this;
        }

    }

}
表示请求头中必须包含的内容。
注意:参数名和参数值之间依然使用逗号,参数值要使用正则表达式
如果Header只有一个值表示请求头中必须包含的参数。如果有两个值,第一个表示请求头必须包含的参数名,第二个表示请求头参数对应值。
routes:
  - id: service-one
    uri: lb://service-one
    predicates:
      - Header=Connection,keep-alive
      - Header=Cache-Control,max-age=0
    filters:
      - StripPrefix=1
    metadata:
      connect-timeout: 15000 #ms
      response-timeout: 15000 #ms
package org.springframework.cloud.gateway.handler.predicate;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.function.Predicate;
import javax.validation.constraints.NotEmpty;
import org.springframework.util.ObjectUtils;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.server.ServerWebExchange;
/**
 * Route predicate factory that matches on a request header: either the header
 * merely has to be present, or one of its values has to match a regular
 * expression.
 *
 * @author Spencer Gibb
 */
public class HeaderRoutePredicateFactory extends AbstractRoutePredicateFactory<HeaderRoutePredicateFactory.Config> {

    /**
     * Header key.
     */
    public static final String HEADER_KEY = "header";

    /**
     * Regexp key.
     */
    public static final String REGEXP_KEY = "regexp";

    /**
     * Creates the factory and registers {@link Config} as its configuration class.
     */
    public HeaderRoutePredicateFactory() {
        super(Config.class);
    }

    /**
     * Declares the order of the shortcut arguments: header name, then the
     * optional regexp.
     *
     * @return the list {@code [header, regexp]}
     */
    @Override
    public List<String> shortcutFieldOrder() {
        return Arrays.asList(HEADER_KEY, REGEXP_KEY);
    }

    /**
     * Builds the predicate for one route.
     *
     * @param config header name and optional value regexp
     * @return a predicate that is {@code true} when the header is present and,
     * if a regexp is configured, at least one of its values matches it
     */
    @Override
    public Predicate<ServerWebExchange> apply(Config config) {
        // Decided once per route: with no regexp only presence is checked.
        boolean hasRegex = !ObjectUtils.isEmpty(config.regexp);

        return new GatewayPredicate() {
            @Override
            public boolean test(ServerWebExchange exchange) {
                List<String> values = exchange.getRequest().getHeaders().getOrDefault(config.header,
                        Collections.emptyList());
                if (values.isEmpty()) {
                    return false;
                }
                // values is now guaranteed to not be empty
                if (hasRegex) {
                    // check if a header value matches
                    for (int i = 0; i < values.size(); i++) {
                        String value = values.get(i);
                        if (value.matches(config.regexp)) {
                            return true;
                        }
                    }
                    return false;
                }

                // there is a value and since regexp is empty, we only check existence.
                return true;
            }

            @Override
            public String toString() {
                return String.format("Header: %s regexp=%s", config.header, config.regexp);
            }
        };
    }

    /**
     * Configuration for {@link HeaderRoutePredicateFactory}; setters return
     * {@code this} so calls can be chained.
     */
    @Validated
    public static class Config {

        /** Name of the header to inspect. */
        @NotEmpty
        private String header;

        /** Optional regexp; when empty only the header's presence is required. */
        private String regexp;

        public String getHeader() {
            return header;
        }

        public Config setHeader(String header) {
            this.header = header;
            return this;
        }

        public String getRegexp() {
            return regexp;
        }

        public Config setRegexp(String regexp) {
            this.regexp = regexp;
            return this;
        }

    }

}
匹配请求头中Host的值(注意是请求头,不是请求参数),可以配置多个模式,使用逗号隔开,满足其中任意一个即可;模式采用以"."为分隔符的Ant风格,**表示支持模糊匹配:
routes:
  - id: service-one
    uri: lb://service-one
    predicates:
      - Host=127.0.0.1:8080,**.test.com
    filters:
      - StripPrefix=1
    metadata:
      connect-timeout: 15000 #ms
      response-timeout: 15000 #ms
package org.springframework.cloud.gateway.handler.predicate;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import org.springframework.cloud.gateway.support.ServerWebExchangeUtils;
import org.springframework.core.style.ToStringCreator;
import org.springframework.util.AntPathMatcher;
import org.springframework.util.PathMatcher;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.server.ServerWebExchange;
/**
 * Route predicate factory that matches the request's {@code Host} header against
 * one or more patterns, using a {@link PathMatcher} whose separator is
 * {@code "."} by default.
 *
 * @author Spencer Gibb
 */
public class HostRoutePredicateFactory extends AbstractRoutePredicateFactory<HostRoutePredicateFactory.Config> {

    /** Matcher applied to host names; replaceable via {@link #setPathMatcher}. */
    private PathMatcher pathMatcher = new AntPathMatcher(".");

    /**
     * Creates the factory and registers {@link Config} as its configuration class.
     */
    public HostRoutePredicateFactory() {
        super(Config.class);
    }

    /**
     * Replaces the matcher used for host patterns.
     *
     * @param pathMatcher the matcher to use from now on
     */
    public void setPathMatcher(PathMatcher pathMatcher) {
        this.pathMatcher = pathMatcher;
    }

    /**
     * @return a one-element list naming the {@code patterns} property
     */
    @Override
    public List<String> shortcutFieldOrder() {
        return Collections.singletonList("patterns");
    }

    /**
     * @return {@link ShortcutType#GATHER_LIST}; presumably this makes all
     * comma-separated shortcut arguments land in the {@code patterns} list
     * (confirm against {@code ShortcutConfigurable})
     */
    @Override
    public ShortcutType shortcutType() {
        return ShortcutType.GATHER_LIST;
    }

    /**
     * Builds the predicate for one route.
     *
     * @param config the host patterns to accept
     * @return a predicate that is {@code true} when the Host header matches any
     * pattern; on a match the extracted template variables are stored on the
     * exchange
     */
    @Override
    public Predicate<ServerWebExchange> apply(Config config) {
        return new GatewayPredicate() {
            @Override
            public boolean test(ServerWebExchange exchange) {
                String host = exchange.getRequest().getHeaders().getFirst("Host");
                if (host == null) {
                    // No Host header: nothing to match, and the matcher must not
                    // be handed a null path.
                    return false;
                }

                // First matching pattern wins.
                String match = null;
                for (int i = 0; i < config.getPatterns().size(); i++) {
                    String pattern = config.getPatterns().get(i);
                    if (pathMatcher.match(pattern, host)) {
                        match = pattern;
                        break;
                    }
                }

                if (match != null) {
                    Map<String, String> variables = pathMatcher.extractUriTemplateVariables(match, host);
                    ServerWebExchangeUtils.putUriTemplateVariables(exchange, variables);
                    return true;
                }

                return false;
            }

            @Override
            public String toString() {
                return String.format("Hosts: %s", config.getPatterns());
            }
        };
    }

    /**
     * Configuration for {@link HostRoutePredicateFactory}.
     */
    @Validated
    public static class Config {

        /** Host patterns; a request matches if any one of them matches. */
        private List<String> patterns = new ArrayList<>();

        public List<String> getPatterns() {
            return patterns;
        }

        public Config setPatterns(List<String> patterns) {
            this.patterns = patterns;
            return this;
        }

        @Override
        public String toString() {
            return new ToStringCreator(this).append("patterns", patterns).toString();
        }

    }

}
Method表示请求方式。支持多个值,使用逗号分隔,多个值之间为or条件
routes:
  - id: service-one
    uri: lb://service-one
    predicates:
      - Method=GET,POST # 表示只允许GET or POST请求通过
    filters:
      - StripPrefix=1
    metadata:
      connect-timeout: 15000 #ms
      response-timeout: 15000 #ms
package org.springframework.cloud.gateway.handler.predicate;
import java.util.Arrays;
import java.util.List;
import java.util.function.Predicate;
import org.springframework.http.HttpMethod;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.server.ServerWebExchange;
import static java.util.Arrays.stream;
/**
 * Route predicate factory that matches when the request's HTTP method is one of
 * the configured methods.
 *
 * @author Spencer Gibb
 * @author Dennis Menge
 */
public class MethodRoutePredicateFactory extends AbstractRoutePredicateFactory<MethodRoutePredicateFactory.Config> {

    /**
     * Methods key.
     */
    public static final String METHODS_KEY = "methods";

    /**
     * Creates the factory and registers {@link Config} as its configuration class.
     */
    public MethodRoutePredicateFactory() {
        super(Config.class);
    }

    /**
     * @return a one-element list containing {@link #METHODS_KEY}
     */
    @Override
    public List<String> shortcutFieldOrder() {
        return Arrays.asList(METHODS_KEY);
    }

    /**
     * @return {@link ShortcutType#GATHER_LIST}; presumably this makes all
     * comma-separated shortcut arguments land in the {@code methods} property
     * (confirm against {@code ShortcutConfigurable})
     */
    @Override
    public ShortcutType shortcutType() {
        return ShortcutType.GATHER_LIST;
    }

    /**
     * Builds the predicate for one route.
     *
     * @param config the accepted HTTP methods
     * @return a predicate that is {@code true} when the request method equals
     * any configured method (logical OR)
     */
    @Override
    public Predicate<ServerWebExchange> apply(Config config) {
        return new GatewayPredicate() {
            @Override
            public boolean test(ServerWebExchange exchange) {
                HttpMethod requestMethod = exchange.getRequest().getMethod();
                // NOTE(review): identity comparison assumes HttpMethod is an enum,
                // as in the Spring version this file targets - confirm on upgrade.
                return stream(config.getMethods()).anyMatch(httpMethod -> httpMethod == requestMethod);
            }

            @Override
            public String toString() {
                return String.format("Methods: %s", Arrays.toString(config.getMethods()));
            }
        };
    }

    /**
     * Configuration for {@link MethodRoutePredicateFactory}.
     */
    @Validated
    public static class Config {

        /** HTTP methods accepted by the route. */
        private HttpMethod[] methods;

        public HttpMethod[] getMethods() {
            return methods;
        }

        public void setMethods(HttpMethod... methods) {
            this.methods = methods;
        }

    }

}
请求url包含的路径,这种是使用得最多的一种,路径可以配置多个,用逗号隔开。
routes:
  - id: service-one
    uri: lb://service-one
    predicates:
      - Path=/service/**,/server/**
    filters:
      - StripPrefix=1
    metadata:
      connect-timeout: 15000 #ms
      response-timeout: 15000 #ms
package org.springframework.cloud.gateway.handler.predicate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Predicate;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.style.ToStringCreator;
import org.springframework.http.server.PathContainer;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.util.pattern.PathPattern;
import org.springframework.web.util.pattern.PathPattern.PathMatchInfo;
import org.springframework.web.util.pattern.PathPatternParser;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.putUriTemplateVariables;
import static org.springframework.http.server.PathContainer.parsePath;
/**
 * Route predicate factory that matches the request path against one or more
 * {@link PathPattern}s and exposes the extracted URI variables on the exchange.
 *
 * @author Spencer Gibb
 * @author Dhawal Kapil
 */
public class PathRoutePredicateFactory extends AbstractRoutePredicateFactory<PathRoutePredicateFactory.Config> {

    private static final Log log = LogFactory.getLog(PathRoutePredicateFactory.class);

    private static final String MATCH_TRAILING_SLASH = "matchTrailingSlash";

    /** Parser shared by all routes built by this factory. */
    private PathPatternParser pathPatternParser = new PathPatternParser();

    /**
     * Creates the factory and registers {@link Config} as its configuration class.
     */
    public PathRoutePredicateFactory() {
        super(Config.class);
    }

    /**
     * Writes a TRACE log line describing a match attempt; does nothing unless
     * TRACE is enabled.
     */
    private static void traceMatch(String prefix, Object desired, Object actual, boolean match) {
        if (log.isTraceEnabled()) {
            String message = String.format("%s \"%s\" %s against value \"%s\"", prefix, desired,
                    match ? "matches" : "does not match", actual);
            log.trace(message);
        }
    }

    /**
     * Replaces the parser used to compile path patterns.
     *
     * @param pathPatternParser the parser to use from now on
     */
    public void setPathPatternParser(PathPatternParser pathPatternParser) {
        this.pathPatternParser = pathPatternParser;
    }

    /**
     * @return the list {@code [patterns, matchTrailingSlash]}
     */
    @Override
    public List<String> shortcutFieldOrder() {
        return Arrays.asList("patterns", MATCH_TRAILING_SLASH);
    }

    /**
     * @return {@link ShortcutType#GATHER_LIST_TAIL_FLAG}; presumably the leading
     * arguments are gathered into {@code patterns} and an optional trailing
     * boolean sets {@code matchTrailingSlash} (confirm against
     * {@code ShortcutConfigurable})
     */
    @Override
    public ShortcutType shortcutType() {
        return ShortcutType.GATHER_LIST_TAIL_FLAG;
    }

    /**
     * Compiles the configured patterns and builds the predicate for one route.
     *
     * @param config path patterns and the trailing-slash flag
     * @return a predicate that is {@code true} when the raw request path matches
     * any compiled pattern; on a match the URI variables are stored on the
     * exchange
     */
    @Override
    public Predicate<ServerWebExchange> apply(Config config) {
        final List<PathPattern> pathPatterns = new ArrayList<>();
        // The parser is shared and its trailing-separator option is changed here,
        // so the option update and the parsing are done under one lock.
        synchronized (this.pathPatternParser) {
            pathPatternParser.setMatchOptionalTrailingSeparator(config.isMatchTrailingSlash());
            config.getPatterns().forEach(pattern -> {
                PathPattern pathPattern = this.pathPatternParser.parse(pattern);
                pathPatterns.add(pathPattern);
            });
        }

        return new GatewayPredicate() {
            @Override
            public boolean test(ServerWebExchange exchange) {
                PathContainer path = parsePath(exchange.getRequest().getURI().getRawPath());

                // First matching pattern wins.
                PathPattern match = null;
                for (int i = 0; i < pathPatterns.size(); i++) {
                    PathPattern pathPattern = pathPatterns.get(i);
                    if (pathPattern.matches(path)) {
                        match = pathPattern;
                        break;
                    }
                }

                if (match != null) {
                    traceMatch("Pattern", match.getPatternString(), path, true);
                    PathMatchInfo pathMatchInfo = match.matchAndExtract(path);
                    putUriTemplateVariables(exchange, pathMatchInfo.getUriVariables());
                    return true;
                }
                else {
                    traceMatch("Pattern", config.getPatterns(), path, false);
                    return false;
                }
            }

            @Override
            public String toString() {
                return String.format("Paths: %s, match trailing slash: %b", config.getPatterns(),
                        config.isMatchTrailingSlash());
            }
        };
    }

    /**
     * Configuration for {@link PathRoutePredicateFactory}; setters return
     * {@code this} so calls can be chained.
     */
    @Validated
    public static class Config {

        /** Path patterns; a request matches if any one of them matches. */
        private List<String> patterns = new ArrayList<>();

        /** Passed to the parser's optional-trailing-separator option; defaults to {@code true}. */
        private boolean matchTrailingSlash = true;

        public List<String> getPatterns() {
            return patterns;
        }

        public Config setPatterns(List<String> patterns) {
            this.patterns = patterns;
            return this;
        }

        /**
         * @deprecated use {@link #isMatchTrailingSlash()}
         */
        @Deprecated
        public boolean isMatchOptionalTrailingSeparator() {
            return isMatchTrailingSlash();
        }

        /**
         * @deprecated use {@link #setMatchTrailingSlash(boolean)}
         */
        @Deprecated
        public Config setMatchOptionalTrailingSeparator(boolean matchOptionalTrailingSeparator) {
            setMatchTrailingSlash(matchOptionalTrailingSeparator);
            return this;
        }

        public boolean isMatchTrailingSlash() {
            return matchTrailingSlash;
        }

        public Config setMatchTrailingSlash(boolean matchTrailingSlash) {
            this.matchTrailingSlash = matchTrailingSlash;
            return this;
        }

        @Override
        public String toString() {
            return new ToStringCreator(this).append("patterns", patterns)
                    .append(MATCH_TRAILING_SLASH, matchTrailingSlash).toString();
        }

    }

}
Query指请求参数,有两种,一种是必须包含参数名:
routes:
  - id: service-one
    uri: lb://service-one
    predicates:
      - Query=username
    filters:
      - StripPrefix=1
    metadata:
      connect-timeout: 15000 #ms
      response-timeout: 15000 #ms
另外一种是同时指定参数名和参数值,参数值用正则表达式匹配(要求整个值完全匹配,例如想匹配以admin开头的值应写成admin.*):
routes:
  - id: service-one
    uri: lb://service-one
    predicates:
      - Query=username,admin*
    filters:
      - StripPrefix=1
    metadata:
      connect-timeout: 15000 #ms
      response-timeout: 15000 #ms
package org.springframework.cloud.gateway.handler.predicate;
import java.util.Arrays;
import java.util.List;
import java.util.function.Predicate;
import javax.validation.constraints.NotEmpty;
import org.springframework.util.StringUtils;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.server.ServerWebExchange;
/**
 * Route predicate factory that matches on a query parameter: either the
 * parameter merely has to be present, or one of its values has to match a
 * regular expression.
 *
 * @author Spencer Gibb
 */
public class QueryRoutePredicateFactory extends AbstractRoutePredicateFactory<QueryRoutePredicateFactory.Config> {

    /**
     * Param key.
     */
    public static final String PARAM_KEY = "param";

    /**
     * Regexp key.
     */
    public static final String REGEXP_KEY = "regexp";

    /**
     * Creates the factory and registers {@link Config} as its configuration class.
     */
    public QueryRoutePredicateFactory() {
        super(Config.class);
    }

    /**
     * Declares the order of the shortcut arguments: parameter name, then the
     * optional regexp.
     *
     * @return the list {@code [param, regexp]}
     */
    @Override
    public List<String> shortcutFieldOrder() {
        return Arrays.asList(PARAM_KEY, REGEXP_KEY);
    }

    /**
     * Builds the predicate for one route.
     *
     * @param config parameter name and optional value regexp
     * @return a predicate that is {@code true} when the parameter is present
     * and, if a regexp is configured, at least one non-null value matches it
     */
    @Override
    public Predicate<ServerWebExchange> apply(Config config) {
        return new GatewayPredicate() {
            @Override
            public boolean test(ServerWebExchange exchange) {
                if (!StringUtils.hasText(config.regexp)) {
                    // check existence of query parameter
                    return exchange.getRequest().getQueryParams().containsKey(config.param);
                }

                List<String> values = exchange.getRequest().getQueryParams().get(config.param);
                if (values == null) {
                    return false;
                }
                for (String value : values) {
                    // Null entries are skipped; String.matches requires a full match.
                    if (value != null && value.matches(config.regexp)) {
                        return true;
                    }
                }
                return false;
            }

            @Override
            public String toString() {
                return String.format("Query: param=%s regexp=%s", config.getParam(), config.getRegexp());
            }
        };
    }

    /**
     * Configuration for {@link QueryRoutePredicateFactory}; setters return
     * {@code this} so calls can be chained.
     */
    @Validated
    public static class Config {

        /** Name of the query parameter to inspect. */
        @NotEmpty
        private String param;

        /** Optional regexp; when blank only the parameter's presence is required. */
        private String regexp;

        public String getParam() {
            return param;
        }

        public Config setParam(String param) {
            this.param = param;
            return this;
        }

        public String getRegexp() {
            return regexp;
        }

        public Config setRegexp(String regexp) {
            this.regexp = regexp;
            return this;
        }

    }

}
允许访问的客户端地址,不能使用类似localhost这种:
routes:
  - id: service-one
    uri: lb://service-one
    predicates:
      - RemoteAddr=127.0.0.1
    filters:
      - StripPrefix=1
    metadata:
      connect-timeout: 15000 #ms
      response-timeout: 15000 #ms
package org.springframework.cloud.gateway.handler.predicate;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.function.Predicate;
import javax.validation.constraints.NotEmpty;
import javax.validation.constraints.NotNull;
import io.netty.handler.ipfilter.IpFilterRuleType;
import io.netty.handler.ipfilter.IpSubnetFilterRule;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.cloud.gateway.support.ipresolver.RemoteAddressResolver;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.server.ServerWebExchange;
import static org.springframework.cloud.gateway.support.ShortcutConfigurable.ShortcutType.GATHER_LIST;
/**
 * Route predicate factory that matches when the resolved remote address of the
 * request falls inside one of the configured IP addresses or CIDR ranges.
 *
 * @author Spencer Gibb
 */
public class RemoteAddrRoutePredicateFactory
        extends AbstractRoutePredicateFactory<RemoteAddrRoutePredicateFactory.Config> {

    private static final Log log = LogFactory.getLog(RemoteAddrRoutePredicateFactory.class);

    /**
     * Creates the factory and registers {@link Config} as its configuration class.
     */
    public RemoteAddrRoutePredicateFactory() {
        super(Config.class);
    }

    /**
     * @return {@code GATHER_LIST}; presumably this makes all comma-separated
     * shortcut arguments land in the {@code sources} list (confirm against
     * {@code ShortcutConfigurable})
     */
    @Override
    public ShortcutType shortcutType() {
        return GATHER_LIST;
    }

    /**
     * @return a one-element list naming the {@code sources} property
     */
    @Override
    public List<String> shortcutFieldOrder() {
        return Collections.singletonList("sources");
    }

    /**
     * Converts textual sources ({@code ip} or {@code ip/prefix}) into subnet rules.
     *
     * @param values the configured sources
     * @return one ACCEPT rule per source, in the same order
     */
    @NotNull
    private List<IpSubnetFilterRule> convert(List<String> values) {
        List<IpSubnetFilterRule> sources = new ArrayList<>();
        for (String arg : values) {
            addSource(sources, arg);
        }
        return sources;
    }

    /**
     * Builds the predicate for one route. The sources are parsed once here, not
     * on every request.
     *
     * @param config the allowed sources and the address resolver
     * @return a predicate that is {@code true} when the resolved remote address
     * matches any configured subnet rule
     */
    @Override
    public Predicate<ServerWebExchange> apply(Config config) {
        List<IpSubnetFilterRule> sources = convert(config.sources);

        return new GatewayPredicate() {
            @Override
            public boolean test(ServerWebExchange exchange) {
                InetSocketAddress remoteAddress = config.remoteAddressResolver.resolve(exchange);
                if (remoteAddress != null && remoteAddress.getAddress() != null) {
                    String hostAddress = remoteAddress.getAddress().getHostAddress();
                    String host = exchange.getRequest().getURI().getHost();

                    // Diagnostic only; does not influence the result.
                    if (log.isDebugEnabled() && !hostAddress.equals(host)) {
                        log.debug("Remote addresses didn't match " + hostAddress + " != " + host);
                    }

                    for (IpSubnetFilterRule source : sources) {
                        if (source.matches(remoteAddress)) {
                            return true;
                        }
                    }
                }

                // Unresolvable address or no rule matched.
                return false;
            }

            @Override
            public String toString() {
                return String.format("RemoteAddrs: %s", config.getSources());
            }
        };
    }

    /**
     * Parses one source and appends the corresponding ACCEPT rule.
     *
     * @param sources list to append to
     * @param source an IP address, optionally followed by {@code /prefixLength}
     * @throws NumberFormatException if the prefix length is not an integer
     */
    private void addSource(List<IpSubnetFilterRule> sources, String source) {
        if (!source.contains("/")) { // no netmask, add default
            // A bare address means "exactly this host": 32 bits for IPv4 and
            // 128 bits for IPv6. Using /32 for an IPv6 literal would accept an
            // entire /32 network instead of a single address.
            source = source + (source.contains(":") ? "/128" : "/32");
        }

        String[] ipAddressCidrPrefix = source.split("/", 2);
        String ipAddress = ipAddressCidrPrefix[0];
        int cidrPrefix = Integer.parseInt(ipAddressCidrPrefix[1]);

        sources.add(new IpSubnetFilterRule(ipAddress, cidrPrefix, IpFilterRuleType.ACCEPT));
    }

    /**
     * Configuration for {@link RemoteAddrRoutePredicateFactory}; setters return
     * {@code this} so calls can be chained.
     */
    @Validated
    public static class Config {

        /** Allowed addresses or CIDR ranges. */
        @NotEmpty
        private List<String> sources = new ArrayList<>();

        /** Resolver for the client address; the default relies on the interface's own default behaviour. */
        @NotNull
        private RemoteAddressResolver remoteAddressResolver = new RemoteAddressResolver() {
        };

        public List<String> getSources() {
            return sources;
        }

        public Config setSources(List<String> sources) {
            this.sources = sources;
            return this;
        }

        public Config setSources(String... sources) {
            this.sources = Arrays.asList(sources);
            return this;
        }

        public Config setRemoteAddressResolver(RemoteAddressResolver remoteAddressResolver) {
            this.remoteAddressResolver = remoteAddressResolver;
            return this;
        }

    }

}
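按权重分流:Weight有两个参数,第一个是分组名(group),第二个是权重值(整数)。同一分组内的路由按各自权重的比例分配请求,例如下面的配置中大约20%的请求会转发到service-one,80%的请求会转发到service-two: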
routes:
  - id: service-one
    uri: lb://service-one
    predicates:
      - Weight=group,2
    filters:
      - StripPrefix=1
  - id: service-two
    uri: lb://service-two
    predicates:
      - Weight=group,8
    filters:
      - StripPrefix=1
package org.springframework.cloud.gateway.handler.predicate;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.cloud.gateway.event.WeightDefinedEvent;
import org.springframework.cloud.gateway.support.WeightConfig;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ApplicationEventPublisherAware;
import org.springframework.web.server.ServerWebExchange;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_PREDICATE_ROUTE_ATTR;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.WEIGHT_ATTR;
/**
 * Route predicate factory for weighted routing: a route matches only when it is
 * the route that was chosen for its weight group on the current exchange.
 *
 * @author Spencer Gibb
 */
// TODO: make this a generic Choose out of group predicate?
public class WeightRoutePredicateFactory extends AbstractRoutePredicateFactory<WeightConfig>
        implements ApplicationEventPublisherAware {

    /**
     * Weight config group key.
     */
    public static final String GROUP_KEY = WeightConfig.CONFIG_PREFIX + ".group";

    /**
     * Weight config weight key.
     */
    public static final String WEIGHT_KEY = WeightConfig.CONFIG_PREFIX + ".weight";

    private static final Log log = LogFactory.getLog(WeightRoutePredicateFactory.class);

    /** Used to announce each weight definition; may be unset. */
    private ApplicationEventPublisher publisher;

    /**
     * Creates the factory and registers {@link WeightConfig} as its configuration
     * class.
     */
    public WeightRoutePredicateFactory() {
        super(WeightConfig.class);
    }

    @Override
    public void setApplicationEventPublisher(ApplicationEventPublisher publisher) {
        this.publisher = publisher;
    }

    /**
     * Declares the order of the shortcut arguments: group, then weight.
     *
     * @return the list {@code [GROUP_KEY, WEIGHT_KEY]}
     */
    @Override
    public List<String> shortcutFieldOrder() {
        return Arrays.asList(GROUP_KEY, WEIGHT_KEY);
    }

    /**
     * @return {@link WeightConfig#CONFIG_PREFIX}
     */
    @Override
    public String shortcutFieldPrefix() {
        return WeightConfig.CONFIG_PREFIX;
    }

    /**
     * Publishes a {@link WeightDefinedEvent} for the given configuration when a
     * publisher has been injected; otherwise does nothing.
     *
     * @param config the weight configuration about to be applied
     */
    @Override
    public void beforeApply(WeightConfig config) {
        if (publisher != null) {
            publisher.publishEvent(new WeightDefinedEvent(this, config));
        }
    }

    /**
     * Builds the predicate for one route.
     *
     * @param config the group and weight of the route
     * @return a predicate that is {@code true} when the exchange's weight
     * attribute names the current route as the chosen one for the group
     */
    @Override
    public Predicate<ServerWebExchange> apply(WeightConfig config) {
        return new GatewayPredicate() {
            @Override
            public boolean test(ServerWebExchange exchange) {
                Map<String, String> weights = exchange.getAttributeOrDefault(WEIGHT_ATTR, Collections.emptyMap());

                String routeId = exchange.getAttribute(GATEWAY_PREDICATE_ROUTE_ATTR);

                // all calculations and comparison against random num happened in
                // WeightCalculatorWebFilter
                String group = config.getGroup();
                if (weights.containsKey(group)) {

                    String chosenRoute = weights.get(group);
                    if (log.isTraceEnabled()) {
                        log.trace("in group weight: " + group + ", current route: " + routeId + ", chosen route: "
                                + chosenRoute);
                    }

                    // Null-safe: an exchange without the route-id attribute simply
                    // does not match instead of failing with a NullPointerException.
                    return routeId != null && routeId.equals(chosenRoute);
                }
                else if (log.isTraceEnabled()) {
                    log.trace("no weights found for group: " + group + ", current route: " + routeId);
                }

                return false;
            }

            @Override
            public String toString() {
                return String.format("Weight: %s %s", config.getGroup(), config.getWeight());
            }
        };
    }

}