利用springboot+redis+lua配合注解实现接口限流

Spring Boot:是由Pivotal团队提供的全新框架,其设计目的是用来简化Spring应用的创建、运行、调试、部署等。使用Spring Boot可以做到专注于Spring应用的开发,而无需过多关注XML的配置。Spring Boot使用“习惯优于配置”的理念,简单来说,它提供了一堆依赖打包,并已经按照使用习惯解决了依赖问题。使用Spring Boot可以不用或者只需要很少的Spring配置就可以让企业项目快速运行起来

Redis:全称 Remote Dictionary Server(即远程字典服务),它是一个基于内存实现的键值型非关系(NoSQL)数据库,由意大利人 Salvatore Sanfilippo 使用 C 语言编写。

Lua: 是一种轻量小巧的脚本语言,用标准C语言编写并以源代码形式开放, 其设计目的是为了嵌入应用程序中,从而为应用程序提供灵活的扩展和定制功能。

注解:注解也叫元数据,即一种描述数据的数据。注解是JDK1.5版本开始引入的一个特性,用于对代码进行说明,可以对包、类、接口、字段、方法参数、局部变量等进行注解。

课前结束,进入正题!

1.准备工作

首先我们创建一个 Spring Boot 工程,引入 Web 和 Redis 依赖,同时考虑到接口限流一般是通过注解来标记,而注解是通过 AOP 来解析的,所以我们还需要加上 AOP 的依赖,最终的依赖如下:

 
        
            org.springframework.boot
            spring-boot-starter-data-redis
        
        
        
            org.springframework.boot
            spring-boot-starter-web
        
        
        
            org.springframework.boot
            spring-boot-starter-aop
        
        
        
            org.apache.httpcomponents
            httpcore
            4.4.12
        

然后提前准备好一个 Redis 实例,这里我们项目配置好之后,直接配置一下 Redis 的基本信息即可,如下:

spring.redis.host=192.168.19.10
spring.redis.port=6379
spring.redis.password=123123

好啦,准备工作就算是到位了。

2. 限流注解

接下来我们创建一个限流注解,我们将限流分为两种情况:

  1. 针对当前接口的全局性限流,例如该接口可以在 1 分钟内访问 100 次。

  2. 针对某一个 IP 地址的限流,例如某个 IP 地址可以在 1 分钟内访问 100 次

针对这两种情况,我们创建一个枚举类:

public class CurrentLimitingEnum {
    public enum LimitType {
        /**
         * 默认策略全局限流
         */
        DEFAULT,
        /**
         * 根据请求者IP进行限流
         */
        IP
    }
}

接下来我们来创建限流注解:

@Target(ElementType.METHOD)//作用在那个位置
@Retention(RetentionPolicy.RUNTIME)//作用在什么时候
@Documented
//格式:修饰符(public abstract【默认且唯一】),返回值类型,属性名,默认值(可忽略)
public @interface RateLimiter {
    /**
     * 限流key
     */
    public String key() default "rate_limit";

    /**
     * 限流时间,单位秒
     */
    public int time() default 60;

    /**
     * 限流次数
     */
    public int count() default 100;

    /**
     * 限流类型
     */

    public CurrentLimitingEnum.LimitType limitType() default CurrentLimitingEnum.LimitType.DEFAULT;
}

第一个参数限流的 key,这个仅仅是一个前缀,将来完整的 key 是这个前缀再加上接口方法的完整路径,共同组成限流 key,这个 key 将被存入到 Redis 中。

另外三个参数好理解,我就不多说了。

好了,将来哪个接口需要限流,就在哪个接口上添加 @RateLimiter 注解,然后配置相关参数即可。

3. 定制 RedisTemplate

小伙伴们知道,在 Spring Boot 中,我们其实更习惯使用 Spring Data Redis 来操作 Redis,不过默认的 RedisTemplate 有一个小坑,就是序列化用的是 JdkSerializationRedisSerializer,不知道小伙伴们有没有注意过,直接用这个序列化工具将来存到 Redis 上的 key 和 value 都会莫名其妙多一些前缀,这就导致你用命令读取的时候可能会出错。

例如存储的时候,key 是 name,value 是 javaboy,但是当你在命令行操作的时候,get name 却获取不到你想要的数据,原因就是存到 redis 之后 name 前面多了一些字符,此时只能继续使用 RedisTemplate 将之读取出来。

我们用 Redis 做限流会用到 Lua 脚本,使用 Lua 脚本的时候,就会出现上面说的这种情况,所以我们需要修改 RedisTemplate 的序列化方案。

修改 RedisTemplate 序列化方案,代码如下:

@Configuration
public class RedisConfig {

    @Bean
    public RedisTemplate redisTemplate(RedisConnectionFactory connectionFactory) {
        RedisTemplate redisTemplate = new RedisTemplate<>();
        redisTemplate.setConnectionFactory(connectionFactory);
        // 使用Jackson2JsonRedisSerialize 替换默认序列化(默认采用的是JDK序列化)
        Jackson2JsonRedisSerializer jackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer<>(Object.class);
        ObjectMapper om = new ObjectMapper();
        om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
        om.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL);
        jackson2JsonRedisSerializer.setObjectMapper(om);
        redisTemplate.setKeySerializer(jackson2JsonRedisSerializer);
        redisTemplate.setValueSerializer(jackson2JsonRedisSerializer);
        redisTemplate.setHashKeySerializer(jackson2JsonRedisSerializer);
        redisTemplate.setHashValueSerializer(jackson2JsonRedisSerializer);
        return redisTemplate;
    }
} 
  

 4. 开发 Lua 脚本

  1. 在 Redis 服务端定义好 Lua 脚本,然后计算出来一个散列值,在 Java 代码中,通过这个散列值锁定要执行哪个 Lua 脚本。

  2. 直接在 Java 代码中将 Lua 脚本定义好,然后发送到 Redis 服务端去执行。

 我们在 resources 目录下新建 lua 文件夹专门用来存放 lua 脚本,脚本内容如下:

local key = KEYS[1]
local count = tonumber(ARGV[1])
local time = tonumber(ARGV[2])
local current = redis.call('get', key)
if current and tonumber(current) > count then
    return tonumber(current)
end
current = redis.call('incr', key)
if tonumber(current) == 1 then
    redis.call('expire', key, time)
end
return tonumber(current)

这个脚本其实不难,大概瞅一眼就知道干啥用的。KEYS 和 ARGV 都是一会调用时候传进来的参数,tonumber 就是把字符串转为数字,redis.call 就是执行具体的 redis 指令,具体流程是这样: 

  1. 首先获取到传进来的 key 以及 限流的 count 和时间 time。

  2. 通过 get 获取到这个 key 对应的值,这个值就是当前时间窗内这个接口可以访问多少次。

  3. 如果是第一次访问,此时拿到的结果为 nil,否则拿到的结果应该是一个数字,所以接下来就判断,如果拿到的结果是一个数字,并且这个数字还大于 count,那就说明已经超过流量限制了,那么直接返回查询的结果即可。

  4. 如果拿到的结果为 nil,说明是第一次访问,此时就给当前 key 自增 1,然后设置一个过期时间。

  5. 最后把自增 1 后的值返回就可以了。

接下来我们在一个 Bean 中来加载这段 Lua 脚本,如下:

local key = KEYS[1]
local count = tonumber(ARGV[1])
local time = tonumber(ARGV[2])
local current = redis.call('get', key)
if current and tonumber(current) > count then
    return tonumber(current)
end
current = redis.call('incr', key)
if tonumber(current) == 1 then
    redis.call('expire', key, time)
end
return tonumber(current)

5. 注解解析(重难点)

接下来我们就需要自定义切面,来解析这个注解了,我们来看看切面的定义:

@Aspect
@Component
public class RateLimiterAspect {
    private static final Logger log = LoggerFactory.getLogger(RateLimiterAspect.class);
    @Autowired
    private RedisTemplate redisTemplate;

    @Autowired
    private RedisScript limitScript;

    @Before("@annotation(rateLimiter)")
    public void doBefore(JoinPoint point, RateLimiter rateLimiter) throws Throwable {
        String key = rateLimiter.key();
        int time = rateLimiter.time();
        int count = rateLimiter.count();

        String combineKey = getCombineKey(rateLimiter, point);
        List keys = Collections.singletonList(combineKey);
        try {
            Long number = redisTemplate.execute(limitScript, keys, count, time);
            if (number==null || number.intValue() > count) {
                throw new ServiceException("访问过于频繁,小哥哥注意节制哦");
            }
            log.info("限制请求'{}',当前请求'{}',缓存key'{}'", count, number.intValue(), key);
        } catch (ServiceException e) {
            throw e;
        } catch (Exception e) {
            throw new RuntimeException("服务器限流异常,请稍候再试");
        }
    }

    public String getCombineKey(RateLimiter rateLimiter, JoinPoint point) {
        StringBuffer stringBuffer = new StringBuffer(rateLimiter.key());
        if (rateLimiter.limitType() == CurrentLimitingEnum.LimitType.IP) {
            stringBuffer.append(IPUtils.getIpAddr(((ServletRequestAttributes) RequestContextHolder.currentRequestAttributes()).getRequest())).append("-");
        }
        MethodSignature signature = (MethodSignature) point.getSignature();
        Method method = signature.getMethod();
        Class targetClass = method.getDeclaringClass();
        stringBuffer.append(targetClass.getName()).append("-").append(method.getName());
        return stringBuffer.toString();
    }
} 
  

 这个切面就是拦截所有加了 @RateLimiter 注解的方法,在前置通知中对注解进行处理。

  1. 首先获取到注解中的 key、time 以及 count 三个参数。

  2. 获取一个组合的 key,所谓的组合的 key,就是在注解的 key 属性基础上,再加上方法的完整路径,如果是 IP 模式的话,就再加上 IP 地址。以 IP 模式为例,最终生成的 key 类似这样:rate_limit:127.0.0.1-org.javaboy.ratelimiter.controller.HelloController-hello(如果不是 IP 模式,那么生成的 key 中就不包含 IP 地址)。

  3. 将生成的 key 放到集合中。

  4. 通过 redisTemplate.execute 方法取执行一个 Lua 脚本,第一个参数是脚本所封装的对象,第二个参数是 key,对应了脚本中的 KEYS,后面是可变长度的参数,对应了脚本中的 ARGV。

  5. 将 Lua 脚本执行的结果与 count 进行比较,如果大于 count,就说明过载了,抛异常就行了。

6. 接口测试

@RestController
public class IndexController {
    @RequestMapping("/index")
    @RateLimiter(time = 5, count = 3, limitType = CurrentLimitingEnum.LimitType.IP)
    public String hello(HttpServletRequest request, HttpServletResponse response) {
        ServletContext servletContext = request.getServletContext();
        Integer num = (Integer) servletContext.getAttribute("num");
        if (num == null) {
            servletContext.setAttribute("num", 1);
        } else {
            servletContext.setAttribute("num", ++num);
        }
        SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy:MM:dd HH:mm:ss");
        String dateTime = dateFormat.format(new Date());
        return "您是第" + servletContext.getAttribute("num") + "次访问了,时间是:" + dateTime+"
"+"5秒内只可以访问3次哦,小哥哥"; } }

7. 全局异常处理

 

public class ServiceException extends RuntimeException {
    private static final long serialVersionUID = 1L;
    private String msg;
    private int code = 500;

    public ServiceException(String msg) {
        super(msg);
        this.msg = msg;
    }

    public ServiceException(String msg, Throwable e) {
        super(msg, e);
        this.msg = msg;
    }

    public ServiceException(String msg, int code) {
        super(msg);
        this.msg = msg;
        this.code = code;
    }

    public ServiceException(String msg, int code, Throwable e) {
        super(msg, e);
        this.msg = msg;
        this.code = code;
    }

    public String getMsg() {
        return msg;
    }

    public void setMsg(String msg) {
        this.msg = msg;
    }

    public int getCode() {
        return code;
    }

    public void setCode(int code) {
        this.code = code;
    }
}

 

@RestControllerAdvice
public class GlobalException {
    @ExceptionHandler(ServiceException.class)
    public R serviceException(ServiceException e) {
            R r = new R();
            r.put("code", e.getCode());
            r.put("msg", e.getMessage());
            return r;
    }
}

8.工具类(实体类R,IPUtils) 

public class R extends HashMap {
	private static final long serialVersionUID = 1L;
	
	public R() {
		put("code", 0);
		put("msg", "success");
	}
	
	public static R error() {
		return error(HttpStatus.SC_INTERNAL_SERVER_ERROR, "未知异常,请联系管理员");
	}
	
	public static R error(String msg) {
		return error(HttpStatus.SC_INTERNAL_SERVER_ERROR, msg);
	}
	
	public static R error(int code, String msg) {
		R r = new R();
		r.put("code", code);
		r.put("msg", msg);
		return r;
	}

	public static R ok(String msg) {
		R r = new R();
		r.put("msg", msg);
		return r;
	}
	
	public static R ok(Map map) {
		R r = new R();
		r.putAll(map);
		return r;
	}
	
	public static R ok() {
		return new R();
	}

	public R put(String key, Object value) {
		super.put(key, value);
		return this;
	}
}
public class IPUtils {
	private static Logger logger = LoggerFactory.getLogger(IPUtils.class);

	/**
	 * 获取IP地址
	 * 
	 * 使用Nginx等反向代理软件, 则不能通过request.getRemoteAddr()获取IP地址
	 * 如果使用了多级反向代理的话,X-Forwarded-For的值并不止一个,而是一串IP地址,X-Forwarded-For中第一个非unknown的有效IP字符串,则为真实IP地址
	 */
	public static String getIpAddr(HttpServletRequest request) {
    	String ip = null;
        try {
            ip = request.getHeader("x-forwarded-for");
            if (StringUtils.isEmpty(ip) || "unknown".equalsIgnoreCase(ip)) {
                ip = request.getHeader("Proxy-Client-IP");
            }
            if (StringUtils.isEmpty(ip) || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
                ip = request.getHeader("WL-Proxy-Client-IP");
            }
            if (StringUtils.isEmpty(ip) || "unknown".equalsIgnoreCase(ip)) {
                ip = request.getHeader("HTTP_CLIENT_IP");
            }
            if (StringUtils.isEmpty(ip) || "unknown".equalsIgnoreCase(ip)) {
                ip = request.getHeader("HTTP_X_FORWARDED_FOR");
            }
            if (StringUtils.isEmpty(ip) || "unknown".equalsIgnoreCase(ip)) {
                ip = request.getRemoteAddr();
            }
        } catch (Exception e) {
        	logger.error("IPUtils ERROR ", e);
        }
        
//        //使用代理,则获取第一个IP地址
//        if(StringUtils.isEmpty(ip) && ip.length() > 15) {
//			if(ip.indexOf(",") > 0) {
//				ip = ip.substring(0, ip.indexOf(","));
//			}
//		}
        
        return ip;
    }
	
}

我们启动项目!!

第一次访问我们正常访问

利用springboot+redis+lua配合注解实现接口限流_第1张图片

 然后我们5秒内点击3次以上

利用springboot+redis+lua配合注解实现接口限流_第2张图片

大功告成!!!

你可能感兴趣的:(JavaEE,redis,lua,spring,boot)