基于切面的访问次数限制

对于某些接口(比如获取验证码接口),我们希望限制同一用户在10分钟内最多获取5次验证码。实现如下:
AccessLimit.java

import java.lang.annotation.*;
import java.util.concurrent.TimeUnit;

/**
 * 接口防刷注解(访问限制)
 *
 * @author redreamer
 */
@Documented
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface AccessLimit {

    /**
     * 限制的时间长度
     */
    int timeLength();

    /**
     * 限制的时间长度单位
     */
    TimeUnit timeUnit() default TimeUnit.SECONDS;

    /**
     * 最大访问次数
     */
    int maxCount();

    /**
     * 唯一标识的参数名:作为唯一的条件
     */
    String keyArgName() default "";

    /**
     * 超限提示语
     */
    String message() default "";
}

AccessLimitAop.java

import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.CodeSignature;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;

import javax.annotation.Resource;
import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.TimeUnit;

/**
 * Aspect that enforces {@link AccessLimit} (anti-abuse rate limiting) on annotated methods.
 *
 * <p>Before each call it atomically increments a Redis counter. The counter is keyed by the
 * method's fully-qualified name plus the Base64 form of the argument named by
 * {@link AccessLimit#keyArgName()}, if any.
 *
 * <p>The counter's TTL is set when the counter is created, so the window starts at the first
 * call. Once more than {@link AccessLimit#maxCount()} calls fall into one window, a
 * {@code ServiceException} is thrown and the target method does not run.
 *
 * @author redreamer
 */
@Slf4j
@Aspect
@Component
public class AccessLimitAop {

    /** Fallback message used when {@link AccessLimit#message()} is blank. */
    private static final String DEFAULT_MESSAGE = "您的访问过于频繁,请稍后再试!";

    /** Redis TTL reply meaning "key exists but has no expiry". */
    private static final long NO_EXPIRY = -1L;

    @Resource
    private RedisTemplate<String, String> redisTemplate;

    /**
     * Pointcut matching every method annotated with {@link AccessLimit}.
     */
    @Pointcut("@annotation(mypackage.accesslimit.AccessLimit)")
    public void pointcut() {
    }

    /**
     * Counts the intercepted call and rejects it if its key has exceeded the configured limit.
     *
     * @param joinPoint the intercepted method invocation
     * @throws Exception {@code NoSuchMethodException} if the target method cannot be resolved, or
     *                   {@code ServiceException} when the access limit is exceeded
     */
    @Before("pointcut()")
    public void joinPoint(JoinPoint joinPoint) throws Exception {
        MethodSignature methodSignature = (MethodSignature) joinPoint.getSignature();
        Method method = joinPoint.getTarget().getClass().getMethod(methodSignature.getName(),
                methodSignature.getParameterTypes());

        AccessLimit accessLimit = method.getAnnotation(AccessLimit.class);
        String methodFullName = method.getDeclaringClass().getName() + "." + method.getName();
        String paramValue = resolveKeyArgValue(joinPoint, accessLimit.keyArgName());

        // The argument may contain non-ASCII text (e.g. Chinese); Base64 keeps the key printable.
        // ':' separates the method name from the encoded argument so the two cannot run together.
        String key = methodFullName + ":" + toBase64String(paramValue);

        // INCR is atomic. Concurrent requests can no longer all read the same stale count and
        // slip past the limit, as they could with the previous GET-then-SET.
        Long current = redisTemplate.opsForValue().increment(key, 1L);
        long count = current == null ? 0L : current;
        int maxCount = accessLimit.maxCount();

        // Set the TTL only when the window opens, so later calls no longer push the expiry back
        // (previously every allowed call reset it). If a key ever ends up without a TTL (e.g. the
        // process died between INCR and EXPIRE), repair it on the reject path. Otherwise the
        // caller would be locked out forever.
        boolean windowOpened = count == 1L;
        boolean ttlMissing = count > maxCount
                && Long.valueOf(NO_EXPIRY).equals(redisTemplate.getExpire(key));
        if (windowOpened || ttlMissing) {
            redisTemplate.expire(key, accessLimit.timeLength(), accessLimit.timeUnit());
        }

        // Allows exactly maxCount calls per window; the old `count <= maxCount` check allowed maxCount + 1.
        if (count > maxCount) {
            log.warn("访问过于频繁:{} - {}", methodFullName, paramValue);
            String message = accessLimit.message();
            throw new ServiceException(StringUtils.isBlank(message) ? DEFAULT_MESSAGE : message);
        }
    }

    /**
     * Returns the string form of the argument whose parameter name equals {@code argName}.
     *
     * @param joinPoint the intercepted method invocation
     * @param argName   parameter name configured via {@link AccessLimit#keyArgName()}
     * @return the argument as a string, or {@code ""} when {@code argName} is empty, no parameter
     *         has that name, or the argument is {@code null}
     * @throws IllegalStateException if {@code argName} is set but parameter names are unavailable
     */
    private String resolveKeyArgValue(JoinPoint joinPoint, String argName) {
        if (StringUtils.isEmpty(argName)) {
            return "";
        }
        CodeSignature signature = (CodeSignature) joinPoint.getStaticPart().getSignature();
        String[] parameterNames = signature.getParameterNames();
        if (parameterNames == null) {
            // Falling back to "" here would silently make the limit global across all callers.
            throw new IllegalStateException("Parameter names of " + signature
                    + " are unavailable; cannot resolve keyArgName '" + argName + "'");
        }
        Object[] args = joinPoint.getArgs();
        for (int i = 0; i < parameterNames.length; i++) {
            if (argName.equals(parameterNames[i])) {
                // Objects.toString instead of a (String) cast: non-String arguments no longer
                // throw ClassCastException.
                return Objects.toString(args[i], "");
            }
        }
        return "";
    }

    /**
     * Encodes a parameter value as a UTF-8 Base64 string.
     *
     * @param paramValue parameter value, may be {@code null}
     * @return the Base64 string, or {@code ""} for a {@code null}/empty value (previously
     *         {@code null}, which ended up as the literal text "null" in the Redis key)
     */
    private String toBase64String(String paramValue) {
        if (StringUtils.isEmpty(paramValue)) {
            return "";
        }
        return Base64.getEncoder().encodeToString(paramValue.getBytes(StandardCharsets.UTF_8));
    }

}

使用示例:

/**
 * Sends an SMS verification code to {@code phoneNumber} and returns it as {@code {"code": ...}}.
 *
 * <p>Rate limiting is configured via {@link AccessLimit}: {@code maxCount = 5} within a 10-minute
 * window, keyed on the phone number. {@code keyArgName = "phoneNumber"} must match the Java
 * parameter name below, not the request-parameter name. The aspect looks parameter names up
 * reflectively, which presumably requires compiling with parameter-name information (e.g.
 * {@code -parameters} or debug info) — TODO confirm for this build.
 *
 * <p>NOTE(review): the generated code is echoed back in the response body. That lets any caller
 * obtain a valid code for any phone number without access to the phone, which defeats SMS
 * verification. Keep this only for local testing; do not ship it to production.
 */
@ApiImplicitParams({
        @ApiImplicitParam(paramType = "query", dataType = "String", name = "p1", value = "参数1"),
        @ApiImplicitParam(paramType = "query", dataType = "String", name = "p2", value = "参数2"),
        @ApiImplicitParam(paramType = "query", dataType = "String", name = "phoneNumber", value = "手机号码")
})
@ApiOperation(value = "获取手机验证码")
@AccessLimit(timeLength = 10, timeUnit = TimeUnit.MINUTES, maxCount = 5, keyArgName = "phoneNumber")
@GetMapping("/code")
public ResponseEntity<?> getCheckCode(@RequestParam(value = "p1") String p1,
                                      @RequestParam(value = "p2") String p2,
                                      @Mobile @RequestParam("phoneNumber") String phoneNumber) {
    String code = checkCodeService.sendSms(CountryCodeEnum.CN, phoneNumber, 6);
    return ResponseEntity.ok(ImmutableMap.of("code", code));
}

你可能感兴趣的:(Java)