关于REST Framework认证分析

首先一个请求会经过类视图的diapatch()方法,就以此方法为突破口,分析restFramework的认证系统。
先将类视图代码贴一下:

    class Authentication(Object):
    
    	def authenticate(self, request):
    		#  内部进行验证,并返回一个元组
    		return (xx, xxx)   # request.user, request.auth  最后会赋值给他俩
    		
    	def authentication_header(self, request):  # 这个必须写,不写就报错
    		pass
				
	# 类视图
	class ClassView(APIView):
				
		authentication_classes = [Authentication, ]
				
		def get(self, request, *args, **kwargs):
			pass
				
		def post(self, request, *args, **kwargs):
			pass

-dispatch()源码:
为什么会执行dispatch()方法呢?因为只要是类视图,再urls.py中就会执行as_view()方法,restframework中的APIView类重写了as_view()方法,因为APIView是继承了django的View类,但是跟django上的as_view()差不多,APIView(View)下面有这样一行代码:view = super(APIView, cls).as_view(**initkwargs),其实就是调用APIView的父类的as_view()方法,其返回view()方法,源码如下:

    def view(request, *args, **kwargs):
        self = cls(**initkwargs)
        if hasattr(self, 'get') and not hasattr(self, 'head'):
            self.head = self.get
        self.request = request
        self.args = args
        self.kwargs = kwargs
        return self.dispatch(request, *args, **kwargs)  # 这里返回dispatch()
	    
    def dispatch(self, request, *args, **kwargs):
    	self.args = args
    	self.kwargs = kwargs
    	# 其他的先不管,这里对原request进行了封装,现在进入initialize_request看看发生了什么
        request = self.initialize_request(request, *args, **kwargs)
    	self.request = request
    	self.headers = self.default_response_headers  # deprecate?
    	
	    try:
	       self.initial(request, *args, **kwargs)
	
	       # Get the appropriate handler method
	       if request.method.lower() in self.http_method_names:
	           handler = getattr(self, request.method.lower(),
	                                  self.http_method_not_allowed)
	       else:
	           handler = self.http_method_not_allowed
	
	       response = handler(request, *args, **kwargs)
	
	    except Exception as exc:
	       response = self.handle_exception(exc)
	
	    self.response = self.finalize_response(request, response, *args, **kwargs)
	    return self.response

-initialize_request()源码:

	def initialize_request(self, request, *args, **kwargs):
		parser_context = self.get_parser_context(request)
		# 这个函数返回了Request类实列,再进入Request类看看
		return Request(
			request,
			parsers=self.get_parsers(),
			authenticators=self.get_authenticators(),
			negotiator=self.get_content_negotiator(),
			parser_context=parser_context
		)

-Request类的构造函数:

    def __init__(self, request, parsers=None, authenticators=None,
                     negotiator=None, parser_context=None):

新的request已经是Request类的实列了,其中传入了原生的request,authenticators认证器,其他的也不知道什么意思。

如果调用原生的request,要这样写:request._request,具体看Request()的构造函数
这个验证器是怎么回事呢?authenticators=self.get_authenticators()
看看get_authenticators()函数:

	def get_authenticators(self):
		return [auth() for auth in self.authentication_classes]

[auth() for auth in self.authentication_classes] 这是一个列表生成式,不过里面元素是类实列

-其中在restframework中的APIView(View)类中,有这样一个字段:

authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES

这个authentication_classes,如果没在api_settings中定义DEFAULT_AUTHENTICATION_CLASSES,就需要在你的类视图中添加
一般情况下,写API时候,如果要验证,是要重写验证类的:

     class Authentication(Object):
     
     	def authenticate(self, request):
     		#  内部进行验证,并返回一个元组
     		return (xx, xxx)   # request.user, request.auth  最后会赋值给他俩
     		
     	def authentication_header(self, request):  # 这个必须写,不写就报错
     		pass
 				
 	# 类视图
 	class ClassView(APIView):
 				
 		authentication_classes = [Authentication, ]
 				
 		def get(self, request, *args, **kwargs):
 			pass
 				
 		def post(self, request, *args, **kwargs):
 			pass

-现在再看下一步,dispatch()中try里的 self.initial(request, *args, **kwargs),这就是对封装后的request进行验证了
-initial(request, *args, **kwargs)源码:

	def initial(self, request, *args, **kwargs):
		self.format_kwarg = self.get_format_suffix(**kwargs)
		neg = self.perform_content_negotiation(request)
		request.accepted_renderer, request.accepted_media_type = neg

		version, scheme = self.determine_version(request, *args, **kwargs)
		request.version, request.versioning_scheme = version, scheme

		# 其他的也看不懂,反正看到authentication了,这个就是执行验证啊
		self.perform_authentication(request)  
		self.check_permissions(request)
		self.check_throttles(request)
			

-perform_authentication(request)源码:

	def perform_authentication(self, request):
		request.user

这里是什么鬼?先ctrl+B进去看看user
这里的request实际上是Request(Object)的实列了,实列调用类方法,因为使用了@property,所以不用括号()

Request(Object)的user()方法:

    @property
    def user(self):
    	# self指的是当前Request(Object)类的实列,看它的构造函数,是没有_user属性的
    	if not hasattr(self, '_user'):  
			with wrap_attributeerrors(): 
				self._authenticate()   # 所以执行_authenticate()
		return self._user

Request(Object)的_authenticate()方法:

    def _authenticate(self):
    	for authenticator in self.authenticators:  # 循环重写的认证实列,前面写有
    		try:
		    	# 执行自定义认证类的authenticate()
		    	# 1.认证成功,看user_auth_tuple名字知道是一个元组,执行下方if代码
		    	# 2.没有认证成功,抛出异常。
		    	# 3.如果自定义的认证器都通过,又没有抛出异常,那就跳出for循环,执行self._not_authenticated()
    			user_auth_tuple = authenticator.authenticate(self)
    		except exceptions.APIException:
    			self._not_authenticated()
    			raise

			if user_auth_tuple is not None:
				self._authenticator = authenticator
				# 将元组中的值赋给request
				self.user, self.auth = user_auth_tuple  
				return
			
		self._not_authenticated()
		

-通过验证之后,dispatch会根据请求方法找到相应的视图函数并返回。

-REST Framework内置有认证基类:

    from rest_framework.authentication import BaseAuthentication
    
    class BaseAuthentication(object):
	    """
	    All authentication classes should extend BaseAuthentication.
	    """

	    def authenticate(self, request):
		    """
	        Authenticate the request and return a two-tuple of (user, token).
	        """
	        #  这里提示,.authenticate()必须要重写。前面有个点,就是说它是一个方法
	        raise NotImplementedError(".authenticate() must be overridden.")
	
	    def authenticate_header(self, request):
		    """
	        Return a string to be used as the value of the `WWW-Authenticate`
	        header in a `401 Unauthenticated` response, or `None` if the
	        authentication scheme should return `403 Permission Denied` responses.
	        """
	        pass

这个基类写的很清楚了,All authentication classes should extend BaseAuthentication,就不翻译了。
所以上面的实列代码需要改一下:

	class Authentication(BaseAuthentication):
    
    	def authenticate(self, request):
    		#  Authenticate the request and return a two-tuple of (user, token).
    		#  内部进行验证,并返回一个元组
    		return (user, token)   # request.user, request.auth  最后会赋值给他俩

另外一个方法就可以不用写了,因为父类有嘛。。。。
另外一个方法看文档描述应该是认证失败的时候用的,目前还没研究到。。。

2019.3.22更新
下面将用户登录认证的代码贴上
1.首先要使用restFramework,要再settings.py中进行app注册
2.贴代码:

	def md5(user):
	    ctime = str(time.time())
	    token = hashlib.md5()
	    token.update(user.encode('utf-8'))
	    token.update(ctime.encode('utf-8'))
	    return token.hexdigest()

	# 如果不想让未登录的用户访问某些页面
	# 那么就可以对未登录用户进行验证,下面是认证器,是否携带登录时产生的token码
	class UserAuthentication(BaseAuthentication):
	
	    def authenticate(self, request):
	        # 提取token,这里没有写前端,只是使用Postman进行测试
	        token = request._request.GET.get('token')
	        # 验证token是否正确
	        token_obj = UserToken.objects.filter(token=token).first()
	        if not token_obj:
	                raise exceptions.AuthenticationFailed('用户未登录')
	        else:
	            # 下面的元组内的元素会分别赋值给 request.user request.auth
	            return (token_obj.user, token_obj)



	# 先写一个rest framework登录视图
	class LoginView(APIView):
	
	    def post(self, request, *args, **kwargs):
	        # 定义一个result,返回给前端
	        restful = {'code': 100, 'message': None, 'data': None}
	        username = request._request.POST.get('username')
	        password = request._request.POST.get('password')
	        try:
	            user_obj = UserInfo.objects.get(username=username, password=password)
            if user_obj:
	                token = md5(username)
	                # 如果用户登录,创建一个token值,第一次登录就创建,否则就是更新
	                UserToken.objects.update_or_create(user=user_obj, defaults={'token': token})
	                restful['code'] = 200
	                restful['message'] = "登录成功"
	            else:
	                restful['code'] = 400
	                restful['message'] = "用户输入错误"
	        except Exception as e:
	            pass
	        return JsonResponse(restful)
	
	
	# 再写一个类视图,用来限制未登录的用户访问 个人中心吧
	class UserCenter(APIView):
	    authentication_classes=[UserAuthentication,]
	    def get(self, request, *args, **kwargs):
	        return HttpResponse('欢迎来到个人中心')
	
	    def post(self, request, *args, **kwargs):
	        return HttpResponse('post')

2019.3.24更新
昨天给自己放了个假。。。

除了用户登录认证,还有权限认证了。
如果用户认证源码搞定了,那么后面的源码都是差不多了。。
权限认证的代码,这里就不贴模型代码了。。。

# 只有黄金会员权限访问
	class UserPermission(BasePermission):
		message = '没有权限访问'
	    def has_permission(self, request, view):
	        # 只有登录后才可以访问的,所以先经过登录认证,这时可以从request.user获取用户
	        user = request.user
	        user_type = user.user_type
	        if user_type == 3:
	            return True
	        else:
	            # raise exceptions.PermissionDenied('没有权限访问')  这里最好不要这么写,因为验证失败,系统自己会抛出错误,定义一个message就好
	            # if not permission.has_permission(request, self):  这里会用has_permission()返回的值进行判断
	            return False

-再写一个类视图代码:

	class UserSVip(APIView):
	
	    authentication_classes = [UserAuthentication,]  # 添加前面的用户认证器
	    permission_classes = [UserPermission,]  # 添加权限
	
	    def get(self, request, *args, **kwargs):
	        return HttpResponse('这是只有黄金会员才能访问的页面。。。')
	
	    def post(self, request, *args, **kwargs):
	        pass

-权限认证还是比较简单的,继承权限的基类,重写has_permission()方法就好了。。

-接下来是节流认证,用来控制用户访问频率,免得恶意一直访问。。。
源码如下,这是rest_framework自带的一个节流系统,比自己写好多了,拿来就用。。

	class SimpleRateThrottle(BaseThrottle):
		"""
		A simple cache implementation, that only requires `.get_cache_key()`
		to be overridden.

		The rate (requests / seconds) is set by a `rate` attribute on the View
		class.  The attribute is a string of the form 'number_of_requests/period'.

		Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')

		Previous request information used for throttling is stored in the cache.
		"""
		cache = default_cache   #使用django的缓存
		timer = time.time  # 拿到当前时间
		cache_format = 'throttle_%(scope)s_%(ident)s'
		scope = None  # 先定义为None, 实际写的时候这个值要进行覆盖
		# 再django的设置中写的rest_famework设置,它是一个字典,key为scope的值,value为rate
		THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES  

		def __init__(self):  # 实列化的时候,就会执行构造函数
			if not getattr(self, 'rate', None):  # 如果没有rate
				self.rate = self.get_rate()  # 得到一个标识值,下面有get_rate()函数
				# 拿到rate值后,执行parse_rate()函数,在下面
			self.num_requests, self.duration = self.parse_rate(self.rate)  
		
		# 获取缓存key的函数,这里面什么都没有,需要重写覆盖
		def get_cache_key(self, request, view):  
			"""
			Should return a unique cache-key which can be used for throttling.
			Must be overridden.

			May return `None` if the request should not be throttled.
			"""
			raise NotImplementedError('.get_cache_key() must be overridden')

		def get_rate(self):
			"""
			Determine the string representation of the allowed request rate.
			"""
			if not getattr(self, 'scope', None):  # 判断是否有scope值
				msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
					   self.__class__.__name__)
				raise ImproperlyConfigured(msg)

			try:
			# 如果有scope值的话,从配置中拿到,scope是作为一个健存在的,拿到value,也就是rate
				return self.THROTTLE_RATES[self.scope]  
			except KeyError:
				msg = "No default throttle rate set for '%s' scope" % self.scope
				raise ImproperlyConfigured(msg)

		def parse_rate(self, rate):
			"""
			Given the request rate string, return a two tuple of:
			, 
			"""
			if rate is None:
				return (None, None)
			# 其实rate是一个字符串,类似于“5/m”,每分钟最多访问五次,这里进行分割
			num, period = rate.split('/')  
			num_requests = int(num)
			# 不一定非要设置为m,还有每天,每小时等等
			duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]  
			return (num_requests, duration)

		def allow_request(self, request, view):  # 然后执行这个函数
			"""
			Implement the check to see if the request should be throttled.

			On success calls `throttle_success`.
			On failure calls `throttle_failure`.
			"""
			if self.rate is None:
				return True
			# 设置一个缓存key,从当前访问request中提取一个值,可以将用户id,ip作为key,以此函数返回的值为准
			self.key = self.get_cache_key(request, view)  
			if self.key is None:
				return True
			# 从缓存中拿到key对应的value,如果没有就设置一个空列表
			self.history = self.cache.get(self.key, [])  
			self.now = self.timer()

			# Drop any requests from the history which have now passed the
			# throttle duration
			# 这里对history进行判断,没有值,while就不执行了,因为你没有访问记录再value中
			# duration==60
			while self.history and self.history[-1] <= self.now - self.duration: 
				self.history.pop()
			if len(self.history) >= self.num_requests:
				return self.throttle_failure()
			 # 上面几行代码都是对history进行判断,满足情况就执行throttle_success(),不满足就执行throttle_failure()
			return self.throttle_success() 
 
		def throttle_success(self):
			"""
			Inserts the current request's timestamp along with the key
			into the cache.
			"""
			self.history.insert(0, self.now)  # 将刚访问的实际插入到history中
			# 设置缓存的value,有效时间为duration的值
			self.cache.set(self.key, self.history, self.duration)  
			return True

		def throttle_failure(self):
			"""
			Called when a request to the API has failed due to throttling.
			"""
			return False
		# 下面的就很简单了,如果访问频率超过限制,就会执行这个函数
		def wait(self):
			"""
			Returns the recommended next request time in seconds.
			"""
			if self.history:
				remaining_duration = self.duration - (self.now - self.history[-1])
			else:
				remaining_duration = self.duration

			available_requests = self.num_requests - len(self.history) + 1
			if available_requests <= 0:
				return None

			return remaining_duration / float(available_requests)

-从上面的代码分析来看,如果使用自带的节流认证SimpleRateThrottle(BaseThrottle):,只需继承这个类,然后设置scope的值,这个值随便写,想写啥就写啥,再写一个get_cache_key()函数,返回缓存需要的值就好了,得到这个值可以根据基类BaseThrottle(Object)中的get_ident(self, request)方法就好了。

所以,只需要写的代码就四行:

	class UserThrottle(SimpleRateThrottle):
		scope = 'stu'
		def get_cache_key(self, request, view):
			return self.get_ident(request)

然后再django配置中:

	REST_FRAMEWORK = {
				"DEFAULT_THROTTLE_RATES":{
				'stu': '5/m',
				}
			}

配置到类视图中:

	class UserSVip(APIView):
	
	    authentication_classes = [UserAuthentication,]
	    permission_classes = [UserPermission,]
	    throttle_classes = [UserThrottle,]
	
	    def get(self, request, *args, **kwargs):
	        return HttpResponse('这是只有黄金会员才能访问的页面。。。')
	
	    def post(self, request, *args, **kwargs):
	        pass

这样就简单了分析了用户认证,权限认证,节流认证三个基本的功能,还有很多需要自己去研究。。。。

你可能感兴趣的:(Python)