场景:限制请求后端接口的频率,例如1秒钟内请求次数不能超过10次,通常的写法是:
1.先去从redis里面拿到当前请求次数
2.判断当前次数是否大于或等于限制次数
3.当前请求次数小于限制次数时进行自增
这三步在请求不是很密集的时候,程序执行很快,可能不会产生问题;但如果两个请求几乎在同一时刻到来,它们可能在第1步读到相同的次数,并同时通过第2步的判断,因为"读取、判断、自增"这三步合在一起并不是一个原子操作,最终实际放行的请求数就会超过限制。
改进方式:使用redis的lua脚本,将"读取值、判断大小、自增"放到redis的一次操作中,redis底层所有的操作请求都是串行的,也就是一个请求执行完,才会执行下一个请求。
自增的lua脚本如下
/**
 * Returns the Lua source of the atomic "check, increment, expire" counter script.
 *
 * <p>The script reads the counter stored under {@code KEYS[1]}. If the counter already
 * exceeds {@code ARGV[1]} (the limit) it is returned unchanged; otherwise it is
 * incremented, and when the increment produced {@code 1} (the key was just created) the
 * key is given an expiry of {@code ARGV[2]}. The script always returns the counter value
 * as a number.
 *
 * @return the script text, one Lua statement per line, joined with {@code \n}
 */
private String maxCountScriptText() {
    return String.join("\n",
            "local key = KEYS[1]",
            "local count = tonumber(ARGV[1])",
            "local time = tonumber(ARGV[2])",
            "local current = redis.call('get', key);",
            // Counter is already over the limit: report it without incrementing further.
            "if current and tonumber(current) > count then",
            " return tonumber(current);",
            "end",
            "current = redis.call('incr', key)",
            // First hit in this window: start the expiry clock.
            "if tonumber(current) == 1 then",
            " redis.call('expire', key, time)",
            "end",
            "return tonumber(current);");
}
将接口限流功能封装成一个注解@RateLimiter,在接口方法上面加上@RateLimiter就可以实现限流:
redis工具类:
package com.zhou.redis.util;
import com.zhou.redis.dto.MyRedisMessage;
import com.zhou.redis.exception.LockException;
import com.zhou.redis.script.MaxCountQueryScript;
import com.zhou.redis.script.MaxCountScript;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.core.HashOperations;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.ValueOperations;
import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;
/**
 * Thin helper around {@link RedisTemplate} offering a simple key-based lock and the
 * script-backed counter used for rate limiting.
 *
 * <p>NOTE(review): the class is registered through {@code @Configuration} although it
 * declares no {@code @Bean} methods; {@code @Component} would express the intent better.
 * Left unchanged here so the file's imports stay untouched.
 */
@Configuration
@Slf4j
public class RedisUtil {
    /** Lock lifetime, in seconds, applied when the caller supplies no positive value. */
    private static final long DEFAULT_LOCK_SECONDS = 30L;

    public RedisTemplate<String, Object> redisTemplate;
    private final MaxCountScript maxCountScript;
    private final MaxCountQueryScript maxCountQueryScript;

    public RedisUtil(RedisTemplate<String, Object> redisTemplate, MaxCountScript maxCountScript, MaxCountQueryScript maxCountQueryScript) {
        this.redisTemplate = redisTemplate;
        this.maxCountScript = maxCountScript;
        this.maxCountQueryScript = maxCountQueryScript;
    }

    /**
     * Tries to acquire a lock by setting {@code key} only if it is absent.
     *
     * @param key   lock key
     * @param value value stored under the key while the lock is held
     * @param time  lock lifetime in seconds; {@code null} or a non-positive value falls
     *              back to 30 seconds
     * @return {@code true} if the key was set (lock acquired), {@code false} otherwise,
     *         including when Redis returns no result
     */
    public boolean tryLock(String key, Object value, Long time) {
        long seconds = (time == null || time <= 0) ? DEFAULT_LOCK_SECONDS : time;
        Boolean acquired = redisTemplate.opsForValue().setIfAbsent(key, value, Duration.ofSeconds(seconds));
        return Boolean.TRUE.equals(acquired);
    }

    /**
     * Releases a lock by deleting {@code key}. Call this only after the lock was acquired.
     *
     * <p>NOTE(review): the key is deleted unconditionally; the stored value is not
     * compared, so a caller whose lock already expired can remove a lock that is now
     * held by someone else.
     *
     * @param key lock key
     * @return {@code true} if the key was deleted, {@code false} otherwise
     */
    public boolean unLock(String key) {
        Boolean deleted = redisTemplate.delete(key);
        return Boolean.TRUE.equals(deleted);
    }

    /**
     * Runs the counter script against {@code key}.
     *
     * @param key      counter key
     * @param maxCount limit, passed to the script as its first argument
     * @param time     passed to the script as its second argument; the rate-limit aspect
     *                 supplies {@code RateLimiter.time()} here, i.e. a window length in
     *                 seconds
     * @return the value returned by the script (presumably the counter after the call —
     *         the script text is supplied where {@code MaxCountScript} is constructed)
     */
    public Long incr(String key, int maxCount, int time) {
        List<String> keys = Collections.singletonList(key);
        // Explicit cast: keeps this compiling even where the script type is used raw,
        // in which case execute() is typed as returning Object.
        return (Long) redisTemplate.execute(maxCountScript, keys, maxCount, time);
    }

    /**
     * Runs the query script against {@code key} and returns its result.
     *
     * @param key counter key
     * @return the value returned by the query script
     */
    public Long incrNow(String key) {
        List<String> keys = Collections.singletonList(key);
        return (Long) redisTemplate.execute(maxCountQueryScript, keys);
    }
}
redis配置类:
package com.zhou.redis.config;
import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.zhou.redis.listener.MyRedisListener;
import com.zhou.redis.script.MaxCountQueryScript;
import com.zhou.redis.script.MaxCountScript;
import com.zhou.redis.util.RedisTopic;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.listener.PatternTopic;
import org.springframework.data.redis.listener.RedisMessageListenerContainer;
import org.springframework.data.redis.listener.Topic;
import org.springframework.data.redis.listener.adapter.MessageListenerAdapter;
import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;
import java.util.Arrays;
import java.util.List;
@Configuration
public class RedisConfig {
@SuppressWarnings("all")
@Bean
public RedisTemplate redisTemplate(RedisConnectionFactory factory) {
RedisTemplate template = new RedisTemplate<>();
template.setConnectionFactory(factory);
//Json序列化配置
Jackson2JsonRedisSerializer
拦截模式枚举类:根据ip拦截或者方法拦截
package com.zhou.aop;
/**
 * Selects how the rate-limit key is built for a {@code @RateLimiter} method.
 *
 * @author lang.zhou
 * @since 2023/1/31 17:56
 */
public enum LimitType {
    /** The caller's IP address is included in the key, so each IP is counted separately. */
    IP,

    /** No IP is included in the key; all callers of the method share one counter. */
    DEFAULT;
}
封装自定义注解:@RateLimiter
package com.zhou.aop;
import java.lang.annotation.*;
/**
 * Marks a method as rate limited: at most {@link #count()} calls are allowed within
 * {@link #time()} seconds.
 *
 * @author lang.zhou
 * @since 2023/1/31 17:49
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD})
public @interface RateLimiter {

    /** Prefix of the key under which the call counter is stored. */
    String key() default "RateLimiter";

    /** Length of the rate-limit window, in seconds. */
    int time() default 60;

    /** Maximum number of calls allowed within one window. */
    int count() default 100;

    /** How the counter key is built (per IP or shared). */
    LimitType limitType() default LimitType.DEFAULT;

    /** Message reported to the caller when the limit is exceeded. */
    String limitMsg() default "访问过于频繁,请稍候再试";
}
注解的切面逻辑:
package com.zhou.aop;
import com.zhou.redis.util.RedisUtil;
import com.zhou.common.utils.IpUtil;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Method;
/**
 * Aspect that enforces {@link RateLimiter} before the annotated method runs.
 *
 * <p>The counter for the call is incremented through {@link RedisUtil#incr}; when the
 * returned value is missing or greater than the configured count, the call is rejected
 * with a {@link RuntimeException} carrying {@link RateLimiter#limitMsg()}.
 *
 * @author lang.zhou
 * @since 2023/1/31 17:50
 */
@Aspect
@Slf4j
@Component
public class RateLimiterAspect {
    @Autowired
    private RedisUtil redisUtils;

    /**
     * Counts the call and rejects it when the limit is exceeded.
     *
     * @param point       the intercepted join point
     * @param rateLimiter the annotation instance found on the intercepted method
     * @throws RuntimeException with {@code limitMsg} when the limit is exceeded, or with a
     *                          generic "busy" message when the counter cannot be updated
     */
    @Before("@annotation(rateLimiter)")
    public void doBefore(JoinPoint point, RateLimiter rateLimiter) {
        int time = rateLimiter.time();
        int count = rateLimiter.count();
        String combineKey = getCombineKey(rateLimiter, point);

        Long number;
        try {
            number = redisUtils.incr(combineKey, count, time);
        } catch (ServiceRuntimeException e) {
            // Project exception type (declared outside this file): passed through untouched.
            throw e;
        } catch (Exception e) {
            // Counter update failed: reject the call, but keep the cause for diagnosis.
            log.warn("限流计数失败,key={}", combineKey, e);
            throw new RuntimeException("网络繁忙,请稍候再试", e);
        }

        // Checked outside the try block so the limit message is not swallowed by the
        // generic catch above and replaced with the "busy" message.
        if (number == null || number > count) {
            log.info("请求【{}】被拦截,{}秒内请求次数{}", combineKey, time, number);
            throw new RuntimeException(rateLimiter.limitMsg());
        }
    }

    /**
     * Builds the counter key: {@code key()} + optional {@code "<ip>-"} + declaring class
     * name + {@code "."} + method name.
     *
     * <p>The IP part is added only for {@link LimitType#IP} and only when a servlet
     * request is bound to the current thread. Overloaded methods share one key because
     * only the method name is used.
     *
     * @param rateLimiter the annotation instance
     * @param point       the intercepted join point
     * @return the counter key
     */
    public String getCombineKey(RateLimiter rateLimiter, JoinPoint point) {
        StringBuilder s = new StringBuilder(rateLimiter.key());
        ServletRequestAttributes requestAttributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        if (requestAttributes != null && rateLimiter.limitType() == LimitType.IP) {
            HttpServletRequest request = requestAttributes.getRequest();
            s.append(IpUtil.getIpAddr(request)).append("-");
        }
        MethodSignature signature = (MethodSignature) point.getSignature();
        Method method = signature.getMethod();
        Class<?> targetClass = method.getDeclaringClass();
        s.append(targetClass.getName()).append(".").append(method.getName());
        return s.toString();
    }
}
lua自增脚本类:
package com.zhou.redis.script;
import org.springframework.data.redis.core.script.DefaultRedisScript;
/**
 * Redis script wrapper for the counter-increment Lua script; its result is read as a
 * {@link Long}.
 *
 * @author lang.zhou
 * @since 2023/2/25
 */
public class MaxCountScript extends DefaultRedisScript<Long> {
    /**
     * @param script Lua source text of the script
     */
    public MaxCountScript(String script) {
        super(script, Long.class);
    }
}
lua查询当前值的脚本类:
package com.zhou.redis.script;
import org.springframework.data.redis.core.script.DefaultRedisScript;
/**
 * Redis script wrapper for the Lua script that queries the current counter value; its
 * result is read as a {@link Long}.
 *
 * @author lang.zhou
 * @since 2023/2/25
 */
public class MaxCountQueryScript extends DefaultRedisScript<Long> {
    /**
     * @param script Lua source text of the script
     */
    public MaxCountQueryScript(String script) {
        super(script, Long.class);
    }
}
订阅消息通道的常量类:
package com.zhou.redis.util;
/**
 * Names of the Redis pub/sub topics used by the application.
 */
public final class RedisTopic {
    public static final String TOPIC1 = "TOPIC1";
    public static final String TOPIC2 = "TOPIC2";

    private RedisTopic() {
        // Constants holder; not meant to be instantiated.
    }
}
消息实体类:
package com.zhou.redis.dto;
import lombok.Data;
import java.io.Serializable;
/**
 * Payload of a Redis pub/sub message; the listener deserializes bodies received on
 * {@code RedisTopic.TOPIC2} into this type.
 * @since 2022/11/11 17:34
 */
@Data
public class MyRedisMessage implements Serializable {
/** Message content carried by the payload. */
private String msg;
}
订阅消息监听器:
package com.zhou.redis.listener;
import com.zhou.redis.dto.MyRedisMessage;
import com.zhou.redis.util.RedisTopic;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.connection.Message;
import org.springframework.data.redis.connection.MessageListener;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;
import javax.script.ScriptException;
/**
 * Listener for Redis pub/sub messages. Logs the topic of every message and, for
 * {@link RedisTopic#TOPIC2}, deserializes and logs the {@link MyRedisMessage} body.
 *
 * @author lang.zhou
 */
@Slf4j
@Component
public class MyRedisListener implements MessageListener {
    @Autowired
    private RedisTemplate<?, ?> redisTemplate;

    /**
     * Handles one received message.
     *
     * @param message the received message
     * @param pattern the subscription pattern that matched; when it is {@code null} the
     *                message's own channel is used as the topic instead
     */
    @Override
    public void onMessage(Message message, byte[] pattern) {
        byte[] rawTopic = pattern != null ? pattern : message.getChannel();
        // Decode with an explicit charset instead of the platform default.
        String topic = new String(rawTopic, java.nio.charset.StandardCharsets.UTF_8);
        log.info("channel:{}", topic);
        if (RedisTopic.TOPIC1.equals(topic)) {
            // No handling for TOPIC1 yet.
        } else if (RedisTopic.TOPIC2.equals(topic)) {
            // Publisher and subscriber must use the same value serializer for this to work.
            Object body = redisTemplate.getValueSerializer().deserialize(message.getBody());
            if (body instanceof MyRedisMessage) {
                log.info("message:{}", body);
            } else {
                // Avoid a ClassCastException when the body is not the expected type.
                log.warn("unexpected message body on {}: {}", topic,
                        body == null ? null : body.getClass().getName());
            }
        }
    }
}
注解使用方式:1秒内一个ip最多只能请求10次
/**
 * Usage example for {@code @RateLimiter}: the endpoint below is limited per client IP to
 * 10 calls within a 1-second window.
 */
@RestController
@RequestMapping("/test/api")
public class CheckController{
/**
 * Rate-limited endpoint ({@code POST /test/api/limit}).
 *
 * @param request the current HTTP request
 */
@PostMapping("/limit")
@RateLimiter(time = 1, count = 10, limitType = LimitType.IP, limitMsg = "请求过于频繁,请稍后重试")
public void limit(HttpServletRequest request){
// run the business logic here
}
}