【springboot】spring的Aop结合Redis实现对短信接口的限流

前言

场景: 为了限制短信验证码接口的访问次数,防止被刷,结合Aop和redis根据用户ip对用户限流

1.准备工作

首先我们创建一个 Spring Boot 工程,引入 Web 和 Redis 依赖,同时考虑到接口限流一般是通过注解来标记,而注解是通过 AOP 来解析的,所以我们还需要加上 AOP 的依赖,最终的依赖如下:

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-aop</artifactId>
        </dependency>
         <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-data-redis</artifactId>
        </dependency>
      	<dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>

然后提前准备好一个 Redis 实例,这里我们项目配置好之后,直接配置一下 Redis 的基本信息即可,比如:

spring.redis.host=localhost
spring.redis.port=6379
spring.redis.password=123

2.限流注解

接下来我们创建一个限流注解,我们将限流分为两种情况:
1:针对当前接口的全局性限流,例如该接口可以在 1 分钟内访问 100 次。
2: 针对某一个 IP 地址的限流,例如某个 IP 地址可以在 1 分钟内访问 100 次

针对这两种情况,我们创建一个枚举类:

public enum LimitType {
    /**
     * 默认策略全局限流
     */
    DEFAULT,
    /**
     * 根据请求者IP进行限流
     */
    IP
}

接下来我们来创建限流注解:

@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RateLimiter {
    /**
     * 限流key
     */
    String key() default "rate_limit:";

    /**
     * 限流时间,单位秒
     */
    int time() default 60;

    /**
     * 限流次数
     */
    int count() default 100;

    /**
     * 限流类型
     */
    LimitType limitType() default LimitType.DEFAULT;
}

第一个参数限流的 key,这个仅仅是一个前缀,将来完整的 key 是这个前缀再加上接口方法的完整路径,共同组成限流 key,这个 key 将被存入到 Redis 中。

另外三个参数好理解,我就不多说了。

好了,将来哪个接口需要限流,就在哪个接口上添加 @RateLimiter 注解,然后配置相关参数即可。

3. 定制或者选择redisTemplate

1. 定制 RedisTemplate(看需要,我使用第二种方案)

在 Spring Boot 中,我们其实更习惯使用 Spring Data Redis 来操作 Redis,不过默认的 RedisTemplate 有一个小坑,就是序列化用的是 JdkSerializationRedisSerializer,不知道小伙伴们有没有注意过,直接用这个序列化工具将来存到 Redis 上的 key 和 value 都会莫名其妙多一些前缀,这就导致你用命令读取的时候可能会出错。

例如存储的时候,key 是 name,value 是 javaboy,但是当你在命令行操作的时候,get name 却获取不到你想要的数据,原因就是存到 redis 之后 name 前面多了一些字符,此时只能继续使用 RedisTemplate 将之读取出来。

我们用 Redis 做限流会用到 Lua 脚本,使用 Lua 脚本的时候,就会出现上面说的这种情况,所以我们需要修改 RedisTemplate 的序列化方案。

修改 RedisTemplate 序列化方案,此配置用到了jackson2JsonRedisSerializer 的序列化器(忘了要不要引入依赖),代码参考案例如下:

@Configuration
public class RedisConfig {

    @Bean
    public RedisTemplate<Object, Object> redisTemplate(RedisConnectionFactory connectionFactory) {
        RedisTemplate<Object, Object> redisTemplate = new RedisTemplate<>();
        redisTemplate.setConnectionFactory(connectionFactory);
        // 使用Jackson2JsonRedisSerialize 替换默认序列化(默认采用的是JDK序列化)
        Jackson2JsonRedisSerializer<Object> jackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer<>(Object.class);
        ObjectMapper om = new ObjectMapper();
        om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
        om.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL);
        jackson2JsonRedisSerializer.setObjectMapper(om);
        redisTemplate.setKeySerializer(jackson2JsonRedisSerializer);
        redisTemplate.setValueSerializer(jackson2JsonRedisSerializer);
        redisTemplate.setHashKeySerializer(jackson2JsonRedisSerializer);
        redisTemplate.setHashValueSerializer(jackson2JsonRedisSerializer);
        return redisTemplate;
    }
}

2.直接使用StringRedisTemplate

StringRedisTemplate是Spring Data Redis定义好的一个操作redis的模版,继承了redisTemplate,默认使用了字符串序列序列化器(就是key和value都是用String进行存储,定制RedisTemplate狭义无非是定义哪种存储类型的序列化器,上面第一种数一种Json形式的序列化器,在本文章中在实现限流功能上没有区别)。

  • 我选择这种的理由,懒,不用配置
  • 但是注意传入的key和value都要先转为toString()转为字符串,不然会报错

4. 开放lua脚本

  1. lua脚本操作Redis的目的就是为了保证多个Redis操作的原子性,如果其中一条Redis操作出错了,可以抛出异常给springboot进行Redis回滚。如果不知道怎么使用lua脚本,可以去B站黑马Redis关于lua的几节补补。

2.脚本流程的意思大概如下:

  • 首先获取到传进来的 key 以及 限流的 count 和时间 time。
  • 通过 get 获取到这个 key 对应的值,这个值就是当前时间窗内这个接口可以访问多少次。
  • 如果是第一次访问,此时拿到的结果为 nil,否则拿到的结果应该是一个数字,所以接下来就判断,如果拿到的结果是一个数字,并且这个数字还大于 count,那 就说明已经超过流量限制了,那么直接返回查询的结果即可。
  • 如果拿到的结果为 nil,说明是第一次访问,此时就给当前 key 自增 1,然后设置一个过期时间。
  • 最后把自增 1 后的值返回就可以了。
-- redis限流脚本
-- key参数
local key = KEYS[1]
-- 限流的次数
local limitCount = tonumber(ARGV[1])
-- 限流的时间
local limitTime = tonumber(ARGV[2])
-- 获取当前时间
local currentCount = redis.call('get', key)
-- if 获取key的当前数 > limitCount 则返回最大值
if currentCount and tonumber(currentCount) > limitCount then
    return tonumber(currentCount)
end
-- key 自增1
currentCount = redis.call("incr",key)
-- if key的值 == 1 设置过期限流过期时间
if tonumber(currentCount) == 1 then
    redis.call("expire",key,limitTime)
end
-- 返回key的值
return tonumber(currentCount)

5.注解解析

  • springboot记得在main方法开启aop注解功能(不会自己查)
  • 核心代码如下(看不懂问gpt)
  • 代码中的异常为自定义异常,可换成自己的异常类抛出或处理异常
@Component
@Aspect
@Slf4j
public class RateLimiterAspect {

    @Resource
    private StringRedisTemplate stringRedisTemplate;


    @Resource
    private RedisScript<Long> limitScript;

    @Before("@annotation(rateLimiter)")
    public void doBefore(JoinPoint point, RateLimiter rateLimiter) {
        String key = rateLimiter.key();
        int time = rateLimiter.time();
        int count = rateLimiter.count();

        String combineKey = getCombineKey(rateLimiter, point);
        List<String> keys = Collections.singletonList(combineKey);
        try {
            Long number = stringRedisTemplate.execute(limitScript, keys, String.valueOf(count), String.valueOf(time));
            if (number==null || number.intValue() > count) {
                throw new BusinessException(ErrorCode.PARAMS_ERROR,"访问过于频繁,请稍候再试");
            }
            log.info("限制请求'{}',当前请求'{}',缓存key'{}'", count, number.intValue(), keys.get(0));
        } catch (ServiceException e) {
            throw e;
        } catch (Exception e) {
            throw new BusinessException(ErrorCode.SYSTEM_ERROR, "系统繁忙,请稍候再试");
        }
    }

    /**
     * 获取ip为key
     * @param rateLimiter
     * @param point
     * @return
     */
    public String getCombineKey(RateLimiter rateLimiter, JoinPoint point) {
        StringBuffer stringBuffer = new StringBuffer(rateLimiter.key());
        if (rateLimiter.limitType() == LimitType.IP) {
            stringBuffer.append(
                    IpUtils.getIpAddr(((ServletRequestAttributes) RequestContextHolder.currentRequestAttributes())
                            .getRequest()))
                    .append("-");
        }
        MethodSignature signature = (MethodSignature) point.getSignature();
        Method method = signature.getMethod();
        Class<?> targetClass = method.getDeclaringClass();
        stringBuffer.append(targetClass.getName()).append("-").append(method.getName());
        return stringBuffer.toString();
    }

}

根据HttpRequest获取用户ip的工具类(看不懂细节问gpt)

public class IpUtils {
    public static String getIpAddr(HttpServletRequest request) {
        String ipAddress = null;
        try {
            ipAddress = request.getHeader("x-forwarded-for");
            if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) {
                ipAddress = request.getHeader("Proxy-Client-IP");
            }
            if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) {
                ipAddress = request.getHeader("WL-Proxy-Client-IP");
            }
            if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) {
                ipAddress = request.getRemoteAddr();
                if (ipAddress.equals("127.0.0.1")) {
                    // 根据网卡取本机配置的IP
                    try {
                        ipAddress = InetAddress.getLocalHost().getHostAddress();
                    } catch (UnknownHostException e) {
                        e.printStackTrace();
                    }
                }
            }
            // 通过多个代理的情况,第一个IP为客户端真实IP,多个IP按照','分割
            if (ipAddress != null) {
                if (ipAddress.contains(",")) {
                    return ipAddress.split(",")[0];
                } else {
                    return ipAddress;
                }
            } else {
                return "";
            }
        } catch (Exception e) {
            e.printStackTrace();
            return "";
        }
    }
}

6.接口测试

如下: 根据用户IP地址,60秒,只能调用一次接口

    @GetMapping("/message")
    @RateLimiter(time = 60,count = 1,limitType = LimitType.IP)
    public BaseResponse<String> sendMessage(String phone,HttpServletRequest request) {
        if (StringUtils.isBlank(phone)) {
            throw new BusinessException(ErrorCode.PARAMS_ERROR);
        }
        boolean result = userVenueReservationService.sendMessage(phone,request);
        return ResultUtils.success(result ? "发送成功" : "发送失败");
    }

有错再补~~

你可能感兴趣的:(项目记录,spring,boot,spring,redis)