Django Rest Framework 源码解析--节流

Django Rest Framework 源码解析--节流

接上一篇文章;restframework重写的dispatch()方法中,执行了inital()函数。inital()中check_throttles((request) 方法实现了请求的访问频率控制功能。
1、check_throttles(request)函数中,循环了限流类的对象列表,依次执行限流对象的 allow_request() 方法

def check_throttles(self, request):
    """
    Check if request should be throttled.
    Raises an appropriate exception if the request is throttled.
    
    检查请求是否应该被节流。
    如果被节流,则抛出相应异常。
    """
    # 遍历 限流对象列表,如果返回Fasle,则被限流,抛出异常(可传参数throttle.wait()的返回值)
    for throttle in self.get_throttles():
        # 如果被节流,返回False,则抛出相应异常
        if not throttle.allow_request(request, self):
            self.throttled(request, throttle.wait())

2、self.get_throttles()限流对象列表生成式

def get_throttles(self):
    """
    Instantiates and returns the list of throttles that this view uses.
    """
    # 可以在view中重写throttle_classes,指定限流对象列表
    # 也可以在setting.py中定义
    return [throttle() for throttle in self.throttle_classes]

3、如果被节流,则通过throttled()方法,抛出相应异常,可传入wait作为等待时间的参数

def throttled(self, request, wait):
    """
    If request is throttled, determine what kind of exception to raise.
    """
    # wait参数,出入的值是 节流类的wait()方法的返回值(单位:秒)
    raise exceptions.Throttled(wait)

使用示例:

1、自定义限流类,继承BaseThrottle类,重写 allow_request()wait() 这两个方法

from rest_framework.throttling import BaseThrottle
 
import time
 
# 保存访问记录
VISIT_RECORD = {}
 
class VisitThrottle(BaseThrottle):
    """
    自定义限流类:60秒内只能访问3次,超过就限流
    
    返回True,允许请求访问
    返回False,禁止请求访问
    """
    
    
    # 通过 self.history这个对象的成员变量,
    # 在allow_request()和 wait()这两个成员方法之间传递history的值
    def __init__(self):
        self.history = None  # 初始化访问记录
 
    def allow_request(self, request, view):
        # 获取用户ip  
        remote_addr = self.get_ident(request)
        timer = time.time()
        if remote_addr not in REQ_RECORD:
            VISIT_RECORD[remote_addr]=[timer]
            return True
        # 获取当前ip的历史访问记录
        history = VISIT_RECORD[remote_addr]
        self.history = history
        # 如果有历史访问记录,并且最早一次的访问记录离当前时间超过60s,就删除最早的那个访问记录,
        # 只要为True,就一直循环删除最早的一次访问记录
        while history and history[-1] < timer - 60:
            history.pop()
 
        # 60秒内的访问记录,是否超过3次
        # 如果没有超过,则记录这次访问,并返回True,允许访问
        # 如果超过,则返回False,禁止访问
        if len(history) < 3:
            history.insert(0, timer)
            return True
        return False
 
    def wait(self):
        '''还需要等多久才能访问'''
        timer = time.time()
        return 60 - (timer - self.history[-1])

2、在View中调用限流类

class TestView(APIView):
    # 在View中重写throttle_classes限流类列表,一般只写一个限流,
    # 或者不限流,使列表为空,throttle_classes = []
    throttle_classes = [VisitThrottle, ]
 
    def get(self,request,*args,**kwargs):
        pass

3、或者在setting.py中指定全站默认使用的限流类的路径

REST_FRAMEWORK = {
    "DEFAULT_THROTTLE_CLASSES": ('plugin.restframework.throttling.VisitThrottle')
}

内置的限流类

与认证类、鉴权类通常继承BaseXXX类,高度自定义不同,限流类我个人觉得继承restframework提供的其他内置的类更方便

例如继承 SimpleRateThrottle

class TestTrottle(SimpleRateThrottle):
    # 设定规定时间内能访问的次数,例如 3/m, 1/s, 1000/h, 9999/day
    # 通常设定在setting.py中
    THROTTLE_RATES = {
        "Test": '5/m'
    }
    # 指定scope值为 查找THROTTLE_RATES的key
    scope = "Test"
    
    def get_cache_key(self, request, view):
        # 通过ip限制节流
        return self.get_ident(request)
        
        # return request.user.pk   # 通常也使用requser.user作为标识一个用户的ID

你可能感兴趣的:(Django Rest Framework 源码解析--节流)