首先一个请求会经过类视图的diapatch()方法,就以此方法为突破口,分析restFramework的认证系统。
先将类视图代码贴一下:
class Authentication(Object):
def authenticate(self, request):
# 内部进行验证,并返回一个元组
return (xx, xxx) # request.user, request.auth 最后会赋值给他俩
def authentication_header(self, request): # 这个必须写,不写就报错
pass
# 类视图
class ClassView(APIView):
authentication_classes = [Authentication, ]
def get(self, request, *args, **kwargs):
pass
def post(self, request, *args, **kwargs):
pass
-dispatch()源码:
为什么会执行dispatch()方法呢?因为只要是类视图,再urls.py中就会执行as_view()方法,restframework中的APIView类重写了as_view()方法,因为APIView是继承了django的View类,但是跟django上的as_view()差不多,APIView(View)下面有这样一行代码:view = super(APIView, cls).as_view(**initkwargs),其实就是调用APIView的父类的as_view()方法,其返回view()方法,源码如下:
def view(request, *args, **kwargs):
self = cls(**initkwargs)
if hasattr(self, 'get') and not hasattr(self, 'head'):
self.head = self.get
self.request = request
self.args = args
self.kwargs = kwargs
return self.dispatch(request, *args, **kwargs) # 这里返回dispatch()
def dispatch(self, request, *args, **kwargs):
self.args = args
self.kwargs = kwargs
# 其他的先不管,这里对原request进行了封装,现在进入initialize_request看看发生了什么
request = self.initialize_request(request, *args, **kwargs)
self.request = request
self.headers = self.default_response_headers # deprecate?
try:
self.initial(request, *args, **kwargs)
# Get the appropriate handler method
if request.method.lower() in self.http_method_names:
handler = getattr(self, request.method.lower(),
self.http_method_not_allowed)
else:
handler = self.http_method_not_allowed
response = handler(request, *args, **kwargs)
except Exception as exc:
response = self.handle_exception(exc)
self.response = self.finalize_response(request, response, *args, **kwargs)
return self.response
-initialize_request()源码:
def initialize_request(self, request, *args, **kwargs):
parser_context = self.get_parser_context(request)
# 这个函数返回了Request类实列,再进入Request类看看
return Request(
request,
parsers=self.get_parsers(),
authenticators=self.get_authenticators(),
negotiator=self.get_content_negotiator(),
parser_context=parser_context
)
-Request类的构造函数:
def __init__(self, request, parsers=None, authenticators=None,
negotiator=None, parser_context=None):
新的request已经是Request类的实列了,其中传入了原生的request,authenticators认证器,其他的也不知道什么意思。
如果调用原生的request,要这样写:request._request,具体看Request()的构造函数
这个验证器是怎么回事呢?authenticators=self.get_authenticators()
看看get_authenticators()函数:
def get_authenticators(self):
return [auth() for auth in self.authentication_classes]
[auth() for auth in self.authentication_classes]
这是一个列表生成式,不过里面元素是类实列
-其中在restframework中的APIView(View)类中,有这样一个字段:
authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES
这个authentication_classes,如果没在api_settings中定义DEFAULT_AUTHENTICATION_CLASSES,就需要在你的类视图中添加
一般情况下,写API时候,如果要验证,是要重写验证类的:
class Authentication(Object):
def authenticate(self, request):
# 内部进行验证,并返回一个元组
return (xx, xxx) # request.user, request.auth 最后会赋值给他俩
def authentication_header(self, request): # 这个必须写,不写就报错
pass
# 类视图
class ClassView(APIView):
authentication_classes = [Authentication, ]
def get(self, request, *args, **kwargs):
pass
def post(self, request, *args, **kwargs):
pass
-现在再看下一步,dispatch()中try里的 self.initial(request, *args, **kwargs),这就是对封装后的request进行验证了
-initial(request, *args, **kwargs)源码:
def initial(self, request, *args, **kwargs):
self.format_kwarg = self.get_format_suffix(**kwargs)
neg = self.perform_content_negotiation(request)
request.accepted_renderer, request.accepted_media_type = neg
version, scheme = self.determine_version(request, *args, **kwargs)
request.version, request.versioning_scheme = version, scheme
# 其他的也看不懂,反正看到authentication了,这个就是执行验证啊
self.perform_authentication(request)
self.check_permissions(request)
self.check_throttles(request)
-perform_authentication(request)源码:
def perform_authentication(self, request):
request.user
这里是什么鬼?先ctrl+B进去看看user
这里的request实际上是Request(Object)的实列了,实列调用类方法,因为使用了@property,所以不用括号()
Request(Object)的user()方法:
@property
def user(self):
# self指的是当前Request(Object)类的实列,看它的构造函数,是没有_user属性的
if not hasattr(self, '_user'):
with wrap_attributeerrors():
self._authenticate() # 所以执行_authenticate()
return self._user
Request(Object)的_authenticate()方法:
def _authenticate(self):
for authenticator in self.authenticators: # 循环重写的认证实列,前面写有
try:
# 执行自定义认证类的authenticate()
# 1.认证成功,看user_auth_tuple名字知道是一个元组,执行下方if代码
# 2.没有认证成功,抛出异常。
# 3.如果自定义的认证器都通过,又没有抛出异常,那就跳出for循环,执行self._not_authenticated()
user_auth_tuple = authenticator.authenticate(self)
except exceptions.APIException:
self._not_authenticated()
raise
if user_auth_tuple is not None:
self._authenticator = authenticator
# 将元组中的值赋给request
self.user, self.auth = user_auth_tuple
return
self._not_authenticated()
-通过验证之后,dispatch会根据请求方法找到相应的视图函数并返回。
-REST Framework内置有认证基类:
from rest_framework.authentication import BaseAuthentication
class BaseAuthentication(object):
"""
All authentication classes should extend BaseAuthentication.
"""
def authenticate(self, request):
"""
Authenticate the request and return a two-tuple of (user, token).
"""
# 这里提示,.authenticate()必须要重写。前面有个点,就是说它是一个方法
raise NotImplementedError(".authenticate() must be overridden.")
def authenticate_header(self, request):
"""
Return a string to be used as the value of the `WWW-Authenticate`
header in a `401 Unauthenticated` response, or `None` if the
authentication scheme should return `403 Permission Denied` responses.
"""
pass
这个基类写的很清楚了,All authentication classes should extend BaseAuthentication,就不翻译了。
所以上面的实列代码需要改一下:
class Authentication(BaseAuthentication):
def authenticate(self, request):
# Authenticate the request and return a two-tuple of (user, token).
# 内部进行验证,并返回一个元组
return (user, token) # request.user, request.auth 最后会赋值给他俩
另外一个方法就可以不用写了,因为父类有嘛。。。。
另外一个方法看文档描述应该是认证失败的时候用的,目前还没研究到。。。
2019.3.22更新
下面将用户登录认证的代码贴上
1.首先要使用restFramework,要再settings.py中进行app注册
2.贴代码:
def md5(user):
ctime = str(time.time())
token = hashlib.md5()
token.update(user.encode('utf-8'))
token.update(ctime.encode('utf-8'))
return token.hexdigest()
# 如果不想让未登录的用户访问某些页面
# 那么就可以对未登录用户进行验证,下面是认证器,是否携带登录时产生的token码
class UserAuthentication(BaseAuthentication):
def authenticate(self, request):
# 提取token,这里没有写前端,只是使用Postman进行测试
token = request._request.GET.get('token')
# 验证token是否正确
token_obj = UserToken.objects.filter(token=token).first()
if not token_obj:
raise exceptions.AuthenticationFailed('用户未登录')
else:
# 下面的元组内的元素会分别赋值给 request.user request.auth
return (token_obj.user, token_obj)
# 先写一个rest framework登录视图
class LoginView(APIView):
def post(self, request, *args, **kwargs):
# 定义一个result,返回给前端
restful = {'code': 100, 'message': None, 'data': None}
username = request._request.POST.get('username')
password = request._request.POST.get('password')
try:
user_obj = UserInfo.objects.get(username=username, password=password)
if user_obj:
token = md5(username)
# 如果用户登录,创建一个token值,第一次登录就创建,否则就是更新
UserToken.objects.update_or_create(user=user_obj, defaults={'token': token})
restful['code'] = 200
restful['message'] = "登录成功"
else:
restful['code'] = 400
restful['message'] = "用户输入错误"
except Exception as e:
pass
return JsonResponse(restful)
# 再写一个类视图,用来限制未登录的用户访问 个人中心吧
class UserCenter(APIView):
authentication_classes=[UserAuthentication,]
def get(self, request, *args, **kwargs):
return HttpResponse('欢迎来到个人中心')
def post(self, request, *args, **kwargs):
return HttpResponse('post')
2019.3.24更新
昨天给自己放了个假。。。
除了用户登录认证,还有权限认证了。
如果用户认证源码搞定了,那么后面的源码都是差不多了。。
权限认证的代码,这里就不贴模型代码了。。。
# 只有黄金会员权限访问
class UserPermission(BasePermission):
message = '没有权限访问'
def has_permission(self, request, view):
# 只有登录后才可以访问的,所以先经过登录认证,这时可以从request.user获取用户
user = request.user
user_type = user.user_type
if user_type == 3:
return True
else:
# raise exceptions.PermissionDenied('没有权限访问') 这里最好不要这么写,因为验证失败,系统自己会抛出错误,定义一个message就好
# if not permission.has_permission(request, self): 这里会用has_permission()返回的值进行判断
return False
-再写一个类视图代码:
class UserSVip(APIView):
authentication_classes = [UserAuthentication,] # 添加前面的用户认证器
permission_classes = [UserPermission,] # 添加权限
def get(self, request, *args, **kwargs):
return HttpResponse('这是只有黄金会员才能访问的页面。。。')
def post(self, request, *args, **kwargs):
pass
-权限认证还是比较简单的,继承权限的基类,重写has_permission()方法就好了。。
-接下来是节流认证,用来控制用户访问频率,免得恶意一直访问。。。
源码如下,这是rest_framework自带的一个节流系统,比自己写好多了,拿来就用。。
class SimpleRateThrottle(BaseThrottle):
"""
A simple cache implementation, that only requires `.get_cache_key()`
to be overridden.
The rate (requests / seconds) is set by a `rate` attribute on the View
class. The attribute is a string of the form 'number_of_requests/period'.
Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')
Previous request information used for throttling is stored in the cache.
"""
cache = default_cache #使用django的缓存
timer = time.time # 拿到当前时间
cache_format = 'throttle_%(scope)s_%(ident)s'
scope = None # 先定义为None, 实际写的时候这个值要进行覆盖
# 再django的设置中写的rest_famework设置,它是一个字典,key为scope的值,value为rate
THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
def __init__(self): # 实列化的时候,就会执行构造函数
if not getattr(self, 'rate', None): # 如果没有rate
self.rate = self.get_rate() # 得到一个标识值,下面有get_rate()函数
# 拿到rate值后,执行parse_rate()函数,在下面
self.num_requests, self.duration = self.parse_rate(self.rate)
# 获取缓存key的函数,这里面什么都没有,需要重写覆盖
def get_cache_key(self, request, view):
"""
Should return a unique cache-key which can be used for throttling.
Must be overridden.
May return `None` if the request should not be throttled.
"""
raise NotImplementedError('.get_cache_key() must be overridden')
def get_rate(self):
"""
Determine the string representation of the allowed request rate.
"""
if not getattr(self, 'scope', None): # 判断是否有scope值
msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
self.__class__.__name__)
raise ImproperlyConfigured(msg)
try:
# 如果有scope值的话,从配置中拿到,scope是作为一个健存在的,拿到value,也就是rate
return self.THROTTLE_RATES[self.scope]
except KeyError:
msg = "No default throttle rate set for '%s' scope" % self.scope
raise ImproperlyConfigured(msg)
def parse_rate(self, rate):
"""
Given the request rate string, return a two tuple of:
,
"""
if rate is None:
return (None, None)
# 其实rate是一个字符串,类似于“5/m”,每分钟最多访问五次,这里进行分割
num, period = rate.split('/')
num_requests = int(num)
# 不一定非要设置为m,还有每天,每小时等等
duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
return (num_requests, duration)
def allow_request(self, request, view): # 然后执行这个函数
"""
Implement the check to see if the request should be throttled.
On success calls `throttle_success`.
On failure calls `throttle_failure`.
"""
if self.rate is None:
return True
# 设置一个缓存key,从当前访问request中提取一个值,可以将用户id,ip作为key,以此函数返回的值为准
self.key = self.get_cache_key(request, view)
if self.key is None:
return True
# 从缓存中拿到key对应的value,如果没有就设置一个空列表
self.history = self.cache.get(self.key, [])
self.now = self.timer()
# Drop any requests from the history which have now passed the
# throttle duration
# 这里对history进行判断,没有值,while就不执行了,因为你没有访问记录再value中
# duration==60
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
if len(self.history) >= self.num_requests:
return self.throttle_failure()
# 上面几行代码都是对history进行判断,满足情况就执行throttle_success(),不满足就执行throttle_failure()
return self.throttle_success()
def throttle_success(self):
"""
Inserts the current request's timestamp along with the key
into the cache.
"""
self.history.insert(0, self.now) # 将刚访问的实际插入到history中
# 设置缓存的value,有效时间为duration的值
self.cache.set(self.key, self.history, self.duration)
return True
def throttle_failure(self):
"""
Called when a request to the API has failed due to throttling.
"""
return False
# 下面的就很简单了,如果访问频率超过限制,就会执行这个函数
def wait(self):
"""
Returns the recommended next request time in seconds.
"""
if self.history:
remaining_duration = self.duration - (self.now - self.history[-1])
else:
remaining_duration = self.duration
available_requests = self.num_requests - len(self.history) + 1
if available_requests <= 0:
return None
return remaining_duration / float(available_requests)
-从上面的代码分析来看,如果使用自带的节流认证SimpleRateThrottle(BaseThrottle):,只需继承这个类,然后设置scope的值,这个值随便写,想写啥就写啥,再写一个get_cache_key()函数,返回缓存需要的值就好了,得到这个值可以根据基类BaseThrottle(Object)中的get_ident(self, request)方法就好了。
所以,只需要写的代码就四行:
class UserThrottle(SimpleRateThrottle):
scope = 'stu'
def get_cache_key(self, request, view):
return self.get_ident(request)
然后再django配置中:
REST_FRAMEWORK = {
"DEFAULT_THROTTLE_RATES":{
'stu': '5/m',
}
}
配置到类视图中:
class UserSVip(APIView):
authentication_classes = [UserAuthentication,]
permission_classes = [UserPermission,]
throttle_classes = [UserThrottle,]
def get(self, request, *args, **kwargs):
return HttpResponse('这是只有黄金会员才能访问的页面。。。')
def post(self, request, *args, **kwargs):
pass
这样就简单了分析了用户认证,权限认证,节流认证三个基本的功能,还有很多需要自己去研究。。。。