基于Redis的分布式令牌桶限流器

本文根据Guava RateLimiter令牌桶限流器修改的基于Redis的分布式限流器。令牌桶采用横定速率生成令牌存放入桶中,通过计算获取指定令牌数所需要的等待时间来进行限流。

注:其中对于令牌桶的更新需要依赖分布式同步锁:DistributedLock
本文采用基于Redis的RedLock来实现,请参见本人另外的一篇文章:基于Redis RedLock的分布式同步锁

1、Guava RateLimiter原理

根据令牌桶算法,桶中的令牌是持续生成存放的,有请求时需要先从桶中拿到令牌才能开始执行,谁来持续生成令牌存放呢?

对于令牌桶中令牌的产生一般有两种做法:

  • 一种解法是,开启一个定时任务,由定时任务持续生成令牌。这样的问题在于会极大的消耗系统资源,如,某接口需要分别对每个用户做访问频率限制,假设系统中存在6W用户,则至多需要开启6W个定时任务来维持每个桶中的令牌数,这样的开销是巨大的。
  • 另一种解法则是延迟计算,如上resync函数。该函数会在每次获取令牌之前调用,其实现思路为,若当前时间晚于nextFreeTicketMicros,则计算该段时间内可以生成多少令牌,将生成的令牌加入令牌桶中并更新数据。这样一来,只需要在获取令牌时计算一次即可。

Guava RateLimiter的做法是第二种,当每次获取令牌时,先执行resync来更新令牌桶中令牌的数量,从而达到异步产生令牌的目的。

关键变量:
  • nextFreeTicketMicros:表示下一次允许补充令牌的时间(时刻)。这个变量的解释比较拗口,看下面流程会比较清晰
  • maxPermits:最大令牌数
  • storedPermits:当前存储的令牌数,数量不能超过最大令牌数

其中关键方法如下:

void resync(long nowMicros) {
  // if nextFreeTicket is in the past, resync to now
  if (nowMicros > nextFreeTicketMicros) {
    double newPermits = (nowMicros - nextFreeTicketMicros) / coolDownIntervalMicros();
    storedPermits = min(maxPermits, storedPermits + newPermits);
    nextFreeTicketMicros = nowMicros;
  }
}

主要有如下几步操作:

  • 1、根据当前的时间戳nowMicros与上次一个获取令牌后设置的下次允许补充令牌的时间戳nextFreeTicketMicros进行比较,如果当前时间在上一次设定的nextFreeTicketMicros之后,那么表示可以有多的令牌可以获取
  • 2、计算当前时间与上一次设定时间的差值,除以每个令牌产生的时间间隔coolDownIntervalMicros()来获取这一时间段新产生的令牌数,同时加上上次剩余的令牌数与最大令牌数进行比较,取小者作为当前的桶中的令牌数
  • 3、将下次允许产生令牌的时间设置为当前时间nowMicros

2、基于Redis的分布式令牌桶

令牌桶
/**
 * redis令牌桶
 * @author: Meng.Liu
 * @date: 2018/11/12 下午4:07
 */
@Data
public class RedisPermits {

    /**
     * maxPermits 最大存储令牌数
     */
    private Long maxPermits;
    /**
     * storedPermits 当前存储令牌数
     */
    private Long storedPermits;
    /**
     * intervalMillis 添加令牌时间间隔
     */
    private Long intervalMillis;
    /**
     * nextFreeTicketMillis 下次请求可以获取令牌的起始时间,默认当前系统时间
     */
    private Long nextFreeTicketMillis;

    /**
     * @param permitsPerSecond 每秒放入的令牌数
     * @param maxBurstSeconds  maxPermits由此字段计算,最大存储maxBurstSeconds秒生成的令牌
     */
    public RedisPermits(Double permitsPerSecond, Integer maxBurstSeconds) {
        if (null == maxBurstSeconds) {
            maxBurstSeconds = 60;
        }
        this.maxPermits = (long) (permitsPerSecond * maxBurstSeconds);
        this.storedPermits = permitsPerSecond.longValue();
        this.intervalMillis = (long) (TimeUnit.SECONDS.toMillis(1) / permitsPerSecond);
        this.nextFreeTicketMillis = System.currentTimeMillis();
    }

    /**
     * redis的过期时长
     * @return
     */
    public long expires() {
        long now = System.currentTimeMillis();
        return 2 * TimeUnit.MINUTES.toSeconds(1)
                + TimeUnit.MILLISECONDS.toSeconds(Math.max(nextFreeTicketMillis, now) - now);
    }

    /**
     * 异步更新当前持有的令牌数
     * 若当前时间晚于nextFreeTicketMicros,则计算该段时间内可以生成多少令牌,将生成的令牌加入令牌桶中并更新数据
     * @param now
     * @return
     */
    public boolean reSync(long now){
        if (now > nextFreeTicketMillis) {
            storedPermits = Math.min(maxPermits, storedPermits + (now - nextFreeTicketMillis) / intervalMillis);
            nextFreeTicketMillis = now;
            return true;
        }
        return false;
    }
}

该类为令牌桶信息,其中包含了令牌桶的大小,令牌产生速率以及核心令牌桶异步更新方法reSync。

限流器处理类
/**
 * 令牌桶限流器
 * @author: Meng.Liu
 * @date: 2018/11/12 下午4:31
 */
@Slf4j
@Data
public class RateLimiter {

    /**
     * redis key
     */
    private String key;
    /**
     * redis分布式锁的key
     * @return
     */
    private String lockKey;
    /**
     * 每秒存入的令牌数
     */
    private Double permitsPerSecond;
    /**
     * 最大存储maxBurstSeconds秒生成的令牌
     */
    private Integer maxBurstSeconds;
    /**
     * 分布式同步锁
     */
    private DistributedLock syncLock;

    public RateLimiter(String key, Double permitsPerSecond, Integer maxBurstSeconds, DistributedLock syncLock){
        this.key = key;
        this.lockKey = "DISTRIBUTED_LOCK:" + key;
        this.permitsPerSecond = permitsPerSecond;
        this.maxBurstSeconds = maxBurstSeconds;
        this.syncLock = syncLock;
    }

    /**
     * 生成并存储默认令牌桶
     * @return
     */
    private RedisPermits putDefaultPermits() {
        this.lock();
        try{
            Object obj = RedisUtils.select().getValue(key);
            if( null == obj ){
                RedisPermits permits = new RedisPermits(permitsPerSecond, maxBurstSeconds);
                RedisUtils.select().addValue(key, permits, permits.expires(), TimeUnit.SECONDS);
                return permits;
            }else{
                return (RedisPermits) obj;
            }
        }finally {
            this.unlock();
        }

    }

    /**
     * 加锁
     */
    private void lock(){
        syncLock.lock(lockKey);
    }

    /**
     * 解锁
     */
    private void unlock(){
        syncLock.unLock(lockKey);
    }

    /**
     * 获取令牌桶
     * @return
     */
    public RedisPermits getPermits() {
        Object obj = RedisUtils.select().getValue(key);
        if( null == obj ){
            return putDefaultPermits();
        }
        return (RedisPermits) obj;
    }

    /**
     * 更新令牌桶
     * @param permits
     */
    public void setPermits(RedisPermits permits) {
        RedisUtils.select().addValue(key, permits, permits.expires(), TimeUnit.SECONDS);
    }

    /**
     * 等待直到获取指定数量的令牌
     * @param tokens
     * @return
     * @throws InterruptedException
     */
    public Long acquire(Long tokens) throws InterruptedException {
        long milliToWait = this.reserve(tokens);
        log.info("acquire for {}ms {}", milliToWait, Thread.currentThread().getName());
        Thread.sleep(milliToWait);
        return milliToWait;
    }

    /**
     * 获取1一个令牌
     * @return
     * @throws InterruptedException
     */
    private long acquire() throws InterruptedException{
        return acquire(1L);
    }

    /**
     *
     * @param tokens 要获取的令牌数
     * @param timeout 获取令牌等待的时间,负数被视为0
     * @param unit
     * @return
     * @throws InterruptedException
     */
    private Boolean tryAcquire(Long tokens, Long timeout, TimeUnit unit) throws InterruptedException{
        long timeoutMicros = Math.max(unit.toMillis(timeout), 0);
        checkTokens(tokens);
        Long milliToWait;
        try {
            this.lock();
            if (!this.canAcquire(tokens, timeoutMicros)) {
                return false;
            } else {
                milliToWait = this.reserveAndGetWaitLength(tokens);
            }
        } finally {
            this.unlock();
        }
        Thread.sleep(milliToWait);
        return true;
    }

    /**
     * 获取一个令牌
     * @param timeout
     * @param unit
     * @return
     * @throws InterruptedException
     */
    private Boolean tryAcquire(Long timeout , TimeUnit unit) throws InterruptedException{
        return tryAcquire(1L,timeout, unit);
    }

    private long redisNow(){
        Long time = RedisUtils.select().currentTime();
        return null == time ? System.currentTimeMillis() : time;
    }

    /**
     * 获取令牌n个需要等待的时间
     * @param tokens
     * @return
     */
    private long reserve(Long tokens) {
        this.checkTokens(tokens);
        try {
            this.lock();
            return this.reserveAndGetWaitLength(tokens);
        } finally {
            this.unlock();
        }
    }

    /**
     * 校验token值
     * @param tokens
     */
    private void checkTokens(Long tokens) {
        if( tokens < 0 ){
            throw new IllegalArgumentException("Requested tokens " + tokens + " must be positive");
        }
    }

    /**
     * 在等待的时间内是否可以获取到令牌
     * @param tokens
     * @param timeoutMillis
     * @return
     */
    private Boolean canAcquire(Long tokens, Long timeoutMillis){
        return queryEarliestAvailable(tokens) - timeoutMillis <= 0;
    }

    /**
     * 返回获取{tokens}个令牌最早可用的时间
     * @param tokens
     * @return
     */
    private Long queryEarliestAvailable(Long tokens){
        long n = redisNow();
        RedisPermits permit = this.getPermits();
        permit.reSync(n);
        // 可以消耗的令牌数
        long storedPermitsToSpend = Math.min(tokens, permit.getStoredPermits());
        // 需要等待的令牌数
        long freshPermits = tokens - storedPermitsToSpend;
        // 需要等待的时间
        long waitMillis = freshPermits * permit.getIntervalMillis();
        return LongMath.saturatedAdd(permit.getNextFreeTicketMillis() - n, waitMillis);
    }

    /**
     * 预定@{tokens}个令牌并返回所需要等待的时间
     * @param tokens
     * @return
     */
    private Long reserveAndGetWaitLength(Long tokens){
        long n = redisNow();
        RedisPermits permit = this.getPermits();
        permit.reSync(n);
        // 可以消耗的令牌数
        long storedPermitsToSpend = Math.min(tokens, permit.getStoredPermits());
        // 需要等待的令牌数
        long freshPermits = tokens - storedPermitsToSpend;
        // 需要等待的时间
        long waitMillis = freshPermits * permit.getIntervalMillis();
        permit.setNextFreeTicketMillis(LongMath.saturatedAdd(permit.getNextFreeTicketMillis(), waitMillis));
        permit.setStoredPermits( permit.getStoredPermits() - storedPermitsToSpend );
        this.setPermits(permit);
        return permit.getNextFreeTicketMillis() - n;
    }
}
令牌桶限流器工厂
/**
 * 令牌桶限流器工厂
 * @author: Meng.Liu
 * @date: 2018/11/12 下午4:26
 */
@Component
@Slf4j
@ConditionalOnBean(DistributedLock.class)
public class RateLimiterFactory {

    @Autowired
    private DistributedLock distributedLock;

    /**
     * 本地持有对象
     */
    private volatile Map rateLimiterMap = new ConcurrentHashMap<>();

    /**
     * @param key              redis key
     * @param permitsPerSecond 每秒产生的令牌数
     * @param maxBurstSeconds  最大存储多少秒的令牌
     * @return
     */
    public RateLimiter build(String key, Double permitsPerSecond, Integer maxBurstSeconds) {
        if (!rateLimiterMap.containsKey(key)) {
            synchronized (this) {
                if (!rateLimiterMap.containsKey(key)) {
                    rateLimiterMap.put(key, new RateLimiter(key, permitsPerSecond, maxBurstSeconds, distributedLock));
                }
            }
        }
        return rateLimiterMap.get(key);
    }
}

核心方法介绍

修饰符和类型 方法和描述
double acquire() 从RateLimiter获取一个许可,该方法会被阻塞直到获取到请求
double acquire(int permits) 从RateLimiter获取指定许可数,该方法会被阻塞直到获取到请求
boolean tryAcquire(int permits, long timeout, TimeUnit unit) 从RateLimiter 获取指定许可数如果该许可数可以在不超过timeout的时间内获取得到的话,或者如果无法在timeout 过期之前获取得到许可数的话,那么立即返回false (无需等待)
boolean tryAcquire(long timeout, TimeUnit unit) 从RateLimiter 获取许可如果该许可可以在不超过timeout的时间内获取得到的话,或者如果无法在timeout 过期之前获取得到许可的话,那么立即返回false(无需等待)

你可能感兴趣的:(Redis,互联网架构)