在我们程序中,有时候需要对一些接口做访问控制,使程序更稳定,最常用的一种是通过ip限制,还有一种是通过用户名限制,也可以把一个接口限制死,在一段时间内只能访问多少次,这个根据自己需求来,不固定。在需要做限制的方法上加上一个自定义注解,用aop获取到这个方法,利用redis中的increment方法,去计数访问次数,超过访问次数,return一个自定义异常。
选用的是hash结构类型去存储访问次数,用访问路径作为外层key,ip作为内层key,访问次数作为value。
<dependency>
<groupId>org.springframework.bootgroupId>
<artifactId>spring-boot-starter-aopartifactId>
dependency>
<dependency>
<groupId>org.springframework.bootgroupId>
<artifactId>spring-boot-starter-data-redisartifactId>
dependency>
声明配置类
@Configuration
public class RedisConfig {
@Bean
@SuppressWarnings(value = { "unchecked", "rawtypes" }) //屏蔽一些无关紧张的警告
public RedisTemplate<Object,Object> redisTemplate(RedisConnectionFactory connectionFactory){
RedisTemplate<Object, Object> redisTemplate = new RedisTemplate<>();
//GenericJackson2JsonRedisSerializer比Jackson2JsonRedisSerializer效率低
//GenericJackson2JsonRedisSerializer jsonRedisSerializer = new GenericJackson2JsonRedisSerializer();
Jackson2JsonRedisSerializer jsonRedisSerializer = new Jackson2JsonRedisSerializer(Object.class);
redisTemplate.setConnectionFactory(connectionFactory);
// 使用StringRedisSerializer来序列化和反序列化redis的key值
redisTemplate.setKeySerializer(new StringRedisSerializer());
redisTemplate.setValueSerializer(jsonRedisSerializer);
//Hash的key也采用StringRedisSerializer的序列化方式
redisTemplate.setHashKeySerializer(new StringRedisSerializer());
redisTemplate.setHashValueSerializer(jsonRedisSerializer);
redisTemplate.afterPropertiesSet();
return redisTemplate;
}
}
application.yml配置
spring:
# redis配置
redis:
host: localhost
port: 6379
database: 0 #默认连接0号数据库
/**
* 接口访问频率注解,默认一分钟只能访问5次
*/
@Documented
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface AccessLimit {
long time() default 60000; 限制时间 单位:毫秒(默认值:一分钟)
int value() default 5;// 允许请求的次数(默认值:5次)
}
@Aspect
@Component
@Slf4j
public class InterfaceLimitAspect {
@Autowired
private RedisTemplate redisTemplate;
@Pointcut("@annotation(accessLimit)")
public void pt(AccessLimit accessLimit){}
@Around("pt(accessLimit)")
public Object Around(ProceedingJoinPoint joinPoint,AccessLimit accessLimit) throws Throwable {
// 获得request对象
ServletRequestAttributes sra =(ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
HttpServletRequest request = sra.getRequest();
log.info(request.getRequestURI());
//redis这里推荐使用hash类型,url为外层key,ip作为内层key,访问次数作为value
BoundHashOperations hashOps = redisTemplate.boundHashOps("interfaceLimit:"+request.getRequestURI());
//获取ip获取接口访问次数
Integer ipCount = (Integer)hashOps.get(request.getRemoteAddr());
Integer count = ipCount==null?0:ipCount;
//判断访问次数是否大于限制的次数
if(count>=accessLimit.value()){//超过次数,不执行目标方法
log.error("接口拦截:{} 请求超过限制频率【{}次/{}ms】,IP为{}",
request.getRequestURI(),
accessLimit.value(),
accessLimit.time(),
request.getRemoteAddr());
throw new AccessLimitException(ResultCodeEnum.ACCESS_LIMIT);
}else{
//请求时,设置有效时间, 记录加一
hashOps.increment(request.getRemoteAddr(),1);
hashOps.expire(accessLimit.time()*5, TimeUnit.MILLISECONDS);
}
Object result = joinPoint.proceed();
return result;
}
}
public class AccessLimitException extends RuntimeException{
private Integer code;
private String message;
public Integer getCode(){
return code;
}
public String getMessage(){
return message;
}
public AccessLimitException(ResultCodeEnum resultCodeEnum){
super(resultCodeEnum.getMessage());
this.code = resultCodeEnum.getCode();
this.message = resultCodeEnum.getMessage();
}
}
@RestControllerAdvice
@Slf4j
public class HandlerException {
@ExceptionHandler(Exception.class)
public Result handle(Throwable e){
log.info(e.getMessage());
return Result.build(null, 507,"系统错误");
}
@ExceptionHandler(AccessLimitException.class)
public Result AccessHandle(AccessLimitException e){
log.error(e.getMessage());
return Result.build(null,e.getCode(),e.getMessage());
}
}