在之前这篇文章中,我大致介绍了一下google guava库中的RateLimiter的实现以及它背后的令牌桶算法原理。但是也有新的问题,在分布式的环境中,我们如何针对多机环境做限流呢?在查阅了一些资料和其他人的博客之后,我采用了redis来作为限流器的实现基础。
原因主要有以下几点:
但是我们也知道,限流器在每次请求令牌和放入令牌操作中,存在一个协同的问题,即获取令牌操作要尽可能保证原子性,否则无法保证限流器是否能正常工作。在RateLimiter的实现中使用了mutex作为互斥锁来保证操作的原子性,那么在redis中就需要一个类似于事务的机制来保证获取令牌中多重操作的原子性。
面对这样的需求,我们有几个选择:
经过权衡,我采用了第四种方式,通过redis和lua来编写令牌桶算法来完成分布式限流的需求。
话不多说,先贴出lua代码
-- 返回码 1:操作成功 0:未配置 -1: 获取失败 -2:修改错误,建议重新初始化 -500:不支持的操作
-- redis hashmap 中存放的内容:
-- last_mill_second 上次放入令牌或者初始化的时间
-- stored_permits 目前令牌桶中的令牌数量
-- max_permits 令牌桶容量
-- interval 放令牌间隔
-- app 一个标志位,表示对于当前key有没有限流存在
local SUCCESS = 1
local NO_LIMIT = 0
local ACQUIRE_FAIL = -1
local MODIFY_ERROR = -2
local UNSUPPORT_METHOD = -500
local ratelimit_info = redis.pcall("HMGET",KEYS[1], "last_mill_second", "stored_permits", "max_permits", "interval", "app")
local last_mill_second = ratelimit_info[1]
local stored_permits = tonumber(ratelimit_info[2])
local max_permits = tonumber(ratelimit_info[3])
local interval = tonumber(ratelimit_info[4])
local app = ratelimit_info[5]
local method = ARGV[1]
--获取当前毫秒
--考虑主从策略和脚本回放机制,这个time由客户端获取传入
--local curr_time_arr = redis.call('TIME')
--local curr_timestamp = curr_time_arr[1] * 1000 + curr_time_arr[2]/1000
local curr_timestamp = tonumber(ARGV[2])
-- 当前方法为初始化
if method == 'init' then
--如果app不为null说明已经初始化过,不要重复初始化
if(type(app) ~='boolean' and app ~=nil) then
return SUCCESS
end
redis.pcall("HMSET", KEYS[1],
"last_mill_second", curr_timestamp,
"stored_permits", ARGV[3],
"max_permits", ARGV[4],
"interval", ARGV[5],
"app", ARGV[6])
--始终返回成功
return SUCCESS
end
-- 当前方法为修改配置
if method == "modify" then
if(type(app) =='boolean' or app ==nil) then
return MODIFY_ERROR
end
--只能修改max_permits和interval
redis.pcall("HMSET", KEYS[1],
"max_permits", ARGV[3],
"interval", ARGV[4])
return SUCCESS
end
-- 当前方法为删除
if method == "delete" then
--已经清除完毕
if(type(app) =='boolean' or app ==nil) then
return SUCCESS
end
redis.pcall("DEL", KEYS[1])
return SUCCESS
end
-- 尝试获取permits
if method == "acquire" then
-- 如果app为null说明没有对这个进行任何配置,返回0代表不限流
if(type(app) =='boolean' or app ==nil) then
return NO_LIMIT
end
--需要获取令牌数量
local acquire_permits = tonumber(ARGV[3])
--计算上一次放令牌到现在的时间间隔中,一共应该放入多少令牌
local reserve_permits = math.max(0, math.floor((curr_timestamp - last_mill_second) / interval))
local new_permits = math.min(max_permits, stored_permits + reserve_permits)
local result = ACQUIRE_FAIL
--如果桶中令牌数量够则放行
if new_permits >= acquire_permits then
result = SUCCESS
new_permits = new_permits - acquire_permits
end
--更新当前桶中的令牌数量
redis.pcall("HSET", KEYS[1], "stored_permits", new_permits)
--如果这次有放入令牌,则更新时间
if reserve_permits > 0 then
redis.pcall("HSET", KEYS[1], "last_mill_second", curr_timestamp)
end
return result
end
return UNSUPPORT_METHOD
绝大部分逻辑在注释里面都已经写清楚了(我java客户端用的代码删掉了所有的注释,因为提交上去报编译错误,但是redis-cli调试就没问题,我也没太关注原因)。
大致上,我在这个脚本中编写了4种函数:
代码基本上仿照了Guava RateLimiter的逻辑,实现了触发式的放令牌策略。
由于我的需求中不需要像guava RateLimiter那样的预支令牌的逻辑,因此如果当前没有令牌可供服务,我就直接返回获取失败了。
还有一点需要注意的是,我本来在脚本中写了获取redis服务器当前时间的代码,但是我通过redis-cli执行的时候报错了:
Write commands not allowed after non deterministic commands.
这个错误的原因大家可以参见这篇文章,大致原因跟redis集群的重放和备份策略有关,相当于我调用TIME操作,会在主从各执行一次,得到的结果肯定会存在差异,这个差异就给最终逻辑正确性带来了不确定性。在redis 4.0之后引入了redis.replicate_commands()来放开限制。但我考虑了几个因素之后,还是采用网上大部分人的做法,在执行前先行获取到redis的时间戳,然后当做参数传上去。
对lua调试最开始花掉了我不少时间,主要对于redis-cli命令不太熟悉。大家有一样问题的可以参见这篇文章。大致来说就是将写好的脚本放到redis所在文件夹下(我是windows环境),然后在cmd下执行 redis-cli.exe --eval rate_limit.lua test2(key,可重复) , (逗号分隔) init 10101 100 100 10 test2 (后跟参数,空格隔开)。
在完成了lua的调试工作之后,我们就开始java部分的集成代码编写,我们使用的是spring boot来完成开发。
第一部分是redis配置:
@Bean("rateLimitLua")
public DefaultRedisScript getRateLimitScript() {
DefaultRedisScript rateLimitLua = new DefaultRedisScript<>();
rateLimitLua.setLocation(new ClassPathResource("rate_limit.lua"));
rateLimitLua.setResultType(Long.class);
return rateLimitLua;
}
然后是一些与lua适配的枚举和一些bean:
/**
* @author: Yuanqing Luo
* @date: 2018/10/22
*
* 限流的具体方法
*/
public enum RateLimitMethod {
//initialize rate limiter
init,
//modify rate limiter parameter
modify,
//delete rate limiter
delete,
//acquire permits
acquire;
}
/**
* @author: Yuanqing Luo
* @date: 2018/10/22
* rate limite result
**/
public enum RateLimitResult {
SUCCESS(1L),
NO_LIMIT(0L),
ACQUIRE_FAIL(-1L),
MODIFY_ERROR(-2L),
UNSUPPORT_METHOD(-500L),
ERROR(-505L);
private Long code;
RateLimitResult(Long code){
this.code = code;
}
public static RateLimitResult getResult(Long code){
for(RateLimitResult enums: RateLimitResult.values()){
if(enums.code.equals(code)){
return enums;
}
}
throw new IllegalArgumentException("unknown rate limit return code:" + code);
}
}
/**
* @author: Yuanqing Luo
* @date: 2018/10/22
**/
@Getter
@Setter
public class RateLimitVo {
private String url;
private boolean isLimit;
private Double interval;
private Integer maxPermits;
private Integer initialPermits;
}
第三部分就是限流器的调用组装部分:
/**
* @author: Yuanqing Luo
* @date: 2018/10/22
**/
@Service
@Slf4j
public class RateLimitClient {
private static final String RATE_LIMIT_PREFIX = "ratelimit:";
@Autowired
StringRedisTemplate redisTemplate;
@Resource
@Qualifier("rateLimitLua")
RedisScript rateLimitScript;
public RateLimitResult init(String key, RateLimitVo rateLimitInfo){
return exec(key, RateLimitMethod.init,
rateLimitInfo.getInitialPermits(),
rateLimitInfo.getMaxPermits(),
rateLimitInfo.getInterval(),
key);
}
public RateLimitResult modify(String key, RateLimitVo rateLimitInfo){
return exec(key, RateLimitMethod.modify, key,
rateLimitInfo.getMaxPermits(),
rateLimitInfo.getInterval());
}
public RateLimitResult delete(String key){
return exec(key, RateLimitMethod.delete);
}
public RateLimitResult acquire(String key){
return acquire(key, 1);
}
public RateLimitResult acquire(String key, Integer permits){
return exec(key, RateLimitMethod.acquire, permits);
}
/**
* 执行redis的具体方法,限制method,保证没有其他的东西进来
* @param key
* @param method
* @param params
* @return
*/
private RateLimitResult exec(String key, RateLimitMethod method, Object... params){
try {
Long timestamp = getRedisTimestamp();
String[] allParams = new String[params.length + 2];
allParams[0] = method.name();
allParams[1] = timestamp.toString();
for(int index = 0;index < params.length; index++){
allParams[2 + index] = params[index].toString();
}
Long result = redisTemplate.execute(rateLimitScript,
Collections.singletonList(getKey(key)),
allParams);
return RateLimitResult.getResult(result);
} catch (Exception e){
log.error("execute redis script fail, key:{}, method:{}",
key, method.name(), e);
return RateLimitResult.ERROR;
}
}
private Long getRedisTimestamp(){
Long currMillSecond = redisTemplate.execute(
(RedisCallback) redisConnection -> redisConnection.time()
);
return currMillSecond;
}
private String getKey(String key){
return RATE_LIMIT_PREFIX + key;
}
}
java代码这块比较简单了,基本就是封装了之前lua脚本中的4项操作。
第四部分就是测试代码:
/**
* @author: Yuanqing Luo
* @date: 2018/10/22
**/
@RunWith(SpringRunner.class)
@SpringBootTest(classes = OpenApiGatewayApplication.class)
public class RateLimitTest {
@Autowired
private RateLimitClient rateLimitClient;
@Test
public void testInit(){
RateLimitVo vo = new RateLimitVo();
vo.setInitialPermits(500);
vo.setMaxPermits(500);
vo.setInterval(2.0);
rateLimitClient.init("test", vo);
}
@Test
public void testAcquire() throws InterruptedException {
//10个线程
ExecutorService executorService = Executors.newFixedThreadPool(20);
Subject writeSubject = new SerializedSubject(PublishSubject.create());
Observable readSubject = writeSubject.share();
Observable bucketStream = Observable.defer(()->{
return readSubject.window(200, TimeUnit.MILLISECONDS)
.flatMap(
observable->
observable.reduce(new RateLimitSummary(0,0,0),
(a, b)-> a.reduce(b))
);
});
Observable rollingBucketStream = bucketStream.window(5, 1)
.flatMap(observable->observable.reduce(new RateLimitSummary(0, 0, 0),
(a, b)-> a.reduce(b)));
Runnable acquire = () -> {
Random random = new Random();
while(true){
try {
Thread.sleep(30);
} catch (InterruptedException e) {
e.printStackTrace();
}
RateLimitResult result = rateLimitClient.acquire("test");
writeSubject.onNext(new RateLimitSummary(result));
}
};
//初始时间
final long currentMillis = System.currentTimeMillis();
rollingBucketStream.subscribe(summary->{
double timestamp = (System.currentTimeMillis() - currentMillis)/1000.0;
System.out.println("time:"+ timestamp + ", acquired:" + summary.acquire +
", reject " + summary.reject + ", error: " + summary.error);
});
for(int i=0;i<20;i++){
executorService.submit(acquire);
}
while(true){
Thread.sleep(5000);
}
}
private static class RateLimitSummary{
public int acquire;
public int reject;
public int error;
public RateLimitSummary(RateLimitResult result){
this.acquire = result == RateLimitResult.SUCCESS?1:0;
this.reject = result == RateLimitResult.ACQUIRE_FAIL?1:0;
this.error = result == RateLimitResult.ERROR?1:0;
}
public RateLimitSummary(int acquire, int reject, int error){
this.acquire = acquire;
this.reject = reject;
this.error = error;
}
public RateLimitSummary reduce(RateLimitSummary toAdd){
return new RateLimitSummary(this.acquire + toAdd.acquire,
this.reject + toAdd.reject,
this.error + toAdd.error);
}
}
}
这一段代码我仿照了Hystrix中的熔断统计的代码,通过一个subject来存放获取令牌结果,然后通过第一层bucketStream来将令牌结果按照200ms来分组并且reduce成一个结果。接着通过rollingBucketStream来将200ms的分组组合成一个一秒的时间窗(即5个为一组),并且以200ms为步长滚动。最后统计出来的结果通过subscribe来打印结果。之前的init代码我们看已经初始化了一个大小为500的令牌桶,存放令牌的时间间隔为2.0ms,所以支持的QPS为500。接着我们执行这段代码,并截取一部分输出:
time:75.857, acquired:460, reject 8, error: 0
time:76.056, acquired:483, reject 36, error: 0
time:76.268, acquired:506, reject 52, error: 0
time:76.454, acquired:503, reject 59, error: 0
time:76.707, acquired:457, reject 69, error: 0
time:76.854, acquired:417, reject 66, error: 0
time:77.054, acquired:454, reject 36, error: 0
time:77.255, acquired:459, reject 54, error: 0
time:77.453, acquired:458, reject 77, error: 0
time:77.658, acquired:474, reject 103, error: 0
time:77.858, acquired:490, reject 132, error: 0
可以看到,这个结果基本每200ms输出一次,然后一秒钟内的获取了令牌数目最大值跟500接近,并且能够很好地处理reject。有一部分结果一秒钟获取的令牌数与500差距较大,我分析的原因是因为请求重复时间段比较多,很多请求发生在前一个获取了令牌之后的2ms内,产生了reject。
通过redis和lua,我实现了一个简单的分布式限流器。通过上述代码,大家能看到一个大致的实现框架,并且通过测试代码完成了验证。如果各位看官有什么问题欢迎留言,希望能跟大家共同学习。