restframework学习笔记——APIView源码解读之访问节流即控制访问频率

通过APIView中的dispatch()中执行self.initial(request, *args, **kwargs),认证和权限判断都要经过这一步。

def initial(self, request, *args, **kwargs):
    """
    Runs anything that needs to occur prior to calling the method handler.
    """
    self.format_kwarg = self.get_format_suffix(**kwargs)

    # Perform content negotiation and store the accepted info on the request
    neg = self.perform_content_negotiation(request)
    request.accepted_renderer, request.accepted_media_type = neg

    # Determine the API version, if versioning is in use.
    version, scheme = self.determine_version(request, *args, **kwargs)
    request.version, request.versioning_scheme = version, scheme

    # Ensure that the incoming request is permitted
    self.perform_authentication(request)
    self.check_permissions(request)
    # 控制访问频率..........................
    self.check_throttles(request)

from rest_framework.throttling import BaseThrottle
class BaseThrottle(object):
    """
    Rate throttling of requests.
    """

    def allow_request(self, request, view):
        """
        Return `True` if the request should be allowed, `False` otherwise.
        """
        raise NotImplementedError('.allow_request() must be overridden')
	#获取标识
    def get_ident(self, request):
        """
        Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR
        if present and number of proxies is > 0. If not use all of
        HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR.
        """
        xff = request.META.get('HTTP_X_FORWARDED_FOR')
        #获取请求的IP地址并返回
        remote_addr = request.META.get('REMOTE_ADDR')
        num_proxies = api_settings.NUM_PROXIES

        if num_proxies is not None:
            if num_proxies == 0 or xff is None:
                return remote_addr
            addrs = xff.split(',')
            client_addr = addrs[-min(num_proxies, len(addrs))]
            return client_addr.strip()

        return ''.join(xff.split()) if xff else remote_addr

    def wait(self):
        """
        Optionally, return a recommended number of seconds to wait before
        the next request.
        """
        return None

SimpleRateThrottle继承了BaseThrottle,并做进一步的操作

class SimpleRateThrottle(BaseThrottle):
	cache = default_cache
    timer = time.time
    cache_format = 'throttle_%(scope)s_%(ident)s'
    scope = None
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
	....

1.在构造函数中先反射获取“rate”,但是SimpleRateThrottle()全局中没有;
2.所以执行self.get_rate()。
3.在get_rate()中又反射获取“scope”,“scope”在全局中为None;
4.继续在try中执行return self.THROTTLE_RATES[self.scope],即以scope为键取值并返回,THROTTLE_RATES在全局中为“THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES”,是配置文件,需要去settings.py找.
所以,视图类需要继承SimpleRateThrottle,在视图类中自定义scope的名称scope = ‘xxx’,并且这个名称’xxx’会在配置文件中作为键,值就是你自己设置的访问频率
REST_FRAMEWORK = {
“DEFAULT_THROTTLE_RATES”:{
“xxx”: ‘5/m’,
}
}
当你配置好scope后,get_rate(self)就可以获取它的值,返回给self.rate,然后再由self.parse_rate(self.rate)解析——》转下文

 def __init__(self):
        if not getattr(self, 'rate', None):
            self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)
   
       
 def get_rate(self):
	  """
	  Determine the string representation of the allowed request rate.
	  """
	  if not getattr(self, 'scope', None):
	      msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
	             self.__class__.__name__)
	      raise ImproperlyConfigured(msg)
	
	  try:
	      return self.THROTTLE_RATES[self.scope]
	  except KeyError:
	      msg = "No default throttle rate set for '%s' scope" % self.scope
	      raise ImproperlyConfigured(msg)

获取rate后,交由parse_rate(self,rate)处理。根据这里可以看出,我们的scope配置必须要符合它的要求,才能正确解析。解析出来频率和时间周期->(num_requests, duration)并返回

def parse_rate(self, rate):
    """
    Given the request rate string, return a two tuple of:
    , 
    """
    #配置文件中的scope值要符合要求,例如5/m、20/h
    if rate is None:
        return (None, None)
    num, period = rate.split('/')
    num_requests = int(num)
    duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
    return (num_requests, duration)

获取rate后继续往下,到 allow_request(self, request, view),其中self.key = self.get_cache_key(request, view)

 def allow_request(self, request, view):
     """
     Implement the check to see if the request should be throttled.

     On success calls `throttle_success`.
     On failure calls `throttle_failure`.
     """
     if self.rate is None:
         return True
	'''
	get_cache_key()中啥都没写,所以需要我们自己重写,返回一个			unique cache-key,且这个unique cache-key可以被用来限制访问,例如IP地址、用户名称等等。如果此请求不用限制访问,返回None即可
	'''
     self.key = self.get_cache_key(request, view)
     if self.key is None:
         return True
	‘’‘
	self.cache就是前面类全局中的cache = default_cache,Django内置的缓存
	’‘’
     self.history = self.cache.get(self.key, [])
     self.now = self.timer()

     # Drop any requests from the history which have now passed the
     # throttle duration
     '''
     循环判断,historay列表中的最后一个和时间差对比,如果超出就pop掉。最后如果history长度小于设置的请求频率数,就不限制请求
     '''
     while self.history and self.history[-1] <= self.now - self.duration:
         self.history.pop()
     if len(self.history) >= self.num_requests:
         return self.throttle_failure()
     return self.throttle_success()

你可能感兴趣的:(Django学习笔记)