Spring Boot | 使用Redis实现接口限流

表单重复提交的原因笔者遇到过如下几种:

(1)经费有限,无力更换有故障的鼠标,单击鼠标左键时,实际上触发了多次点击;
(2)服务响应慢或者网络卡顿,用户情绪暴躁,疯狂多次点击提交按钮;
(3)测试工程师炫耀手速故意快速多次点击提交按钮。

解决表单重复提交的方式有多种,前端可以解决,后端也可以解决,本篇文章提供了一种笔者甄选出来的自认为比较优雅的后端解决方案。

创建项目

我们创建一个 Spring Boot 项目,引入 Web、Redis、AOP 依赖。

<dependency>
	<groupId>org.springframework.bootgroupId>
	<artifactId>spring-boot-starter-webartifactId>
	<version>2.4.1version>
dependency>
<dependency>
	<groupId>org.springframework.bootgroupId>
	<artifactId>spring-boot-starter-data-redisartifactId>
	<version>2.4.1version>
dependency>
<dependency>
	<groupId>org.springframework.bootgroupId>
	<artifactId>spring-boot-starter-aopartifactId>
	<version>2.4.1version>
dependency>
<dependency>
	<groupId>cn.hutoolgroupId>
	<artifactId>hutool-allartifactId>
	<version>5.5.6version>
dependency>
<dependency>
	<groupId>com.google.code.gsongroupId>
	<artifactId>gsonartifactId>
	<version>2.8.6version>
dependency>

配置Redis

修改application.properties配置文件,添加 Redis 相关配置。

# redis配置
# redis server ip
spring.redis.host=127.0.0.1
# redis server port
spring.redis.port=6379
# redis server password
spring.redis.password=
# 连接超时时间
spring.redis.timeout=5000

创建Lua脚本

resources目录下创建文件scripts\redis\limit.lua,脚本内容如下:

local key = KEYS[1]
local now = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
local expired = tonumber(ARGV[3])
local max = tonumber(ARGV[4])

redis.call('zremrangebyscore', key, 0, expired)

local current = tonumber(redis.call('zcard', key))
local next = current + 1

if next > max then
  return 0;
else
  redis.call("zadd", key, now, now)
  redis.call("pexpire", key, ttl)
  return next
end

脚本内容引自 https://www.codetd.com/en/article/12543775

定制Redis

开发 Redis 配置类,设置 Redis 序列化方式,加载 Lua 脚本。

@Configuration
public class RedisConfig {

    @Bean
    public RedisScript<Long> limitRedisScript() {
        DefaultRedisScript redisScript = new DefaultRedisScript<>();
        redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("scripts/redis/limit.lua")));
        redisScript.setResultType(Long.class);
        return redisScript;
    }

    @Bean
    public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory factory) {
        RedisTemplate<String, Object> template = new RedisTemplate<>();
        template.setConnectionFactory(factory);
        RedisSerializer jackson2JsonRedisSerializer = getJacksonSerializer();
        StringRedisSerializer stringRedisSerializer = new StringRedisSerializer();
        // key采用String的序列化方式
        template.setKeySerializer(stringRedisSerializer);
        // hash的key也采用String的序列化方式
        template.setHashKeySerializer(stringRedisSerializer);
        // value序列化方式采用jackson的序列化方式
        template.setValueSerializer(jackson2JsonRedisSerializer);
        // hash的value序列化方式采用jackson
        template.setHashValueSerializer(jackson2JsonRedisSerializer);
        template.afterPropertiesSet();
        return template;
    }

    /**
     * redis的json序列化
     * @return
     */
    private RedisSerializer getJacksonSerializer() {
        ObjectMapper om = new ObjectMapper();
        om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
        om.activateDefaultTyping(LaissezFaireSubTypeValidator.instance, ObjectMapper.DefaultTyping.NON_FINAL);
        return new GenericJackson2JsonRedisSerializer(om);
    }
}

限流注解

开发限流注解,可以设置Redis Key、最大请求次数、限制时间、时间单位。

@Documented
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface RateLimiter {

    long DEFAULT_REQUEST = 10;

    /**
     * 最大请求次数
     */
    @AliasFor("max")
    long value() default DEFAULT_REQUEST;

    /**
     * 最大请求次数
     */
    @AliasFor("value")
    long max() default DEFAULT_REQUEST;

    /**
     * Redis key值
     */
    String key() default "";

    /**
     * 限制时间,默认1分钟
     */
    long timeout() default 1;

    /**
     * 时间单位,默认分钟
     */
    TimeUnit timeUnit() default TimeUnit.MINUTES;
}

自定义切面

自定义切面解析@RateLimiter注解。

@Slf4j
@Aspect
@Component
@RequiredArgsConstructor(onConstructor_ = @Autowired)
public class LimiteAspect {

    private final static String SEPARATOR = ":";
    private final static String REDIS_LIMIT_KEY_PREFIX = "limit:";
    private final StringRedisTemplate stringRedisTemplate;
    private final RedisScript<Long> limitRedisScript;

    @Pointcut("@annotation(jenny.learn.springboot.redis.annotation.RateLimiter)")
    public void rateLimit() {
    }

    @Around("rateLimit()")
    public Object pointcut(ProceedingJoinPoint point) throws Throwable {
        MethodSignature signature = (MethodSignature) point.getSignature();
        Method method = signature.getMethod();
        // 获取RateLimiter注解
        RateLimiter rateLimiter = AnnotationUtils.findAnnotation(method, RateLimiter.class);
        if (rateLimiter != null) {
            String key = rateLimiter.key();
            // 如果key值为空,key值以类名+方法名为前缀
            if (StrUtil.isBlank(key)) {
                key = method.getDeclaringClass().getName() + StrUtil.DOT + method.getName();
            }
            key = key + SEPARATOR + IpUtil.getIpAddr();

            // 最大请求次数
            long max = rateLimiter.max();
            // 超时时间
            long timeout = rateLimiter.timeout();
            // 超时时间单位
            TimeUnit timeUnit = rateLimiter.timeUnit();
            boolean limited = shouldLimited(key, max, timeout, timeUnit);
            if (limited) {
                throw new RuntimeException("访问过于频繁!");
            }
        }

        return point.proceed();
    }

    /**
     * 是否限制
     * @param key Redis Key
     * @param max 最大请求次数
     * @param timeout 超时时间
     * @param timeUnit 超时时间单位
     * @return 是否限制,限制返回true
     */
    private boolean shouldLimited(String key, long max, long timeout, TimeUnit timeUnit) {
        // 最终的key格式:limit:定制key:ip 或者 limit:类名.方法名:ip
        key = REDIS_LIMIT_KEY_PREFIX + key;
        // 超时时间(毫秒)
        long ttl = timeUnit.toMillis(timeout);
        // 系统当前时间
        long now = Instant.now().toEpochMilli();
        // 过期
        long expired = now - ttl;
        Long executeTimes = stringRedisTemplate.execute(limitRedisScript, Collections.singletonList(key),
                String.valueOf(now), String.valueOf(ttl), String.valueOf(expired), String.valueOf(max));
        if (executeTimes != null) {
            if (executeTimes == 0) {
                log.error("[{}] The access limit has been reached within {} milliseconds per unit time, the current interface limit is {}", key, ttl, max);
                return true;
            } else {
                log.info("[{}] visit {} times in unit time {} milliseconds", key, ttl, executeTimes);
                return false;
            }
        }
        return false;
    }
}

测试接口

开发一个测试接口,1分钟内最多访问两次,超过两次报错。

@Slf4j
@RestController
@RequestMapping
public class TestController {
    @RateLimiter(value = 2)
    @GetMapping("/test1")
    public Dict test1() {
        return Dict.create().set("msg", "hello world.").set("description", "1分钟内最多访问两次,超过两次返回“访问过于频繁!”");
    }
}

全局异常处理器

开发一个全局异常处理器,捕获Controller抛出的异常。

@Slf4j
@RestControllerAdvice
public class GlobalExceptionHandler {
    @ExceptionHandler(RuntimeException.class)
    public Dict handler(RuntimeException ex) {
        return Dict.create().set("msg", ex.getMessage()).set("code", 500);
    }
}

测试

在浏览器地址栏输入接口请求地址http://127.0.0.1:8780/test1,多刷新几次,就会看到我们想要的结果:
在这里插入图片描述

总结

这样我们开发好了,在想要限流的接口上加个@RateLimiter注解就可以了。

你可能感兴趣的:(redis,spring,boot,java)