参考官方的文档:
https://www.django-rest-framework.org/api-guide/throttling/
-
全局配置:
在settings 里面全局配置
REST_FRAMEWORK = {
'DEFAULT_THROTTLE_CLASSES': [
'rest_framework.throttling.AnonRateThrottle',
'rest_framework.throttling.UserRateThrottle'
],
'DEFAULT_THROTTLE_RATES': {
'anon': '100/day',
'user': '1000/day'
}
}
anon 和 user分别对应默认的AnonRateThrottle 和UserRateThrottle 类
每个类都有一个 scope 的属性
**rest 内置了,这三个限速类 **
AnonRateThrottle scope 属性的值是 anon
UserRateThrottle scope 属性的值是user
每个用户进行简单的全局速率限制,那么 UserRateThrottle 是合适的
上面两个一个用户的所有请求都是累加的 (一个用户的所有视图请求次数累加到一起,判断是不是超速了)
同一 scope和 token 组成 唯一的key,对应请求次数
参考:
django rest framework throttling per user and per view
**ScopedRateThrottle 针对 不同api进行统计 视图必须包含 throttle_scope 属性 **
(其实如果你要对每个视图的每个用户进行速率限制,那么只需要自定义即可 比如 你把自定义的 scope 设置为 请求的 api 路径,文末会讲一下)
-
DEFAULT_THROTTLE_CLASSES :
指定全局使用的限速类
-
DEFAULT_THROTTLE_RATES :
指定全局的scope对应的限速字符串参数
-
DEFAULT_THROTTLE_CLASSES:上面的官方文档的配置其实是全局使用了这两个限速类,并且配置了对应的scope (因为这连个内置类只能从 settings 的 user 和anon 读取 对应的rate 参数 )
如果只想全局配置速率参数,那把限速类 :DEFAULT_THROTTLE_CLASSES 配置去掉即可。
-
单独对函数配置
-
不用全局配置的情况下,可以使用 throttle_classes装饰器 单独对某个 views 函数进行作用:
xxx.views.py
@api_view()
@throttle_classes([UserRateThrottle])
def xxx(request):
pass
throttle_classes 装饰器源码
def throttle_classes(throttle_classes):
def decorator(func):
func.throttle_classes = throttle_classes
return func
return decorator
-
可以用类的方式 APIView
因为 APIView 可以把属性throttle_classes 直接影响到类下面的每个视图函数。其实和 单独的装饰器原理是一样的
from rest_framework.response import Response
from rest_framework.throttling import UserRateThrottle
from rest_framework.views import APIView
class ExampleView(APIView):
throttle_classes = [UserRateThrottle]
def get(self, request, format=None):
content = {
'status': 'request was permitted'
}
return Response(content)
APIView 部分源码:
class APIView(View):
# The following policies may be set at either globally, or per-view.
renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES
parser_classes = api_settings.DEFAULT_PARSER_CLASSES
authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES
throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES
permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES
content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS
metadata_class = api_settings.DEFAULT_METADATA_CLASS
versioning_class = api_settings.DEFAULT_VERSIONING_CLASS
# Allow dependency injection of other settings to make testing easier.
settings = api_settings
#在 APIViews 里面怎么调用的限速类呢,看下面
获取 throttle 实例
def get_throttles(self):
"""
Instantiates and returns the list of throttles that this view uses.
"""
return [throttle() for throttle in self.throttle_classes]
# 执行检查 throttle 实例
def check_throttles(self, request):
"""
Check if request should be throttled.
Raises an appropriate exception if the request is throttled.
"""
for throttle in self.get_throttles():
if not throttle.allow_request(request, self):
self.throttled(request, throttle.wait())
# throttle 不通过速率校验返回这个函数,
def throttled(self, request, wait):
"""
If request is throttled, determine what kind of exception to raise.
"""
raise exceptions.Throttled(wait)
最后返回错误消息是在
class Throttled(APIException) 这个类里面定义的。
属于 rest_framework/exceptions.py 异常包
*******************************************************
**可以看出 APIView 的默认属性都是通过api_settings 来获取的 **
看 api_settings 部分源码:
class APISettings(object):
"""
A settings object, that allows API settings to be accessed as properties.
For example:
from rest_framework.settings import api_settings
print(api_settings.DEFAULT_RENDERER_CLASSES)
Any setting with string import paths will be automatically resolved
and return the class, rather than the string literal.
"""
def __init__(self, user_settings=None, defaults=None, import_strings=None):
if user_settings:
self._user_settings = self.__check_user_settings(user_settings)
self.defaults = defaults or DEFAULTS
self.import_strings = import_strings or IMPORT_STRINGS
self._cached_attrs = set()
@property
def user_settings(self):
if not hasattr(self, '_user_settings'):
self._user_settings = getattr(settings, 'REST_FRAMEWORK', {})
return self._user_settings
看一看出,在user_settings 里面通过读取 settings 里面的 名为 REST_FRAMEWORK 配置并且设置到当前 APISettings 实例。
这就解释了 为什么限速类 最终默认值 都是连接到 settings 里面的
读取顺序 当前函数属性(包括装饰器),然后是 APIVIew 属性 最后是settings REST_FRAMEWORK 默认配置
可以看出无论是 throttle_classes 装饰器,还是 使用类视图 继承 APIView
都是通过指定 throttle_classes 来获取用那个限速类 来限速
*******************************************************
-
自定义限速类
直接继承 SimpleRateThrottle 就可以了,然后指定至少
scope 或者 rate 属性(当然也可以都指定,那只会生效 rate)
scope 或rate 格式:
'nymber/time'
time 可以是 second,minute,hour或day
当然偷懒的做法就是直接继承 rest 里面写好限速类
class BurstRateThrottle(UserRateThrottle):
scope = 'burst'
class SustainedRateThrottle(UserRateThrottle):
scope = 'sustained'
除了 scope 和rate 属性,还有其他的一些属性,可以参考 官方 api 手册
那么限速类内部是怎么工作的呢? 看下面的源码分析一下就知道了
-
SimpleRateThrottle 部分源码
- 初始化部分
class SimpleRateThrottle(BaseThrottle):
"""
A simple cache implementation, that only requires `.get_cache_key()`
to be overridden.
The rate (requests / seconds) is set by a `rate` attribute on the View
class. The attribute is a string of the form 'number_of_requests/period'.
Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')
Previous request information used for throttling is stored in the cache.
"""
cache = default_cache
timer = time.time
cache_format = 'throttle_%(scope)s_%(ident)s'
scope = None
THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
def __init__(self):
if not getattr(self, 'rate', None):
self.rate = self.get_rate()
self.num_requests, self.duration = self.parse_rate(self.rate)
- 函数 get_rate() 部分(获取频率的字符串)
def get_rate(self):
"""
Determine the string representation of the allowed request rate.
"""
if not getattr(self, 'scope', None):
msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
self.__class__.__name__)
raise ImproperlyConfigured(msg)
try:
return self.THROTTLE_RATES[self.scope]
except KeyError:
msg = "No default throttle rate set for '%s' scope" % self.scope
raise ImproperlyConfigured(msg)
- parse_rate 函数部分 (解析频率字符串的参数)
def parse_rate(self, rate):
"""
Given the request rate string, return a two tuple of:
,
"""
if rate is None:
return (None, None)
num, period = rate.split('/')
num_requests = int(num)
duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
return (num_requests, duration)
从上面可以看出来,如果你的类有rate 属性,则直接使用,如果没有,就从settings 里面那个全局的 DEFAULT_THROTTLE_RATES 字典里面用scope 的key取出(这个和django 默认的日志系统,的用法有相似之处)
- 最后 allow_request() 函数
这个函数就是判断请求是不是符合频率的:
def allow_request(self, request, view):
"""
Implement the check to see if the request should be throttled.
On success calls `throttle_success`.
On failure calls `throttle_failure`.
"""
if self.rate is None:
return True
self.key = self.get_cache_key(request, view)
if self.key is None:
return True
self.history = self.cache.get(self.key, [])
self.now = self.timer()
# Drop any requests from the history which have now passed the
# throttle duration
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
if len(self.history) >= self.num_requests:
return self.throttle_failure()
return self.throttle_success()
当然还有 get_cache_key wait 这两个主要的函数,这里就不在细讲了
然后在看一下 内置的 两个限速类的源码怎么写的:
AnonRateThrottle
class AnonRateThrottle(SimpleRateThrottle):
"""
Limits the rate of API calls that may be made by a anonymous users.
The IP address of the request will be used as the unique cache key.
"""
scope = 'anon'
def get_cache_key(self, request, view):
if request.user.is_authenticated:
return None # Only throttle unauthenticated requests.
return self.cache_format % {
'scope': self.scope,
'ident': self.get_ident(request)
}
UserRateThrottle
class UserRateThrottle(SimpleRateThrottle):
"""
Limits the rate of API calls that may be made by a given user.
The user id will be used as a unique cache key if the user is
authenticated. For anonymous requests, the IP address of the request will
be used.
"""
scope = 'user'
def get_cache_key(self, request, view):
if request.user.is_authenticated:
ident = request.user.pk # 登录用户使用密码作为 计数key
else:
ident = self.get_ident(request) # 没有登录用户使用 ip 作为 计数 key
return self.cache_format % {
'scope': self.scope,
'ident': ident
}
从上面可以看出,主要是 get_cache_key 函数要返回 scope 和 ident 属性,因为他们在父类中会有方法使用到。ident 是给请求计算次数用的 肯可能是匿名用户的ip 或者已认证用户的token
好了,总结这么多就差不多了 累死宝宝了
后续理解:
api_view() 装饰器的工作流程:
def api_view(http_method_names=None):
"""
Decorator that converts a function-based view into an APIView subclass.
Takes a list of allowed methods for the view as an argument.
"""
def decorator(func):
# 变成子类
WrappedAPIView = type(
six.PY3 and 'WrappedAPIView' or b'WrappedAPIView',
(APIView,),
{'__doc__': func.__doc__}
)
上面相当于 :
# class WrappedAPIView(APIView):
# pass
# WrappedAPIView.__doc__ = func.doc <--- Not possible to do this
def handler(self, *args, **kwargs):
return func(*args, **kwargs)
for method in http_method_names:
setattr(WrappedAPIView, method.lower(), handler)
WrappedAPIView.__name__ = func.__name__
WrappedAPIView.__module__ = func.__module__
WrappedAPIView.renderer_classes = getattr(func, 'renderer_classes',
APIView.renderer_classes)
WrappedAPIView.parser_classes = getattr(func, 'parser_classes',
APIView.parser_classes)
WrappedAPIView.authentication_classes = getattr(func, 'authentication_classes',
APIView.authentication_classes)
WrappedAPIView.throttle_classes = getattr(func, 'throttle_classes',
APIView.throttle_classes)
WrappedAPIView.permission_classes = getattr(func, 'permission_classes',
APIView.permission_classes)
WrappedAPIView.schema = getattr(func, 'schema',
APIView.schema)
return WrappedAPIView.as_view()
return decorator
从上面的解释可以看出,api_view 函数,其实是把装饰的函数变成了 APIView 的一个子类。然后把被装饰的函数里面的属性进行覆盖掉 APIView 里面的默认属性
主要有一下属性:
函数普通的 比如 name , doc 等
还有APIView 的内置属性,比如 renderer_classes 、parser_classes 、authentication_classes 、throttle_classes 、 permission_classes 、schema
注意最后调用 了 as_view( ) 函数,看的出来 rest_framework 的api_view 函数,其实是整合了其他装饰器。并最后返回一个 APIView.as_view( ) 函数(这就和 django里面使用 类视图,包括自动匹配 get post 方法 练联系起来了。)
还有一点
rest_farmwork 的视图使用 好像最后都要APIView 那个返回,不然报错?至少我工作项目里面会报错。可能有什么配置,以后再说吧
参考 rest_farmwork views 的写法(他里面说了 有类视图继承(APIView),和 函数视图(@api_view()) 这样就会保证 对接了 rest 的东西 比如 request 。
https://www.django-rest-framework.org/api-guide/views/
最后举个例子,对每个用户的每个接口进行速率的限制:
对特定用户,进行设置速率,比如,管理员不限次数,普通用户 五次每分钟:
class StrictRate(UserRateThrottle):
"""普通用户一分钟五次, 管理员不限次数"""
def allow_request(self, request, view):
# 管理员不限次数,普通用户 5次 一分钟
self.rate = None if request.user.is_superuser else "5/m"
super(StrictRate, self).allow_request(request, view)
就是继承,在allow_request 里面 根据user 属性设置 rate 就行了