django-rest-framework(五)(节流)

节流

需求:控制用户的访问频率(例如:10秒内最多访问3次)

思想:设立一个全局变量字典,对于匿名用户,将用户的IP作为字典的键;对于登录用户,将用户的用户名作为字典的键。字典的值是一个列表,列表中按时间顺序存储该用户每次访问的时间戳。每次请求到来时,先剔除列表中已超出时间窗口的记录,再根据列表的长度判断是否允许本次访问,从而控制访问频率。(下面的示例代码只实现了以IP作为键的情况。)

代码实现

from django.http import JsonResponse
import time
from rest_framework.views import APIView
from api.models import UserInfo, UserToken

# Per-client visit log: maps the client's REMOTE_ADDR to a list of request
# timestamps, newest first.  It lives in process memory here; it could be kept
# in a cache backend instead.
VISIT_RECODE = dict()

class MyThrottles(object):
    """Allow each client address at most 3 requests per 10-second window.

    Clients are identified by ``REMOTE_ADDR``; their request timestamps are
    kept in the module-level ``VISIT_RECODE`` dict.
    """

    def __init__(self):
        # History list of the client seen by the latest allow_request() call;
        # wait() reads it.
        self.history = None

    def allow_request(self, request, view):
        """Return True if the request may proceed, False if it is throttled.

        :param request: object exposing ``_request.META`` (a mapping that may
            contain ``REMOTE_ADDR``).
        :param view: the view being accessed (unused).
        """
        remote_addr = request._request.META.get('REMOTE_ADDR')
        ctime = time.time()

        # A new client starts with an EMPTY history.  Seeding the list with
        # the current time and then inserting it again below would count the
        # first request twice and allow only two requests per window.
        history = VISIT_RECODE.setdefault(remote_addr, [])
        self.history = history

        # Timestamps are stored newest first, so expired entries are at the
        # tail: drop everything older than the 10-second window.
        while history and history[-1] < ctime - 10:
            history.pop()

        # Fewer than 3 requests left in the window: record this one and allow.
        if len(history) < 3:
            history.insert(0, ctime)
            return True

        # Be explicit about the denial instead of falling through to None.
        return False

    def wait(self):
        """Return the seconds until the oldest recorded request expires.

        Returns None when there is no recorded history to compute it from.
        """
        if not self.history:
            return None
        ctime = time.time()
        return 10 - (ctime - self.history[-1])


class AuthView(APIView):
    """Login endpoint, rate-limited by MyThrottles (three requests per ten seconds).

    Authentication is disabled for this view so that clients without a token
    can reach it.
    """
    authentication_classes = []
    throttle_classes = [MyThrottles, ]

    def post(self, request, *args, **kwargs):
        """Check the posted username/password and store a token for the user.

        Returns a JsonResponse whose ``code`` is 1000 on success, 1001 when the
        credentials do not match a user, and 1002 when an unexpected error
        occurred while handling the request.
        """
        ret = {'code': 1000, 'msg': None}
        try:
            user = request._request.POST.get('username')
            password = request._request.POST.get('password')
            # NOTE(review): the password is matched as stored, which suggests
            # plain-text passwords in UserInfo -- confirm and consider hashing.
            obj = UserInfo.objects.filter(username=user, password=password).first()
            if not obj:
                ret['code'] = 1001
                ret['msg'] = '用户名或者密码错误'
                # Stop here: no token may be created for a failed login.
                return JsonResponse(ret)

            # Create (or refresh) the token for the authenticated user.  The
            # token is deliberately not printed: it is a credential.
            # NOTE(review): the token is stored but not added to the response;
            # confirm whether clients are expected to receive it here.
            token = md5(user)
            UserToken.objects.update_or_create(user=obj, defaults={'token': token})
        except Exception as e:
            # Report the failure instead of returning the success code.
            print(e)
            ret['code'] = 1002
            ret['msg'] = '请求异常'

        return JsonResponse(ret)

节流的全局使用

  • 在setting中配置
# Project-wide Django REST framework settings.
REST_FRAMEWORK = dict(
    # Authentication classes applied to every view by default.
    DEFAULT_AUTHENTICATION_CLASSES=['api.utils.auth.MyAuthtication'],
    # Value of request.user for anonymous requests ('匿名用户'; None is the
    # alternative).
    UNAUTHENTICATED_USER=lambda: '匿名用户',
    # Value of request.auth for anonymous requests.
    UNAUTHENTICATED_TOKEN=None,
    # Permission classes applied to every view by default.
    DEFAULT_PERMISSION_CLASSES=['api.utils.permission.MyPermission'],
    # Throttle classes applied to every view by default.
    DEFAULT_THROTTLE_CLASSES=['api.utils.throttles.MyThrottles'],
)
  • 将节流类移到单独的模块中(即上面配置中引用的 api/utils/throttles.py)

django-rest-framework中内置的节流

from __future__ import unicode_literals

import time

from django.core.cache import cache as default_cache
from django.core.exceptions import ImproperlyConfigured

from rest_framework.settings import api_settings


class BaseThrottle(object):
    """
    Base class for request rate throttles.
    """

    def allow_request(self, request, view):
        """
        Decide whether the request may proceed: subclasses return `True`
        to allow it and `False` to throttle it.
        """
        raise NotImplementedError('.allow_request() must be overridden')

    def get_ident(self, request):
        """
        Work out an identifier for the machine making the request.

        With NUM_PROXIES set, the client address is picked out of
        HTTP_X_FORWARDED_FOR (falling back to REMOTE_ADDR when the header is
        absent or NUM_PROXIES is 0).  With NUM_PROXIES unset, the whole
        HTTP_X_FORWARDED_FOR value is used when present, else REMOTE_ADDR.
        """
        forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
        direct_addr = request.META.get('REMOTE_ADDR')
        proxy_count = api_settings.NUM_PROXIES

        if proxy_count is None:
            # No proxy count configured: use the full header with all
            # whitespace removed, or the socket address without a header.
            if not forwarded_for:
                return direct_addr
            return ''.join(forwarded_for.split())

        if proxy_count == 0 or forwarded_for is None:
            return direct_addr

        # Count back from the end of the header, clamped to its length.
        hops = forwarded_for.split(',')
        position = min(proxy_count, len(hops))
        return hops[-position].strip()

    def wait(self):
        """
        Optionally suggest how many seconds the client should wait before
        its next request; the base implementation gives no suggestion.
        """
        return None


class SimpleRateThrottle(BaseThrottle):
    """
    Cache-backed throttle; subclasses only have to provide
    `.get_cache_key()`.

    The allowed rate is a string of the form 'number_of_requests/period',
    taken from a `rate` attribute or looked up in THROTTLE_RATES by `scope`.

    Only the first character of the period is significant, so it should be
    one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')

    The timestamps of earlier requests are stored in the cache.
    """
    cache = default_cache
    timer = time.time
    cache_format = 'throttle_%(scope)s_%(ident)s'
    scope = None
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES

    def __init__(self):
        # A truthy `rate` attribute takes precedence over the scope lookup.
        configured_rate = getattr(self, 'rate', None)
        if not configured_rate:
            self.rate = self.get_rate()
        parsed = self.parse_rate(self.rate)
        self.num_requests, self.duration = parsed

    def get_cache_key(self, request, view):
        """
        Return the unique cache key under which this request's history is
        stored.  Must be overridden.

        Returning `None` means the request is not throttled.
        """
        raise NotImplementedError('.get_cache_key() must be overridden')

    def get_rate(self):
        """
        Look up the rate string configured for this throttle's scope.

        Raises ImproperlyConfigured when no scope is set or when
        THROTTLE_RATES has no entry for it.
        """
        if not getattr(self, 'scope', None):
            raise ImproperlyConfigured(
                "You must set either `.scope` or `.rate` for '%s' throttle" %
                self.__class__.__name__
            )

        try:
            rate = self.THROTTLE_RATES[self.scope]
        except KeyError:
            raise ImproperlyConfigured(
                "No default throttle rate set for '%s' scope" % self.scope
            )
        return rate

    def parse_rate(self, rate):
        """
        Split a rate string into a two tuple of:
        <allowed number of requests>, <period of time in seconds>

        A rate of `None` yields `(None, None)`.
        """
        if rate is None:
            return (None, None)
        count_text, period = rate.split('/')
        allowed = int(count_text)
        seconds_per_unit = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}
        return (allowed, seconds_per_unit[period[0]])

    def allow_request(self, request, view):
        """
        Check whether the request should be throttled.

        Returns the result of `throttle_success` when the request is within
        the limit and of `throttle_failure` otherwise.
        """
        # No rate configured: nothing to enforce.
        if self.rate is None:
            return True

        # A `None` key means this request is exempt.
        self.key = self.get_cache_key(request, view)
        if self.key is None:
            return True

        self.history = self.cache.get(self.key, [])
        self.now = self.timer()

        # History is newest first; discard entries at the tail that have
        # fallen out of the throttle window.
        cutoff = self.now - self.duration
        while self.history and self.history[-1] <= cutoff:
            self.history.pop()

        if len(self.history) < self.num_requests:
            return self.throttle_success()
        return self.throttle_failure()

    def throttle_success(self):
        """
        Record the current request's timestamp at the front of the history
        and write the history back to the cache under the key.
        """
        self.history.insert(0, self.now)
        self.cache.set(self.key, self.history, self.duration)
        return True

    def throttle_failure(self):
        """
        Hook called when the request exceeds the allowed rate.
        """
        return False

    def wait(self):
        """
        Return the recommended number of seconds until the next request.
        """
        remaining_duration = (
            self.duration - (self.now - self.history[-1])
            if self.history
            else self.duration
        )

        available_requests = self.num_requests - len(self.history) + 1
        if available_requests > 0:
            return remaining_duration / float(available_requests)
        return None


class AnonRateThrottle(SimpleRateThrottle):
    """
    Rate limit for API calls made by anonymous users.

    The cache key is built from the identifier returned by `get_ident`.
    """
    scope = 'anon'

    def get_cache_key(self, request, view):
        # Only unauthenticated requests are throttled; a `None` key exempts
        # everyone else.
        if not request.user.is_authenticated:
            key_parts = {
                'scope': self.scope,
                'ident': self.get_ident(request)
            }
            return self.cache_format % key_parts

        return None


class UserRateThrottle(SimpleRateThrottle):
    """
    Rate limit for API calls made by a given user.

    Authenticated requests are keyed by the user's primary key; anonymous
    requests are keyed by the identifier returned by `get_ident`.
    """
    scope = 'user'

    def get_cache_key(self, request, view):
        ident = (
            request.user.pk
            if request.user.is_authenticated
            else self.get_ident(request)
        )
        key_parts = {
            'scope': self.scope,
            'ident': ident
        }
        return self.cache_format % key_parts


class ScopedRateThrottle(SimpleRateThrottle):
    """
    Applies different rate limits to different parts of the API.

    A view is throttled only when it defines the `throttle_scope` property.
    The cache key combines that scope with the user id of the request (or
    the `get_ident` identifier for anonymous requests).
    """
    scope_attr = 'throttle_scope'

    def __init__(self):
        """
        Intentionally skip SimpleRateThrottle's setup: the rate depends on
        the view, which is only known once `allow_request` is called.
        """

    def allow_request(self, request, view):
        # The scope comes from the view, so it can only be read now.
        view_scope = getattr(view, self.scope_attr, None)
        self.scope = view_scope

        # Views without a `throttle_scope` are never throttled.
        if not self.scope:
            return True

        # Perform the rate setup that `__init__` would normally have done.
        self.rate = self.get_rate()
        parsed = self.parse_rate(self.rate)
        self.num_requests, self.duration = parsed

        # From here on the regular check applies.
        return super(ScopedRateThrottle, self).allow_request(request, view)

    def get_cache_key(self, request, view):
        """
        Build the unique cache key from the user id (or the `get_ident`
        identifier for anonymous requests) and the scope read from the
        view's `throttle_scope` property.
        """
        ident = (
            request.user.pk
            if request.user.is_authenticated
            else self.get_ident(request)
        )
        key_parts = {
            'scope': self.scope,
            'ident': ident
        }
        return self.cache_format % key_parts
  • 继承BaseThrottle类
    需要自己实现两个方法:allow_request 和 wait
  • 继承SimpleRateThrottle类
    只需重写get_cache_key方法并设置scope属性;scope对应的访问频率在settings的DEFAULT_THROTTLE_RATES中配置

节流类的书写

from rest_framework.throttling import BaseThrottle, SimpleRateThrottle

class MyThrottles(SimpleRateThrottle):
    """Throttle keyed by the client identifier from ``get_ident``.

    The allowed rate is the one configured for the 'wang' scope.
    """
    scope = 'wang'

    def get_cache_key(self, request, view):
        """Return the cache key for this client's request history.

        The key is built with ``cache_format`` so it is namespaced by scope.
        A bare identifier would share one history with any other throttle
        keyed on the same address and could collide with unrelated cache
        entries.
        """
        return self.cache_format % {
            'scope': self.scope,
            'ident': self.get_ident(request),
        }


class MyThrottles1(SimpleRateThrottle):
    """Throttle for logged-in users, keyed by username.

    The allowed rate is the one configured for the 'wang1' scope.
    """
    scope = 'wang1'

    def get_cache_key(self, request, view):
        """Return the cache key for this user's history, or None to skip.

        ``request.user`` is not guaranteed to have a ``username``: for
        anonymous requests it is whatever UNAUTHENTICATED_USER yields (a
        plain string or None in the settings shown in this article).  Such
        requests get a ``None`` key, which means they are not throttled by
        this class, instead of raising AttributeError.
        """
        username = getattr(request.user, 'username', None)
        if not username:
            return None

        # Namespace the key by scope, as the built-in throttles do.
        return self.cache_format % {
            'scope': self.scope,
            'ident': username,
        }

setting中的配置

# Project-wide Django REST framework settings, including throttle rates.
REST_FRAMEWORK = dict(
    # Authentication classes applied to every view by default.
    DEFAULT_AUTHENTICATION_CLASSES=['api.utils.auth.MyAuthtication'],
    # Value of request.user for anonymous requests ('匿名用户'; None is the
    # alternative).
    UNAUTHENTICATED_USER=lambda: '匿名用户',
    # Value of request.auth for anonymous requests.
    UNAUTHENTICATED_TOKEN=None,
    # Permission classes applied to every view by default.
    DEFAULT_PERMISSION_CLASSES=['api.utils.permission.MyPermission'],
    # Throttle classes applied to every view by default.
    DEFAULT_THROTTLE_CLASSES=['api.utils.throttles.MyThrottles'],
    # Allowed rate per throttle scope.
    DEFAULT_THROTTLE_RATES=dict(
        wang='3/m',    # 3 requests per minute
        wang1='10/m',  # 10 requests per minute
    ),
)

view的书写

class AuthView(APIView):
    """Login endpoint, rate-limited by MyThrottles (three requests per minute).

    Authentication and permission checks are disabled for this view so that
    clients without a token can reach it.
    """
    authentication_classes = []
    permission_classes = []
    # Indented with spaces like the rest of the class body; a tab here makes
    # Python reject the file with a TabError.
    throttle_classes = [MyThrottles, ]

    def post(self, request, *args, **kwargs):
        """Check the posted username/password and store a token for the user.

        Returns a JsonResponse whose ``code`` is 1000 on success, 1001 when the
        credentials do not match a user, and 1002 when an unexpected error
        occurred while handling the request.
        """
        ret = {'code': 1000, 'msg': None}
        try:
            user = request._request.POST.get('username')
            password = request._request.POST.get('password')
            # NOTE(review): the password is matched as stored, which suggests
            # plain-text passwords in UserInfo -- confirm and consider hashing.
            obj = UserInfo.objects.filter(username=user, password=password).first()
            if not obj:
                ret['code'] = 1001
                ret['msg'] = '用户名或者密码错误'
                # Stop here: no token may be created for a failed login.
                return JsonResponse(ret)

            # Create (or refresh) the token for the authenticated user.  The
            # token is deliberately not printed: it is a credential.
            # NOTE(review): the token is stored but not added to the response;
            # confirm whether clients are expected to receive it here.
            token = md5(user)
            UserToken.objects.update_or_create(user=obj, defaults={'token': token})
        except Exception as e:
            # Report the failure instead of returning the success code.
            print(e)
            ret['code'] = 1002
            ret['msg'] = '请求异常'

        return JsonResponse(ret)


class OrderView(APIView):
    """Order endpoint: SVIP users only, limited to ten requests per minute.

    Authentication and permission checks come from the project defaults.
    """
    # Without this, the view falls back to the project's default throttle
    # classes rather than the 10-per-minute limit described above;
    # MyThrottles1 is the throttle whose scope ('wang1') is configured as
    # '10/m'.
    throttle_classes = [MyThrottles1, ]

    def get(self, request, *args, **kwargs):
        """Return the order data wrapped in the standard response envelope."""
        # Debug output of the authenticated user and token.
        print(request.user, request.auth)
        ret = {'code': 1000, 'msg': None, 'data': None}
        ret['data'] = ORDER_DICT
        return JsonResponse(ret)

你可能感兴趣的:(django-rest-framework(五)(节流))