使用 Guava 中的 RateLimiter 进行限流。RateLimiter 是基于令牌桶（Token Bucket）算法实现的：令牌以固定速率放入桶中，请求必须先拿到令牌才能继续执行。
引入依赖
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>29.0-jre</version>
</dependency>
测试
/**
 * Demonstrates Guava's token-bucket {@link RateLimiter}: a limiter that issues 1 permit per
 * second is asked for 1, 3 and 5 permits in turn.
 *
 * <p>{@code acquire(n)} blocks until the permits can be granted and returns the time it slept
 * <em>in seconds</em>, so the value is converted to milliseconds before it is logged. Guava lets a
 * request "borrow" future permits: the wait of a request is caused by the permits taken by the
 * previous request, so the logged times should be roughly 0 ms, 1000 ms and 3000 ms.
 *
 * <p>NOTE(review): assumes the enclosing test class provides an SLF4J {@code log} field
 * (e.g. via Lombok {@code @Slf4j}).
 */
void test() {
    RateLimiter rateLimiter = RateLimiter.create(1);
    for (int i = 0; i < 3; i++) {
        int num = 2 * i + 1;
        log.info("获取{}个令牌", num);
        // acquire() reports seconds slept; convert so the "ms" in the log message is accurate.
        double costMs = rateLimiter.acquire(num) * 1000;
        log.info("获取{}个令牌结束,耗时{}ms", num, costMs);
    }
}
限流注解
/**
 * Method-level marker, retained at runtime, that requests token-bucket throttling of the
 * annotated method.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface Limiting {

    /** Number of tokens put into the bucket per second; defaults to 20. */
    double limitNum() default 20D;

    /** Optional descriptive name of the limiter; empty by default. */
    String name() default "";
}
AOP 切面类
/**
 * @author 凉了的凉茶
 * @date by 2023/9/18
 * @des Around advice that throttles methods annotated with {@link Limiting}, using one Guava
 *      token-bucket {@link RateLimiter} per annotated method. Calls that cannot take a token
 *      immediately are rejected with a {@link RuntimeException}.
 */
@Aspect
@Component
@Slf4j
public class RateLimitAspect {

    /**
     * Limiters keyed by method signature, created lazily on the first call. This is the only
     * state shared between requests; {@code computeIfAbsent} makes concurrent first calls agree
     * on a single limiter instead of racing to create several.
     */
    private final ConcurrentHashMap<String, RateLimiter> rateLimiters = new ConcurrentHashMap<>();

    /** Matches every method annotated with {@link Limiting}. */
    @Pointcut("@annotation(top.cymgc.project.aop.Limiting)")
    public void serviceLimit() {
    }

    /**
     * Lets the intercepted call proceed if a token is available right now, otherwise rejects it.
     *
     * <p>The limiter of a method is created from {@code limitNum()} on its first call; the
     * annotation's {@code name()} attribute is not consulted.
     *
     * @param point the intercepted invocation
     * @return the value returned by the intercepted method
     * @throws RuntimeException if no token is immediately available
     * @throws Throwable anything thrown by the intercepted method
     */
    @Around("serviceLimit()")
    public Object around(ProceedingJoinPoint point) throws Throwable {
        MethodSignature msig = (MethodSignature) point.getSignature();
        // Resolve the method on the concrete target class so its @Limiting annotation is visible.
        Method currentMethod = point.getTarget().getClass()
                .getMethod(msig.getName(), msig.getParameterTypes());
        Limiting annotation = currentMethod.getAnnotation(Limiting.class);
        if (annotation == null) {
            // The pointcut matched, but the resolved method carries no annotation to configure
            // a limiter from; let the call through instead of failing with an NPE.
            log.warn("未找到 @Limiting 注解, 不做限流: {}", currentMethod);
            return point.proceed();
        }

        // Key by the full signature: the bare method name would make overloads and same-named
        // methods of different classes share (and be throttled by) one limiter.
        String limiterKey = currentMethod.toGenericString();
        // A local variable, not a field: a field on this shared aspect could be overwritten by a
        // concurrent request for another method between the lookup and tryAcquire().
        RateLimiter rateLimiter = rateLimiters.computeIfAbsent(
                limiterKey, key -> RateLimiter.create(annotation.limitNum()));

        if (rateLimiter.tryAcquire()) {
            log.info("处理完成");
            return point.proceed();
        }
        throw new RuntimeException("服务器繁忙,请稍后再试。");
    }
}
使用
/**
 * Demo controller whose single endpoint is guarded by the {@code @Limiting} token-bucket aspect.
 */
@RestController
@RequestMapping("/auth")
@CrossOrigin
public class Auth {

    /**
     * {@code GET /auth/demo}, limited to 30 tokens per second.
     *
     * @return the constant body {@code "ok"}
     */
    @GetMapping("/demo")
    @Limiting(name = "authDemo", limitNum = 30)
    public String authDemo() {
        final String body = "ok";
        return body;
    }
}
结果
IP 工具类（获取客户端真实 IP）
/**
 * Helpers for resolving the client IP address of an HTTP request, including requests that passed
 * through one or more reverse proxies.
 *
 * <p>NOTE(review): the forwarding headers read here are supplied by the client unless a trusted
 * proxy overwrites them, so the returned address can be spoofed. Do not rely on it for security
 * decisions (such as per-IP rate limiting) unless the proxy chain in front of the app is trusted.
 */
public class IPUtil {

    /**
     * Proxy headers consulted in order. Servlet header lookup is case-insensitive, so each header
     * is listed once regardless of casing.
     */
    private static final String[] PROXY_HEADERS = {
            "x-forwarded-for", "Proxy-Client-IP", "WL-Proxy-Client-IP", "X-Real-IP"
    };

    /**
     * Returns the client IP address of a request.
     *
     * <p>The first known value among {@link #PROXY_HEADERS} is used, falling back to
     * {@code getRemoteAddr()}. A comma-separated proxy chain is reduced to its first known entry,
     * and the IPv6 loopback address is reported as {@code 127.0.0.1}.
     *
     * @param request the request; may be {@code null}
     * @return the client IP, or {@code "unknown"} when {@code request} is {@code null}
     */
    public static String getIpAddr(HttpServletRequest request) {
        if (request == null) {
            return "unknown";
        }
        String ip = null;
        for (String header : PROXY_HEADERS) {
            ip = request.getHeader(header);
            if (!isUnknown(ip)) {
                break;
            }
        }
        if (isUnknown(ip)) {
            ip = request.getRemoteAddr();
        }
        String resolved = getMultistageReverseProxyIp(ip);
        return isIpv6Loopback(resolved) ? "127.0.0.1" : resolved;
    }

    /**
     * Returns the first known address from a comma-separated multi-proxy chain
     * (e.g. {@code "unknown, 10.0.0.1, 10.0.0.2"} → {@code "10.0.0.1"}).
     *
     * @param ip the raw address or address chain; may be {@code null}
     * @return the first known, whitespace-trimmed entry; {@code ip} unchanged when it contains no
     *         comma or when every entry is unknown
     */
    public static String getMultistageReverseProxyIp(String ip) {
        if (ip != null && ip.contains(",")) {
            for (String subIp : ip.split(",")) {
                // Entries are usually separated by ", ", so trim before checking and returning.
                String candidate = subIp.trim();
                if (!isUnknown(candidate)) {
                    return candidate;
                }
            }
        }
        return ip;
    }

    /**
     * Tells whether an HTTP-derived value carries no usable information.
     *
     * @param checkString the value to check
     * @return {@code true} if it is blank or equals {@code "unknown"} ignoring case
     */
    public static boolean isUnknown(String checkString) {
        return StringUtils.isBlank(checkString) || "unknown".equalsIgnoreCase(checkString);
    }

    /** Tells whether {@code ip} is the IPv6 loopback address in full or compressed form. */
    private static boolean isIpv6Loopback(String ip) {
        return "0:0:0:0:0:0:0:1".equals(ip) || "::1".equals(ip);
    }
}
基于 Redis 的滑动窗口限流：引入依赖
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
在 application.yml 中配置 Redis 参数
spring:
  # Redis connection settings (nesting matters: these keys must sit under spring.redis).
  # NOTE: Spring Boot 3.x reads them from spring.data.redis instead.
  redis:
    port: 6379
    host: localhost
    database: 0
    password: redis 密码  # placeholder: replace with the actual Redis password
注解 + AOP 实现
限流注解
/**
 * Declares a sliding-window limit for a method: at most {@link #count()} calls within any window
 * of {@link #time()} seconds. Retained at runtime so an aspect can read it.
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface RateLimiter {

    /** Window length in seconds (default 5). */
    int time() default 5;

    /** Maximum number of calls allowed within one window (default 10). */
    int count() default 10;
}
定义 AOP 切面
/**
 * Before-advice that enforces {@link RateLimiter} with a Redis sorted set as a sliding window,
 * keyed by client IP and method.
 */
@Slf4j
@Aspect
@Component
public class RateLimiterAspect {

    /** Prefix shared by every rate-limit key written to Redis. */
    private static final String KEY_PREFIX = "rate_limit:";

    @Autowired
    private RedisTemplate redisTemplate;

    /**
     * Allows at most {@code count} calls per client IP and method within any sliding window of
     * {@code time} seconds.
     *
     * <p>Each call is recorded as one member of a sorted set scored by its timestamp in ms. Entries
     * older than the window are trimmed, and the remaining cardinality is compared with the limit.
     * The call is recorded before the check, so rejected calls also occupy the window.
     *
     * <p>NOTE(review): the Redis commands are issued one by one, not atomically. A concurrent call
     * can add its entry between this call's add and count, so calls may be rejected slightly early
     * under contention. A Lua script would make the whole check atomic.
     *
     * @param point       the intercepted join point
     * @param rateLimiter the annotation bound by the pointcut
     * @throws RuntimeException when the limit for the current window is exceeded
     */
    @SuppressWarnings("unchecked")
    @Before("@annotation(rateLimiter)")
    public void doBefore(JoinPoint point, RateLimiter rateLimiter) throws Throwable {
        // At most {count} calls are allowed within {time} seconds.
        int time = rateLimiter.time();
        int count = rateLimiter.count();
        String combineKey = getCombineKey(point);
        ZSetOperations zSetOperations = redisTemplate.opsForZSet();

        long currentMs = System.currentTimeMillis();
        // Members must be unique. With the bare timestamp as member, two calls in the same
        // millisecond would collapse into one ZSET entry and be counted once.
        String member = currentMs + "-" + java.util.UUID.randomUUID();
        zSetOperations.add(combineKey, member, currentMs);
        // The key expires once no call has been seen for a whole window, so idle keys do not linger.
        redisTemplate.expire(combineKey, time, TimeUnit.SECONDS);
        // Slide the window: drop calls older than {time} seconds (1000L keeps this in long arithmetic).
        zSetOperations.removeRangeByScore(combineKey, 0, currentMs - time * 1000L);
        Long currCount = zSetOperations.zCard(combineKey);
        // zCard may return null (e.g. inside a pipeline/transaction); do not unbox blindly.
        if (currCount != null && currCount > count) {
            log.error("[limit] 限制请求数'{}',当前请求数'{}',缓存key'{}'", count, currCount, combineKey);
            throw new RuntimeException("访问过于频繁,请稍后再试!");
        }
    }

    /**
     * Builds the Redis key {@code rate_limit:<client ip>-<declaring class>-<method name>}.
     *
     * <p>When no request is bound to the current thread, a {@code null} request is passed to
     * {@code IPUtil.getIpAddr}, which handles it explicitly instead of this method throwing an NPE.
     *
     * @param point the intercepted join point
     * @return the combined key
     */
    private String getCombineKey(JoinPoint point) {
        ServletRequestAttributes attributes =
                (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        HttpServletRequest request = attributes == null ? null : attributes.getRequest();
        Method method = ((MethodSignature) point.getSignature()).getMethod();
        return new StringBuilder(KEY_PREFIX)
                .append(IPUtil.getIpAddr(request))
                .append('-').append(method.getDeclaringClass().getName())
                .append('-').append(method.getName())
                .toString();
    }
}
使用
/**
 * Demo controller whose single endpoint is guarded by the Redis sliding-window
 * {@code @RateLimiter} aspect.
 */
@RestController
@RequestMapping("/auth")
@CrossOrigin
public class Auth {

    /**
     * {@code GET /auth/demo}, limited to 2 calls per client IP within any 5-second window.
     *
     * @return the constant body {@code "ok"}
     */
    @GetMapping("/demo")
    @RateLimiter(count = 2, time = 5)
    public String authDemo() {
        final String body = "ok";
        return body;
    }
}