class Authentication(BaseAuthentication):
    """Token authentication tied to the user stored in the session.

    The token is read from the ``token`` query-string parameter and is only
    accepted when it belongs to the user whose primary key is stored under
    the ``'user'`` key of the session.
    """

    def authenticate(self, request):
        """Authenticate ``request`` and return ``(user name, token)``.

        Args:
            request: The incoming request; its query string and session
                are read.

        Returns:
            A two-tuple ``(token_obj.user.name, token_obj.token)``.

        Raises:
            exceptions.AuthenticationFailed: If the token is missing or
                unknown, or does not belong to the session user.
        """
        token = request.GET.get('token')
        # A missing/empty token can never be valid. Without this guard the
        # lookup below would run with ``token=None``.
        if not token:
            raise exceptions.AuthenticationFailed("验证失败!")

        token_obj = Token.objects.filter(token=token).first()
        # ``.get`` instead of ``[...]``: a session without a 'user' entry
        # must end in an authentication failure, not a KeyError.
        session_user = request.session.get('user')
        if (
            token_obj
            and session_user is not None
            and token_obj.token == token
            and token_obj.user.pk == session_user
        ):
            return token_obj.user.name, token_obj.token
        raise exceptions.AuthenticationFailed("验证失败!")
class PublishDetail(viewsets.ModelViewSet):
    """Model view set for ``Publish`` protected by ``Authentication``."""

    serializer_class = PublishModelSerializers
    queryset = Publish.objects.all()
    # Register the authentication classes used by this view only.
    authentication_classes = [Authentication]
也可以在 settings.py 中进行全局配置:
REST_FRAMEWORK = {
    # The global setting is looked up as DEFAULT_AUTHENTICATION_CLASSES, and
    # each entry is a dotted import path given as a string.
    # NOTE(review): replace the path below with the real module path of
    # the Authentication class in this project.
    'DEFAULT_AUTHENTICATION_CLASSES': ['your_app.auth.Authentication'],
}
首先我们到 APIView 中(rest_framework/views.py)找到 dispatch 方法
def dispatch(self, request, *args, **kwargs):
    """Excerpt of the view's dispatch; elided code is marked with ``...``."""
    ...
    # Wrap the incoming request a second time (see initialize_request).
    request = self.initialize_request(request, *args, **kwargs)
    self.request = request
    self.headers = self.default_response_headers  # deprecate?
    try:
        # Run authentication, permission checks and throttling.
        self.initial(request, *args, **kwargs)
        ...
多余的代码被我省略了,重点是看 initial 方法执行了什么。
initial
def initial(self, request, *args, **kwargs):
    """Prepare the view and request, then run the access checks.

    Stores the format suffix on the view, records the content-negotiation
    and versioning results on ``request``, and finally runs authentication,
    permission checks and throttling, in that order.
    """
    self.format_kwarg = self.get_format_suffix(**kwargs)

    # Content negotiation: remember which renderer / media type was accepted.
    renderer, media_type = self.perform_content_negotiation(request)
    request.accepted_renderer = renderer
    request.accepted_media_type = media_type

    # API versioning, if a versioning scheme is in use.
    api_version, versioning_scheme = self.determine_version(
        request, *args, **kwargs
    )
    request.version = api_version
    request.versioning_scheme = versioning_scheme

    # Make sure the incoming request is allowed through.
    self.perform_authentication(request)  # authentication
    self.check_permissions(request)  # permissions
    self.check_throttles(request)  # rate limiting
perform_authentication
def perform_authentication(self, request):
    """Trigger authentication by reading ``request.user``."""
    # Only the attribute access matters; its value is discarded.
    # Note which ``request`` this is: not the raw WSGI-level request but the
    # one that dispatch re-wrapped via initialize_request. To find ``user``
    # we therefore have to look at the Request class.
    getattr(request, 'user')
在APIView中的dispatch中
dispatch
def dispatch(self, request, *args, **kwargs):
    """Excerpt of the view's dispatch; elided code is marked with ``...``."""
    ...
    # ``request`` is reassigned here, so the next step is to look at what
    # initialize_request() does.
    request = self.initialize_request(request, *args, **kwargs)
    self.request = request
    ...
initialize_request
def initialize_request(self, request, *args, **kwargs):
    """Wrap the incoming request in a ``Request`` object and return it."""
    context = self.get_parser_context(request)
    request_options = {
        'parsers': self.get_parsers(),
        'authenticators': self.get_authenticators(),
        'negotiator': self.get_content_negotiator(),
        'parser_context': context,
    }
    # A Request instance is returned, so the ``user`` attribute we are
    # looking for is defined on the Request class.
    return Request(request, **request_options)
在 Request 类中:
@property
def user(self):
    """Return the user for this request, authenticating on first access.

    The result of authentication is stored in ``self._user``; later reads
    return that stored value without authenticating again.
    """
    if hasattr(self, '_user'):
        return self._user
    with wrap_attributeerrors():
        # The real work happens in _authenticate().
        self._authenticate()
    return self._user
_authenticate
def _authenticate(self):
    """Try each authenticator in turn until one returns credentials.

    On success ``self.user`` and ``self.auth`` are set from the returned
    two-tuple. If an authenticator raises ``APIException`` the request is
    marked as not authenticated and the exception propagates. If no
    authenticator returns credentials, the request is marked as not
    authenticated.
    """
    # self.authenticators comes from get_authenticators(), i.e.
    #     [auth() for auth in self.authentication_classes]
    # where authentication_classes is the list declared on the view, so
    # each ``authenticator`` is an instance of one of those classes.
    for authenticator in self.authenticators:
        try:
            # Every class is asked through authenticate(); a custom class
            # therefore has to implement that method.
            credentials = authenticator.authenticate(self)
        except exceptions.APIException:
            # Rejecting a request is done by raising an APIException.
            self._not_authenticated()
            raise

        if credentials is None:
            continue

        # authenticate() returned a two-tuple; it is unpacked into
        # self.user and self.auth, which callers read as request.user and
        # request.auth.
        self._authenticator = authenticator
        self.user, self.auth = credentials
        return

    self._not_authenticated()
大体的方法也就这个样子, 我们还可以看下如果我们不写这个认证类,restframework 会怎么执行
self.get_authenticators() ==> [auth() for auth in self.authentication_classes]
# 刚刚说的 authentication_classes 是我们自己写的, 如果没写呢?
# 会在 APIView 中找到 authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES
# 也就是说, 只要找到 DEFAULT_AUTHENTICATION_CLASSES 是什么就好了
# Shared settings object: no explicit user settings are passed (None);
# DEFAULTS supplies the fallback values and IMPORT_STRINGS names the
# settings whose values are import paths.
api_settings = APISettings(None, DEFAULTS, IMPORT_STRINGS)
APISettings
class APISettings(object):
    """Excerpt: resolves settings from user settings, falling back to defaults.

    Elided code is marked with ``...``.
    """
    ...

    def __init__(self, user_settings=None, defaults=None, import_strings=None):
        """Store the given settings sources, falling back to the module defaults."""
        # ``_user_settings`` is only set when user settings are passed in.
        if user_settings:
            self._user_settings = self.__check_user_settings(user_settings)
        self.defaults = defaults or DEFAULTS
        self.import_strings = import_strings or IMPORT_STRINGS
        self._cached_attrs = set()
    ...

    def __getattr__(self, attr):
        """Resolve ``attr`` from user settings or defaults and cache it.

        Raises:
            AttributeError: If ``attr`` is not a key of ``self.defaults``.
        """
        if attr not in self.defaults:
            raise AttributeError("Invalid API setting: '%s'" % attr)
        try:
            # Check if present in user settings
            val = self.user_settings[attr]
        except KeyError:
            # Fall back to defaults
            val = self.defaults[attr]
        # Coerce import strings into classes
        if attr in self.import_strings:
            val = perform_import(val, attr)
        # Cache the result: once set as a real attribute, later lookups of
        # ``attr`` no longer reach __getattr__.
        self._cached_attrs.add(attr)
        setattr(self, attr, val)
        return val
    ...
在 api_settings 上找不到 DEFAULT_AUTHENTICATION_CLASSES 这个属性,但是 APISettings 类中定义了 __getattr__ 方法,于是会执行该方法。
def __getattr__(self, attr):
    """Resolve the setting ``attr`` and cache it on the instance.

    User settings win over defaults; values of settings listed in
    ``self.import_strings`` are passed through ``perform_import``.

    Raises:
        AttributeError: If ``attr`` is not a key of ``self.defaults``.
    """
    if attr not in self.defaults:
        raise AttributeError("Invalid API setting: '%s'" % attr)

    try:
        # ``self.user_settings`` is a property defined as:
        #
        #     @property
        #     def user_settings(self):
        #         if not hasattr(self, '_user_settings'):
        #             self._user_settings = getattr(settings, 'REST_FRAMEWORK', {})
        #         return self._user_settings
        #
        # It reads REST_FRAMEWORK from the project settings and falls back
        # to an empty dict. This is where the global REST_FRAMEWORK
        # configuration is picked up.
        value = self.user_settings[attr]
    except KeyError:
        # We are tracing the default case: nothing is configured in
        # settings, so the lookup above raises and we fall back to defaults.
        # self.defaults is DEFAULTS here, because
        #     api_settings = APISettings(None, DEFAULTS, IMPORT_STRINGS)
        # and DEFAULTS contains:
        #     'DEFAULT_AUTHENTICATION_CLASSES': (
        #         'rest_framework.authentication.SessionAuthentication',
        #         'rest_framework.authentication.BasicAuthentication'
        #     ),
        # These are the two built-in authentication classes.
        value = self.defaults[attr]

    # Turn import strings into the objects they name.
    if attr in self.import_strings:
        value = perform_import(value, attr)

    # Cache the resolved value as a real attribute.
    self._cached_attrs.add(attr)
    setattr(self, attr, value)
    return value
SessionAuthentication
这里只是为了看我们应该如何写认证类的内容,并不是为了看逻辑
class SessionAuthentication(BaseAuthentication):
    """Authenticate through Django's session framework."""

    def authenticate(self, request):
        """Return ``(user, None)`` for an active session user, else ``None``.

        The user is taken from the underlying HttpRequest
        (``request._request``). CSRF validation is enforced only when such
        a user is present and active.
        """
        session_user = getattr(request._request, 'user', None)
        if session_user and session_user.is_active:
            self.enforce_csrf(request)
            # CSRF check passed for an authenticated user.
            return (session_user, None)
        # Unauthenticated: no CSRF validation required.
        return None

    def enforce_csrf(self, request):
        """Run the CSRF check and raise ``PermissionDenied`` if it fails."""
        failure = CSRFCheck().process_view(request, None, (), {})
        if not failure:
            return
        # CSRF failed: bail out with an explicit error message.
        raise exceptions.PermissionDenied('CSRF Failed: %s' % failure)
BaseAuthentication
既然内置的认证都继承了BaseAuthentication 那我们是不是也可以继承这个类
class BaseAuthentication(object):
    """Base class that every authentication class is expected to extend."""

    def authenticate(self, request):
        """Authenticate ``request``; subclasses return ``(user, token)``.

        Raises:
            NotImplementedError: Always, until a subclass overrides it.
        """
        raise NotImplementedError(".authenticate() must be overridden.")

    def authenticate_header(self, request):
        """Return the ``WWW-Authenticate`` header value, or ``None``.

        A string is used as the header of a 401 response; ``None`` means
        the scheme should answer with 403 instead. The base implementation
        returns ``None``.
        """
        return None
class VIP(BasePermission):
    """Permission that only lets users with ``vip >= 1`` through."""

    # Shown to the client when permission is denied.
    message = "少年冲钱吧,要不看不了"

    def has_permission(self, request, view):
        """Return ``True`` when the session user exists and has ``vip >= 1``.

        Args:
            request: The incoming request; the user's primary key is read
                from ``request.session['user']``.
            view: The view being accessed (unused).

        Returns:
            ``False`` when the session has no user, the user row does not
            exist, or the user's ``vip`` level is below 1.
        """
        # ``.get`` so that a session without a user is denied instead of
        # failing with KeyError.
        user_id = request.session.get('user')
        if user_id is None:
            return False
        # ``first()`` returns None when no row matches; treat that as denied
        # rather than dereferencing ``None.vip``.
        user = User.objects.filter(pk=user_id).first()
        if user is None:
            return False
        return user.vip >= 1
class PublishDetail(viewsets.ModelViewSet):
    """Model view set for ``Publish`` restricted by the ``VIP`` permission."""

    serializer_class = PublishModelSerializers
    queryset = Publish.objects.all()
    # Permission classes applied to this view only.
    permission_classes = [VIP]
REST_FRAMEWORK = {
    # The global setting is named DEFAULT_PERMISSION_CLASSES, and each
    # entry is a dotted import path given as a string.
    # NOTE(review): replace the path below with the real module path of
    # the VIP class in this project.
    'DEFAULT_PERMISSION_CLASSES': ['your_app.permissions.VIP'],
}
还是从dispatch
方法中入手
APIView dispatch
def dispatch(self, request, *args, **kwargs):
    """Excerpt of the view's dispatch; elided code is marked with ``...``."""
    ...
    request = self.initialize_request(request, *args, **kwargs)
    self.request = request
    self.headers = self.default_response_headers  # deprecate?
    try:
        # initial() is where the permission check is triggered.
        self.initial(request, *args, **kwargs)
        ...
initial
方法
def initial(self, request, *args, **kwargs):
    """Excerpt of initial(); elided code is marked with ``...``."""
    ...
    # Ensure that the incoming request is permitted
    self.perform_authentication(request)
    # Permission checks
    self.check_permissions(request)
    self.check_throttles(request)
第一个方法我们已经看过了 ,接下来看第二个方法
check_permissions
def check_permissions(self, request):
    """Ask every permission object whether ``request`` may proceed.

    For each permission whose ``has_permission`` returns a falsy value,
    ``self.permission_denied`` is called with the permission's optional
    ``message`` attribute.
    """
    # get_permissions() is:
    #     return [permission() for permission in self.permission_classes]
    # permission_classes is the list declared on the view, so each
    # ``permission`` is an instance of one of our permission classes.
    for permission in self.get_permissions():
        # A permission class therefore needs a has_permission() method
        # returning True or False.
        if permission.has_permission(request, self):
            continue
        # On False a message is passed along; since ``permission`` is our
        # own class, we can define ``message`` on it to customise the text.
        self.permission_denied(
            request, message=getattr(permission, 'message', None)
        )
我们看下restframework中默认的一些方法是如何执行的,过程与第一个方法 基本相似, 所以跳过寻找方法
'DEFAULT_PERMISSION_CLASSES': (
'rest_framework.permissions.AllowAny',
),
class AllowAny(BasePermission):
    """Permission that lets every request through.

    An empty ``permission_classes`` list would have the same effect; using
    this class just states the intention explicitly.
    """

    def has_permission(self, request, view):
        """Return ``True`` for every request."""
        # Pattern to copy for a custom permission: inherit from
        # BasePermission and return True or False from has_permission().
        allowed = True
        return allowed
VISIT_RECORD = {}  # remote address -> request timestamps, newest first
GAME_OVER = {}  # remote address -> time at which the address was banned


class VisitThrottle(BaseThrottle):
    """Per-address throttle with a temporary ban.

    An address may make ``MAX_REQUESTS`` requests per ``WINDOW`` seconds.
    Exceeding that bans the address for ``BAN_DURATION`` seconds. Requests
    without a User-Agent header are always refused.
    """

    WINDOW = 60  # seconds a request stays in the history
    MAX_REQUESTS = 20  # requests allowed per window
    BAN_DURATION = 10 * 60  # seconds a banned address stays blocked

    def __init__(self):
        self.history = None  # request timestamps of the current address
        self.remote_addr = None  # set by allow_request(), read by wait()

    def allow_request(self, request, view):
        """Return ``True`` if the request may proceed, ``False`` otherwise.

        Args:
            request: The incoming request; ``REMOTE_ADDR`` and
                ``HTTP_USER_AGENT`` are read from ``request.META``.
            view: The view being accessed (unused).
        """
        now = time.time()
        remote_addr = request.META.get('REMOTE_ADDR')
        user_agent = request.META.get('HTTP_USER_AGENT')
        self.remote_addr = remote_addr

        if not user_agent:
            return False

        banned_at = GAME_OVER.get(remote_addr)
        if banned_at is not None:
            if now - banned_at <= self.BAN_DURATION:
                return False
            # The ban has expired; forget it so the dict does not keep
            # stale entries.
            del GAME_OVER[remote_addr]

        # First visit of an address starts with an empty history.
        history = VISIT_RECORD.setdefault(remote_addr, [])
        self.history = history

        # Timestamps are inserted at the front, so the oldest one is last;
        # drop everything that has fallen out of the window.
        while history and now - history[-1] > self.WINDOW:
            history.pop()

        if len(history) < self.MAX_REQUESTS:
            history.insert(0, now)
            return True

        GAME_OVER[remote_addr] = now
        return False

    def wait(self):
        """Return the seconds left on the ban, or ``None`` if not banned.

        ``None`` covers refusals that are not bans (for example a missing
        User-Agent), where indexing ``GAME_OVER`` directly would raise
        ``KeyError``.
        """
        banned_at = GAME_OVER.get(self.remote_addr)
        if banned_at is None:
            return None
        return max(0, self.BAN_DURATION - (time.time() - banned_at))
class PublishDetail(viewsets.ModelViewSet):
    """Model view set for ``Publish`` rate-limited by ``VisitThrottle``."""

    serializer_class = PublishModelSerializers
    queryset = Publish.objects.all()
    # Throttle classes applied to this view only.
    throttle_classes = [VisitThrottle]
REST_FRAMEWORK = {
    # The global setting is named DEFAULT_THROTTLE_CLASSES, and each entry
    # is a dotted import path given as a string.
    # NOTE(review): replace the path below with the real module path of
    # the VisitThrottle class in this project.
    'DEFAULT_THROTTLE_CLASSES': ['your_app.throttles.VisitThrottle'],
}
APIView dispatch
def dispatch(self, request, *args, **kwargs):
    """Excerpt of the view's dispatch; elided code is marked with ``...``."""
    ...
    request = self.initialize_request(request, *args, **kwargs)
    self.request = request
    self.headers = self.default_response_headers  # deprecate?
    try:
        # initial() is where throttling is triggered.
        self.initial(request, *args, **kwargs)
        ...
initial
方法
def initial(self, request, *args, **kwargs):
    """Excerpt of initial(); elided code is marked with ``...``."""
    ...
    # Ensure that the incoming request is permitted
    self.perform_authentication(request)
    self.check_permissions(request)
    # Throttling (rate limiting)
    self.check_throttles(request)
check_throttles
def check_throttles(self, request):
    """Ask every throttle whether ``request`` may proceed.

    For each throttle whose ``allow_request`` returns a falsy value,
    ``self.throttled`` is called with that throttle's ``wait()`` result.
    """
    # get_throttles() is:
    #     return [throttle() for throttle in self.throttle_classes]
    # throttle_classes is the list we registered on the view class.
    for throttle in self.get_throttles():
        # A throttle class therefore needs an allow_request() method
        # returning True or False.
        if throttle.allow_request(request, self):
            continue
        self.throttled(request, throttle.wait())
而 restframework 中 内置了不少的限流访问的方法,我们可以直接调用
这是一个基类, 所有的频率访问的类都继承它
class BaseThrottle(object):
    """Base class for request rate throttles."""

    def allow_request(self, request, view):
        """Return ``True`` to allow the request, ``False`` to throttle it.

        Raises:
            NotImplementedError: Always, until a subclass overrides it.
        """
        raise NotImplementedError('.allow_request() must be overridden')

    def get_ident(self, request):
        """Return a string identifying the machine making the request.

        With ``NUM_PROXIES`` unset, the whole ``HTTP_X_FORWARDED_FOR``
        value (whitespace removed) is used when present, otherwise
        ``REMOTE_ADDR``. With ``NUM_PROXIES`` set, ``REMOTE_ADDR`` is used
        when it is 0 or the header is missing; otherwise the address that
        many entries from the end of the header is used.
        """
        forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
        remote_addr = request.META.get('REMOTE_ADDR')
        num_proxies = api_settings.NUM_PROXIES

        if num_proxies is None:
            return ''.join(forwarded_for.split()) if forwarded_for else remote_addr

        if num_proxies == 0 or forwarded_for is None:
            return remote_addr

        hops = forwarded_for.split(',')
        return hops[-min(num_proxies, len(hops))].strip()

    def wait(self):
        """Optionally return the recommended seconds to wait; ``None`` here."""
        return None
提供了3个方法:
allow_request
get_ident
wait
除了这个基类之外,restframework 还提供了 SimpleRateThrottle,我们可以直接继承它来完成我们的需求。
先在 settings.py 中配置:
# Throttle rates keyed by scope name; the "Night" scope is limited to "3/m".
REST_FRAMEWORK = {
    "DEFAULT_THROTTLE_RATES": {"Night": "3/m"},
}
class VisitDemo(SimpleRateThrottle):
    """Throttle using the rate configured for the ``'Night'`` scope."""

    scope = 'Night'

    def get_cache_key(self, request, view):
        """Return the client identity from ``get_ident`` as the cache key."""
        ident = self.get_ident(request)
        return ident
VisitDemo 继承自 SimpleRateThrottle(它又继承自 BaseThrottle),所以我们先看 SimpleRateThrottle 的 __init__,
然后看 allow_request,
因为限流时主要调用的就是 allow_request。
class SimpleRateThrottle(BaseThrottle):
    """Cache-backed rate throttle.

    The rate is a string such as ``"3/m"``. It is either set directly as
    ``rate`` or looked up in ``THROTTLE_RATES`` under ``scope``. Subclasses
    must override ``get_cache_key`` to identify the client.
    """

    cache = default_cache
    timer = time.time  # returns the current time
    cache_format = 'throttle_%(scope)s_%(ident)s'  # template for cache keys
    scope = None
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES

    def __init__(self):
        # Use an explicitly set ``rate`` if there is one, otherwise look it
        # up from the configured rates by scope.
        if not getattr(self, 'rate', None):
            self.rate = self.get_rate()
        # Example: rate "3/m" gives num_requests = 3 (request limit) and
        # duration = 60 (window length in seconds).
        self.num_requests, self.duration = self.parse_rate(self.rate)

    def get_cache_key(self, request, view):
        """Return a unique cache key identifying the client.

        Must be overridden. May return ``None`` if the request should not
        be throttled.

        Raises:
            NotImplementedError: Always, until a subclass overrides it.
        """
        raise NotImplementedError('.get_cache_key() must be overridden')

    def get_rate(self):
        """Return the rate string configured for this throttle's scope.

        With the setting ``"DEFAULT_THROTTLE_RATES": {"Night": "3/m"}`` and
        ``scope = 'Night'`` this returns ``"3/m"``.

        Raises:
            ImproperlyConfigured: If ``scope`` is not set, or no rate is
                configured for it.
        """
        # Local import so this block does not rely on a module-level import.
        from django.core.exceptions import ImproperlyConfigured

        if not getattr(self, 'scope', None):
            msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                   self.__class__.__name__)
            raise ImproperlyConfigured(msg)

        try:
            return self.THROTTLE_RATES[self.scope]
        except KeyError:
            msg = "No default throttle rate set for '%s' scope" % self.scope
            raise ImproperlyConfigured(msg)

    def parse_rate(self, rate):
        """Parse ``"<count>/<period>"`` into ``(num_requests, duration)``.

        Only the first letter of the period is used: ``s``, ``m``, ``h``
        or ``d``. The duration is returned in seconds. ``None`` yields
        ``(None, None)``.
        """
        if rate is None:
            return (None, None)
        num, period = rate.split('/')
        num_requests = int(num)
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        return (num_requests, duration)

    def allow_request(self, request, view):
        """Decide whether the request should be throttled.

        Calls ``throttle_success`` when the request is allowed and
        ``throttle_failure`` when it is not. Requests are always allowed
        when no rate is set or ``get_cache_key`` returns ``None``.
        """
        if self.rate is None:
            return True

        self.key = self.get_cache_key(request, view)
        if self.key is None:
            return True

        # Fetch this client's request history from the cache.
        self.history = self.cache.get(self.key, [])
        self.now = self.timer()

        # Drop any requests from the history which have now passed the
        # throttle duration (the oldest entry is at the end).
        while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()
        # Too many requests left in the window: throttle.
        if len(self.history) >= self.num_requests:
            return self.throttle_failure()
        return self.throttle_success()

    def throttle_success(self):
        """Record the current request in the cached history; return ``True``."""
        self.history.insert(0, self.now)
        self.cache.set(self.key, self.history, self.duration)
        return True

    def throttle_failure(self):
        """Called when a request is throttled; returns ``False``."""
        return False

    def wait(self):
        """Return the recommended seconds until the next request, or ``None``."""
        if self.history:
            remaining_duration = self.duration - (self.now - self.history[-1])
        else:
            remaining_duration = self.duration

        available_requests = self.num_requests - len(self.history) + 1
        if available_requests <= 0:
            return None

        return remaining_duration / float(available_requests)
其实如果整体走下来会发现, 自己写的基础使用基本仿照它写的。
下面的几个类都是继承SimpleRateThrottle
也就是和我们使用的形式基本一致
class AnonRateThrottle(SimpleRateThrottle):
    """Rate limit for anonymous users, keyed by the client's identity.

    Authenticated requests are not throttled by this class.
    """

    scope = 'anon'

    def get_cache_key(self, request, view):
        """Return the cache key for anonymous requests, else ``None``."""
        if not request.user.is_authenticated:
            return self.cache_format % {
                'scope': self.scope,
                'ident': self.get_ident(request)
            }
        # Only unauthenticated requests are throttled.
        return None
class UserRateThrottle(SimpleRateThrottle):
    """Rate limit per user.

    Authenticated requests are keyed by the user's primary key; anonymous
    requests are keyed by the client identity from ``get_ident``.
    """

    scope = 'user'

    def get_cache_key(self, request, view):
        """Return the cache key built from the scope and the client ident."""
        ident = (
            request.user.pk
            if request.user.is_authenticated
            else self.get_ident(request)
        )
        return self.cache_format % {'scope': self.scope, 'ident': ident}
class ScopedRateThrottle(SimpleRateThrottle):
    """Rate limit that differs per part of the API.

    The scope is read from the view's ``throttle_scope`` attribute; views
    without it are never throttled by this class. The cache key combines
    that scope with the user id, or with the client identity for anonymous
    requests.
    """

    scope_attr = 'throttle_scope'

    def __init__(self):
        # Deliberately skip SimpleRateThrottle.__init__: the rate cannot be
        # determined until the throttle is called with a view.
        pass

    def allow_request(self, request, view):
        """Resolve the view's scope and rate, then apply the normal check."""
        # The scope is only known once we are called with the view.
        view_scope = getattr(view, self.scope_attr, None)
        self.scope = view_scope

        # Views without a ``throttle_scope`` are always allowed.
        if not view_scope:
            return True

        # Same rate setup that SimpleRateThrottle does in __init__.
        self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)

        return super(ScopedRateThrottle, self).allow_request(request, view)

    def get_cache_key(self, request, view):
        """Return the cache key built from the view's scope and the client ident."""
        ident = (
            request.user.pk
            if request.user.is_authenticated
            else self.get_ident(request)
        )
        return self.cache_format % {'scope': self.scope, 'ident': ident}