【Django】DRF源码分析之三大认证

纸上得来终觉浅,绝知此事要躬行。

前言

之前在【Django】DRF源码分析之五大模块文章中没有讲到认证模块,本章就主要来谈谈认证模块中的三大认证,首先我们先回顾一下DRF请求的流程:

  1. 前台发送请求,后台接受,进行urls.py中的url匹配,执行对应类视图调用as_view()方法
from django.conf.urls import url
from . import views

urlpatterns = [
    url(r'^v1/users/$', views.User.as_view())
]
  1. 之后在APIView中调用父类as_view(),并且在闭包中调用了dispatch()方法,该方法调用的APIView类中的(该类重写了父类)
    def dispatch(self, request, *args, **kwargs):
        
        ......
        # 请求模块和解析模块
        request = self.initialize_request(request, *args, **kwargs)
        
        ......

        try:
            # 三大认证模块
            self.initial(request, *args, **kwargs)

            ......

            # 响应模块
            response = handler(request, *args, **kwargs)

        except Exception as exc:
            # 异常模块
            response = self.handle_exception(exc)
        # 渲染模块
        self.response = self.finalize_response(request, response, *args, **kwargs)
        return self.response
  1. 先执行self.initialize_request(request, *args, **kwargs)请求模块,此步骤是rest_framework对request进行了扩展封装和兼容
    def initialize_request(self, request, *args, **kwargs):
        """
        Returns the initial request object.
        """
        parser_context = self.get_parser_context(request)

        return Request(
            request,
            parsers=self.get_parsers(),
            authenticators=self.get_authenticators(),
            negotiator=self.get_content_negotiator(),
            parser_context=parser_context
        )
  1. 请求模块走完之后,接下来就是认证模块,执行self.initial(request, *args, **kwargs),点击源码进入。
    def initial(self, request, *args, **kwargs):
        ......

        # 认证组件
        self.perform_authentication(request)

        # 权限组件
        self.check_permissions(request)

        # 限流组件
        self.check_throttles(request)

目前请求走到这里,就是我们今天需要讨论的认证模块,分为三个部分,以此来从源码进行剖析。

认证组件

源码分析

首先执行的就是self.perform_authentication(request)(认证组件),点击源码查看:

    def perform_authentication(self, request):
        
        request.user

我们发现该方法只有一行代码,没有返回值,也没有赋值,也不能继续点击进入(可能会出现一堆的东西),但是我们的目的就是找认证组件认证方法,所以猜想这句话就是调用方法,是不是很可能被@property装饰了,还是通过request对象调用的,所以我们就rest_framework/request.py下面找Request类(因为之前的请求模块对原生request进行了扩展就是使用的该类)中的user方法,发现源码如下:

    @property
    def user(self):
        """
        Returns the user associated with the current request, as authenticated
        by the authentication classes provided to the request.
        """
        if not hasattr(self, '_user'):
            with wrap_attributeerrors():
                # 没用户,认证用户
                self._authenticate()
        # 有用户,直接返回
        return self._user

发现对于认证的函数只调用了self._authenticate(),我们继续点击进入,源码分析如下图:

【Django】DRF源码分析之三大认证_第1张图片

结合上图我们需要分析出下面几个问题:

  1. self.authenticators是啥?
  2. authenticate(self)方法执行了什么玩意?
  3. self._not_authenticated()方法做了啥?

先解决第一个问题(self.authenticators是啥?)我们直接点击进去发现跑到了Request类的__init__方法肯定不对,我们往回找,他是request的属性,记得之前请求模块对request进行了扩展,回去发现在APIView类的下面有self.initialize_request(request, *args, **kwargs)方法中有下面的代码:

Request(
    request,
    parsers=self.get_parsers(),
    authenticators=self.get_authenticators(),
    negotiator=self.get_content_negotiator(),
    parser_context=parser_context
)

发现传入了authenticators,他等于self.get_authenticators()的返回值,所以我们去查找self.get_authenticators()的源码:

    def get_authenticators(self):
        """
        Instantiates and returns the list of authenticators that this view can use.
        """
        return [auth() for auth in self.authentication_classes]

发现结果是一个列表推导式,所以上图中可以进行遍历,而且列表中装的也都是对象,我们就去看看到底是什么类的对象,点击查看authentication_classes,发现是通过authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES配置的,所以我们可以去api_settings查找,结果代码如下:

'DEFAULT_AUTHENTICATION_CLASSES': [
        'rest_framework.authentication.SessionAuthentication',
        'rest_framework.authentication.BasicAuthentication'
    ],

默认写了两个类,ok,目前我们已经知道第一个问题(self.authenticators是啥?),他默认就是这两个类的对象列表,下面就是解决第二个问题(authenticate(self)方法执行了什么玩意?),该方法是这两个类的方法,所以我们去这两个默认的类去查看,源码查看文件rest_framework/authentication.py,下面是UML图以及类继承图:

【Django】DRF源码分析之三大认证_第2张图片

【Django】DRF源码分析之三大认证_第3张图片

由此我们发现BaseAuthentication是其他类的父类,而且每个类都有authenticate方法,我们首先查看一下BasicAuthentication类中的authenticate方法实现:

def get_authorization_header(request):
    """
    Return request's 'Authorization:' header, as a bytestring.

    Hide some test client ickyness where the header can be unicode.
    """
    auth = request.META.get('HTTP_AUTHORIZATION', b'')
    if isinstance(auth, str):
        # Work around django test client oddness
        auth = auth.encode(HTTP_HEADER_ENCODING)
    return auth

class BasicAuthentication(BaseAuthentication):
    """
    HTTP Basic authentication against username/password.
    """
    www_authenticate_realm = 'api'

    def authenticate(self, request):
        """
        Returns a `User` if a correct username and password have been supplied
        using HTTP Basic authentication.  Otherwise returns `None`.
        """
        auth = get_authorization_header(request).split() # 第一步:从请求头获取token信息按照空格分割

        if not auth or auth[0].lower() != b'basic': # 第二步:判断我们的值格式:“basic xxxxxxxx”,就是有两段,中间空格隔开
            return None
     
        # 校验分割长度是不是等于2
        if len(auth) == 1: 
            msg = _('Invalid basic header. No credentials provided.')
            raise exceptions.AuthenticationFailed(msg)
        elif len(auth) > 2:
            msg = _('Invalid basic header. Credentials string should not contain spaces.')
            raise exceptions.AuthenticationFailed(msg)

        # 把token值按照一定规则解密
        try:
            auth_parts = base64.b64decode(auth[1]).decode(HTTP_HEADER_ENCODING).partition(':')
        except (TypeError, UnicodeDecodeError, binascii.Error):
            msg = _('Invalid basic header. Credentials not correctly base64 encoded.')
            raise exceptions.AuthenticationFailed(msg)
        
        userid, password = auth_parts[0], auth_parts[2]
        return self.authenticate_credentials(userid, password, request)

    def authenticate_credentials(self, userid, password, request=None):
        """
        Authenticate the userid and password against username and password
        with optional request for context.
        """
        credentials = {
            get_user_model().USERNAME_FIELD: userid,
            'password': password
        }
        user = authenticate(request=request, **credentials)

        if user is None:
            raise exceptions.AuthenticationFailed(_('Invalid username/password.'))

        if not user.is_active:
            raise exceptions.AuthenticationFailed(_('User inactive or deleted.'))

        return (user, None)

    def authenticate_header(self, request):
        return 'Basic realm="%s"' % self.www_authenticate_realm

分析过程:
1. 调用get_authorization_header从请求头获取,Authorization 的值,一般就是token信息,并且按照空格分割
2. 分割完成,判断我们的第一部分是不是basic
3. 校验分割长度是不是等于2
4. 把token值按照一定规则解密,分配
5. 调用self.authenticate_credentials(userid, password, request),可以看成通过解密的信息查询用户,最终返回元祖类型的数据

目第二个问题也已经解决,得知返回的结果是一个(user,None)的元祖,然后把元祖信息拆分给request.userrequest.auth。如果其中任意一个地方发生异常都会调用self._not_authenticated(),下面我们就来看看第三个问题(self._not_authenticated()方法做了啥?)

    def _not_authenticated(self):
        """
        Set authenticator, user & authtoken representing an unauthenticated request.

        Defaults are None, AnonymousUser & None.
        """
        self._authenticator = None

        if api_settings.UNAUTHENTICATED_USER:
            self.user = api_settings.UNAUTHENTICATED_USER()
        else:
            self.user = None

        if api_settings.UNAUTHENTICATED_TOKEN:
            self.auth = api_settings.UNAUTHENTICATED_TOKEN()
        else:
            self.auth = None

源码其实很简单,就是给self.userself.auth赋值,其实就相当于给request.userrequest.auth赋值。其中api_settings.UNAUTHENTICATED_USER()表示的是一个匿名用户也可以理解为游客,而api_settings.UNAUTHENTICATED_TOKEN(),默认值为None。可以在api_settings中查看'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser''UNAUTHENTICATED_TOKEN': None,

整个认证的过程分析完成,我们可以知道大致流程就是:

  • 未携带认证信息的用户访问(游客或匿名用户),返回None赋值给经过_not_authenticated方法,赋值为request.user和request.auth
  • 携带认证信息的用户访问,返回(user,None)赋值给request.user和request.auth
  • 携带错误认证信息或者认证信息失效的用户,抛出异常调用_not_authenticated方法,赋值给request.user和request.auth

也即是说我们可以通过在类视图的request对象直接获取当前访问的用户,判断他是登录用户还是游客。

自定义认证类

通过源码的分析,我们可以知道实现自定义认证类必要条件,继承BaseAuthentication,然后实现authenticate方法,至于验证的逻辑可以结合业务编写,最终返回(user,auth)的元祖

#继承BaseAuthentication
class MyAuthentication(BaseAuthentication):
    def authenticate(self, request):   #重写authenticate方法
        # 1. 从请求的META获取token信息
        # 2. 判断信息是否合法或这不存在
          # 2.1 不存在:表示游客,返回None
          # 2.2 存在但是错误:非法用户,抛出异常
          # 2.3 存在且正确:返回 (用户, 认证信息)

        return (user,None)
  • 全局配置(settings.py)
REST_FRAMEWORK = {
    # 认证类配置
    'DEFAULT_AUTHENTICATION_CLASSES': [
        'rest_framework.authentication.SessionAuthentication',
        'rest_framework.authentication.BasicAuthentication',
        'xxxx.xxxxxx.MyAuthentication'
        # eg:'utils.authentications.MyAuthentication'
    ]
}
  • 局部配置(CBV)
def xxxx(APIView):
    authentication_classes = (MyAuthentication,SessionAuthentication,BasicAuthentication)
    ......
    def get():
        ......

权限组件

源码分析

经过认证组件之后我们知道request对象中保存这当前请求的用户,下面执行self.check_permissions(request)方法,点击进入源码

    def check_permissions(self, request):
        """
        Check if the request should be permitted.
        Raises an appropriate exception if the request is not permitted.
        """
        for permission in self.get_permissions():
            if not permission.has_permission(request, self):
                self.permission_denied(
                    request, message=getattr(permission, 'message', None)
                )

看到这个过程简直是似曾相识,对,他和认证组件一个设计模式,通过self.get_permissions()获取权限类的对象列表,然后遍历,源码:

    def get_permissions(self):
        """
        Instantiates and returns the list of permissions that this view requires.
        """
        return [permission() for permission in self.permission_classes]

发现同样是一个列表推导式,查看源码发现他和认证组件就是放在一起,接着点击self.permission_classes,同样发现permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES也是由api_settings配置得到,从rest_framework/settings.py文件找到得到默认配置。

'DEFAULT_PERMISSION_CLASSES': [
        'rest_framework.permissions.AllowAny',
    ],

默认配置了一个AllowAny,此时程序开始遍历权限类的对象,执行has_permission方法的到返回值,如果为False表示他没有权限,继续执行self.permission_denied()方法直接抛出异常,为True则遍历下一个,直到全部为True,遍历结束,什么也不做就表示拥有配置的所有权限。

所以首先,让我们去了解has_permission到底做了什么?以及系统默认包含了那些认证类,通过rest_framework/permissions.py可以查看所有的权限类

【Django】DRF源码分析之三大认证_第4张图片

【Django】DRF源码分析之三大认证_第5张图片

发现大致分为以下几个类,BasePermission类是其他类的父类,而且每个类都实现了has_permission方法。

class BasePermission(metaclass=BasePermissionMetaclass):
    """
    A base class from which all permission classes should inherit.
    """

    def has_permission(self, request, view):
        """
        Return `True` if permission is granted, `False` otherwise.
        """
        return True

    def has_object_permission(self, request, view, obj):
        """
        Return `True` if permission is granted, `False` otherwise.
        """
        return True


class AllowAny(BasePermission):
    """
    Allow any access.
    This isn't strictly required, since you could use an empty
    permission_classes list, but it's useful because it makes the intention
    more explicit.
    """

    def has_permission(self, request, view):
        return True


class IsAuthenticated(BasePermission):
    """
    Allows access only to authenticated users.
    """

    def has_permission(self, request, view):
        return bool(request.user and request.user.is_authenticated)


class IsAdminUser(BasePermission):
    """
    Allows access only to admin users.
    """

    def has_permission(self, request, view):
        return bool(request.user and request.user.is_staff)


class IsAuthenticatedOrReadOnly(BasePermission):
    """
    The request is authenticated as a user, or is a read-only request.
    """

    def has_permission(self, request, view):
        return bool(
            request.method in SAFE_METHODS or
            request.user and
            request.user.is_authenticated
        )

接下来分别解释一个每个类:

  • AllowAny:直接返回True,任何用户拥有权限
  • IsAuthenticated:必须是认证信息通过的用户
  • IsAdminUser:必须是认证信息通过的用户且is_staff为True的用户,数据库保存的结果为1
  • IsAuthenticatedOrReadOnly:表示通过认证用户拥有权限或者游客以及认证失败的用户只能有SAFE_METHODS属性内定义的请求方法,默认为SAFE_METHODS = ('GET', 'HEAD', 'OPTIONS')

权限组件相对过程比较简单,因为他是建立在认证组件基础之上,下面就让我们自定义权限组件。

自定义权限组件

通过源码的分析,我们可以同样也知道实现自定义权限类必要条件,继承BasePermission,然后实现has_permission方法,最终通过判断返回True或False。

from rest_framework.permissions import BasePermission

class MyPermission(BasePermission):
    def has_permission(self, request, view):
        # 判断逻辑xxxxxxx
        # 返回True或False
        return True or Flase
  • 全局配置(settings.py)
REST_FRAMEWORK = {
    #  权限类配置
    'DEFAULT_PERMISSION_CLASSES': [
        'utils.permissions.MyPermission',
    ],
}
  • 局部配置(CBV)
def xxxx(APIView):
    permission_classes = (MyPermission,)
    .....
    def get():
        ....

限流组件

源码分析

前面的认证和权限组件处理完成之后接下来就是限流组件,代码运行到self.check_throttles(request),点击查看源码:

    def check_throttles(self, request):
        """
        Check if request should be throttled.
        Raises an appropriate exception if the request is throttled.
        """
        throttle_durations = []
        for throttle in self.get_throttles():
            if not throttle.allow_request(request, self):
                throttle_durations.append(throttle.wait())

        if throttle_durations:
            # Filter out `None` values which may happen in case of config / rate
            # changes, see #1438
            durations = [
                duration for duration in throttle_durations
                if duration is not None
            ]

            duration = max(durations, default=None)
            self.throttled(request, duration)

首先定义了一个throttle_durations空列表,之后又是循环遍历self.get_throttles(),可以想象他和认证组件、权限组件应该是一个样子,返回限流类对象列表,源码如下:

    def get_throttles(self):
        """
        Instantiates and returns the list of throttles that this view uses.
        """
        return [throttle() for throttle in self.throttle_classes]

果然是一个列表推导式,保存的是限流类对象,同样我们也会想到它应该也是通过api_settings配置,点击self.throttle_classes查看throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES,继续查看默认配置信息。

'DEFAULT_THROTTLE_CLASSES': [],

发现结果是一个空列表,意思也就是,默认并没有采用任何一个类来限制用户请求频率。通过认证的类定义在rest_frameworks/authentication.py和权限类定义在rest_framework/permission.py,我们应该在rest_framework下面查找类似throttle类,即rest_frameworks/throttling.py,并且通过源码应该不难发现他们应该都实现了allow_request方法。

【Django】DRF源码分析之三大认证_第6张图片

【Django】DRF源码分析之三大认证_第7张图片

通过类的继承关系我们发现SimpleRateThrottle是其他三个类的父类,它虽然继承自BaseThrottle,但是基本重写了BaseThrottle的方法。下面我们查看一下SimpleRateThrottle类的源码以及相关的之类:

class SimpleRateThrottle(BaseThrottle):
    """
    A simple cache implementation, that only requires `.get_cache_key()`
    to be overridden.

    The rate (requests / seconds) is set by a `rate` attribute on the View
    class.  The attribute is a string of the form 'number_of_requests/period'.

    Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')

    Previous request information used for throttling is stored in the cache.
    """
    cache = default_cache
    timer = time.time
    cache_format = 'throttle_%(scope)s_%(ident)s'
    scope = None
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES

    def __init__(self):
        if not getattr(self, 'rate', None):
            self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)

    def get_cache_key(self, request, view):
        """
        Should return a unique cache-key which can be used for throttling.
        Must be overridden.

        May return `None` if the request should not be throttled.
        """
        raise NotImplementedError('.get_cache_key() must be overridden')

    def get_rate(self):
        """
        Determine the string representation of the allowed request rate.
        """
        if not getattr(self, 'scope', None):
            msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                   self.__class__.__name__)
            raise ImproperlyConfigured(msg)

        try:
            return self.THROTTLE_RATES[self.scope]
        except KeyError:
            msg = "No default throttle rate set for '%s' scope" % self.scope
            raise ImproperlyConfigured(msg)

    def parse_rate(self, rate):
        """
        Given the request rate string, return a two tuple of:
        , 
        """
        if rate is None:
            return (None, None)
        num, period = rate.split('/')
        num_requests = int(num)
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        return (num_requests, duration)

    def allow_request(self, request, view):
        """
        Implement the check to see if the request should be throttled.

        On success calls `throttle_success`.
        On failure calls `throttle_failure`.
        """
        if self.rate is None:
            return True

        self.key = self.get_cache_key(request, view)
        if self.key is None:
            return True

        # django缓存
        # 1.导包
        # 2.添加缓存 cache.set(key,value,exp)
        # 3.获取缓存 cache.get(key,default_value)
        self.history = self.cache.get(self.key, [])
        self.now = self.timer()

        # Drop any requests from the history which have now passed the
        # throttle duration
        while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()
        if len(self.history) >= self.num_requests:
            return self.throttle_failure()
        return self.throttle_success()

    def throttle_success(self):
        """
        Inserts the current request's timestamp along with the key
        into the cache.
        """
        self.history.insert(0, self.now)
        self.cache.set(self.key, self.history, self.duration)
        return True

    def throttle_failure(self):
        """
        Called when a request to the API has failed due to throttling.
        """
        return False

    def wait(self):
        """
        Returns the recommended next request time in seconds.
        """
        if self.history:
            remaining_duration = self.duration - (self.now - self.history[-1])
        else:
            remaining_duration = self.duration

        available_requests = self.num_requests - len(self.history) + 1
        if available_requests <= 0:
            return None

        return remaining_duration / float(available_requests)


class AnonRateThrottle(SimpleRateThrottle):
    """
    Limits the rate of API calls that may be made by a anonymous users.

    The IP address of the request will be used as the unique cache key.
    """
    scope = 'anon'

    def get_cache_key(self, request, view):
        if request.user.is_authenticated:
            return None  # Only throttle unauthenticated requests.

        return self.cache_format % {
            'scope': self.scope,
            'ident': self.get_ident(request)
        }


class UserRateThrottle(SimpleRateThrottle):
    """
    Limits the rate of API calls that may be made by a given user.

    The user id will be used as a unique cache key if the user is
    authenticated.  For anonymous requests, the IP address of the request will
    be used.
    """
    scope = 'user'

    def get_cache_key(self, request, view):
        if request.user.is_authenticated:
            ident = request.user.pk
        else:
            ident = self.get_ident(request)

        return self.cache_format % {
            'scope': self.scope,
            'ident': ident
        }

之前我们说过默认的配置文件中没有定义限流类,假定我们任意选择一种类进行定义,在遍历限流类对象列表时都会调用allow_request方法,由于UserRateThrottleAnonRateThrottle不存在该方法,所以查找其父类,最终无论使用那个类都会调用SimpleRateThrottle类的allow_request方法。

下面假定我们定义了UserRateThrottle类,所以此时代码执行SimpleRateThrottle类的allow_request方法。此时需要注意的是,遍历的是限流类对象列表,也就是说明类进行了初始化,也就是说在之前执行了SimpleRateThrottle类的__init__方法。所以我们先从实例化开始分析。

def __init__(self):
        if not getattr(self, 'rate', None):
            self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)

代码分析:

  1. 判断对象是否有rate属性,不存在执行self.get_rate()
    def get_rate(self):
        """
        Determine the string representation of the allowed request rate.
        """
        if not getattr(self, 'scope', None):
            msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                   self.__class__.__name__)
            raise ImproperlyConfigured(msg)

        try:
            return self.THROTTLE_RATES[self.scope]
        except KeyError:
            msg = "No default throttle rate set for '%s' scope" % self.scope
            raise ImproperlyConfigured(msg)

从当前类对象(UserRateThrottle)获取scope,不存在直接抛异常,通过第一大段代码得到scope = 'anon',继续往下走,执行self.THROTTLE_RATES[self.scope],也就是执行self.THROTTLE_RATES['anon'],此时我们需要知道THROTTLE_RATES是什么,发现UserRateThrottle类不存在,去他的父类(SimpleRateThrottle)找,发现 THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES,同样是配置文件中的,老方法取rest_framework/settings.py查找。

'DEFAULT_THROTTLE_RATES': {
        'user': None,
        'anon': None,
    },
  1. 得到结果为None,也就是说get_rate(),返回None,所以self.rate = None,接着执行self.num_requests, self.duration = self.parse_rate(self.rate),把参数带入查看self.parse_rate(None),源码如下:
    def parse_rate(self, rate):
        """
        Given the request rate string, return a two tuple of:
        , 
        """
        if rate is None:
            return (None, None)
        num, period = rate.split('/')
        num_requests = int(num)
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        return (num_requests, duration)

发现我们传入的参数是None,就直接返回了(None,None),但是接着我们继续往下看,我们发现它需要的是一个有“/”的字符串,用来分割,之后从字典匹配。举个例子:
假如我们配置的不是None,而是100/mabc,运行到num, period = rate.split('/'),num="100",period="mabc",之后把num强制转换为int,接着去{'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]],也就是{'s': 1, 'm': 60, 'h': 3600, 'd': 86400}["m"],此时duration=60,最终返回(100,60)

通过猜想我们应该差不多懂了,他的意思就是,我们配置一个xx/xxx格式的字符串,“/"前面表示请求次数,”/“后面只要是s/m/h/d任意开头的字符串就可以,其实他的意思就是表示秒、分、时、天。也就是我们配置一个100/m即表示一分钟之内最多100次访问,1000/h表示一小时之内最多1000次访问。只不过它的单位是秒。

现在初始化完成了,下面就是请求allow_request方法:

    def allow_request(self, request, view):
        """
        Implement the check to see if the request should be throttled.

        On success calls `throttle_success`.
        On failure calls `throttle_failure`.
        """
        if self.rate is None: # 第一步
            return True

        self.key = self.get_cache_key(request, view) # 第二步
        if self.key is None:
            return True

        # django缓存
        # 1.导包
        # 2.添加缓存 cache.set(key,value,exp)
        # 3.获取缓存 cache.get(key,default_value)
        self.history = self.cache.get(self.key, []) # 第三步
        self.now = self.timer()

        # Drop any requests from the history which have now passed the
        # throttle duration
        while self.history and self.history[-1] <= self.now - self.duration: # 第四步
            self.history.pop()
        if len(self.history) >= self.num_requests: # 第五步
            return self.throttle_failure()   # 第六步
        return self.throttle_success()      # 第六步

分析:

  1. 判断是否配置限流,例如:"user":None,"user":"100/m",未配置直接返回True,表示不限流
  2. 执行self.get_cache_key(request, view)
        def get_cache_key(self, request, view):
        if request.user.is_authenticated:
            ident = request.user.pk
        else:
            ident = self.get_ident(request)

        return self.cache_format % {
            'scope': self.scope,
            'ident': ident
        }

判断用户时候认证成功,认证通过,返回格式化的字符串,因为cache_format = 'throttle_%(scope)s_%(ident)s'所以self.cache_format=throttle_user_1ident表示用户的id也就是pk。
3. self.cache.get(self.key, []),通过返回的字符串去缓存获取。不存在赋一个空列表给self.history,否则把缓存记录给self.history
4. 判断self.history是否存在且它的最后一个记录值是否小于当前时间减去设定的访问频率时间,小于这去掉末尾元素继续循环判断,否则结束循环
5. 循环结束或条件不满足跳出,判断现在的self.history列表长度是否大于等于配置的请求次数。
6. 大于等于执行self.throttle_failure(),否则self.throttle_success()

  • self.throttle_failure()
    def throttle_failure(self):
        """
        Called when a request to the API has failed due to throttling.
        """
        return False

直接返回False,表示此次请求达到限流要求

  • self.throttle_success()
    def throttle_success(self):
        """
        Inserts the current request's timestamp along with the key
        into the cache.
        """
        self.history.insert(0, self.now)
        self.cache.set(self.key, self.history, self.duration)
        return True

把当前时间插入到self.history头部,并且在缓存中保存,之后返回True。

所以整体思路就是,先看配置,不存在配置(默认配置None)或配置为None,直接返回True,表示不限流。存在配置,解析出请求次数和单位时间,然后在调用allow_request,从缓存获取历史请求时间列表。缓存没有,表示第一次请求或者缓存过期经过一系列判断,最终执行throttle_success(),当前请求通过,并把请求时间记录存入缓存,返回True。如果缓存存在且满足设置的请求条件,比如:一分钟3次,我请求了两次,第三次我是在第一次请求完成1分钟之后请求的,此时缓存已经过期了,或者第三次请求距离第一次请求间隔为1分钟以内,但是我还有此次请求的机会,仍然返回True。但是,如果缓存存在且不满足设置的请求条件,同样是一分钟三次,我请求了三次,本次是第四次且距离第一次请求是1分钟之内,那么此次请求就被拒绝,返回False。至于怎么拒绝就是后面的执行

  • 请求频率达到配置要求
    def check_throttles(self, request):
        throttle_durations = []
        for throttle in self.get_throttles():
            if not throttle.allow_request(request, self):
                throttle_durations.append(throttle.wait())
       ......

请求返回False,if判断成立,调用throttle.wait(),获取还需等待多长时间可以进行下一次访问

    def wait(self):
        """
        Returns the recommended next request time in seconds.
        """
        if self.history:
            remaining_duration = self.duration - (self.now - self.history[-1])
        else:
            remaining_duration = self.duration

        available_requests = self.num_requests - len(self.history) + 1
        if available_requests <= 0:
            return None

        return remaining_duration / float(available_requests)

自定义限流类

通过源码的分析之后,我们虽然感到限流类的过程较为复杂,但是他的配置方式还是很简单,首先自定义一个限流类

from rest_framework.throttling import SimpleRateThrottle

class SMSRateThrottle(SimpleRateThrottle):
    scope = 'sms'

    # 只对提交手机号的get方法进行限制
    def get_cache_key(self, request, view):
        mobile = request.query_params.get('mobile')
        # 没有手机号就不做频率限制
        if not mobile:
            return None
        # 返回的信息可以用手机号动态变化,且不易重复的字符串,作为缓存的key
        return 'throttle_%(scope)s_%(ident)s' % {'scope': self.scope, 'ident': mobile}
  • 全局配置(settings.py)
REST_FRAMEWORK = {
    # 频率类配置
    'DEFAULT_THROTTLE_CLASSES': [
        'utils.throttling.SMSRateThrottle',
    ],
    # 频率限制条件配置
    'DEFAULT_THROTTLE_RATES': {
        'sms': '1/min'
    },
}
  • 局部配置(CBV)
def xxxx(APIView):
    throttle_classes = (SMSRateThrottle,)
    .....
    def get():
        ....

相关参考:https://www.cnblogs.com/wangcuican/p/11723103.html

你可能感兴趣的:(【Django】DRF源码分析之三大认证)