restframework(二)

restframework

认证

基础使用

class Authentication(BaseAuthentication):
    """Token authentication for the demo API.

    Reads ``token`` from the query string and looks it up in ``Token``. The
    token is accepted only when its owner is the user whose pk is stored in
    the session under ``'user'``.

    Returns:
        A ``(user_name, token)`` tuple; DRF exposes the two items as
        ``request.user`` and ``request.auth``.

    Raises:
        exceptions.AuthenticationFailed: the token is missing or unknown, the
        session has no logged-in user, or the token belongs to someone else.
    """

    def authenticate(self, request):
        token = request.GET.get('token', None)
        # No (or empty) token: reject without querying the database.
        if not token:
            raise exceptions.AuthenticationFailed("验证失败!")
        token_obj = Token.objects.filter(token=token).first()
        # .get() so a request without a session user is rejected as
        # unauthenticated instead of crashing with KeyError (HTTP 500).
        session_user = request.session.get('user')
        # The explicit equality re-check guards against case-insensitive DB
        # collations matching a differently-cased token.
        if (token_obj is not None
                and token_obj.token == token
                and token_obj.user.pk == session_user):
            return token_obj.user.name, token_obj.token
        raise exceptions.AuthenticationFailed("验证失败!")
class PublishDetail(viewsets.ModelViewSet):
    serializer_class = PublishModelSerializers
    queryset = Publish.objects.all()
    # Registered on this view: it replaces the global
    # DEFAULT_AUTHENTICATION_CLASSES for this view only.
    authentication_classes = [Authentication]

在全局配置该认证方法

settings.py

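REST_FRAMEWORK = {
    # DRF only reads setting names that exist in its DEFAULTS (such as
    # DEFAULT_AUTHENTICATION_CLASSES), so a key like 'authentication_classes'
    # is never looked up. The entries must be dotted import paths: DRF imports
    # them lazily, and settings.py normally cannot import the module that
    # defines Authentication.
    # NOTE(review): 'app01.utils.auth.Authentication' is a placeholder -- point
    # it at the module that actually defines Authentication.
    'DEFAULT_AUTHENTICATION_CLASSES': ['app01.utils.auth.Authentication'],
}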

注意事项

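  1. 继承 BaseAuthentication 并重写 authenticate 方法。它的返回值是一个二元组 (user, auth),之后可以通过 request.user 和 request.auth 取到这两个值;认证失败时抛出 AuthenticationFailed 异常
  2. 在视图类中通过 authentication_classes = [...] 注册要使用的认证类;全局配置时使用 REST_FRAMEWORK 中的 DEFAULT_AUTHENTICATION_CLASSES 键,值为认证类的导入路径字符串
  3. 可以同时使用多个认证类,它们按列表顺序执行;如果当前认证类不处理该请求,可以返回 None,交给下一个认证类继续验证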

源码流程

首先我们到 APIView(rest_framework/views.py)中找到 dispatch 方法

# Excerpt of APIView.dispatch (rest_framework/views.py). Unrelated code is
# elided with `...`, which is why the `try` is shown without its handlers.
def dispatch(self, request, *args, **kwargs):
    ...
    # Wrap the incoming request in DRF's own Request object
    request = self.initialize_request(request, *args, **kwargs)
    self.request = request
    self.headers = self.default_response_headers  # deprecate?

    try:
        # Run authentication, permission checks and throttling
        self.initial(request, *args, **kwargs)
        ...

无关的代码已被省略,重点是 initial 方法做了什么。

initial

# Excerpt of APIView.initial, called from APIView.dispatch() on the wrapped
# request. The three checks at the end always run in this order.
def initial(self, request, *args, **kwargs):
    self.format_kwarg = self.get_format_suffix(**kwargs)

    # Perform content negotiation and store the accepted info on the request
    neg = self.perform_content_negotiation(request)
    request.accepted_renderer, request.accepted_media_type = neg

    # Determine the API version, if versioning is in use.
    version, scheme = self.determine_version(request, *args, **kwargs)
    request.version, request.versioning_scheme = version, scheme

    # Ensure that the incoming request is permitted
    # 1. Authentication
    self.perform_authentication(request)
    # 2. Permissions
    self.check_permissions(request)
    # 3. Throttling (request-rate limits)
    self.check_throttles(request)

perform_authentication

def perform_authentication(self, request):
    request.user
    # Simply reading request.user is what triggers authentication. Note that
    # this `request` is the Request object built by initialize_request() in
    # dispatch(), not the original request handed to the view.
    #
    # The request has been wrapped a second time, so we need to find the
    # `user` attribute on that Request class.

在APIView中的dispatch中

dispatch

# Excerpt of APIView.dispatch; unrelated code elided with `...`.
def dispatch(self, request, *args, **kwargs):
    ...
    # `request` is reassigned here, so the next step is to see how
    # initialize_request() builds it.
    request = self.initialize_request(request, *args, **kwargs)
    self.request = request
    ...

initialize_request

def initialize_request(self, request, *args, **kwargs):
    """
    Returns the initial request object.
    """
    parser_context = self.get_parser_context(request)

    return Request(
        request,
        parsers=self.get_parsers(),
        # Instances of the view's authentication classes; the Request keeps
        # them as self.authenticators and uses them in _authenticate().
        authenticators=self.get_authenticators(),
        negotiator=self.get_content_negotiator(),
        parser_context=parser_context
    )
    # A Request instance is returned, so the `user` attribute we are looking
    # for is defined on the Request class.

Request类中

@property
def user(self):
    """
    Returns the user associated with the current request, as authenticated
    by the authentication classes provided to the request.
    """
    # Authentication runs lazily, only while `_user` has not been set yet.
    if not hasattr(self, '_user'):
        with wrap_attributeerrors():
            # The real work happens here; _authenticate() is examined next.
            self._authenticate()
    return self._user

_authenticate

def _authenticate(self):
    """
    Attempt to authenticate the request using each authentication instance
    in turn.
    """
    for authenticator in self.authenticators:
        # self.authenticators was built by self.get_authenticators(), i.e.
        # [auth() for auth in self.authentication_classes]. That is the list
        # declared on the view in the basic-usage example, so each
        # `authenticator` is an instance of one of those classes.
        """
            authenticators = self.get_authenticators() ==> [auth() for auth in self.authentication_classes]
            authentication_classes 在基础使用中 定义的那个列表就是这个
            也就是authenticator ==> 类的实例对象
        """
        try:
            user_auth_tuple = authenticator.authenticate(self)
            # authenticate() is called on every class, so a custom
            # authenticator must implement authenticate().

        except exceptions.APIException:
            # An authenticator can reject the request outright by raising an
            # APIException; the request is marked unauthenticated and the
            # exception is re-raised.
            self._not_authenticated()
            raise

        if user_auth_tuple is not None:
            self._authenticator = authenticator
            self.user, self.auth = user_auth_tuple
            # The return value of authenticator.authenticate(self) is unpacked
            # into self.user and self.auth:
            # 1. authenticate() must return a 2-tuple
            # 2. views can then read it as request.user and request.auth
            return
        # None means "not handled here": try the next authenticator.

    # No authenticator accepted the request.
    self._not_authenticated()

大体的方法也就这个样子, 我们还可以看下如果我们不写这个认证类,restframework 会怎么执行

self.get_authenticators() ==> [auth() for auth in self.authentication_classes] 
# So far authentication_classes was the list we wrote ourselves. What if a view does not set it?
# APIView then falls back to authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES
# so we only need to find out what DEFAULT_AUTHENTICATION_CLASSES resolves to.
api_settings = APISettings(None, DEFAULTS, IMPORT_STRINGS)

APISettings

class APISettings(object):
    # Excerpt of rest_framework.settings.APISettings; `...` marks elided
    # members, including the `user_settings` property used in __getattr__.
    ...
    def __init__(self, user_settings=None, defaults=None, import_strings=None):
        # Overrides are only stored here when passed explicitly. The
        # module-level api_settings passes None, so `self.user_settings` is
        # resolved later by the (elided) property of that name.
        if user_settings:
            self._user_settings = self.__check_user_settings(user_settings)
        self.defaults = defaults or DEFAULTS
        self.import_strings = import_strings or IMPORT_STRINGS
        # Names of settings already resolved and stored as real attributes.
        self._cached_attrs = set()
    ...
    def __getattr__(self, attr):
        # Only names listed in the defaults are valid settings.
        if attr not in self.defaults:
            raise AttributeError("Invalid API setting: '%s'" % attr)

        try:
            # Check if present in user settings
            val = self.user_settings[attr]
        except KeyError:
            # Fall back to defaults
            val = self.defaults[attr]

        # Coerce import strings into classes
        if attr in self.import_strings:
            val = perform_import(val, attr)

        # Cache the result
        # setattr stores it on the instance, so later lookups of the same name
        # no longer reach __getattr__.
        self._cached_attrs.add(attr)
        setattr(self, attr, val)
        return val
    ...

api_settings 实例上并没有 DEFAULT_AUTHENTICATION_CLASSES 这个属性。但 APISettings 类定义了 __getattr__ 方法,访问不存在的属性时 Python 就会调用它。

def __getattr__(self, attr):
        # Annotated copy of APISettings.__getattr__. Python calls it only when
        # normal attribute lookup fails, i.e. for settings not cached yet.
        if attr not in self.defaults:
            raise AttributeError("Invalid API setting: '%s'" % attr)

        try:
            # Check if present in user settings
            val = self.user_settings[attr]
            # The string below quotes the `user_settings` property. It reads
            # REST_FRAMEWORK from Django settings, or {} when it is absent.
            # That is where the project-wide REST_FRAMEWORK config comes from,
            # and it is indexed by the exact setting name (e.g.
            # 'DEFAULT_AUTHENTICATION_CLASSES').
            """
                @property
                def user_settings(self):
                    if not hasattr(self, '_user_settings'):
                        self._user_settings = getattr(settings, 'REST_FRAMEWORK', {})
                        self._user_settings 是在 settings 中找 REST_FRAMEWORK ,如果找到了就返回,找不到就返回空
                    return self._user_settings
                注意该方法: 我们在使用全局配置的时候 使用的 REST_FRAMEWORK 就是从这里来的
            """
            # We want the default here; REST_FRAMEWORK does not set this key,
            # so the lookup above raises KeyError.
        except KeyError:
            # Fall back to defaults
            val = self.defaults[attr]
            # self.defaults = defaults or DEFAULTS
            # api_settings = APISettings(None, DEFAULTS, IMPORT_STRINGS)
            # so DEFAULT_AUTHENTICATION_CLASSES is looked up in DEFAULTS:
            # 'DEFAULT_AUTHENTICATION_CLASSES': (
            #        'rest_framework.authentication.SessionAuthentication',
            #        'rest_framework.authentication.BasicAuthentication'
            #    ),
            # These two are DRF's built-in *authentication* classes; let's see
            # what they contain.


        # Coerce import strings into classes
        if attr in self.import_strings:
            val = perform_import(val, attr)

        # Cache the result
        self._cached_attrs.add(attr)
        setattr(self, attr, val)
        return val

SessionAuthentication

这里只是为了看我们应该如何写认证类的内容,并不是为了看逻辑

class SessionAuthentication(BaseAuthentication):
    """
    Use Django's session framework for authentication.
    """
    # Shown as a template for writing an authenticator: authenticate()
    # returns None to defer to the next authenticator, or a (user, auth)
    # tuple on success.

    def authenticate(self, request):
        """
        Returns a `User` if the request session currently has a logged in user.
        Otherwise returns `None`.
        """

        # Get the session-based user from the underlying HttpRequest object
        user = getattr(request._request, 'user', None)

        # Unauthenticated, CSRF validation not required
        # Returning None lets the next class in authentication_classes try.
        if not user or not user.is_active:
            return None

        self.enforce_csrf(request)

        # CSRF passed with authenticated user
        # request.user becomes the session user; request.auth becomes None.
        return (user, None)

    def enforce_csrf(self, request):
        """
        Enforce CSRF validation for session based authentication.
        """
        reason = CSRFCheck().process_view(request, None, (), {})
        if reason:
            # CSRF failed, bail with explicit error message
            raise exceptions.PermissionDenied('CSRF Failed: %s' % reason)

BaseAuthentication

既然内置的认证都继承了BaseAuthentication 那我们是不是也可以继承这个类

class BaseAuthentication(object):
    """
    All authentication classes should extend BaseAuthentication.
    """

    def authenticate(self, request):
        """
        Authenticate the request and return a two-tuple of (user, token).
        """
        # Subclasses must override this; Request._authenticate() calls it on
        # every configured authenticator.
        raise NotImplementedError(".authenticate() must be overridden.")

    def authenticate_header(self, request):
        """
        Return a string to be used as the value of the `WWW-Authenticate`
        header in a `401 Unauthenticated` response, or `None` if the
        authentication scheme should return `403 Permission Denied` responses.
        """
        # Base implementation returns None, i.e. the 403 behaviour described
        # in the docstring unless a subclass overrides it.
        pass

权限

基础使用

class VIP(BasePermission):
    """Permission that only admits VIP users (``vip >= 1``).

    The user is identified by the primary key stored in the session under
    ``'user'``. When access is denied, DRF's check_permissions() passes
    ``message`` along as the error text.
    """
    message = "少年冲钱吧,要不看不了"

    def has_permission(self, request, view):
        # .get(): a request without a logged-in session is denied instead of
        # crashing with KeyError (HTTP 500).
        user_pk = request.session.get('user')
        if user_pk is None:
            return False
        user = User.objects.filter(pk=user_pk).first()
        # .first() returns None for a stale or unknown pk; deny instead of
        # raising AttributeError on None.vip.
        if user is None:
            return False
        return user.vip >= 1
class PublishDetail(viewsets.ModelViewSet):
    serializer_class = PublishModelSerializers
    queryset = Publish.objects.all()
    # Registered on this view: every request must pass VIP.has_permission().
    permission_classes = [VIP]

在全局配置权限方法

REST_FRAMEWORK = {
    # DRF only reads setting names that exist in its DEFAULTS (such as
    # DEFAULT_PERMISSION_CLASSES), so a key like 'permission_classes' is
    # never looked up. Entries must be dotted import paths that DRF imports
    # lazily; settings.py cannot normally import the class itself.
    # NOTE(review): 'app01.utils.permissions.VIP' is a placeholder -- point it
    # at the module that actually defines VIP.
    'DEFAULT_PERMISSION_CLASSES': ['app01.utils.permissions.VIP'],
}

注意事项

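  1. 继承 BasePermission 并实现 has_permission 方法:有权限时返回 True,否则返回 False。可以通过类属性 message 自定义拒绝时的提示信息
  2. 在视图类中通过 permission_classes 进行注册;全局配置时使用 DEFAULT_PERMISSION_CLASSES
  3. 可以有多个权限类,必须全部返回 True 才允许访问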

源码流程

还是从dispatch 方法中入手

APIView dispatch

# Excerpt of APIView.dispatch; unrelated code (including the `try`
# handlers) is elided with `...`.
def dispatch(self, request, *args, **kwargs):
    ...
    request = self.initialize_request(request, *args, **kwargs)
    self.request = request
    self.headers = self.default_response_headers  # deprecate?

    try:
        # authentication, permissions and throttling all run in initial()
        self.initial(request, *args, **kwargs)
    ...

initial方法

# Excerpt of APIView.initial; earlier lines elided with `...`.
def initial(self, request, *args, **kwargs):
    ...
    # Ensure that the incoming request is permitted
    self.perform_authentication(request)
    # Permission checks (the subject of this section)
    self.check_permissions(request)
    self.check_throttles(request)

第一个方法我们已经看过了 ,接下来看第二个方法

check_permissions

def check_permissions(self, request):
    """
    Check if the request should be permitted.
    Raises an appropriate exception if the request is not permitted.
    """
    for permission in self.get_permissions():
        # get_permissions() (quoted in the string below) returns one instance
        # of each class in the view's permission_classes.
        """
            def get_permissions(self):
                return [permission() for permission in self.permission_classes]
            permission_classes 就是我们自己实现的 那个permission_classes
            permission 就是我们写的 类的实例对象
        """
        if not permission.has_permission(request, self):
            # So a permission class must implement has_permission() and
            # return True or False.
            self.permission_denied(
                request, message=getattr(permission, 'message', None)
            )
            # On False the request is denied. The class's optional `message`
            # attribute is passed as the error text, which is why a custom
            # permission class can define `message`.

我们看下restframework中默认的一些方法是如何执行的,过程与第一个方法 基本相似, 所以跳过寻找方法

# Entry from DRF's DEFAULTS dict (a fragment, not a standalone statement):
'DEFAULT_PERMISSION_CLASSES': (
        'rest_framework.permissions.AllowAny',
    ),
class AllowAny(BasePermission):
    """
    Allow any access.
    This isn't strictly required, since you could use an empty
    permission_classes list, but it's useful because it makes the intention
    more explicit.
    """

    def has_permission(self, request, view):
        return True
# Every request is allowed, so there is not much more to see. The pattern is:
# - extend BasePermission
# - has_permission() returns True or False

频率访问

基础使用

VISIT_RECORD = {}  # per-IP request timestamps, newest first: {remote_addr: [ts, ...]}
GAME_OVER = {}  # banned IPs: {remote_addr: time the ban started}


class VisitThrottle(BaseThrottle):
    """Per-IP throttle: at most ``max_requests`` within ``window`` seconds.

    An IP that goes over the limit is banned for ``ban_seconds``. Requests
    without a ``User-Agent`` header are always rejected. State lives in the
    module-level dicts above, so it is held in this process's memory only.
    """

    window = 60            # length of the sliding window, in seconds
    max_requests = 20      # requests allowed per window
    ban_seconds = 10 * 60  # how long an offending IP stays banned

    def __init__(self):
        self.history = None  # timestamps of the current IP
        # Set here so wait() is safe even if allow_request() never ran.
        self.remote_addr = None

    def allow_request(self, request, view):
        ctime = time.time()
        # Client IP
        remote_addr = request.META.get('REMOTE_ADDR')
        # Client User-Agent
        http_user = request.META.get('HTTP_USER_AGENT')
        self.remote_addr = remote_addr
        if not http_user:
            return False

        # A banned IP stays blocked until ban_seconds have passed.
        banned_at = GAME_OVER.get(remote_addr)
        if banned_at is not None and ctime - banned_at <= self.ban_seconds:
            return False

        if remote_addr not in VISIT_RECORD:
            # First request from this IP.
            VISIT_RECORD[remote_addr] = [ctime]
            return True

        self.history = VISIT_RECORD[remote_addr]
        # Newest timestamps are at the front; drop those older than the window.
        while self.history and ctime - self.history[-1] > self.window:
            self.history.pop()

        if len(self.history) < self.max_requests:
            self.history.insert(0, ctime)
            return True
        # Limit exceeded: start a ban for this IP.
        GAME_OVER[remote_addr] = ctime
        return False

    def wait(self):
        """Return seconds until this IP's ban ends, or None if it is not banned.

        check_throttles() calls wait() after every rejection, including the
        missing-User-Agent case where the IP may never have been banned. So
        this must not assume the IP is in GAME_OVER. It also must not return a
        negative delay after a ban has expired.
        """
        banned_at = GAME_OVER.get(self.remote_addr)
        if banned_at is None:
            return None
        remaining = self.ban_seconds - (time.time() - banned_at)
        return remaining if remaining > 0 else None
class PublishDetail(viewsets.ModelViewSet):
    serializer_class = PublishModelSerializers
    queryset = Publish.objects.all()
    # Registered on this view: each request must pass VisitThrottle.allow_request().
    throttle_classes = [VisitThrottle]

在全局中配置访问频率

REST_FRAMEWORK = {
    # DRF only reads setting names that exist in its DEFAULTS (such as
    # DEFAULT_THROTTLE_CLASSES), so a key like 'throttle_classes' is never
    # looked up. Entries must be dotted import paths that DRF imports lazily;
    # settings.py cannot normally import the class itself.
    # NOTE(review): 'app01.utils.throttles.VisitThrottle' is a placeholder --
    # point it at the module that actually defines VisitThrottle.
    'DEFAULT_THROTTLE_CLASSES': ['app01.utils.throttles.VisitThrottle'],
}

注意事项

  1. 继承 BaseThrottle 并实现 allow_request 方法:允许访问时返回 True,否则返回 False。被限制时 DRF 会调用 wait() 获取建议的等待秒数
  2. 在视图类中通过 throttle_classes 进行注册;全局配置时使用 DEFAULT_THROTTLE_CLASSES
  3. 可以有多个限流类

源码流程

APIView dispatch

# Excerpt of APIView.dispatch; unrelated code (including the `try`
# handlers) is elided with `...`.
def dispatch(self, request, *args, **kwargs):
    ...
    request = self.initialize_request(request, *args, **kwargs)
    self.request = request
    self.headers = self.default_response_headers  # deprecate?

    try:
        # throttling is the last of the checks run inside initial()
        self.initial(request, *args, **kwargs)
    ...

initial方法

# Excerpt of APIView.initial; earlier lines elided with `...`.
def initial(self, request, *args, **kwargs):
    ...
    # Ensure that the incoming request is permitted
    self.perform_authentication(request)
    self.check_permissions(request)
    # Throttling (the subject of this section)
    self.check_throttles(request)

check_throttles

def check_throttles(self, request):
    """
    Check if request should be throttled.
    Raises an appropriate exception if the request is throttled.
    """
    for throttle in self.get_throttles():
        # get_throttles() (quoted in the string below) returns one instance of
        # each class in the view's throttle_classes.
        """
                def get_throttles(self):
                    return [throttle() for throttle in self.throttle_classes]
                throttle_classes 是我门在类中注册时写的 throttle_classes
        """
        if not throttle.allow_request(request, self):
            # So a throttle class must implement allow_request() returning
            # True or False; on False, wait() supplies the suggested delay.
            self.throttled(request, throttle.wait())

而 restframework 中 内置了不少的限流访问的方法,我们可以直接调用

rest_framework.throttling

BaseThrottle

这是一个基类, 所有的频率访问的类都继承它

class BaseThrottle(object):
    """
    Rate throttling of requests.
    """
    # Excerpt of rest_framework.throttling.BaseThrottle.

    def allow_request(self, request, view):
        """
        Return `True` if the request should be allowed, `False` otherwise.
        """
        raise NotImplementedError('.allow_request() must be overridden')

    def get_ident(self, request):
        """
        Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR
        if present and number of proxies is > 0. If not use all of
        HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR.
        """
        xff = request.META.get('HTTP_X_FORWARDED_FOR')
        remote_addr = request.META.get('REMOTE_ADDR')
        num_proxies = api_settings.NUM_PROXIES

        if num_proxies is not None:
            if num_proxies == 0 or xff is None:
                return remote_addr
            addrs = xff.split(',')
            # Take the entry num_proxies positions from the right end of the
            # comma-separated X-Forwarded-For list (clamped to its length).
            client_addr = addrs[-min(num_proxies, len(addrs))]
            return client_addr.strip()

        # NUM_PROXIES unset: use the whole X-Forwarded-For value with all
        # whitespace removed, or REMOTE_ADDR when the header is absent/empty.
        return ''.join(xff.split()) if xff else remote_addr

    def wait(self):
        """
        Optionally, return a recommended number of seconds to wait before
        the next request.
        """
        return None

提供了3个方法:

  • allow_request

    • 主要逻辑: return True or False
  • get_ident

    • 获取访问ip等信息
  • wait

    • 返回建议客户端等待的秒数(可以返回 None),用于被限流时的提示

SimpleRateThrottle

这是restframework 提供的类,所以我们可以直接调用它来完成我们的需求。

settings.py中配置

# Throttle rates per scope, written as "<count>/<period>": "3/m" allows three
# requests per minute for throttle classes whose `scope` is "Night".
REST_FRAMEWORK = dict(
    DEFAULT_THROTTLE_RATES=dict(Night="3/m"),
)
class VisitDemo(SimpleRateThrottle):
    """Throttle each client IP at the rate configured for scope ``'Night'``.

    The rate is read from ``REST_FRAMEWORK['DEFAULT_THROTTLE_RATES']['Night']``.
    """
    scope = 'Night'

    def get_cache_key(self, request, view):
        # Namespace the key with the scope via cache_format, as the built-in
        # AnonRateThrottle/UserRateThrottle do. A bare IP as the key would
        # share one history between every throttle keyed on IP and could
        # collide with unrelated entries in the default cache.
        return self.cache_format % {
            'scope': self.scope,
            'ident': self.get_ident(request),
        }

这个类继承自 BaseThrottle。我们先看 __init__,再看 allow_request,因为 check_throttles 主要调用的就是 allow_request。

class SimpleRateThrottle(BaseThrottle):
    # Excerpt of rest_framework.throttling.SimpleRateThrottle: a cache-backed
    # sliding-window throttle configured by a rate string such as "3/m".
    cache = default_cache
    timer = time.time # returns the current time in seconds
    cache_format = 'throttle_%(scope)s_%(ident)s'  # template used to build cache keys
    scope = None
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES 

    def __init__(self):
        if not getattr(self, 'rate', None):
            self.rate = self.get_rate()
            # Resolve the rate string for self.scope. get_rate() (quoted in
            # the string below) raises ImproperlyConfigured when no scope is
            # set or the scope has no entry in DEFAULT_THROTTLE_RATES. With
            # {"Night": "3/m"} and scope = 'Night' it returns "3/m".
            """
             def get_rate(self):
                if not getattr(self, 'scope', None):
                    msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                           self.__class__.__name__)
                    raise ImproperlyConfigured(msg)
                try:
                    return self.THROTTLE_RATES[self.scope]
                    # THROTTLE_RATES = DEFAULT_THROTTLE_RATES 也就是我们自己写的配置
                        "DEFAULT_THROTTLE_RATES": {
                            "Night": "3/m"
                        }
                    拿到 3/m 
                except KeyError:
                    msg = "No default throttle rate set for '%s' scope" % self.scope
                    raise ImproperlyConfigured(msg)
            """
            # self.rate = "3/m"
        self.num_requests, self.duration = self.parse_rate(self.rate)
        # parse_rate() (quoted in the string below) turns "3/m" into
        # num_requests = 3 (max requests per window) and duration = 60
        # (window length in seconds). Only the first letter of the period is
        # used: s=1, m=60, h=3600, d=86400.
        """
            def parse_rate(self, rate):
                if rate is None:
                    return (None, None)
                num, period = rate.split('/')
                num_requests = int(num)
                duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
                return (num_requests, duration)

            self.num_requests = 3 限制次数
            self.duration = m     限制时间

        """

    def get_cache_key(self, request, view):
        """
        Should return a unique cache-key which can be used for throttling.
        Must be overridden.

        May return `None` if the request should not be throttled.
        """
        """这个方法是需要我们自己重载的 获取用户的唯一标识"""
        # Subclasses override this to return a unique identifier for the client.
        raise NotImplementedError('.get_cache_key() must be overridden')



    def allow_request(self, request, view):
        """
        Implement the check to see if the request should be throttled.

        On success calls `throttle_success`.
        On failure calls `throttle_failure`.
        """
        if self.rate is None:
            return True

        self.key = self.get_cache_key(request, view)
        if self.key is None:
            return True

        self.history = self.cache.get(self.key, [])
        # This client's request timestamps from the cache, newest first
        # (throttle_success inserts at index 0).
        self.now = self.timer()
        # Current time

        # Drop any requests from the history which have now passed the
        # throttle duration
        while self.history and self.history[-1] <= self.now - self.duration:
            # the oldest timestamp has left the window: discard it
            self.history.pop()
        if len(self.history) >= self.num_requests:
            # the window already holds num_requests hits: reject this request
            return self.throttle_failure()
        return self.throttle_success()

    def throttle_success(self):
        """
        Inserts the current request's timestamp along with the key
        into the cache.
        """
        self.history.insert(0, self.now)
        self.cache.set(self.key, self.history, self.duration)
        return True

    def throttle_failure(self):
        """
        Called when a request to the API has failed due to throttling.
        """
        return False

    def wait(self):
        """
        Returns the recommended next request time in seconds.
        """
        if self.history:
            # time until the oldest recorded request leaves the window
            remaining_duration = self.duration - (self.now - self.history[-1])
        else:
            remaining_duration = self.duration

        available_requests = self.num_requests - len(self.history) + 1
        if available_requests <= 0:
            return None

        # spread the remaining window evenly over the requests still allowed
        return remaining_duration / float(available_requests)

其实如果整体走下来会发现, 自己写的基础使用基本仿照它写的。

下面的几个类都是继承SimpleRateThrottle 也就是和我们使用的形式基本一致

AnonRateThrottle
class AnonRateThrottle(SimpleRateThrottle):
    """
    Limits the rate of API calls that may be made by a anonymous users.

    The IP address of the request will be used as the unique cache key.
    """
    # Rate is taken from DEFAULT_THROTTLE_RATES['anon'].
    scope = 'anon'

    def get_cache_key(self, request, view):
        if request.user.is_authenticated:
            return None  # Only throttle unauthenticated requests.

        # Key on the client IP, namespaced by scope via cache_format.
        return self.cache_format % {
            'scope': self.scope,
            'ident': self.get_ident(request)
        }
UserRateThrottle
class UserRateThrottle(SimpleRateThrottle):
    """
    Limits the rate of API calls that may be made by a given user.

    The user id will be used as a unique cache key if the user is
    authenticated.  For anonymous requests, the IP address of the request will
    be used.
    """
    # Rate is taken from DEFAULT_THROTTLE_RATES['user'].
    scope = 'user'

    def get_cache_key(self, request, view):
        # Authenticated users are keyed by primary key, anonymous ones by IP.
        if request.user.is_authenticated:
            ident = request.user.pk
        else:
            ident = self.get_ident(request)

        return self.cache_format % {
            'scope': self.scope,
            'ident': ident
        }
ScopedRateThrottle
class ScopedRateThrottle(SimpleRateThrottle):
    """
    Limits the rate of API calls by different amounts for various parts of
    the API.  Any view that has the `throttle_scope` property set will be
    throttled.  The unique cache key will be generated by concatenating the
    user id of the request, and the scope of the view being accessed.
    """
    # Name of the view attribute that selects the scope (and hence the rate).
    scope_attr = 'throttle_scope'

    def __init__(self):
        # Override the usual SimpleRateThrottle, because we can't determine
        # the rate until called by the view.
        pass

    def allow_request(self, request, view):
        # We can only determine the scope once we're called by the view.
        self.scope = getattr(view, self.scope_attr, None)

        # If a view does not have a `throttle_scope` always allow the request
        if not self.scope:
            return True

        # Determine the allowed request rate as we normally would during
        # the `__init__` call.
        # The rate is looked up as DEFAULT_THROTTLE_RATES[view.throttle_scope].
        self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)

        # We can now proceed as normal.
        return super(ScopedRateThrottle, self).allow_request(request, view)

    def get_cache_key(self, request, view):
        """
        If `view.throttle_scope` is not set, don't apply this throttle.

        Otherwise generate the unique cache key by concatenating the user id
        with the '.throttle_scope` property of the view.
        """
        # Authenticated users are keyed by primary key, anonymous ones by IP;
        # the scope in the key keeps each view scope's history separate.
        if request.user.is_authenticated:
            ident = request.user.pk
        else:
            ident = self.get_ident(request)

        return self.cache_format % {
            'scope': self.scope,
            'ident': ident
        }

你可能感兴趣的:(django)