在上篇文章中,我们介绍了rest framework框架的3个认证函数,perform_authentication(身份验证)、permissions(许可验证)、throttles(节流认证、限制访问数量)
,而该类方法,是通过APIView继承于View类方法,通过dispatch方法反射的操作,通过initialize_request方法封装了一个request(Request类)
,之后在调用initial方法来执行认证函数,而在此之前initial方法还有赋值的操作,而是通过什么来赋值,作用是什么?让我们来通过源码来观察吧。
initial函数方法如下:
def initial(self, request, *args, **kwargs):
self.format_kwarg = self.get_format_suffix(**kwargs)
neg = self.perform_content_negotiation(request)
request.accepted_renderer, request.accepted_media_type = neg
version, scheme = self.determine_version(request, *args, **kwargs)
request.version, request.versioning_scheme = version, scheme
# 权限认证相关
self.perform_authentication(request)
self.check_permissions(request)
self.check_throttles(request)
我们先从version开始看,能发现version的中文为版本,且通过调用self.determine_version方法,返回的参数由接收,最后再将接收到的参数赋值给request中的version、versioning_scheme对象中。
determine_version函数如下:
def determine_version(self, request, *args, **kwargs):
if self.versioning_class is None:
return (None, None)
scheme = self.versioning_class()
return (scheme.determine_version(request, *args, **kwargs), scheme)
通过determine_version函数可以发现,参数需要传递request,且判断了self.versioning_class是否为空
,而根据我们前面源码分析判断,可以推断出,该变量是APIView方法中默认的全局配置,所以我们在自定义版本的时候就可以加上versioning_class或配置在全局中。通过代码可以发现,该方法没有进行遍历,且后面通过scheme = self.versioning_class()调用了函数
,也就是说我们在子类中定义的versioning_class不是列表,而是一个函数对象,通过scheme对象调用了determine_version方法,可以得知我们在自己定义类的时候必须传入determine_version函数
。
最后return返回了2个参数,第一个则为根据determine_version函数发送来的请求来获取到版本号
,而第二个返回参数则返回了versioning_class这个函数
。
所以我们继续返回到initial函数中,发现返回的两个参数赋值给了version, scheme
,而version, scheme又赋值给了request.version, request.versioning_scheme。
由此得出当我们子类自定义函数时,若想获取版本信息,可以通过request.version获取版本,而想获取自定义类方法的对象时可以通过request.versioning_scheme获取(子类没有就从父类找)。
urls如下:
from django.conf.urls import url
from api import views
urlpatterns = [
url(r'^api/v1/auth/$', views.AuthView.as_view()),
url(r'^api/v1/order/$', views.OrderView.as_view()),
url(r'^api/v1/users/$', views.UsersView.as_view()),
]
api/view如下:
from django.shortcuts import render, HttpResponse
from rest_framework.views import APIView
from django.http import JsonResponse
class ParamVersion(object):
def determine_version(self, request, *args, **kwargs):
version = request.query_params.get('version')
return version
class UsersView(APIView):
authentication_classes = []
permission_classes = []
throttle_classes = []
versioning_class = ParamVersion
def get(self, request, *args, **kwargs):
print(request.version, request.versioning_scheme)
return HttpResponse('ok')
此时访问http://127.0.0.1:8000/api/v1/users/?version=v1打印:
v1 <api.views.ParamVersion object at 0x000001CF29E7DD08>
可以发现通过request是可以获取到版本和ParamVersion函数的对象的,且在Reuqest类中定义了一个query_params函数,该函数的作用就是返回原生request中的GET请求的参数
(具体还有很多返回请求的函数可以自行通过Request类中查看)。
BaseVersioning类是django内置的version版本类,该方法作为多个类方法的父类(子类继承
),用于定义默认版本参数,以及版本的选择。
BaseVersioning源码如下:
class BaseVersioning:
default_version = api_settings.DEFAULT_VERSION
allowed_versions = api_settings.ALLOWED_VERSIONS
version_param = api_settings.VERSION_PARAM
def determine_version(self, request, *args, **kwargs):
msg = '{cls}.determine_version() must be implemented.'
raise NotImplementedError(msg.format(
cls=self.__class__.__name__
))
def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
return _reverse(viewname, args, kwargs, request, format, **extra)
def is_allowed_version(self, version):
if not self.allowed_versions:
return True
return ((version is not None and version == self.default_version) or
(version in self.allowed_versions))
看到api_settings我们知道了可以通过settings设置全局,所以我们就可以通过全局的设置来改变默认参数、允许的版本、传入的值通过什么来接收。
settings.py如下:
REST_FRAMEWORK = {
"DEFAULT_VERSION": 'v1',
"ALLOWED_VERSIONS": ['v1', 'v2'],
"VERSION_PARAM": 'version'
}
BaseVersioning类还有个reverse方法,该方法就是通过获取路由后面传递参数的值,然后调用django的reverse方法将url返回给用户。
urls.py如下:
from django.conf.urls import url
from api import views
urlpatterns = [
url(r'^api/v1/auth/$', views.AuthView.as_view()),
url(r'^api/v1/order/$', views.OrderView.as_view()),
url(r'^api/v1/users/$', views.UsersView.as_view(), name='uuu'),
]
api/views.py如下:
from django.shortcuts import render, HttpResponse
from rest_framework.views import APIView
from rest_framework.versioning import BaseVersioning
class ParamVersion(BaseVersioning):
def determine_version(self, request, *args, **kwargs):
version = request.query_params.get('version')
return version
class UsersView(APIView):
authentication_classes = []
permission_classes = []
throttle_classes = []
versioning_class = ParamVersion
def get(self, request, *args, **kwargs):
u1 = request.versioning_scheme.reverse(viewname='uuu', request=request)
print(request.version, request.versioning_scheme, u1)
return HttpResponse('ok')
此时访问http://127.0.0.1:8000/api/v1/users/?version=v2打印:
v2 <api.views.ParamVersion object at 0x00000286D85413C8> http://127.0.0.1:8000/api/v1/users/
QueryParameterVersioning方法继承于BaseVersioning的类
,该方法通过以GET请求的方式获取版本,且还可以通过内置的reverse类来获取URL。
QueryParameterVersioning源码如下:
class QueryParameterVersioning(BaseVersioning):
invalid_version_message = _('Invalid version in query parameter.')
def determine_version(self, request, *args, **kwargs):
version = request.query_params.get(self.version_param, self.default_version)
if not self.is_allowed_version(version):
raise exceptions.NotFound(self.invalid_version_message)
return version
def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
url = super().reverse(
viewname, args, kwargs, request, format, **extra
)
if request.version is not None:
return replace_query_param(url, self.version_param, request.version)
return url
可以发现QueryParameterVersioning中的determine_version方法通过request.query_params.get获取GET请求中的参数,通过我们定义的全局配置属性,self.version_param获取键值,还有默认属性
,之后通过is_allowed_version判断我们是否传了值,没传则显示默认值
。
而reverse方法则是调用了父类BaseVersioning的reverse方法。
is_allowed_version函数如下:
def is_allowed_version(self, version):
if not self.allowed_versions:
return True
return ((version is not None and version == self.default_version) or
(version in self.allowed_versions))
api/view如下:
from rest_framework.versioning import QueryParameterVersioning
from django.shortcuts import render, HttpResponse
class UsersView(APIView):
authentication_classes = []
permission_classes = []
throttle_classes = []
versioning_class = QueryParameterVersioning
def get(self, request, *args, **kwargs):
u1 = request.versioning_scheme.reverse(viewname='uuu', request=request)
print(request.version, request.versioning_scheme, u1)
return HttpResponse('ok')
此时打印的数据和之前一样的。
URLPathVersioning方法继承了BaseVersioning的类,该方法封装了版本、URL路径,可以使用版本放在路由中,通过调用父类的reverse方法来获取URL。
URLPathVersioning源码如下:
class URLPathVersioning(BaseVersioning):
"""
urlpatterns = [
re_path(r'^(?P[v1|v2]+)/users/$', users_list, name='users-list'),
re_path(r'^(?P[v1|v2]+)/users/(?P[0-9]+)/$', users_detail, name='users-detail')
]
"""
invalid_version_message = _('Invalid version in URL path.')
def determine_version(self, request, *args, **kwargs):
version = kwargs.get(self.version_param, self.default_version)
if version is None:
version = self.default_version
if not self.is_allowed_version(version):
raise exceptions.NotFound(self.invalid_version_message)
return version
def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
if request.version is not None:
kwargs = {
} if (kwargs is None) else kwargs
kwargs[self.version_param] = request.version
return super().reverse(
viewname, args, kwargs, request, format, **extra
)
从源码注释来看可以知道,该类方法通过kwargs的get方法获取通过正则来获取路由上的数据,且定义路由时的方式可以为(?P[v1|v2]+)
来定义版本,如果什么都没有就返回默认版本的信息。
而reverse方法则先判断request.version是不是存在,存在则赋值给kwargs,之后调用父类的BaseVersioning方法的reverse函数
,所以我们可以知道,在URLPathVersioning中也只需要传入request即可,内部会进行赋值操作的。
urls.py如下:
from django.conf.urls import url
from api import views
urlpatterns = [
url(r'^api/v1/auth/$', views.AuthView.as_view()),
url(r'^api/v1/order/$', views.OrderView.as_view()),
url(r'^api/(?P[v1|v2]+)/users/$' , views.UsersView.as_view(), name='uuu'),
]
app/views.py如下:
from rest_framework.versioning import URLPathVersioning
from django.shortcuts import render, HttpResponse
class UsersView(APIView):
authentication_classes = []
permission_classes = []
throttle_classes = []
versioning_class = URLPathVersioning
def get(self, request, *args, **kwargs):
# reverse(viewname='uuu', request={'version':'v1'})
u1 = request.versioning_scheme.reverse(viewname='uuu', request=request)
print(request.version, request.versioning_scheme, u1)
return HttpResponse('ok')
此时访问http://127.0.0.1:8000/api/v1/users/打印:
v1 <rest_framework.versioning.URLPathVersioning object at 0x000001ACD1710488> http://127.0.0.1:8000/api/v1/users/
一般我们定义版本的时候都是放在路由中,所以为了方便,我们可以通过配置versioning_class来使得全局使用上。
settings.py如下:
REST_FRAMEWORK = {
"DEFAULT_AUTHENTICATION_CLASSES": ['api.utils.auth.Authtication', ],
"UNAUTHENTICATED_USER": None,
"UNAUTHENTICATED_TOKEN": None,
"DEFAULT_PERMISSION_CLASSES": ['api.utils.permissions.MyPermission', ],
"DEFAULT_THROTTLE_CLASSES": ['api.utils.throttle.UserThrottle'],
"DEFAULT_THROTTLE_RATES": {
"scope": '3/m',
"user": '5/m',
},
"DEFAULT_VERSION": 'v1',
"ALLOWED_VERSIONS": ['v1', 'v2'],
"VERSION_PARAM": 'version',
"DEFAULT_VERSIONING_CLASS": "rest_framework.versioning.URLPathVersioning"
}
当数据返回的时候,如果请求通过post请求返回,那么就需要遵循post请求数据的规范:
(request.post请求去request.body解析数据)
一般对于封装类似请求的东西,APIView都会封装在Request类中,之后在把原生request打包好,之后就可以通过打包好的函数调用,而解析器其实就封装在Request类中的data函数中
。
initialize_request函数如下:
def initialize_request(self, request, *args, **kwargs):
"""
Returns the initial request object.
"""
parser_context = self.get_parser_context(request)
return Request(
request,
parsers=self.get_parsers(),
authenticators=self.get_authenticators(),
negotiator=self.get_content_negotiator(),
parser_context=parser_context
)
可以发现Request类中parsers这个对象调用了get_parsers()方法。get_parsers函数如下:
def get_parsers(self):
return [parser() for parser in self.parser_classes]
可以发现通过调用self.parser_classes对象遍历成每一个函数,所以从这里我们就可以知道,要想在子类定义解析器,需写上self.parser_classes这个列表
。
Request/data函数如下:
@property
def data(self):
if not _hasattr(self, '_full_data'):
self._load_data_and_files()
return self._full_data
可以看到在data函数中添加了@property装饰器,所以在使用request.data的时候是无需传括号
的,且我们可以看到data函数中调用了_load_data_and_files函数。
_load_data_and_files函数源码如下:
def _load_data_and_files(self):
if not _hasattr(self, '_data'):
self._data, self._files = self._parse()
if self._files:
self._full_data = self._data.copy()
self._full_data.update(self._files)
else:
self._full_data = self._data
if is_form_media_type(self.content_type):
self._request._post = self.POST
self._request._files = self.FILES
此时可以看到_load_data_and_files方法又调用了self._parse(),且返回的参数给self._data, self._files。
self._parse()函数源码如下:
def _parse(self):
media_type = self.content_type
try:
stream = self.stream
except RawPostDataException:
if not hasattr(self._request, '_post'):
raise
if self._supports_form_parsing():
return (self._request.POST, self._request.FILES)
stream = None
if stream is None or media_type is None:
if media_type and is_form_media_type(media_type):
empty_data = QueryDict('', encoding=self._request._encoding)
else:
empty_data = {
}
empty_files = MultiValueDict()
return (empty_data, empty_files)
parser = self.negotiator.select_parser(self, self.parsers)
if not parser:
raise exceptions.UnsupportedMediaType(media_type)
try:
parsed = parser.parse(stream, media_type, self.parser_context)
except Exception:
self._data = QueryDict('', encoding=self._request._encoding)
self._files = MultiValueDict()
self._full_data = self._data
raise
try:
return (parsed.data, parsed.files)
except AttributeError:
empty_files = MultiValueDict()
return (parsed, empty_files)
可以看到_parse函数获取到了self.content_type(即请求头content_type的参数)。
我们直接看 parser = self.negotiator.select_parser(self, self.parsers)
这里传入了self.parsers(解析器),通过self.negotiator对象的select_parser方法来解析,之后将值返还给parser。
DefaultContentNegotiation类/select_parser函数如下:
def select_parser(self, request, parsers):
for parser in parsers:
if media_type_matches(parser.media_type, request.content_type):
return parser
return None
可以发现此时通过循环parsers来获取到该方法中的media_type请求头,根据支持的请求头,返回该请求头的解析器。
之后我们继续往下走到parsed = parser.parse(stream, media_type, self.parser_context)
可以发现该方法是通过调用子类的parse函数执行的。
rest_framework有内置给我们的解析器,且它们都继承BaseParser类,而该类有一个parse函数,继承BaseParser类的解析器有很多,不过每一个源码流程都相似,这里就通过JSONParser的源码来了解一下解析器的整个流程。
JSONParser/parse函数如下:
def parse(self, stream, media_type=None, parser_context=None):
parser_context = parser_context or {
}
encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
try:
decoded_stream = codecs.getreader(encoding)(stream)
parse_constant = json.strict_constant if self.strict else None
return json.load(decoded_stream, parse_constant=parse_constant)
except ValueError as exc:
raise ParseError('JSON parse error - %s' % str(exc))
decoded_stream 接收到的就是request.body返回的参数,然后通过json.load的方式将json格式解析成了字典的形式。
urls.py如下:
from django.conf.urls import url
from api import views
urlpatterns = [
url(r'^api/v1/auth/$', views.AuthView.as_view()),
url(r'^api/v1/order/$', views.OrderView.as_view()),
url(r'^api/(?P[v1|v2]+)/users/$' , views.UsersView.as_view(), name='uuu'),
url(r'^api/(?P[v1|v2]+)/parser/$' , views.ParserView.as_view(), name='ddd'),
]
api/views.py如下:
from rest_framework.parsers import JSONParser
class ParserView(APIView):
authentication_classes = []
permission_classes = []
throttle_classes = []
parser_classes = [JSONParser]
def post(self, request, *args, **kwargs):
print(request.data)
return HttpResponse('ParserView')
此时访问http://127.0.0.1:8000/api/v1/parser/显示:
此时控制台打印:
{
'name': 'sehun', 'age': 18, 'gender': '男'}
而要想每一个都能使用,我们继续将其放入全局配置中。
settings.py如下:
REST_FRAMEWORK = {
"DEFAULT_AUTHENTICATION_CLASSES": ['api.utils.auth.Authtication', ],
"UNAUTHENTICATED_USER": None,
"UNAUTHENTICATED_TOKEN": None,
"DEFAULT_PERMISSION_CLASSES": ['api.utils.permissions.MyPermission', ],
"DEFAULT_THROTTLE_CLASSES": ['api.utils.throttle.UserThrottle'],
"DEFAULT_THROTTLE_RATES": {
"scope": '3/m',
"user": '5/m',
},
"DEFAULT_VERSION": 'v1',
"ALLOWED_VERSIONS": ['v1', 'v2'],
"VERSION_PARAM": 'version',
"DEFAULT_VERSIONING_CLASS": "rest_framework.versioning.URLPathVersioning",
"rest_framework": ["rest_framework.parsers.JSONParser", ["rest_framework.parsers.FormParser"]],
}
此时就会可以全局都配置上解析器了。