django rest farmework 限制接口请求速率

参考官方的文档:
https://www.django-rest-framework.org/api-guide/throttling/

30568817.jpg

  • 全局配置:

在settings 里面全局配置

REST_FRAMEWORK = {
    'DEFAULT_THROTTLE_CLASSES': [
        'rest_framework.throttling.AnonRateThrottle',
        'rest_framework.throttling.UserRateThrottle'
    ],
    'DEFAULT_THROTTLE_RATES': {
        'anon': '100/day',
        'user': '1000/day'
    }
}

anon 和 user分别对应默认的AnonRateThrottle 和UserRateThrottle 类
每个类都有一个 scope 的属性

**rest 内置了,这三个限速类 **
AnonRateThrottle scope 属性的值是 anon
UserRateThrottle scope 属性的值是user
每个用户进行简单的全局速率限制,那么 UserRateThrottle 是合适的
上面两个一个用户的所有请求都是累加的 (一个用户的所有视图请求次数累加到一起,判断是不是超速了)
同一 scope和 token 组成 唯一的key,对应请求次数
参考:

django rest framework throttling per user and per view

**ScopedRateThrottle 针对 不同api进行统计 视图必须包含 throttle_scope 属性 **
(其实如果你要对每个视图的每个用户进行速率限制,那么只需要自定义即可 比如 你把自定义的 scope 设置为 请求的 api 路径,文末会讲一下)

  • DEFAULT_THROTTLE_CLASSES :

指定全局使用的限速类

  • DEFAULT_THROTTLE_RATES :

指定全局的scope对应的限速字符串参数

  • DEFAULT_THROTTLE_CLASSES:上面的官方文档的配置其实是全局使用了这两个限速类,并且配置了对应的scope (因为这连个内置类只能从 settings 的 user 和anon 读取 对应的rate 参数 )

如果只想全局配置速率参数,那把限速类 :DEFAULT_THROTTLE_CLASSES 配置去掉即可。

  • 单独对函数配置

  • 不用全局配置的情况下,可以使用 throttle_classes装饰器 单独对某个 views 函数进行作用:
xxx.views.py 
@api_view()
@throttle_classes([UserRateThrottle])
def xxx(request):
      pass

throttle_classes 装饰器源码

def throttle_classes(throttle_classes):
    def decorator(func):
        func.throttle_classes = throttle_classes
        return func
    return decorator
  • 可以用类的方式 APIView

因为 APIView 可以把属性throttle_classes 直接影响到类下面的每个视图函数。其实和 单独的装饰器原理是一样的

from rest_framework.response import Response
from rest_framework.throttling import UserRateThrottle
from rest_framework.views import APIView

class ExampleView(APIView):
    throttle_classes = [UserRateThrottle]

    def get(self, request, format=None):
        content = {
            'status': 'request was permitted'
        }
        return Response(content)

APIView 部分源码:

class APIView(View):

    # The following policies may be set at either globally, or per-view.
    renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES
    parser_classes = api_settings.DEFAULT_PARSER_CLASSES
    authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES
    throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES
    permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES
    content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS
    metadata_class = api_settings.DEFAULT_METADATA_CLASS
    versioning_class = api_settings.DEFAULT_VERSIONING_CLASS

    # Allow dependency injection of other settings to make testing easier.
    settings = api_settings
#在 APIViews 里面怎么调用的限速类呢,看下面
获取 throttle 实例
    def get_throttles(self):
        """
        Instantiates and returns the list of throttles that this view uses.
        """
        return [throttle() for throttle in self.throttle_classes]
# 执行检查 throttle 实例
    def check_throttles(self, request):
        """
        Check if request should be throttled.
        Raises an appropriate exception if the request is throttled.
        """
        for throttle in self.get_throttles():
            if not throttle.allow_request(request, self):
                self.throttled(request, throttle.wait())
# throttle 不通过速率校验返回这个函数,
    def throttled(self, request, wait):
        """
        If request is throttled, determine what kind of exception to raise.
        """
        raise exceptions.Throttled(wait)

最后返回错误消息是在

class Throttled(APIException) 这个类里面定义的。
属于 rest_framework/exceptions.py  异常包
*******************************************************

**可以看出 APIView 的默认属性都是通过api_settings 来获取的 **
看 api_settings 部分源码:

class APISettings(object):
    """
    A settings object, that allows API settings to be accessed as properties.
    For example:

        from rest_framework.settings import api_settings
        print(api_settings.DEFAULT_RENDERER_CLASSES)

    Any setting with string import paths will be automatically resolved
    and return the class, rather than the string literal.
    """
    def __init__(self, user_settings=None, defaults=None, import_strings=None):
        if user_settings:
            self._user_settings = self.__check_user_settings(user_settings)
        self.defaults = defaults or DEFAULTS
        self.import_strings = import_strings or IMPORT_STRINGS
        self._cached_attrs = set()

    @property
    def user_settings(self):
        if not hasattr(self, '_user_settings'):
            self._user_settings = getattr(settings, 'REST_FRAMEWORK', {})
        return self._user_settings

看一看出,在user_settings 里面通过读取 settings 里面的 名为 REST_FRAMEWORK 配置并且设置到当前 APISettings 实例。
这就解释了 为什么限速类 最终默认值 都是连接到 settings 里面的
读取顺序 当前函数属性(包括装饰器),然后是 APIVIew 属性 最后是settings REST_FRAMEWORK 默认配置

可以看出无论是 throttle_classes 装饰器,还是 使用类视图 继承 APIView
都是通过指定 throttle_classes 来获取用那个限速类 来限速

*******************************************************
  • 自定义限速类

直接继承 SimpleRateThrottle 就可以了,然后指定至少
scope 或者 rate 属性(当然也可以都指定,那只会生效 rate)
scope 或rate 格式:

'nymber/time'      
time 可以是 second,minute,hour或day

当然偷懒的做法就是直接继承 rest 里面写好限速类

class BurstRateThrottle(UserRateThrottle):
    scope = 'burst'

class SustainedRateThrottle(UserRateThrottle):
    scope = 'sustained'

除了 scope 和rate 属性,还有其他的一些属性,可以参考 官方 api 手册

30568818.jpg

那么限速类内部是怎么工作的呢? 看下面的源码分析一下就知道了

  • SimpleRateThrottle 部分源码

  • 初始化部分
class SimpleRateThrottle(BaseThrottle):
    """
    A simple cache implementation, that only requires `.get_cache_key()`
    to be overridden.

    The rate (requests / seconds) is set by a `rate` attribute on the View
    class.  The attribute is a string of the form 'number_of_requests/period'.

    Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')

    Previous request information used for throttling is stored in the cache.
    """
    cache = default_cache
    timer = time.time
    cache_format = 'throttle_%(scope)s_%(ident)s'
    scope = None
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES

    def __init__(self):
        if not getattr(self, 'rate', None):
            self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)
  • 函数 get_rate() 部分(获取频率的字符串)

    def get_rate(self):
        """
        Determine the string representation of the allowed request rate.
        """
        if not getattr(self, 'scope', None):
            msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                   self.__class__.__name__)
            raise ImproperlyConfigured(msg)

        try:
            return self.THROTTLE_RATES[self.scope]
        except KeyError:
            msg = "No default throttle rate set for '%s' scope" % self.scope
            raise ImproperlyConfigured(msg)

  • parse_rate 函数部分 (解析频率字符串的参数)
    def parse_rate(self, rate):
        """
        Given the request rate string, return a two tuple of:
        , 
        """
        if rate is None:
            return (None, None)
        num, period = rate.split('/')
        num_requests = int(num)
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        return (num_requests, duration)

从上面可以看出来,如果你的类有rate 属性,则直接使用,如果没有,就从settings 里面那个全局的 DEFAULT_THROTTLE_RATES 字典里面用scope 的key取出(这个和django 默认的日志系统,的用法有相似之处)

  • 最后 allow_request() 函数
    这个函数就是判断请求是不是符合频率的:
    def allow_request(self, request, view):
        """
        Implement the check to see if the request should be throttled.

        On success calls `throttle_success`.
        On failure calls `throttle_failure`.
        """
        if self.rate is None:
            return True

        self.key = self.get_cache_key(request, view)
        if self.key is None:
            return True

        self.history = self.cache.get(self.key, [])
        self.now = self.timer()

        # Drop any requests from the history which have now passed the
        # throttle duration
        while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()
        if len(self.history) >= self.num_requests:
            return self.throttle_failure()
        return self.throttle_success()
当然还有 get_cache_key wait 这两个主要的函数,这里就不在细讲了
30568819.jpg

然后在看一下 内置的 两个限速类的源码怎么写的:
AnonRateThrottle

class AnonRateThrottle(SimpleRateThrottle):
    """
    Limits the rate of API calls that may be made by a anonymous users.

    The IP address of the request will be used as the unique cache key.
    """
    scope = 'anon'

    def get_cache_key(self, request, view):
        if request.user.is_authenticated:
            return None  # Only throttle unauthenticated requests.

        return self.cache_format % {
            'scope': self.scope,
            'ident': self.get_ident(request)
        }

UserRateThrottle

class UserRateThrottle(SimpleRateThrottle):
    """
    Limits the rate of API calls that may be made by a given user.

    The user id will be used as a unique cache key if the user is
    authenticated.  For anonymous requests, the IP address of the request will
    be used.
    """
    scope = 'user'

    def get_cache_key(self, request, view):
        if request.user.is_authenticated:
            ident = request.user.pk   # 登录用户使用密码作为 计数key
        else:
            ident = self.get_ident(request)  # 没有登录用户使用 ip 作为  计数 key 

        return self.cache_format % {
            'scope': self.scope,
            'ident': ident
        }

从上面可以看出,主要是 get_cache_key 函数要返回 scope 和 ident 属性,因为他们在父类中会有方法使用到。ident 是给请求计算次数用的 肯可能是匿名用户的ip 或者已认证用户的token

好了,总结这么多就差不多了 累死宝宝了

后续理解:
api_view() 装饰器的工作流程:

def api_view(http_method_names=None):
    """
    Decorator that converts a function-based view into an APIView subclass.
    Takes a list of allowed methods for the view as an argument.
    """
    def decorator(func):
        # 变成子类
        WrappedAPIView = type(
            six.PY3 and 'WrappedAPIView' or b'WrappedAPIView',
            (APIView,),
            {'__doc__': func.__doc__}
        )
上面相当于 :
        #     class WrappedAPIView(APIView):
        #         pass
        #     WrappedAPIView.__doc__ = func.doc    <--- Not possible to do this

        def handler(self, *args, **kwargs):
            return func(*args, **kwargs)

        for method in http_method_names:
            setattr(WrappedAPIView, method.lower(), handler)

        WrappedAPIView.__name__ = func.__name__
        WrappedAPIView.__module__ = func.__module__

        WrappedAPIView.renderer_classes = getattr(func, 'renderer_classes',
                                                  APIView.renderer_classes)

        WrappedAPIView.parser_classes = getattr(func, 'parser_classes',
                                                APIView.parser_classes)

        WrappedAPIView.authentication_classes = getattr(func, 'authentication_classes',
                                                        APIView.authentication_classes)

        WrappedAPIView.throttle_classes = getattr(func, 'throttle_classes',
                                                  APIView.throttle_classes)

        WrappedAPIView.permission_classes = getattr(func, 'permission_classes',
                                                    APIView.permission_classes)

        WrappedAPIView.schema = getattr(func, 'schema',
                                        APIView.schema)

        return WrappedAPIView.as_view()

    return decorator

从上面的解释可以看出,api_view 函数,其实是把装饰的函数变成了 APIView 的一个子类。然后把被装饰的函数里面的属性进行覆盖掉 APIView 里面的默认属性
主要有一下属性:
函数普通的 比如 namedoc
还有APIView 的内置属性,比如 renderer_classes 、parser_classes 、authentication_classes 、throttle_classes 、 permission_classes 、schema

注意最后调用 了 as_view( ) 函数,看的出来 rest_framework 的api_view 函数,其实是整合了其他装饰器。并最后返回一个 APIView.as_view( ) 函数(这就和 django里面使用 类视图,包括自动匹配 get post 方法 练联系起来了。)
还有一点
rest_farmwork 的视图使用 好像最后都要APIView 那个返回,不然报错?至少我工作项目里面会报错。可能有什么配置,以后再说吧
参考 rest_farmwork views 的写法(他里面说了 有类视图继承(APIView),和 函数视图(@api_view()) 这样就会保证 对接了 rest 的东西 比如 request 。
https://www.django-rest-framework.org/api-guide/views/

最后举个例子,对每个用户的每个接口进行速率的限制:


对特定用户,进行设置速率,比如,管理员不限次数,普通用户 五次每分钟:

class StrictRate(UserRateThrottle):
    """普通用户一分钟五次, 管理员不限次数"""

    def allow_request(self, request, view):
        # 管理员不限次数,普通用户 5次 一分钟
        self.rate = None if request.user.is_superuser else "5/m"
        super(StrictRate, self).allow_request(request, view)

就是继承,在allow_request 里面 根据user 属性设置 rate 就行了

你可能感兴趣的:(django rest farmework 限制接口请求速率)