限制API接口访问速率（基于 Guava RateLimiter + Spring AOP 注解实现）

文章目录

  • 依赖
  • 注解
  • aop
  • helper
  • Test

免责声明：本文内容并非本人原创，但我未能找到原作者及原文链接，因此在此整理记录。如有侵权之嫌，请联系我删除本文。

依赖

  <!-- https://mvnrepository.com/artifact/com.google.guava/guava -->
        <dependency>
            <groupId>com.google.guava</groupId>
            <artifactId>guava</artifactId>
            <version>32.1.3-jre</version>
        </dependency>

注解

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;


/**
 * Marks a method whose invocations should be throttled by a token-bucket rate limiter.
 *
 * <p>The annotation is retained at runtime so an aspect can read it by reflection, and it may
 * only be placed on methods.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface RateConfigAnno {

    /**
     * Key identifying the limiter bucket. A registry keyed by this value (such as a map from
     * type to limiter) lets several methods share one limiter by using the same key.
     */
    String limitType();

    /**
     * Rate used when the limiter for {@link #limitType()} is first created. When handed to
     * Guava's {@code RateLimiter.create(double)}, this is interpreted as permits per second.
     * Defaults to 5.0.
     */
    double limitCount() default 5.0;
}

aop

import cn.hutool.core.thread.ThreadUtil;
import com.alibaba.fastjson2.JSONObject;
import com.google.common.util.concurrent.RateLimiter;
import com.tjbchtyw.tjflowcontrol.annocation.RateConfigAnno;
import com.tjbchtyw.tjflowcontrol.helper.RateLimitHelper;
import jakarta.servlet.ServletOutputStream;
import jakarta.servlet.http.HttpServletResponse;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import java.io.IOException;
import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.util.Objects;

/**
 * Spring AOP aspect that throttles methods annotated with {@link RateConfigAnno}.
 *
 * <p>Before each annotated method runs, one permit is taken from the Guava {@link RateLimiter}
 * that {@link RateLimitHelper} holds for the annotation's {@code limitType}. If no permit can be
 * obtained within {@link #MAX_WAIT_SECONDS} seconds, the aspect writes
 * {@code {"success":false,"msg":"限流中"}} to the current HTTP response and throws
 * {@link RateLimitExceededException}. The exception is what actually prevents the annotated
 * method from executing: a {@code @Before} advice that merely returns lets the call proceed.
 */
@Aspect
@Component
public class GuavaLimitAop {

    private static final Logger logger = LoggerFactory.getLogger(GuavaLimitAop.class);

    /** Maximum time a caller may block waiting for a permit before the call is rejected. */
    private static final long MAX_WAIT_SECONDS = 5L;

    /**
     * Acquires a permit for the intercepted method, or rejects the call.
     *
     * @param joinPoint the intercepted method execution
     * @throws RateLimitExceededException if no permit was obtained within
     *     {@link #MAX_WAIT_SECONDS} seconds; the rejection body has already been written
     */
    @Before("@annotation(com.tjbchtyw.tjflowcontrol.annocation.RateConfigAnno)")
    public void limit(JoinPoint joinPoint) {
        // 1. Resolve the intercepted method and read its annotation once.
        Method currentMethod = getCurrentMethod(joinPoint);
        if (Objects.isNull(currentMethod)) {
            return;
        }
        RateConfigAnno config = currentMethod.getAnnotation(RateConfigAnno.class);
        if (config == null) {
            return;
        }
        String limitType = config.limitType();

        // 2. Token bucket. tryAcquire(timeout) waits only as long as a permit actually needs, and
        //    returns false immediately if none can be obtained within the timeout. This replaces
        //    the fixed 1-second sleep/retry loop.
        RateLimiter rateLimiter = RateLimitHelper.getRateLimiter(limitType, config.limitCount());
        if (rateLimiter.tryAcquire(MAX_WAIT_SECONDS, java.util.concurrent.TimeUnit.SECONDS)) {
            logger.debug("获取到令牌, limitType={}", limitType);
            return;
        }

        // 3. Reject. Write the JSON body, then throw. Throwing is the only way a @Before advice can
        //    stop the target method; without it the throttled call would still execute.
        logger.warn("{}秒内未获取到令牌 开始限流, limitType={}", MAX_WAIT_SECONDS, limitType);
        writeRejection();
        throw new RateLimitExceededException(limitType);
    }

    /**
     * Writes the JSON rejection body to the response bound to the current thread, if any.
     * Write failures are logged rather than propagated, so the caller still throws the
     * rate-limit exception.
     */
    private void writeRejection() {
        Object attributes = RequestContextHolder.getRequestAttributes();
        if (!(attributes instanceof ServletRequestAttributes)) {
            // No servlet request is bound to this thread, so there is no response to write to.
            return;
        }
        HttpServletResponse resp = ((ServletRequestAttributes) attributes).getResponse();
        if (resp == null) {
            return;
        }
        JSONObject jsonObject = new JSONObject();
        jsonObject.put("success", false);
        jsonObject.put("msg", "限流中");
        try {
            output(resp, jsonObject.toJSONString());
        } catch (IOException | IllegalStateException e) {
            logger.error("限流响应写入失败", e);
        }
    }

    /**
     * Finds the public method on the target class that is being intercepted.
     *
     * <p>Matches on name, parameter count and presence of {@link RateConfigAnno}, so an
     * un-annotated overload with the same name is never returned. Matching on name alone could
     * pick such an overload and make {@code getAnnotation} return null.
     * NOTE(review): two annotated overloads with the same name and arity remain ambiguous.
     *
     * @param joinPoint the intercepted method execution
     * @return the matching method, or {@code null} if none matches
     */
    private Method getCurrentMethod(JoinPoint joinPoint) {
        String name = joinPoint.getSignature().getName();
        int argCount = joinPoint.getArgs().length;
        for (Method method : joinPoint.getTarget().getClass().getMethods()) {
            if (method.getName().equals(name)
                    && method.getParameterCount() == argCount
                    && method.isAnnotationPresent(RateConfigAnno.class)) {
                return method;
            }
        }
        return null;
    }

    /**
     * Writes {@code msg} as a UTF-8 JSON body and closes the response output stream.
     *
     * @param response the response to write to
     * @param msg the JSON text to send
     * @throws IOException if obtaining, writing, flushing or closing the stream fails
     */
    public void output(HttpServletResponse response, String msg) throws IOException {
        response.setContentType("application/json;charset=UTF-8");
        // try-with-resources: the stream is closed even if write fails, and there is no
        // finally-block NPE when getOutputStream() itself throws.
        try (ServletOutputStream outputStream = response.getOutputStream()) {
            outputStream.write(msg.getBytes(StandardCharsets.UTF_8));
            outputStream.flush();
        }
    }

    /**
     * Thrown by {@link #limit(JoinPoint)} to stop a throttled call from executing.
     *
     * <p>When this is thrown, the rejection body has normally already been written and the
     * response committed. Any exception handler mapped to it should therefore not try to write
     * another body.
     */
    public static final class RateLimitExceededException extends RuntimeException {

        private static final long serialVersionUID = 1L;

        private final String limitType;

        /** @param limitType the limiter key that ran out of permits */
        public RateLimitExceededException(String limitType) {
            // No stack trace: this is an expected control-flow signal, not a programming error.
            super("限流中: " + limitType, null, false, false);
            this.limitType = limitType;
        }

        /** @return the limiter key that ran out of permits */
        public String getLimitType() {
            return limitType;
        }
    }
}

helper

import com.google.common.util.concurrent.RateLimiter;

import java.util.HashMap;
import java.util.Map;

/**
 * Process-wide registry of Guava {@link RateLimiter}s, keyed by limit type.
 *
 * <p>Callers on different request threads share this registry, so it is backed by a concurrent
 * map. Each key gets exactly one limiter, even when several threads ask for it simultaneously.
 */
public final class RateLimitHelper {

    private RateLimitHelper(){}

    /**
     * Limiters by limit type. The previous plain {@code HashMap} was mutated from concurrent
     * threads without synchronization, which could corrupt the map or create duplicate limiters
     * for one key.
     */
    private static final Map<String, RateLimiter> rateMap = new java.util.concurrent.ConcurrentHashMap<>();

    /**
     * Returns the limiter for {@code limitType}, creating it atomically on first use.
     *
     * <p>{@code limitCount} is only used when the limiter is created; later calls with the same
     * {@code limitType} return the existing limiter unchanged and ignore their {@code limitCount}.
     *
     * @param limitType key of the limiter; must not be {@code null}
     * @param limitCount rate handed to {@link RateLimiter#create(double)} (permits per second)
     * @return the shared limiter for {@code limitType}
     * @throws NullPointerException if {@code limitType} is {@code null}
     * @throws IllegalArgumentException if a new limiter must be created and {@code limitCount}
     *     is rejected by {@link RateLimiter#create(double)} (e.g. zero or negative)
     */
    public static RateLimiter getRateLimiter(String limitType, double limitCount ){
        return rateMap.computeIfAbsent(limitType, type -> RateLimiter.create(limitCount));
    }

}

Test

/**
 * Demo controller showing how to apply {@link RateConfigAnno} to an endpoint.
 */
@RestController
@RequestMapping("/pred-api")
@Tag(name = "测试Controller", description = "这是描述")
public class FlowController {

    /** Injected collaborator; not used by the demo endpoint below. */
    @Autowired
    FlowContext flowContext;

    /**
     * Rate-limited demo endpoint.
     *
     * <p>The annotation assigns this endpoint to the "makePdf" limiter with a limitCount of 15.
     * The limit is enforced by the rate-limit aspect, not by this method; limitCount is the
     * value passed to Guava's RateLimiter.create, i.e. permits per second, not a total call
     * count.
     *
     * @param Param request body (not read by this demo)
     * @return the constant string "test"
     */
    @Operation(summary = "限流接口")
    @RateConfigAnno(limitType = "makePdf", limitCount = 15)
    @PostMapping("/pdf/test")
    public String flowCont(
            @Parameter(name = "pdfParam", description = "参数对象,type标识执行不同的策略")
            @RequestBody Param Param) {
        return "test";
    }
}

你可能感兴趣的:(java)