Django_restframework源码解析

源码入口

self.dispatch()
这个方法中主要有两个操作,一个是重新封装request,另一个就是执行initialize方法

initialize方法

这里面主要做了这几件事

1. 获取版本信息

version, scheme = self.determine_version(request, *args, **kwargs)
request.version, request.versioning_scheme = version, scheme

进入determine_version方法,

if self.versioning_class is None:
     return (None, None)
scheme = self.versioning_class()
     return (scheme.determine_version(request, *args, **kwargs), scheme)

如果没有配置版本信息,则return None,如果配置了,执行配置类的determine_version方法

这里提供的版本相关的配置有:

BaseVersioning
AcceptHeaderVersioning
URLPathVersioning
NamespaceVersioning
HostNameVersioning
QueryParameterVersioning

这里主要介绍URLPathVersioningQueryParameterVersioning

QueryParameterVersioning

这个配置主要是从get请求的参数里获取版本信息,看源码:

version = request.query_params.get(self.version_param, self.default_version)

URLPathVersioning

这个配置是从url中获取版本信息

version = kwargs.get(self.version_param, self.default_version)

这种方式,需要在路由匹配里面,用一个参数接收版本号

url(r'^(?P\w+)/users/',views.UserView.as_view()),

反向生成url

    def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
        if request.version is not None:
            kwargs = {} if (kwargs is None) else kwargs
            kwargs[self.version_param] = request.version

        return super(URLPathVersioning, self).reverse(
            viewname, args, kwargs, request, format, **extra
        )

2. 获取用户认证

源码逻辑

这个是从initial方法中的self.perform_authentication(request)中看起

首先来到perform_authentication方法,只返回了一个request.user

再选择rest_framework的request中的user

    def _authenticate(self):
        """
        Attempt to authenticate the request using each authentication instance
        in turn.
        """
        for authenticator in self.authenticators:
            try:
                user_auth_tuple = authenticator.authenticate(self)
            except exceptions.APIException:
                self._not_authenticated()
                raise

            if user_auth_tuple is not None:
                self._authenticator = authenticator
                self.user, self.auth = user_auth_tuple
                return

        self._not_authenticated()

发现这个里面,如果有user_auth_tuple这个元组,就直接返回了这个元组的内容,用户验证的默认配置是ForcedAuthentication

if force_user is not None or force_token is not None:
    forced_auth = ForcedAuthentication(force_user, force_token)
    self.authenticators = (forced_auth,)

这是个强制认证,一般用在测试中,此时不必进行身份验证

如果我们需要使用用户认证,就得修改这个配置类,所用的配置类必须要有authenticate方法,但是这里并没有找到更进一步的信息,所以要退回去重新找。先来看一下这里最后认证不成功的处理:

def _not_authenticated(self):
    """
    Set authenticator, user & authtoken representing an unauthenticated request.

    Defaults are None, AnonymousUser & None.
    """
    self._authenticator = None

    if api_settings.UNAUTHENTICATED_USER:
        self.user = api_settings.UNAUTHENTICATED_USER()
    else:
        self.user = None

    if api_settings.UNAUTHENTICATED_TOKEN:
        self.auth = api_settings.UNAUTHENTICATED_TOKEN()
    else:
        self.auth = None

这里有两条配置,UNAUTHENTICATED_USERUNAUTHENTICATED_TOKEN

这两个就是配置验证不通过时,也就是用户没有登录时的配置,可以配置成匿名用户,如果不配置,就会返回None

回到开始,从封装request里面,看能不能找到想要的

看到这里有一个操作,authenticators=self.get_authenticators(),让人充满希望

def get_authenticators(self):
    return [auth() for auth in self.authentication_classes]

果然在这里,看到一个东西加括号,可以联想到这几种情况,函数加括号执行,类加括号实例化。

ps,其实这里还应该看到这个文件的整齐性,能看到的函数,都是这种形式的返回值,可以留意一下

那我们先看看authentication_classes里面是什么

是配置信息authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES

此时我们可以打印一下这个authentication_classes,看看有没有默认值,打印之后发现真的有默认值:

[, ]

这就好说了,我们可以到这个默认的配置里看看,

from rest_framework.authentication import BasicAuthentication

哈,这里有个authenticate方法,这就和前面对上了,同时也说明了前面那个auth()是实例化对象

看一下这个文件,提供了这么几个配置

BaseAuthentication
BasicAuthentication
SessionAuthentication
TokenAuthentication
RemoteUserAuthentication

这些配置最后都是返回了一个两个元素的元组,第一个元素是user,第二个是token

如果验证不成功,返回None表示支持匿名用户,抛出异常表示不支持匿名用户

if token_obj:
    return (token_obj.user, token_obj)
else:
    raise exceptions.AuthenticationFailed('请登录')    #抛出异常表示不允许匿名用户访问

好了,到此为止,关于用户认证的源码也基本上明白了

扩展

用户登录,会获得一个token,再次登录,就会替换掉这个token,如何实现多设备登陆呢?

在数据库中,token表是和用户表做OnetoOne关联的,此时只能有一个设备登陆,如果我们改成ManytoMany,就可以实现多设备登陆了

3. 获取用户权限信息

回到最开始的initialize方法,用户验证的下面就是获取权限的代码

self.perform_authentication(request)
self.check_permissions(request)  # 权限
self.check_throttles(request)

然后我们就来到了这:

def get_permissions(self):
    return [permission() for permission in self.permission_classes]

和用户验证同样的套路

提供的配置还挺多的,以BasePermission为例吧

class BasePermission(object):
    """
    A base class from which all permission classes should inherit.
    """

    def has_permission(self, request, view):
        """
        Return `True` if permission is granted, `False` otherwise.
        授予权限返回True,否则返回False
        """
        return True

    def has_object_permission(self, request, view, obj):
        """
        Return `True` if permission is granted, `False` otherwise.
        """
        return True

所以,这两个方法都是,有权限返回True,没有权限返回False,具体逻辑需要我们自己写

4. 限制用户访问频率

源码逻辑
self.perform_authentication(request)
self.check_permissions(request)  # 权限
self.check_throttles(request)   # 限制访问频率

也是一样的套路,看一下配置,发现后面几个都继承了SimpleRateThrottle

代码太多就不贴了

基本原理就是,首先要获取用户的唯一标识,通过这个get_cache_key函数,默认这个函数是抛出了一个异常的,所以如果用到这个函数,就必须重写这个方法。

如果没有获取到用户的唯一标识,直接返回True,表示不限制访问频率

将用户的访问时间,记录到一个列表中(以用户的唯一标识为key,访问记录为这个列表),访问一次记录一次(插到列表第一个位置),如果列表中的最后一条记录(最早的记录)的时间超过当前时间减去限制的时间,就从列表中剔除,否则插入到列表,当列表长度超过限制的访问次数,抛出异常。

主要的逻辑代码在这:

def allow_request(self, request, view):
    """
    Implement the check to see if the request should be throttled.

    On success calls `throttle_success`.
    On failure calls `throttle_failure`.
    """
    if self.rate is None:
        return True

    self.key = self.get_cache_key(request, view)  # 获取用户唯一标识
    if self.key is None:
        return True  # 返回True就不执行下面的了

    self.history = self.cache.get(self.key, [])
    self.now = self.timer()

    # Drop any requests from the history which have now passed the
    # throttle duration
    while self.history and self.history[-1] <= self.now - self.duration:
        self.history.pop()
    if len(self.history) >= self.num_requests:
        return self.throttle_failure()
    return self.throttle_success()

对于登录用户,可以用他的用户名标识身份,如果是匿名用户,如何标识身份呢?源码中AnonRateThrottle里为我们提供了一种方案就是使用用户的ip,

return self.cache_format % {
            'scope': self.scope,
            'ident': self.get_ident(request)
        }

这里说明了配置应该怎么写:

    def parse_rate(self, rate):
        """
        Given the request rate string, return a two tuple of:
        , 
        """
        if rate is None:
            return (None, None)
        num, period = rate.split('/')
        num_requests = int(num)
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        return (num_requests, duration)

需要注意,这个scope = None默认是None,使用的时候需要指定,比如登录用户设置成'user',匿名用户设置成'anon'

那配置就可以这样写:

REST_FRAMEWORK = {
    'DEFAULT_THROTTLE_RATES': {
            'user': '5/min',    
            'anon': '2/min'
        }
}
扩展

默认的情况是,限制访问频率是对站点中所有页面的限制,是访问站点的频率。如果想限制某个用户对某一个页面的访问频率,可以怎么做?

本质上就是设置用户的唯一标识,可以把唯一标识设置成用户+指定页面的名字

指定页面的名字,其实就是CBV中,对应的视图类的名字,所以

self.key = request.user.user + view.__class__.__name__

关于csrf的处理

其实在路由匹配中,就已经做了处理

看一下这个as_view()方法

@classmethod
    def as_view(cls, **initkwargs):
        """
        Store the original class on the view function.

        This allows us to discover information about the view when we do URL
        reverse lookups.  Used for breadcrumb generation.
        """
        if isinstance(getattr(cls, 'queryset', None), models.query.QuerySet):
            def force_evaluation():
                raise RuntimeError(
                    'Do not evaluate the `.queryset` attribute directly, '
                    'as the result will be cached and reused between requests. '
                    'Use `.all()` or call `.get_queryset()` instead.'
                )
            cls.queryset._fetch_all = force_evaluation

        view = super(APIView, cls).as_view(**initkwargs)
        view.cls = cls
        view.initkwargs = initkwargs

        # Note: session based authentication is explicitly CSRF validated,
        # all other authentication is CSRF exempt.
        return csrf_exempt(view)

最后默认就禁止了csrf

这里还为我们提供了一种思路,就是CBV的装饰器还可以这么用csrf_exempt(view),其实这也是装饰器的本质

注意是因为在CBV中,CBV中的装饰器才能这么用

CBV装饰器的三种写法,可以看看这篇文章

解析器

关于request.POST 里面有没有东西的讨论,在这里已经说过,

在rest framework里面,怎么处理呢? restful里面有个解析器

先看看封装request那里面,还有个negotiator=self.get_content_negotiator(),

其实这里面就是处理的解析器

配置在这里

content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS

一般情况下不用自定制,如果需要的话这样写:

from rest_framework.renderers import JSONRenderer, TemplateHTMLRenderer, BrowsableAPIRenderer, MultiPartRenderer
parser_classes = [JSONRenderer, TemplateHTMLRenderer, BrowsableAPIRenderer, MultiPartRenderer]

只有在调用request.data 的时候,才会触发这些配置

序列化/验证

从数据库直接查询到的结果,是queryset类型,如何把这个queryset发送给前端,在Django中可以用queryset.values(),然后list转成列表发送给前端,在rest framework里面,有个serializers

这个序列化有两大功能,一是将用户请求的数据序列化,二是对用户提交的数据进行验证

手动版本:

from rest_framework import serializers

class PasswordValidator(object):
    #密码字段的复杂验证
    def __call__(self,value):
        """

        :param vlaue: 用户传来的数据
        :return:
        """
        try:
            int(value)  #这个只是举个例子
        except Exception as e:
            message = '密码必须是数字'
            raise serializers.ValidationError(message)


class UserSerialize(serializers.Serializer):
    """
    要先写这个类,定义字段,字段是要取的数据表的字段
    """
    user = serializers.CharField()
    pwd = serializers.CharField(validators=[PasswordValidator()])#复杂验证
    email = serializers.CharField()
    user_type_id = serializers.IntegerField()
    # ug = serializers.CharField(source='ug.title')#跨表查询用source
    ug = serializers.CharField(source='ug.title',required=False)#表示这个字段可以为空
class CustomSerialize(APIView):
    def get(self, request, *args,**kwargs):
        user_list = models.UserInfo.objects.filter()
        # print(user_list.first().ug)
        ser = UserSerialize(instance=user_list,many=True)   # 如果是一个子段,many就等于False,多个字段就用True
        # user_obj = models.UserInfo.objects.filter().first()
        # ser = UserSerialize(instance=user_obj,many=False)
        return Response(ser.data)       #.data方法就拿到了所有的数据
        #这个Response 是restframework的返回值    from rest_framework.response import Response

    def post(self, request, *args,**kwargs):
        #序列化的第二大功能,对用户输入的数据进行合法性验证

        ser = UserSerialize(data=request.data)#提交数据就不能用对象了,而应该是用data参数
        if ser.is_valid():
            #返回提交的post请求的数据
            print(ser.validated_data)#OrderedDict([('user', 'zhang'), ('pwd', '123'), ('email', '[email protected]'), ('user_type_id', 2), ('ug', {'title': '2'})])  返回结果是一个有序字典
            print(request.data)#
            return Response(ser.validated_data)#页面就显示刚才提交的数据
        else:

            return Response(ser.errors)#用户数据不通过验证,则返回错误信息

注意最上面的那个复杂验证,用到了__call__方法,因为源码里有用到示例对象加括号的操作,所以想到了__call__,事实证明有效

还有个ModelSerializer,类似Django中的modelForm,可以帮我们封装自定义字段的步骤,上面例子中的UserSerialize可以这么写

class UserSerialize(serializers.ModelSerializer):
    class Meta:
        model = models.UserInfo
        fields = '__all__'  # 表示获取所有字段

更多操作:

# class UserSerialize(serializers.ModelSerializer):
#     #还可以添加字段,已有就更新,没有的就新增
#     # user = serializers.CharField()      #可以加简单的限制
#     ug = serializers.HyperlinkedIdentityField(view_name='xxx')#这个view_name 就是反向生成的name
#     v = serializers.CharField()         #发送post请求的时候,可以多添加这个参数,数据表中已有这个字段则覆盖,
#                                             # 但是注意如果要创建数据,要把这条记录删掉,否则会报错,可以用pop拿出来
#
#     class Meta:
#         model = models.UserInfo
#         fields = '__all__'
#
#         extra_kwargs = {
#             'user':{'min_length':2}     #这个限制和上面那个效果是一样的,这个的意义是添加复杂认证
#         }
#         depth = 2       #跨表的深度也可以设置

分页

from rest_framework.pagination import PageNumberPagination, LimitOffsetPagination, CursorPagination
class UserPageSerialize(serializers.ModelSerializer):
    class Meta:
        model = models.UserInfo
        fields = '__all__'

说一说这三种分页的配置:

class StandardResultsSetPagination(PageNumberPagination):

    #根据页码分页
    page_size = 1   #默认每页显示的数据条数

    page_size_query_param = 'page_size'   #获取url参数中设置的每页 显示数据条数

    page_query_param = 'page'   #获取url参数中传入的页码key

    # max_page_size = 1 #最大支持的每页显示的数据条数
class StandardResultsSetPagination(LimitOffsetPagination):
    # 根据位置和个数分页
    default_limit = 10  # 默认每页显示的数据条数

    limit_query_param = 'limit'  # URL中传入的显示数据条数的参数,是从0开始的

    offset_query_param = 'offset'  # URL中传入的数据位置的参数

    max_limit = None  # 最大每页显得条数
class StandardResultsSetPagination(CursorPagination):
    # 游标分页,只能点上一页,下一页

    cursor_query_param = 'cursor'  # URL传入的游标参数

    page_size = 1  # 默认每页显示的数据条数

    page_size_query_param = 'page_size'  # URL传入的每页显示条数的参数

    max_page_size = 1000  # 每页显示数据最大条数

    ordering = "id"  # 根据ID从大到小排列

最终返回数据的代码

class UserViewSet(APIView):
    def get(self, request, *args, **kwargs):
        paginator = StandardResultsSetPagination()  # 获取上面设置的页码配置
        user_list = models.UserInfo.objects.filter().order_by('-id')
        page_user_list = paginator.paginate_queryset(user_list, request, self)  # 这个要拿到每页的数据

        ser = UserPageSerialize(instance=page_user_list, many=True)
        print(user_list)
        # return Response(ser.data)#这种只会返回数据,下面这种还能返回上一页和下一页
        response = paginator.get_paginated_response(ser.data)  # 这样就会有上一页下一页的连接了
        return response

这三种页码中,重点说一下游标这个,这里面涉及一个加密解密的过程

如果使用这种配置,页面中将不能直接跳到多少页,只能乖乖的一页一页点

为什么要有这种设定呢?

分页有个性能问题1亿条数据,如果用offset和limit如果拿第9000万条数据,会把之前所有的都读一遍所以,页码越靠后,越慢
所以想办法不让他扫描,就是这个游标(cursor)

记录上一页的最后一个下标,找到时候就直接从那个下标开始找所以有的网站只有上一页和下一页

比如

select * from offset 220 limit 5

从第220条开始找5条

但是这种是一条一条找到第220条数据的,然后开始继续找5条,这样就会有性能问题

然而cursor ,是这样的: select * from offset where id>220 offfset 0 limit 5

这样就不会一条一条找前220条数据了

路由

这个其实涉及挺多的,一点一点说吧

1. 手动写路由和视图

我们知道,如果是单纯访问数据,应该支持url最后是以.json这种格式结尾的,这个怎么实现呢?

url(r'^ser\.(?P\w+)$', views.CustomSerialize.as_view())

这种就表示带不带后缀都能访问到数据,注意这里只能用format作为正则分组的名字,否则要改配置

但是现在这个路由有个问题,能处理查询和新增,但是不能做修改和删除,因为没有id,所以还需要再增加一个url

url(r'^test\.(?P\w+)/(?P\d+)', views.RouteView.as_view())

然后查询数据的请求再来的时候,视图中就可以判断参数中有没有id,有的话就是就返回指定的数据,没有就返回全部数据

最终需要的路由有这四条:

url(r'^test\.(?P\w+)', views.RouteView.as_view()),

url(r'^test/', views.RouteView.as_view()),

url(r'^test/(?P\d+)/', views.RouteView.as_view()),
url(r'^test\.(?P\w+)/(?P\d+)', views.RouteView.as_view()),

视图中:

class RouteView(APIView):
    # authentication_classes = [CustomAuthenticate,]          #应用认证配置类

    def get(self,request,*args,**kwargs):
        pk = kwargs.get('pk')
        if not pk:
            user_list = models.UserInfo.objects.filter()
            ser = TestServerlizer(instance=user_list,many=True)     #many=True表示多条
        else:
            obj = models.UserInfo.objects.filter(pk=pk)
            ser = TestServerlizer(instance=obj, many=False)

        return Response(ser.data)           #ser.data 才能拿到数据
2. GenericAPIView

路由匹配不改,视图中可以简化一点

需要先导入这个类

from rest_framework.generics import GenericAPIView
class RouteView(GenericAPIView):#要继承这个类
    queryset = models.UserInfo.objects.all()
    serializer_class = RouteSerializer
    pagination_class = StandardResultsSetPagination
    def get(self,request,*args,**kwargs):
        user_list = self.get_queryset()
        page_user_list = self.paginate_queryset(user_list)

        ser = self.get_serializer(instance=page_user_list,many=True)
        response = self.get_paginated_response(ser.data)#还内置了分页的方法

        return response

这种方法其实没啥用,一点都没有简化,但是也算是个过程,开始不用继承APIView了

3. GenericViewSet

这个就可以简化路由了,并加入增删改查方法

视图中

url(r'^test', views.RouteView.as_view({'get':'list','post':'create'})),
url(r'^test/(?P\d+)', views.RouteView.as_view({'get': 'retrieve', 'put': 'update', 'patch': 'partial_update', 'delete': 'destroy'}))

增删改查都需要导入类

from rest_framework.viewsets.mixins import ListModelMixin,CreateModelMixin,UpdateModelMixin,DestroyModelMixin
class RouteView(GenericViewSet,ListModelMixin,CreateModelMixin,UpdateModelMixin,DestroyModelMixin):#要继承这个类
    queryset = models.UserInfo.objects.all()
    serializer_class = RouteSerializer
    pagination_class = StandardResultsSetPagination

此时不用我们手动写增删改查方法了,只需要导入类就行了

4. ModelViewSet

这种就真的是完全自动了

路由:

首先在urlpartterns之前,注册好url访问的路径

from rest_framework.routers import DefaultRouter
route = DefaultRouter()
route.register('abc',views.RouteView)

比如访问的路径是abc,不管后面有没有id,有没有format,都可以识别

路由里只需要写一句话就行

url(r'^',include(route.urls)),

然后视图里面也很简单,继承一个ModelViewSet就行了

from rest_framework.viewsets import ModelViewSet


class RouteSerializer(serializers.ModelSerializer):
    class Meta:
        model = models.UserInfo
        fields = '__all__'


class RouteView(ModelViewSet):  
    queryset = models.UserInfo.objects.all()
    serializer_class = RouteSerializer

实际上这ModelViewSet内部帮我们封装了增删改查那几个类

源码中:

class ModelViewSet(mixins.CreateModelMixin,
                   mixins.RetrieveModelMixin,
                   mixins.UpdateModelMixin,
                   mixins.DestroyModelMixin,
                   mixins.ListModelMixin,
                   GenericViewSet):
    """
    A viewset that provides default `create()`, `retrieve()`, `update()`,
    `partial_update()`, `destroy()` and `list()` actions.
    """
    pass

但这种封装度这么高,也注定了他并不会被经常使用

渲染

这个是根据url,使用合适的渲染组件

配置在这里:

from rest_framework.renderers import JSONRenderer,AdminRenderer,HTMLFormRenderer,TemplateHTMLRenderer

一般就用JSON这个就行了,其他的可以了解一下

在视图类里面:

renderer_classes = [JSONRenderer, ]

注意路由中url的格式,要写成这样:

url(r'^test\.(?P[a-z0-9]+)', views.RouteView.as_view()),

浏览器访问的时候,url是这样的:

http://127.0.0.1:8000/test/?format=json
http://127.0.0.1:8000/test.json
http://127.0.0.1:8000/test/ 

这三种都行

其他几种写法都类似:

AdminRenderer 是表格形式,HTMLFormRenderer是form表单,

TemplateHTMLRenderer是自定义的模板

你可能感兴趣的:(Django_restframework源码解析)