springboot + redis + lua 实现访问量控制

访问量控制是很常见的需求:对外提供的服务,有的需要限制 QPS,防止服务被压垮宕机;有的需要限制某一时间段内的访问次数。

本文基于 Spring Boot,采用切面(AOP)+ Redis 的方式实现,思路如下:

  1. 在需要进行访问量控制的地方加入注解;
  2. 在处理该注解的切面中,获取当前请求的 IP 地址,利用 Redis 做计数,超过 limit 则抛出异常;
  3. 问题的关键在于:在分布式环境下,多个实例对 Redis 的"读取—判断—写入"操作可能相互竞争,所以要把这些 Redis 操作封装在一个 Lua 脚本中执行,这样整组操作是原子性的。

自定义注解:

import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;

import java.lang.annotation.*;

/**
 * Marks a method whose calls are rate-limited per client.
 *
 * <p>At most {@link #count()} calls are allowed within each window of {@link #time()}
 * milliseconds; further calls are rejected until the window expires.
 *
 * <p>This annotation intentionally carries no {@code @Order}: {@code @Order} placed on an
 * annotation type does not affect the precedence of the aspect that handles it. To run the
 * limit check before other advice, put {@code @Order} on the aspect class instead.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@Documented
public @interface RequestLimit {
    /**
     * Maximum number of calls allowed within one window.
     * Defaults to {@link Integer#MAX_VALUE}, i.e. effectively unlimited.
     */
    int count() default Integer.MAX_VALUE;

    /**
     * Length of the counting window in milliseconds. Defaults to one minute (60000 ms).
     */
    long time() default 60000;
}

切面操作:

import com.example.common.Constants;
import com.example.common.ErrorCode;
import com.example.exception.BusinessException;
import com.example.utils.IpUtils;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.annotation.Pointcut;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;
import java.util.Collections;

@Aspect
@Component
public class RequestLimitAspect {
    @Autowired
    private DefaultRedisScript redisScript;

    @Autowired
    private StringRedisTemplate stringRedisTemplate;

    @Pointcut("@annotation(com.example.aspect.RequestLimit)")
    public void pointcut() {
    }

    @Before("pointcut() && @annotation(requestLimit)")
    public void doBefore(JoinPoint joinPoint, RequestLimit requestLimit) {
        ServletRequestAttributes requestAttributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        if (null == requestAttributes) {
            return;
        }
        HttpServletRequest httpRequest = requestAttributes.getRequest();

        String ip = IpUtils.getRealIP(httpRequest);
        String key = Constants.KEY_PREFIX + ip;

        Boolean allow = stringRedisTemplate.execute(
                redisScript,
                Collections.singletonList(key),
                String.valueOf(requestLimit.count()), //limit
                String.valueOf(requestLimit.time())); //expire

        assert allow != null;
        if (!allow) {
            throw new BusinessException(ErrorCode.REQUEST_EXCEED_LIMIT);
        }

        return;
    }
}

其中,切面执行的 Lua 脚本通过下面的配置类注册为 Spring Bean:

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.scripting.support.ResourceScriptSource;

@Configuration
public class LuaRedisConfiguration {

    /**
     * Registers the request-limit Lua script loaded from {@code classpath:script/requestLimit.lua}.
     *
     * <p>The script's reply is mapped to {@link Boolean}, so the bean is declared as
     * {@code DefaultRedisScript<Boolean>} rather than the raw type. Injection points that still
     * use the raw {@code DefaultRedisScript} type continue to match.
     *
     * @return the script bean executed by the request-limit aspect
     */
    @Bean
    public DefaultRedisScript<Boolean> redisScript() {
        DefaultRedisScript<Boolean> redisScript = new DefaultRedisScript<>();
        redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("script/requestLimit.lua")));
        redisScript.setResultType(Boolean.class);
        return redisScript;
    }
}

lua脚本:

-- Fixed-window request counter, executed atomically by Redis.
-- KEYS[1]: counter key
-- ARGV[1]: maximum number of requests allowed per window
-- ARGV[2]: window length in milliseconds
-- Returns 1 if this request is within the limit, 0 otherwise.
local key = KEYS[1]
local limit = tonumber(ARGV[1])
local expire = ARGV[2]

-- Count this request first, so a limit of 0 (or less) rejects even the very first call.
local current = redis.call("INCR", key)

-- Start the window on the first hit. Also repair a counter that has no expiry
-- (PTTL == -1, checked in ms to match PEXPIRE); otherwise it would never reset.
if current == 1 or redis.call("PTTL", key) == -1 then
    redis.call("PEXPIRE", key, expire)
end

if current <= limit then
    return 1
end
return 0

参考文章:http://www.genxiaogu.com/Springboot-%E9%9B%86%E7%BE%A4QPS%E6%8E%A7%E5%88%B6starter/

你可能感兴趣的:(springboot + redis + lua 实现访问量控制)