springboot整合ratelimiter+redis+lua实现限流

依赖

<!-- Spring Data Redis starter: supplies the org.springframework.data.redis classes (RedisTemplate, RedisScript) used below -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
<!-- Spring Web starter: supplies the web/controller annotations and request context classes used below -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-web</artifactId>
</dependency>
<!-- Spring AOP starter: presumably what enables the @Aspect-based rate-limit interceptor below - confirm against your Boot version -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-aop</artifactId>
</dependency>

限流注解

/**
 * Scope over which a {@code RateLimiter} quota is counted.
 *
 * @author jigua
 * @version 1.0
 * @className LimitType
 * @description Rate-limit scope selector
 * @create 2022/6/17 14:09
 */
public enum LimitType {

    /** Default strategy: one shared (global) counter for the limited method. */
    DEFAULT,

    /** One counter per caller IP address. */
    IP;
}

import java.lang.annotation.*;

/**
 * Marks a method as rate limited.
 *
 * <p>The annotation is retained at runtime and may only be placed on methods.
 * Its attributes describe the quota: at most {@link #count()} calls within a
 * window of {@link #time()} seconds, counted per {@link #limitType()} scope.
 *
 * @author jigua
 * @version 1.0
 * @className RateLimiter
 * @description Declarative rate-limit marker
 * @create 2022/6/17 14:10
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD})
public @interface RateLimiter {

    /** Prefix of the key under which the counter is stored. */
    String key() default "rate_limit:";

    /** Length of the limiting window, in seconds. */
    int time() default 30;

    /** Maximum number of calls allowed within one window. */
    int count() default 300;

    /** Scope of the counter; defaults to {@link LimitType#DEFAULT}. */
    LimitType limitType() default LimitType.DEFAULT;
}

RedisTemplate配置

import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer;
import org.springframework.scripting.support.ResourceScriptSource;

/**
 * Redis beans for the rate limiter: a {@link RedisTemplate} that serializes
 * with Jackson instead of the default JDK serialization, and the Lua limit
 * script loaded from the classpath.
 *
 * @author jigua
 * @version 1.0
 * @className RedisConfig
 * @description Replaces the RedisTemplate serialization scheme
 * @create 2022/6/17 14:12
 */
@Configuration
public class RedisConfig {

    /**
     * Builds a {@code RedisTemplate<Object, Object>} whose keys, values, hash keys
     * and hash values are all written with one shared Jackson JSON serializer.
     *
     * @param connectionFactory connection factory the template is bound to
     * @return the configured template
     */
    @Bean
    public RedisTemplate<Object, Object> redisTemplate(RedisConnectionFactory connectionFactory) {
        // Mapper sees every property regardless of visibility and records type
        // information for non-final types.
        // NOTE(review): enableDefaultTyping appears to be deprecated in newer Jackson
        // releases in favour of activateDefaultTyping - confirm against the version in use.
        ObjectMapper mapper = new ObjectMapper();
        mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
        mapper.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL);

        // JSON serializer used in place of the default JDK serialization, which per the
        // original author can emit unexpected prefixes that make reads fail.
        Jackson2JsonRedisSerializer<Object> jsonSerializer =
                new Jackson2JsonRedisSerializer<>(Object.class);
        jsonSerializer.setObjectMapper(mapper);

        RedisTemplate<Object, Object> template = new RedisTemplate<>();
        template.setConnectionFactory(connectionFactory);
        template.setKeySerializer(jsonSerializer);
        template.setHashKeySerializer(jsonSerializer);
        template.setValueSerializer(jsonSerializer);
        template.setHashValueSerializer(jsonSerializer);

        System.out.println("redis加载完成");
        return template;
    }

    /**
     * Loads {@code lua/limit.lua} from the classpath as a script returning {@code Long}.
     *
     * @return the rate-limit script bean
     */
    @Bean
    public DefaultRedisScript<Long> limitScript() {
        DefaultRedisScript<Long> script = new DefaultRedisScript<>();
        script.setResultType(Long.class);
        script.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/limit.lua")));
        System.out.println("lua脚本加载完成");
        return script;
    }
}

lua脚本

-- Fixed-window counter.
--   KEYS[1]  counter key
--   ARGV[1]  maximum number of calls allowed in one window
--   ARGV[2]  window length in seconds
-- Returns the counter value; the caller treats a value above ARGV[1] as "limited".
local limitKey = KEYS[1]
local maxCount = tonumber(ARGV[1])
local windowSeconds = tonumber(ARGV[2])

-- Already over the limit: report the stored value without incrementing further.
local stored = redis.call('get', limitKey)
if stored then
    local used = tonumber(stored)
    if used > maxCount then
        return used
    end
end

-- Count this call; the expiry is set only when the increment produced 1,
-- i.e. on the first call of a window.
local updated = tonumber(redis.call('incr', limitKey))
if updated == 1 then
    redis.call('expire', limitKey, windowSeconds)
end
return updated

使用lua脚本的两种思路

  • 在 Redis 服务端定义好 Lua 脚本,然后计算出来一个散列值,在 Java 代码中,通过这个散列值锁定要执行哪个 Lua 脚本
  • 直接在 Java 代码中将 Lua 脚本定义好,然后发送到 Redis 服务端去执行。(本文使用)

脚本解释

  • 首先获取到传进来的 key 以及 限流的 count 和时间 time
  • 通过 get 获取到这个 key 对应的值,这个值就是当前时间窗内这个接口已经被访问了多少次
  • 如果是第一次访问,此时拿到的结果为 nil,否则拿到的结果应该是一个数字,所以接下来就判断,如果拿到的结果是一个数字,并且这个数字还大于 count,那就说明已经超过流量限制了,那么直接返回查询的结果即可
  • 如果没有超过限制(包括拿到的结果为 nil 的第一次访问),就给当前 key 自增 1;如果自增后的值等于 1,说明是当前时间窗内的第一次访问,此时再给这个 key 设置一个过期时间
  • 最后把自增 1 后的值返回就可以了

自定义切面

import com.example.huibaozi.util.IpUtil;
import com.google.protobuf.ServiceException;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import java.lang.reflect.Method;
import java.util.Collections;
import java.util.List;

/**
 * Aspect that enforces {@link RateLimiter} quotas before the annotated method runs.
 *
 * <p>For every intercepted call the injected Lua script is executed against Redis
 * with a single key (see {@link #getCombineKey}) and two arguments: the allowed
 * count and the window length in seconds. The call is rejected when the script
 * returns nothing or a value greater than the allowed count.
 *
 * @author jigua
 * @version 1.0
 * @className RateLimiterAspect
 * @description Rate-limit enforcement aspect
 * @create 2022/6/17 14:22
 */
@Aspect
@Component
public class RateLimiterAspect {
    private static final Logger log = LoggerFactory.getLogger(RateLimiterAspect.class);

    @Autowired
    private RedisTemplate<Object, Object> redisTemplate;

    @Autowired
    private RedisScript<Long> limitScript;

    /**
     * Runs the limit script for the intercepted method and rejects the call when
     * the quota is exceeded.
     *
     * @param point       the intercepted join point
     * @param rateLimiter the annotation instance found on the intercepted method
     * @throws ServiceException if the script returns {@code null} or a value above
     *                          {@code rateLimiter.count()}
     * @throws RuntimeException if executing the script fails for any other reason;
     *                          the original exception is attached as the cause
     */
    @Before("@annotation(rateLimiter)")
    public void doBefore(JoinPoint point, RateLimiter rateLimiter) throws Throwable {
        int time = rateLimiter.time();
        int count = rateLimiter.count();

        String combineKey = getCombineKey(rateLimiter, point);
        List<Object> keys = Collections.singletonList(combineKey);

        Long number;
        try {
            number = redisTemplate.execute(limitScript, keys, count, time);
        } catch (Exception e) {
            // Keep the cause so the underlying Redis/script failure stays diagnosable.
            throw new RuntimeException("服务器限流异常,请稍候再试", e);
        }

        // Compare as long: narrowing the script result to int could wrap around.
        if (number == null || number > count) {
            throw new ServiceException("访问过于频繁,请稍候再试");
        }
        // Log the key actually used in Redis, not just the annotation's prefix.
        log.info("限制请求'{}',当前请求'{}',缓存key'{}'", count, number, combineKey);
    }

    /**
     * Builds the Redis key for a limited method.
     *
     * <p>Layout: {@code <key prefix>[<ip>-]<declaring class name>-<method name>}.
     * The IP segment is only present for {@link LimitType#IP} and is whatever
     * {@code IpUtil.getIpAddr} returns for the current servlet request.
     *
     * @param rateLimiter the annotation supplying the key prefix and limit type
     * @param point       the intercepted join point; its signature must be a
     *                    {@link MethodSignature}
     * @return the combined key
     */
    public String getCombineKey(RateLimiter rateLimiter, JoinPoint point) {
        // Local, single-threaded buffer: StringBuilder avoids StringBuffer's needless locking.
        StringBuilder combined = new StringBuilder(rateLimiter.key());
        if (rateLimiter.limitType() == LimitType.IP) {
            ServletRequestAttributes attributes =
                    (ServletRequestAttributes) RequestContextHolder.currentRequestAttributes();
            combined.append(IpUtil.getIpAddr(attributes.getRequest())).append("-");
        }
        MethodSignature signature = (MethodSignature) point.getSignature();
        Method method = signature.getMethod();
        Class<?> targetClass = method.getDeclaringClass();
        combined.append(targetClass.getName()).append("-").append(method.getName());
        return combined.toString();
    }
}

全局异常拦截器

import com.google.protobuf.ServiceException;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.RestControllerAdvice;

import java.util.HashMap;
import java.util.Map;
/**
 * Global exception advice that turns a {@link ServiceException} into a small
 * JSON-style response body.
 *
 * @author jigua
 * @version 1.0
 * @className GlobalException
 * @description Global exception handler
 * @create 2022/6/17 14:27
 */
@RestControllerAdvice
public class GlobalException {

    /**
     * Handles {@link ServiceException} with a fixed payload; the exception's own
     * message is not used.
     *
     * @param e the exception being handled
     * @return a map holding {@code status} (500) and a fixed {@code message}
     */
    @ExceptionHandler(ServiceException.class)
    public Map<String, Object> serviceException(ServiceException e) {
        final Map<String, Object> responseBody = new HashMap<>(2);
        responseBody.put("status", 500);
        responseBody.put("message", "访问人数过多");
        return responseBody;
    }
}

测试

基于ip的测试

    /**
     * Demo endpoint for IP-scoped limiting: annotated with a quota of 3 calls
     * per 50-second window, counted per {@link LimitType#IP}.
     *
     * @return the literal string {@code "success"}
     */
    @RateLimiter(limitType = LimitType.IP, count = 3, time = 50)
    @GetMapping("/testIp")
    @ResponseBody
    public String testIp() throws Exception {
        return "success";
    }

在这里插入图片描述
在这里插入图片描述
测试时有这样的情况不要慌,把请求路径中的localhost替换成ip就好了

springboot整合ratelimiter+redis+lua实现限流_第1张图片
在这里插入图片描述
在这里插入图片描述

基于全局的测试

    /**
     * Demo endpoint for the default (global) limiting scope: annotated with a
     * quota of 3 calls per 50-second window.
     *
     * @return a JSON object containing the single entry {@code "1" -> 1}
     */
    @RateLimiter(limitType = LimitType.DEFAULT, count = 3, time = 50)
    @GetMapping("/testService")
    @ResponseBody
    public JSONObject testService() throws Exception {
        JSONObject payload = new JSONObject();
        payload.put("1", 1);
        return payload;
    }

在这里插入图片描述
springboot整合ratelimiter+redis+lua实现限流_第2张图片
springboot整合ratelimiter+redis+lua实现限流_第3张图片

你可能感兴趣的:(spring,springboot,redis,spring,boot,lua)