Gateway rate limiting not taking effect with Alibaba Cloud Redis in cluster mode

Rate limiting fails on Alibaba Cloud Redis in cluster mode

Symptom

Rate limiting does not work against an Alibaba Cloud Redis cluster, and the following exception is reported: Error in execution; nested exception is io.lettuce.core.RedisCommandExecutionException:
ERR bad lua script for redis cluster, all the keys that the script uses should be passed using the KEYS array, and KEYS should not be in expression

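In other words, the Redis cluster rejected the Lua script. Every key the script uses must be passed through the KEYS array, and KEYS must not appear inside an expression.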
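Some background on why the KEYS array matters. A script is sent as EVAL script numkeys key [key ...] arg [arg ...]. The first numkeys arguments after the script are the keys, and the script sees them as KEYS[1..n]. The remaining arguments become ARGV. A cluster, or the proxy in front of it, can therefore only route the call by those declared keys.

The hypothetical helper below only lays out that argument list in plain Java and does not talk to Redis. It uses the gateway's key names and argument order as an example.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper: lays out EVAL arguments in the order the Redis command expects.
class EvalCommandSketch {
    static List<String> evalArgs(String script, List<String> keys, List<String> args) {
        List<String> cmd = new ArrayList<>();
        cmd.add("EVAL");
        cmd.add(script);
        // numkeys tells Redis how many of the following arguments are keys
        cmd.add(String.valueOf(keys.size()));
        cmd.addAll(keys);   // seen by the script as KEYS[1], KEYS[2], ...
        cmd.addAll(args);   // seen by the script as ARGV[1], ARGV[2], ...
        return cmd;
    }
}
```

For the gateway this yields numkeys = 2. The two keys are followed by four ARGV values: replenishRate, burstCapacity, the current epoch second and requestedTokens.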

Now look at the Lua script the gateway uses. It lives in the gateway dependency jar under META-INF/scripts/request_rate_limiter.lua.

The gateway's original Lua script:

local tokens_key = KEYS[1]
local timestamp_key = KEYS[2]
--redis.log(redis.LOG_WARNING, "tokens_key " .. tokens_key)

local rate = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local requested = tonumber(ARGV[4])

local fill_time = capacity/rate
local ttl = math.floor(fill_time*2)

--redis.log(redis.LOG_WARNING, "rate " .. ARGV[1])
--redis.log(redis.LOG_WARNING, "capacity " .. ARGV[2])
--redis.log(redis.LOG_WARNING, "now " .. ARGV[3])
--redis.log(redis.LOG_WARNING, "requested " .. ARGV[4])
--redis.log(redis.LOG_WARNING, "filltime " .. fill_time)
--redis.log(redis.LOG_WARNING, "ttl " .. ttl)

local last_tokens = tonumber(redis.call("get", tokens_key))
if last_tokens == nil then
  last_tokens = capacity
end
--redis.log(redis.LOG_WARNING, "last_tokens " .. last_tokens)

local last_refreshed = tonumber(redis.call("get", timestamp_key))
if last_refreshed == nil then
  last_refreshed = 0
end
--redis.log(redis.LOG_WARNING, "last_refreshed " .. last_refreshed)

local delta = math.max(0, now-last_refreshed)
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
local allowed = filled_tokens >= requested
local new_tokens = filled_tokens
local allowed_num = 0
if allowed then
  new_tokens = filled_tokens - requested
  allowed_num = 1
end

--redis.log(redis.LOG_WARNING, "delta " .. delta)
--redis.log(redis.LOG_WARNING, "filled_tokens " .. filled_tokens)
--redis.log(redis.LOG_WARNING, "allowed_num " .. allowed_num)
--redis.log(redis.LOG_WARNING, "new_tokens " .. new_tokens)

if ttl > 0 then
  redis.call("setex", tokens_key, ttl, new_tokens)
  redis.call("setex", timestamp_key, ttl, now)
end

-- return { allowed_num, new_tokens, capacity, filled_tokens, requested, new_tokens }
return { allowed_num, new_tokens }

The version below drops the commented-out debug code and adds explanatory comments:

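-- key that stores the current number of tokens
local tokens_key = KEYS[1]
-- key that stores the timestamp of the last update
local timestamp_key = KEYS[2]

-- rate at which tokens are added to the bucket, in tokens per second
local rate = tonumber(ARGV[1])
-- maximum capacity of the bucket
local capacity = tonumber(ARGV[2])
-- current timestamp, in seconds
local now = tonumber(ARGV[3])
-- number of tokens this request consumes
local requested = tonumber(ARGV[4])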

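-- time needed to fill an empty bucket completely
local fill_time = capacity/rate
-- expiry time of the Redis keys: twice the fill time.
-- After fill_time seconds without requests the bucket is full again, and a missing key is treated
-- as a full bucket below, so letting the keys expire after 2 * fill_time loses no state.
local ttl = math.floor(fill_time*2)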

-- read the token count; if the key is missing, treat the bucket as full (capacity tokens)
local last_tokens = tonumber(redis.call("get", tokens_key))
if last_tokens == nil then
  last_tokens = capacity
end

-- read the timestamp of the last update; if the key is missing, use 0
local last_refreshed = tonumber(redis.call("get", timestamp_key))
if last_refreshed == nil then
  last_refreshed = 0
end

-- seconds elapsed since the last update
local delta = math.max(0, now-last_refreshed)

-- token count after refilling for delta seconds, capped at the capacity
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
-- check whether enough tokens are available for this request
local allowed = filled_tokens >= requested
local new_tokens = filled_tokens
-- 0 means false, i.e. the request is rate-limited
local allowed_num = 0

-- if allowed, deduct the requested tokens and mark the request as allowed
if allowed then
  new_tokens = filled_tokens - requested
  allowed_num = 1
end

-- write the new state back to Redis
if ttl > 0 then
  redis.call("setex", tokens_key, ttl, new_tokens)
  redis.call("setex", timestamp_key, ttl, now)
end

-- a Lua table can return several values; the Java side reads them as a List
return { allowed_num, new_tokens }
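To reason about the algorithm, including the 2x TTL, without a Redis instance, the script can be ported line by line to Java. This is a hypothetical in-memory sketch, not gateway code. Its assumptions:

- Two fields stand in for KEYS[1] and KEYS[2].
- SETEX expiry is simulated in whole seconds.
- Inputs are integers, like the gateway's replenishRate, burstCapacity and requestedTokens.

```java
// Hypothetical in-memory port of request_rate_limiter.lua, for reasoning about the algorithm only.
// storedTokens / storedTimestamp stand in for KEYS[1] / KEYS[2]; expireAt simulates SETEX expiry.
class TokenBucketSketch {
    private Long storedTokens;     // value of KEYS[1], null when missing or expired
    private Long storedTimestamp;  // value of KEYS[2], null when missing or expired
    private long expireAt;         // second at which both keys expire

    /** Returns {allowedNum, newTokens}, mirroring the script's return value. */
    long[] tryAcquire(long rate, long capacity, long requested, long now) {
        double fillTime = (double) capacity / rate;
        long ttl = (long) Math.floor(fillTime * 2);

        if (storedTokens != null && now >= expireAt) {  // simulated key expiry
            storedTokens = null;
            storedTimestamp = null;
        }
        long lastTokens = storedTokens == null ? capacity : storedTokens;  // missing = full bucket
        long lastRefreshed = storedTimestamp == null ? 0 : storedTimestamp;

        long delta = Math.max(0, now - lastRefreshed);
        long filledTokens = Math.min(capacity, lastTokens + delta * rate);
        boolean allowed = filledTokens >= requested;
        long newTokens = allowed ? filledTokens - requested : filledTokens;

        if (ttl > 0) {  // like the two SETEX calls: both values share the same TTL
            storedTokens = newTokens;
            storedTimestamp = now;
            expireAt = now + ttl;
        }
        return new long[] {allowed ? 1 : 0, newTokens};
    }
}
```

Take rate = 1, capacity = 2 and requested = 1, with three requests in the same second. They get {1, 1}, {1, 0} and {0, 0}. One second later, one token has been refilled.

After a long idle period the keys expire. The missing values are then read as a full bucket, which is the same result the refill arithmetic would have produced anyway.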

The core problem lies here:

local tokens_key = KEYS[1]

local timestamp_key = KEYS[2]

Here KEYS[1] and KEYS[2] are assigned to local variables, and those variables are then used throughout the rest of the script.

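However, ApsaraDB for Redis (Alibaba Cloud's cluster edition) restricts Lua scripts so that every operation in a script runs on the same slot. The restrictions are: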

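  1. All keys must be passed through the KEYS array. In redis.call/pcall, the key position of a Redis command must be a KEYS array element; a Lua variable must not stand in for KEYS. Otherwise the error ERR bad lua script for redis cluster, all the keys that the script uses should be passed using the KEYS array is returned.
  2. All keys must be in the same slot. Otherwise the error ERR eval/evalsha command keys must be in same slot is returned.
  3. The call must carry at least one key. Otherwise the error ERR for redis cluster, eval/evalsha number of keys can't be negative or zero is returned.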
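Rule 2 is about Redis Cluster key slots. A key's slot is CRC16(key) mod 16384. If the key contains a non-empty {...} section, only the text inside the first such braces is hashed; this is called a hash tag.

That is why the gateway builds its keys as request_rate_limiter.{id}.tokens and request_rate_limiter.{id}.timestamp: both land in the slot of id. Below is a small pure-Java sketch of that rule, using CRC16-XMODEM as described in the Redis Cluster specification. The class name is hypothetical.

```java
import java.nio.charset.StandardCharsets;

// Hypothetical helper reproducing the Redis Cluster key-to-slot mapping.
class KeySlotSketch {
    // CRC16-XMODEM: polynomial 0x1021, initial value 0, MSB first, no final XOR.
    static int crc16(byte[] bytes) {
        int crc = 0;
        for (byte b : bytes) {
            crc ^= (b & 0xFF) << 8;
            for (int i = 0; i < 8; i++) {
                crc = (crc & 0x8000) != 0 ? (crc << 1) ^ 0x1021 : crc << 1;
                crc &= 0xFFFF;
            }
        }
        return crc;
    }

    static int keySlot(String key) {
        // Hash tag: if the key contains "{...}" with at least one character inside,
        // only that substring is hashed; an empty "{}" is ignored.
        int open = key.indexOf('{');
        if (open >= 0) {
            int close = key.indexOf('}', open + 1);
            if (close > open + 1) {
                key = key.substring(open + 1, close);
            }
        }
        return crc16(key.getBytes(StandardCharsets.UTF_8)) % 16384;
    }
}
```

With this, keySlot("request_rate_limiter.{u1}.tokens") and keySlot("request_rate_limiter.{u1}.timestamp") both equal keySlot("u1"), so the gateway's two keys always satisfy rule 2.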

The crux: the gateway's built-in Lua script violates rule 1, because it substitutes Lua variables for KEYS.
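One rough way to spot this kind of violation before deploying a script is a regex check:

- Look at the first argument after the command name in every redis.call or redis.pcall.
- Flag anything that is not a literal KEYS[n].

This is a hypothetical heuristic, not Alibaba Cloud's actual validation. It assumes the command name is a string literal. It also assumes the command takes its key as the first argument, which holds for GET and SETEX here but not for every Redis command.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical lint heuristic; not the check performed by the Redis cluster or its proxy.
class KeysArgumentCheck {
    // matches redis.call("cmd", <firstArg> ... or redis.pcall('cmd', <firstArg> ...
    private static final Pattern CALL = Pattern.compile(
            "redis\\.p?call\\(\\s*[\"'](\\w+)[\"']\\s*,\\s*([^,)]+)");
    private static final Pattern KEYS_REF = Pattern.compile("KEYS\\[\\d+\\]");

    /** Returns "command argument" pairs whose key argument is not a direct KEYS[n]. */
    static List<String> suspiciousKeyArguments(String lua) {
        List<String> suspicious = new ArrayList<>();
        for (String line : lua.split("\n")) {
            // drop Lua line comments (naive: ignores "--" inside string literals)
            int comment = line.indexOf("--");
            String code = comment >= 0 ? line.substring(0, comment) : line;
            Matcher m = CALL.matcher(code);
            while (m.find()) {
                String firstArg = m.group(2).trim();
                if (!KEYS_REF.matcher(firstArg).matches()) {
                    suspicious.add(m.group(1) + " " + firstArg);
                }
            }
        }
        return suspicious;
    }
}
```

Against the gateway's original script, this reports the get and setex calls that use tokens_key and timestamp_key. A correctly edited script should produce an empty list.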

How to fix it

Step 1: Modify the Lua script

Replace the script. For readability, I removed unrelated comments and marked every change with a number:

  1. Delete local tokens_key = KEYS[1] and local timestamp_key = KEYS[2] (here they are commented out so the change is easier to see)
  2. Replace every use of tokens_key with KEYS[1]
  3. Replace every use of timestamp_key with KEYS[2]
-- local tokens_key = KEYS[1] -- 1. comment out this line
-- local timestamp_key = KEYS[2] -- 2. comment out this line


local rate = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local requested = tonumber(ARGV[4])


local fill_time = capacity/rate
local ttl = math.floor(fill_time*2)

-- local last_tokens = tonumber(redis.call("get", tokens_key)) -- 3.1 replace tokens_key on this line with KEYS[1]
local last_tokens = tonumber(redis.call("get", KEYS[1])) -- 3.2 the modified line
if last_tokens == nil then
  last_tokens = capacity
end

-- local last_refreshed = tonumber(redis.call("get", timestamp_key)) -- 4.1 replace timestamp_key on this line with KEYS[2]
local last_refreshed = tonumber(redis.call("get", KEYS[2])) -- 4.2 the modified line
if last_refreshed == nil then
  last_refreshed = 0
end

local delta = math.max(0, now-last_refreshed)

local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
local allowed = filled_tokens >= requested
local new_tokens = filled_tokens
local allowed_num = 0

if allowed then
  new_tokens = filled_tokens - requested
  allowed_num = 1
end

if ttl > 0 then
  -- redis.call("setex", tokens_key, ttl, new_tokens) -- 5.1 replace tokens_key on this line with KEYS[1]
  redis.call("setex", KEYS[1], ttl, new_tokens) -- 5.2 the modified line
  -- redis.call("setex", timestamp_key, ttl, now) -- 6.1 replace timestamp_key on this line with KEYS[2]
  redis.call("setex", KEYS[2], ttl, now) -- 6.2 the modified line
end

return { allowed_num, new_tokens }

Step 2: Override isAllowed() and replace the original Lua script with the modified one

import cn.hutool.core.collection.CollUtil;
import com.alibaba.fastjson.JSONObject;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.cloud.gateway.event.FilterArgsEvent;
import org.springframework.cloud.gateway.filter.ratelimit.RedisRateLimiter;
import org.springframework.cloud.gateway.route.RouteDefinitionRouteLocator;
import org.springframework.cloud.gateway.support.ConfigurationService;
import org.springframework.data.redis.core.ReactiveStringRedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.core.script.RedisScript;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * @author whitebrocade
 * @version 1.0
 * @description: Custom rate limiter
 */
@Slf4j
@ConfigurationProperties("spring.cloud.gateway.redis-rate-limiter")
public class MyRequestRateLimiter extends RedisRateLimiter {

    /**
     * Configuration key of the custom rate-limiting strategy (key resolver)
     */
    private static final String KEY_RESOLVER_KEY = "key-resolver";

    /**
     * Separator
     */
    private static final String INTERVAL_MARK = "_";

    public MyRequestRateLimiter(ReactiveStringRedisTemplate redisTemplate, RedisScript<List<Long>> script, ConfigurationService configurationService) {
        super(redisTemplate, script, configurationService);
    }

    @Override
    public void onApplicationEvent(FilterArgsEvent event) {
        Map<String, Object> args = event.getArgs();
        if (args.containsKey(KEY_RESOLVER_KEY)) {
            String routeId = event.getRouteId() + INTERVAL_MARK + args.get(KEY_RESOLVER_KEY).hashCode();
            super.onApplicationEvent(new FilterArgsEvent(event.getSource(), routeId, args));
        }
        super.onApplicationEvent(event);
    }

    // ------------- Everything below is needed to override isAllowed()

    @Autowired
    @Qualifier("reactiveStringRedisTemplateFlowLimit")
    ReactiveStringRedisTemplate redisTemplate;

    private Config defaultConfig;

    private AtomicBoolean initialized = new AtomicBoolean(false);

    /**
     * Rate-limiting Lua script (modified version)
     */
    private final String luaScriptStr = "local rate = tonumber(ARGV[1])\n" +
            "local capacity = tonumber(ARGV[2])\n" +
            "local now = tonumber(ARGV[3])\n" +
            "local requested = tonumber(ARGV[4])\n" +
            "\n" +
            "local fill_time = capacity/rate\n" +
            "local ttl = math.floor(fill_time*2)\n" +
            "\n" +
            "local last_tokens = tonumber(redis.call(\"get\", KEYS[1]))\n" +
            "if last_tokens == nil then\n" +
            "  last_tokens = capacity\n" +
            "end\n" +
            "\n" +
            "local last_refreshed = tonumber(redis.call(\"get\", KEYS[2]))\n" +
            "if last_refreshed == nil then\n" +
            "  last_refreshed = 0\n" +
            "end\n" +
            "\n" +
            "local delta = math.max(0, now-last_refreshed)\n" +
            "\n" +
            "local filled_tokens = math.min(capacity, last_tokens+(delta*rate))\n" +
            "local allowed = filled_tokens >= requested\n" +
            "local new_tokens = filled_tokens\n" +
            "local allowed_num = 0\n" +
            "\n" +
            "if allowed then\n" +
            "  new_tokens = filled_tokens - requested\n" +
            "  allowed_num = 1\n" +
            "end\n" +
            "\n" +
            "if ttl > 0 then\n" +
            "  redis.call(\"setex\", KEYS[1], ttl, new_tokens)\n" +
            "  redis.call(\"setex\", KEYS[2], ttl, now)\n" +
            "end\n" +
            "\n" +
            "return { allowed_num, new_tokens }\n";

    @Override
    public Mono<Response> isAllowed(String routeId, String id) {
        Config routeConfig = loadConfiguration(routeId);

        int replenishRate = routeConfig.getReplenishRate();

        int burstCapacity = routeConfig.getBurstCapacity();

        int requestedTokens = routeConfig.getRequestedTokens();

        try {
            // KEYS arguments
            List<String> keys = getKeys(id);
            // ARGV arguments for the script
            List<String> scriptArgs = Arrays.asList(
                    replenishRate + "",
                    burstCapacity + "",
                    Instant.now().getEpochSecond() + "",
                    requestedTokens + "");
            // execute the Lua script
            DefaultRedisScript<List> luaScript = new DefaultRedisScript<>(luaScriptStr, List.class);
            Flux<List> flux = this.redisTemplate.execute(luaScript, keys, scriptArgs);

            // on error, log it and fall back to [1, -1]; otherwise collect the script result
            return flux.onErrorResume(throwable -> {
                log.error("Failed to execute the rate-limiting Lua script for keys {}", keys, throwable);
                // fallback result: allowed (1) with -1 tokens left, in a mutable ArrayList
                ArrayList<Long> arr = new ArrayList<>();
                CollUtil.addAll(arr, Arrays.asList(1L, -1L));
                return Flux.just(arr);
            }).reduce(new ArrayList<Long>(), (longs, l) -> {
                longs.addAll(l);
                return longs;
            }).map(results -> { // build the response from the script result
                // 0 = rate-limited, 1 = allowed
                boolean allowed = results.get(0) == 1L;
                Long tokensLeft = results.get(1);
                Response response = new Response(allowed, getHeaders(routeConfig, tokensLeft));
                log.debug("Rate limiter response: {}", response);
                return response;
            });
        } catch (Exception e) {
            log.error("Redis限流异常", e);
        }
        // If an exception was thrown, let the request through and report -1 remaining tokens
        return Mono.just(new Response(true, getHeaders(routeConfig, -1L)));
    }
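    /*
     * Key format used for rate limiting:
     * 1. request_rate_limiter.{id}.tokens
     * 2. request_rate_limiter.{id}.timestamp
     * The {id} part is a Redis Cluster hash tag, so both keys map to the same slot.
     */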
    static List<String> getKeys(String id) {
        String prefix = "request_rate_limiter.{" + id;

        String tokenKey = prefix + "}.tokens";
        String timestampKey = prefix + "}.timestamp";
        return Arrays.asList(tokenKey, timestampKey);
    }

    /*
     * Load the rate-limit configuration for the given routeId
     */
    Config loadConfiguration(String routeId) {
        Config routeConfig = getConfig().getOrDefault(routeId, defaultConfig);

        if (routeConfig == null) {
            routeConfig = getConfig().get(RouteDefinitionRouteLocator.DEFAULT_FILTERS);
        }

        if (routeConfig == null) {
            throw new IllegalArgumentException("No Configuration found for route " + routeId + " or defaultFilters");
        }
        return routeConfig;
    }
}
