使用 RedisTemplate + Lua 脚本实现原子自增,限制接口请求频率

场景:限制请求后端接口的频率,例如1秒内的请求次数不能超过10次。通常的写法是:

1.先去从redis里面拿到当前请求次数

2.判断当前次数是否大于或等于限制次数

3.当前请求次数小于限制次数时进行自增

在请求不密集时,这三步执行得很快,一般不会出问题;但如果两个请求几乎同时到来,第1~3步(读取、判断、自增)之间无法保证原子性:两个请求可能读到同一个计数值、都通过判断并各自自增,导致实际放行的请求数超过限制。

改进方式:使用 Redis 的 Lua 脚本,把"读取值、判断大小、自增"放到一次脚本调用中执行。Redis 以单线程串行执行命令,一个 Lua 脚本在执行期间不会有其他命令插入,因此整个脚本相当于一个原子操作。

自增的 Lua 脚本如下(KEYS[1] 为计数 key,ARGV[1] 为限制次数,ARGV[2] 为过期时间,单位秒):

    /**
     * Returns the Lua source of the atomic "check then increment" rate-limit script.
     *
     * <p>KEYS[1] is the counter key, ARGV[1] the limit and ARGV[2] the expiry in seconds.
     * If the stored counter is already greater than ARGV[1] it is returned unchanged;
     * otherwise the key is INCR'd and, when that INCR produced 1 (first hit of the window),
     * an EXPIRE of ARGV[2] seconds is set. The script returns the resulting counter.
     */
    private String maxCountScriptText() {
        return String.join("\n",
                "local key = KEYS[1]",
                "local count = tonumber(ARGV[1])",
                "local time = tonumber(ARGV[2])",
                "local current = redis.call('get', key);",
                "if current and tonumber(current) > count then",
                "    return tonumber(current);",
                "end",
                "current = redis.call('incr', key)",
                "if tonumber(current) == 1 then",
                "    redis.call('expire', key, time)",
                "end",
                "return tonumber(current);");
    }

 将接口限流功能封装成一个注解@RateLimiter,在接口方法上面加上@RateLimiter就可以实现限流:

 redis工具类:

package com.zhou.redis.util;

import com.zhou.redis.dto.MyRedisMessage;
import com.zhou.redis.exception.LockException;
import com.zhou.redis.script.MaxCountQueryScript;
import com.zhou.redis.script.MaxCountScript;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.core.HashOperations;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.ValueOperations;

import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;


@Configuration
@Slf4j
public class RedisUtil {
    /** Lock lifetime in seconds used by {@link #tryLock} when no positive time is supplied. */
    private static final long DEFAULT_LOCK_SECONDS = 30L;

    /**
     * Template used for every Redis call. Left public because other code may access it directly
     * (NOTE(review): confirm before narrowing visibility).
     */
    public RedisTemplate redisTemplate;

    private final MaxCountScript maxCountScript;
    private final MaxCountQueryScript maxCountQueryScript;

    public RedisUtil(RedisTemplate redisTemplate, MaxCountScript maxCountScript, MaxCountQueryScript maxCountQueryScript) {
        this.redisTemplate = redisTemplate;
        this.maxCountScript = maxCountScript;
        this.maxCountQueryScript = maxCountQueryScript;
    }

    /**
     * Tries to take a simple lock by storing {@code value} under {@code key} only if the key is absent.
     *
     * @param key   lock key
     * @param value value stored under the key
     * @param time  lock lifetime in seconds; {@code null} or a non-positive value means 30 seconds
     * @return {@code true} if this call created the key, {@code false} otherwise
     */
    public boolean tryLock(String key, Object value, Long time) {
        long seconds = (time == null || time <= 0) ? DEFAULT_LOCK_SECONDS : time;
        Boolean acquired = redisTemplate.opsForValue().setIfAbsent(key, value, Duration.ofSeconds(seconds));
        return Boolean.TRUE.equals(acquired);
    }

    /**
     * Releases a lock by deleting its key (call only after a successful {@link #tryLock}).
     *
     * <p>NOTE(review): the key is deleted unconditionally, without comparing the stored value.
     * If the caller's lock already expired and another client re-acquired it, this deletes the
     * other client's lock.
     *
     * @param key lock key
     * @return {@code true} if a key was deleted
     */
    public boolean unLock(String key) {
        Boolean deleted = redisTemplate.delete(key);
        return Boolean.TRUE.equals(deleted);
    }

    /**
     * Atomically increments the counter at {@code key} unless it already exceeds {@code maxCount}.
     * On the first increment the key's expiry is set to {@code time} seconds.
     *
     * @param key      counter key
     * @param maxCount limit; once the stored value is greater than this it is no longer incremented
     * @param time     expiry of the counter in seconds, applied when the counter is created
     * @return the counter value after the call
     */
    public Long incr(String key, int maxCount, int time) {
        List<String> keys = Collections.singletonList(key);
        // redisTemplate is a raw type, so the generic execute(...) is erased to return Object.
        return (Long) redisTemplate.execute(maxCountScript, keys, maxCount, time);
    }

    /**
     * Reads the current counter value at {@code key} without modifying it.
     *
     * @param key counter key
     * @return the counter value, or {@code null} if the key does not exist
     */
    public Long incrNow(String key) {
        List<String> keys = Collections.singletonList(key);
        return (Long) redisTemplate.execute(maxCountQueryScript, keys);
    }
}

 redis配置类:

package com.zhou.redis.config;


import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.zhou.redis.listener.MyRedisListener;
import com.zhou.redis.script.MaxCountQueryScript;
import com.zhou.redis.script.MaxCountScript;
import com.zhou.redis.util.RedisTopic;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.listener.PatternTopic;
import org.springframework.data.redis.listener.RedisMessageListenerContainer;
import org.springframework.data.redis.listener.Topic;
import org.springframework.data.redis.listener.adapter.MessageListenerAdapter;
import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;

import java.util.Arrays;
import java.util.List;


@Configuration
public class RedisConfig {

    /**
     * RedisTemplate that writes keys and hash keys as plain strings and values / hash values as
     * Jackson JSON (with type information).
     */
    @SuppressWarnings("all")
    @Bean
    public RedisTemplate redisTemplate(RedisConnectionFactory factory) {
        RedisTemplate template = new RedisTemplate<>();
        template.setConnectionFactory(factory);

        Jackson2JsonRedisSerializer jackson2JsonRedisSerializer = jacksonSerializer();
        StringRedisSerializer stringRedisSerializer = new StringRedisSerializer();

        // keys and hash keys as plain strings
        template.setKeySerializer(stringRedisSerializer);
        template.setHashKeySerializer(stringRedisSerializer);
        // values and hash values as JSON
        template.setValueSerializer(jackson2JsonRedisSerializer);
        template.setHashValueSerializer(jackson2JsonRedisSerializer);
        template.afterPropertiesSet();
        return template;
    }

    /**
     * Pub/sub listener container that subscribes {@code listener} to the {@link RedisTopic#TOPIC1}
     * and {@link RedisTopic#TOPIC2} patterns.
     *
     * <p>The topic serializer is deliberately left at its default (plain string). It only encodes
     * channel/pattern NAMES, not message bodies. A JSON serializer would wrap names in quotes
     * ({@code "TOPIC1"}), so the subscriptions would no longer match channels published via
     * RedisTemplate. Message bodies are decoded by the listener itself.
     *
     * @param redisConnectionFactory connection factory
     * @param listener               listener receiving messages on both topics
     * @param adapter                not used here; kept so the bean signature stays unchanged
     * @return the listener container
     */
    @Bean
    @SuppressWarnings("all")
    public RedisMessageListenerContainer container(RedisConnectionFactory redisConnectionFactory,
                                                   MyRedisListener listener,
                                                   MessageListenerAdapter adapter) {
        RedisMessageListenerContainer container = new RedisMessageListenerContainer();
        container.setConnectionFactory(redisConnectionFactory);
        // every subscription must be registered here; more listeners/topics can be added the same way
        List<Topic> topicList = Arrays.asList(
                new PatternTopic(RedisTopic.TOPIC1),
                new PatternTopic(RedisTopic.TOPIC2)
        );
        container.addMessageListener(listener, topicList);
        return container;
    }

    /**
     * MessageListenerAdapter configured with the shared JSON serializer. It is created without a
     * delegate; the default delegate method name is {@code handleMessage}.
     */
    @SuppressWarnings("all")
    @Bean
    public MessageListenerAdapter listenerAdapter() {
        MessageListenerAdapter receiveMessage = new MessageListenerAdapter();
        receiveMessage.setSerializer(jacksonSerializer());
        return receiveMessage;
    }

    /** Script bean for the atomic check-and-increment counter. */
    @Bean
    public MaxCountScript maxCountScript() {
        return new MaxCountScript(maxCountScriptText());
    }

    /** Script bean that reads the current counter value. */
    @Bean
    public MaxCountQueryScript maxCountQueryScript() {
        return new MaxCountQueryScript(maxCountQueryScriptText());
    }

    /**
     * Builds the Jackson serializer shared by the template and the adapter, so both use the same
     * settings.
     *
     * <p>All fields are visible regardless of modifiers, and non-final types are written with their
     * class name. NOTE: default typing lets the reader instantiate classes named in the payload, so
     * only deserialize trusted data. {@code enableDefaultTyping} is deprecated since Jackson 2.10 in
     * favor of {@code activateDefaultTyping}.
     */
    @SuppressWarnings({"deprecation", "unchecked"})
    private static Jackson2JsonRedisSerializer<Object> jacksonSerializer() {
        Jackson2JsonRedisSerializer<Object> serializer = new Jackson2JsonRedisSerializer<>(Object.class);
        ObjectMapper objectMapper = new ObjectMapper();
        objectMapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
        objectMapper.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL);
        serializer.setObjectMapper(objectMapper);
        return serializer;
    }

    /**
     * Lua script: returns the counter unchanged if it already exceeds ARGV[1]. Otherwise it INCRs
     * KEYS[1] and, on the first increment, sets an EXPIRE of ARGV[2] seconds. It returns the
     * resulting counter.
     */
    private String maxCountScriptText() {
        return "local key = KEYS[1]\n" +
                "local count = tonumber(ARGV[1])\n" +
                "local time = tonumber(ARGV[2])\n" +
                "local current = redis.call('get', key);\n" +
                "if current and tonumber(current) > count then\n" +
                "    return tonumber(current);\n" +
                "end\n" +
                "current = redis.call('incr', key)\n" +
                "if tonumber(current) == 1 then\n" +
                "    redis.call('expire', key, time)\n" +
                "end\n" +
                "return tonumber(current);";
    }

    /**
     * Lua script: returns the numeric value of KEYS[1]. If the key is missing it returns the result
     * of GET unchanged (nil).
     */
    private String maxCountQueryScriptText() {
        return "local key = KEYS[1]\n" +
                "local current = redis.call('get', key);\n" +
                "if current then\n" +
                "    return tonumber(current);\n" +
                "else\n" +
                "    return current\n" +
                "end\n";
    }
}
  

 拦截模式枚举类:根据ip拦截或者方法拦截

package com.zhou.aop;

/**
 * Selects how the rate-limit key is scoped.
 *
 * @author lang.zhou
 * @since 2023/1/31 17:56
 */
public enum LimitType {
    /** Per client IP: when a servlet request is available, the caller's IP is part of the key. */
    IP,
    /** Per method only: every caller of the method shares one counter. */
    DEFAULT
}

 封装自定义注解:@RateLimiter

package com.zhou.aop;

import java.lang.annotation.*;

/**
 * Marks a method whose invocations are counted in Redis and rejected once the limit is exceeded
 * within the time window.
 *
 * @author lang.zhou
 * @since 2023/1/31 17:49
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD})
public @interface RateLimiter {

    /** Prefix of the Redis counter key. */
    String key() default "RateLimiter";

    /** Length of the limiting window in seconds (used as the counter's expiry). */
    int time() default 60;

    /** Maximum number of calls allowed within one window. */
    int count() default 100;

    /** How the counter key is scoped (per IP or per method). */
    LimitType limitType() default LimitType.DEFAULT;

    /** Message of the exception raised when the limit is exceeded. */
    String limitMsg() default "访问过于频繁,请稍候再试";
}

 注解的切面逻辑:

package com.zhou.aop;

import com.zhou.redis.util.RedisUtil;
import com.zhou.common.utils.IpUtil;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Method;

/**
 * 接口限流切面
 * @author lang.zhou
 * @since 2023/1/31 17:50
 */
@Aspect
@Slf4j
@Component
public class RateLimiterAspect {

    @Autowired
    private RedisUtil redisUtils;

    @Before("@annotation(rateLimiter)")
    public void doBefore(JoinPoint point, RateLimiter rateLimiter) {
        int time = rateLimiter.time();
        int count = rateLimiter.count();
        String combineKey = getCombineKey(rateLimiter, point);
        try {
            Long number = redisUtils.incr(combineKey, count, time);
            if (number == null || number.intValue() > count){
                log.info("请求【{}】被拦截,{}秒内请求次数{}",combineKey,time,number);
                throw new RuntimeException(rateLimiter.limitMsg());
            }
        } catch (ServiceRuntimeException e) {
            throw e;
        } catch (Exception e) {
            throw new RuntimeException("网络繁忙,请稍候再试");
        }
    }

    /**
     * 获取限流key
     */
    public String getCombineKey(RateLimiter rateLimiter, JoinPoint point) {
        StringBuilder s = new StringBuilder(rateLimiter.key());
        ServletRequestAttributes requestAttributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        if(requestAttributes != null){
            HttpServletRequest request = requestAttributes.getRequest();
            if (rateLimiter.limitType() == LimitType.IP) {
                s.append(IpUtil.getIpAddr(request)).append("-");
            }
        }

        MethodSignature signature = (MethodSignature) point.getSignature();
        Method method = signature.getMethod();
        Class targetClass = method.getDeclaringClass();
        s.append(targetClass.getName()).append(".").append(method.getName());
        return s.toString();
    }

}

 lua自增脚本类:

package com.zhou.redis.script;

import org.springframework.data.redis.core.script.DefaultRedisScript;

/**
 * Lua script whose result is mapped to {@link Long}; holds the check-and-increment counter script.
 *
 * @author lang.zhou
 * @since 2023/2/25
 */
public class MaxCountScript extends DefaultRedisScript<Long> {
    /**
     * @param script Lua source of the script
     */
    public MaxCountScript(String script) {
        super(script, Long.class);
    }
}

 lua查询当前值的脚本类:

package com.zhou.redis.script;

import org.springframework.data.redis.core.script.DefaultRedisScript;

/**
 * Lua script whose result is mapped to {@link Long}; holds the "read current counter" script.
 *
 * @author lang.zhou
 * @since 2023/2/25
 */
public class MaxCountQueryScript extends DefaultRedisScript<Long> {
    /**
     * @param script Lua source of the script
     */
    public MaxCountQueryScript(String script) {
        super(script, Long.class);
    }
}

 订阅消息通道的常量类:

package com.zhou.redis.util;


/**
 * Names of the Redis pub/sub channels used by the application. This class only holds constants.
 */
public final class RedisTopic {
    public static final String TOPIC1 = "TOPIC1";
    public static final String TOPIC2 = "TOPIC2";

    private RedisTopic() {
        // constants holder; not meant to be instantiated
    }
}

消息实体类: 

package com.zhou.redis.dto;

import lombok.Data;

import java.io.Serializable;

/**
 * Payload of a Redis pub/sub message (decoded by the listener for TOPIC2).
 * Lombok {@code @Data} generates the accessors, equals/hashCode and toString.
 *
 * @since 2022/11/11 17:34
 */
@Data
public class MyRedisMessage implements Serializable {
    private static final long serialVersionUID = 1L;

    /** Message text. */
    private String msg;
}

订阅消息监听器: 

package com.zhou.redis.listener;

import com.zhou.redis.dto.MyRedisMessage;
import com.zhou.redis.util.RedisTopic;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.connection.Message;
import org.springframework.data.redis.connection.MessageListener;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;

import javax.script.ScriptException;

/**
 * Receives Redis pub/sub messages and dispatches them by topic.
 *
 * @author lang.zhou
 */
@Slf4j
@Component
public class MyRedisListener implements MessageListener {
    @Autowired
    private RedisTemplate redisTemplate;

    /**
     * Handles one message.
     *
     * @param message received message (body and channel)
     * @param pattern pattern that matched the channel; may be {@code null} or empty for plain
     *                channel subscriptions, in which case the message's channel is used instead
     */
    @Override
    public void onMessage(Message message, byte[] pattern) {
        byte[] topicBytes = (pattern != null && pattern.length > 0) ? pattern : message.getChannel();
        String topic = new String(topicBytes, java.nio.charset.StandardCharsets.UTF_8);
        log.info("channel:{}", topic);

        if (RedisTopic.TOPIC1.equals(topic)) {
            // TODO: handle TOPIC1 messages
        } else if (RedisTopic.TOPIC2.equals(topic)) {
            // The body must be decoded with the same serializer the publisher used.
            Object body = redisTemplate.getValueSerializer().deserialize(message.getBody());
            if (body instanceof MyRedisMessage) {
                MyRedisMessage msg = (MyRedisMessage) body;
                log.info("message:{}", msg);
            } else {
                log.warn("Unexpected payload on {}: {}", topic, body);
            }
        }
    }
}

注解使用方式:1秒内一个ip最多只能请求10次

/**
 * Demo controller showing how {@link RateLimiter} is applied to an endpoint.
 */
@RequestMapping("/test/api")
@RestController
public class CheckController {

    /**
     * Rate-limited endpoint: per client IP, at most 10 calls per 1-second window. Further calls are
     * rejected with "请求过于频繁,请稍后重试".
     */
    @RateLimiter(limitType = LimitType.IP, time = 1, count = 10, limitMsg = "请求过于频繁,请稍后重试")
    @PostMapping(value = "/limit")
    public void limit(HttpServletRequest request) {
        // business logic goes here
    }

}

你可能感兴趣的:(lua,redis,java)