rest_framework其他相关组件及源码分析

rest_framework视图组件

总共有四种方式

原始代码

class PublishSer(serializers.ModelSerializer):
    class Meta:
        model = models.Publish
        fields = '__all__'
class Publish(APIView):
    def get(self,request):
        ret = models.Publish.objects.all()
        ser = PublishSer(instance=ret,many=True)
        return Response(ser.data)

    def post(self,request):
        back_msg = {'status':0,'data':None,'msg':'错误'}
        ser = PublishSer(data=request.data)
        if ser.is_valid():
            ser.save()
            back_msg['status'] = 1
            back_msg['data'] = ser.data
            back_msg['msg'] = '创建成功'
        return Response(back_msg)


class PublishDetail(APIView):
    def get(self,request,pk):
        ret = models.Publish.objects.filter(pk=pk).first()
        ser = PublishSer(instance=ret,many=False)
        return Response(ser.data)

    def put(self,request,pk):
        ret = models.Publish.objects.filter(pk=pk).first()
        ser = PublishSer(instance=ret, data=request.data)
        if ser.is_valid():
            ser.save()
            return Response(ser.data)

    def delete(self,request,pk):
        models.Publish.objects.filter(pk=pk).delete()
        return Response('')

方式一(用的最多)

class List():
    def list(self, request):
        ret = self.query_set.objects.all()
        ser = self.serializer_class(instance=ret, many=True)
        return Response(ser.data)


class Create():
    def create(self, request):
        back_msg = {'status': 0, 'data': None, 'msg': '错误'}
        ser = self.serializer_class(data=request.data)
        if ser.is_valid():
            ser.save()
            back_msg['status'] = 1
            back_msg['data'] = ser.data
            back_msg['msg'] = '创建成功'
        return Response(back_msg)


class Publish(APIView, List, Create):
    query_set = models.Publish
    serializer_class = PublishSer
    def get(self, request):
        return self.list(request)

    def post(self, request):
        return self.create(request)

方式二

from rest_framework.mixins import *
from rest_framework.generics import GenericAPIView
class Publish(GenericAPIView,CreateModelMixin,ListModelMixin):
    queryset = models.Publish.objects.all()
    # queryset = models.Publish.objects
    serializer_class = PublishSer
    def get(self, request):
        return self.list(request)
    def post(self, request):
        return self.create(request)

方式三

from rest_framework.mixins import *
from rest_framework.generics import GenericAPIView,ListCreateAPIView,RetrieveUpdateDestroyAPIView
class Publish(ListCreateAPIView):
    queryset = models.Publish.objects.all()
    serializer_class = PublishSer
class PublishDetail(RetrieveUpdateDestroyAPIView):
    # 查,删,更新
    queryset = models.Publish.objects.all()
    serializer_class = PublishSer

前三种方式的路由设计

#原始方式
url(r'^publish/$', views.Publish.as_view()),
url(r'^publish/(?P\d+)/$', views.PublishDetail.as_view()),

方式四

from rest_framework.viewsets import ModelViewSet
class Publish(ModelViewSet):
    queryset = models.Publish.objects.all()
    serializer_class = PublishSer
    # get
    # post
    # get(查一条)
    # put
    # delete

方式四路由设计

#半自动化路由

走的是ModelViewSet-->GenericViewSet-->ViewSetMixin的as_view(cls, actions=None, **initkwargs)方法,可以传参
    # actions===>传的参数{'get':'list','post':'create'}

    # 源码分析
    # for method, action in actions.items():  #  method='get',action='list'
    #     handler = getattr(self, action)
    #     setattr(self, method, handler)
    url(r'^publish/$', views.Publish.as_view({'get':'list','post':'create'})),
    url(r'^publish/(?P\d+)/$', views.Publish.as_view({'get':'retrieve','put':'update','delete':'destroy'})),

路由

'''
原始方式(用的最多)
url(r'^publish/$', views.Publish.as_view()),
url(r'^publish/(?P\d+)/$', views.PublishDetail.as_view()),


半自动路由(需要继承ModelViewSet)
url(r'^publish/$', views.Publish.as_view({'get':'list','post':'create'})),
url(r'^publish/(?P\d+)/$', views.Publish.as_view({'get':'retrieve','put':'update','delete':'destroy'})),

全自动化路由(需要继承ModelViewSet)
from django.conf.urls import url,include
from rest_framework import routers
router = routers.DefaultRouter()
router.register('publish',views.Publish)

    url(r'api/v1/', include(router.urls)),

^publish/$ [name='publish-list']
^publish\.(?P[a-z0-9]+)/?$ [name='publish-list']
^publish/(?P[^/.]+)/$ [name='publish-detail']
^publish/(?P[^/.]+)\.(?P[a-z0-9]+)/?$ [name='publish-detail']
^$ [name='api-root']
^\.(?P[a-z0-9]+)/?$ [name='api-root']
'''

用的最多是APIview及ViewSetMixin。

认证,频率,权限组件分析

一.认证

models.py文件
class UserInfo(models.Model):
    name = models.CharField(max_length=32)
    pwd = models.CharField(max_length=32)
    # 权限会用的到
    choices = ((1, '超级用户'), (2, '普通用户'), (3, '二笔用户'))
    user_type = models.IntegerField(choices=choices,null=True)


class UserToken(models.Model):
    user = models.OneToOneField(to=UserInfo)
    token = models.CharField(max_length=64)
utils文件下的common.py写个认证类
class MyAuth(BaseAuthentication):
    def authenticate(self,request):
        # 取出token然后与数据库对比
        # request.GET.get('token')
        token = request.query_params.get('token')
        ret = UserToken.objects.filter(token=token).first()
        if ret:
            # 认证通过
            # 返回的两个值对应的是 self.user, self.auth,详见Request源码384行
            return ret.user,ret.user.name
        else:
            # 没验证通过抛异常
            # raise APIException('认证失败')
            raise AuthenticationFailed('认证失败')

    # 继承了BaseAuthentication就不用写这个方法了
    # def authenticate_header(self,request):
    #     pass
views.py
from django.shortcuts import render,HttpResponse,redirect
import json
from django.http import JsonResponse
from rest_framework.views import APIView
from utils.common import *
from app01.models import *
from rest_framework.response import Response
from rest_framework.exceptions import APIException,AuthenticationFailed
from rest_framework.authentication import BaseAuthentication


class Login(APIView):
    # 如果配置了全局,登录也会受到影响,所以在自己这里配置一个空的就可以了
    authentication_classes = []
    def post(self,request,*args,**kwargs):
        # 拿到状态码等信息
        response = GetSomething()
        # 拿到用户登录信息
        name = request.data.get('name')
        pwd = request.data.get('pwd')
        # 去数据库拿到用户信息
        user = UserInfo.objects.filter(name=name,pwd=pwd).first()
        if user:
            # 生成一个随机字符串
            token = get_token(name)
            # 如果不存在会创建,如果存在会更新
            UserToken.objects.update_or_create(user=user,defaults={'token':token})
            response.status=200
            response.mag='登录成功'
            # 将随机字符串传给前端
            response.token=token
            # "token": "8028525238b2619e2c0682a4c1fbff63"
            # token=8028525238b2619e2c0682a4c1fbff63
        else:
            response.mag = '用户名或密码错误'
        return Response(response.get_dic())


class Course(APIView):
    # 先走认证,频率,权限再走get等方法,如果认证,频率,权限都没通过,会直接返回回去
    # 认证组件
    # authentication_classes=[认证类,认证类...]

    # 详见Request源码374行,发现是一个for循环,一旦有返回值,循环就结束掉了,如果有返回值的那个认证类必须写在最后一个
    # authentication_classes=[MyAuth1,MyAuth2]
    authentication_classes=[MyAuth,]
    permission_classes = [MyPermission,]
    throttle_classes = [MyThor,]

    def get(self,request,*args,**kwargs):
        print(request.user)  # UserInfo object
        print(request.auth)  # tom
        return JsonResponse({'name':'python'})

utils文件下的common.py
import hashlib
import time
class GetSomething:
    def __init__(self):
        self.status = 403
        self.mag = None

    # 返回字典格式数据
    def get_dic(self):
        return self.__dict__

# 生产随机字符串
def get_token(name):
    m = hashlib.md5()
    m.update(name.encode('utf-8'))
    m.update(str(time.time()).encode('utf-8'))

    return m.hexdigest()

源码分析

APIView—>APIView的dispatch方法—>到dispatch方法下的self.initial(request, *args, **kwargs)---->APIView的initial方法下的self.perform_authentication(request),可以看到这个方法执行了request.user,这是Request类下的一个方法通过@property伪装成了属性,进入这个方法----->执行self._authenticate()---->for authenticator in self.authenticators这里的authenticators是不是很熟悉,就是initialize_request这个方法下Request实例化时authenticators=self.get_authenticators(),---->执行了get_authenticators(),最后执行了authenticate方法。

rest_framework其他相关组件及源码分析_第1张图片

rest_framework其他相关组件及源码分析_第2张图片

rest_framework其他相关组件及源码分析_第3张图片

rest_framework其他相关组件及源码分析_第4张图片

局部使用

# 详见Request源码374行,发现是一个for循环,一旦有返回值,循环就结束掉了,如果有返回值的那个认证类必须写在最后一个
# authentication_classes=[MyAuth1,MyAuth2]

全局使用

查找顺序:自定义的APIView里找--->项目settings里找---->内置默认的

settings配置:
REST_FORMEWORK = {
'DEFAULT_AUTHENTICATION_CLASSES':['utils.common.MyAuth']
}
不存数据库验证token
def get_token(id, salt='123'): # salt 加盐,一般写在配置里,且不要写死
    import hashlib
    md = hashlib.md5()
    md.update(bytes(str(id), encoding='utf-8'))
    md.update(bytes(salt, encoding='utf-8'))

    return md.hexdigest() + '|' + str(id)


def check_token(token, salt='123'):
    ll = token.split('|')
    import hashlib
    md = hashlib.md5()
    md.update(bytes(ll[-1], encoding='utf-8'))
    md.update(bytes(salt, encoding='utf-8'))
    if ll[0] == md.hexdigest():
        return True
    else:
        return False


class Login(APIView):
    def post(self, request):
        back_msg = {'status': 1001, 'msg': None}
        try:
            name = request.data.get('name')
            pwd = request.data.get('pwd')
            user = UserInfo.objects.filter(name=name, pwd=pwd).first()
            if user:
                token = get_token(user.pk)
                back_msg['status'] = '1000'
                back_msg['msg'] = '登录成功'
                back_msg['token'] = token
            else:
                back_msg['msg'] = '用户名或密码错误'
        except Exception as e:
            back_msg['msg'] = str(e)
        return Response(back_msg)


from rest_framework.authentication import BaseAuthentication
class TokenAuth(BaseAuthentication):
    def authenticate(self, request):
        token = request.GET.get('token')
        token_obj = UserToken.objects.filter(token=token).first()
        if token_obj:
            return
        else:
            raise AuthenticationFailed('认证失败')
    # def authenticate_header(self,request):
    #     pass


class Course(APIView):
    authentication_classes = [TokenAuth, ]

    def get(self, request):
        return HttpResponse('get')

    def post(self, request):
        return HttpResponse('post')

二.权限组件

models.py
class UserInfo(models.Model):
    name = models.CharField(max_length=32)
    pwd = models.CharField(max_length=32)
    # 权限会用的到
    choices = ((1, '超级用户'), (2, '普通用户'), (3, '二笔用户'))
    user_type = models.IntegerField(choices=choices,null=True)


class UserToken(models.Model):
    user = models.OneToOneField(to=UserInfo)
    token = models.CharField(max_length=64)

utils文件下的common.py写个权限类
# 能走到权限类,说明用户肯定是登录了
from rest_framework.permissions import BasePermission
class MyPermission(BasePermission):
    message = '不是超级用户,查看不了'
    def has_permission(self,request,view):
        token = request.query_params.get('token')
        ret = UserToken.objects.filter(token=token).first()
        # 取出对应的中文介绍
        # ret.user.get_user_type_display() 
        if ret.user.user_type == 1:
            # 超级用户可以访问
            return True
        else:
            return False
局部使用
permission_classes = [MyPermission,]
全局使用
查找顺序:自定义的APIView里找--->项目settings里找---->内置默认的

settings配置:
REST_FORMEWORK = {
'DEFAULT_PERMISSION_CLASSES':['utils.common.MyPermission']
}

源码分析

与认证差不多,最大差别是方法has_permission(self,request,view)

rest_framework其他相关组件及源码分析_第5张图片

三.频率组件

局部使用
throttle_classes = [MyThor,]
全局使用
查找顺序:自定义的APIView里找--->项目settings里找---->内置默认的

settings配置:
REST_FORMEWORK = {
'DEFAULT_THROTTLE_CLASSES':['utils.common.MyThor',],
'DEFAULT_THROTTLE_RATES':{
        # 一分钟访问5次,可以修改
        'luffy':'5/m'
    }
}
utils文件下的common.py写个权限类
from rest_framework.throttling import SimpleRateThrottle
class MyThor(SimpleRateThrottle):
    scope = 'luffy'

    # 关键在于传一个能唯一标识用户的数据(源码里是保存到缓存内的,我们自定义可以保存到字典)
    # 也可以做别的校验,只要是唯一的
    # {'ip':[]}
    # {'id':[]}

    def get_cache_key(self, request, view):
        # 看源码知道self.get_ident(request)返回的是IP地址,也就相当于下面两句代码:
        # ip = request.META.get('REMOTE_ADDR')
        # return ip

        return self.get_ident(request)

        # 如果要求只允许用户一分钟访问5次
        # 先取出用户的ID。通过源码知道先走认证再权限再到频率,
        # 能走到频率这一步说明认证已经通过了,认证通过返回两个值,所以可以取出用户的ID了
        # id = request.user.pk
        # return id

源码分析

rest_framework其他相关组件及源码分析_第6张图片
在这里插入图片描述
rest_framework其他相关组件及源码分析_第7张图片

在这里插入图片描述

rest_framework其他相关组件及源码分析_第8张图片

rest_framework其他相关组件及源码分析_第9张图片

rest_framework其他相关组件及源码分析_第10张图片

自定义频率类,自定义频率规则(一分钟最多访问三次)

class MyThrottles():
    VISIT_RECORD = {}
    def __init__(self):
        self.history=None
    # 频率限制的逻辑
    def allow_request(self,request, view):
        #(1)取出访问者ip
        # print(request.META)
        ip=request.META.get('REMOTE_ADDR')
        import time
        ctime=time.time()
        # (2)判断当前ip不在访问字典里,添加进去,并且直接返回True,表示第一次访问
        if ip not in self.VISIT_RECORD:
            self.VISIT_RECORD[ip]=[ctime,]
            return True
        self.history=self.VISIT_RECORD.get(ip)
        # (3)循环判断当前ip的列表,有值,并且当前时间减去列表的最后一个时间大于60s,把这种数据pop掉,这样列表中只有60s以内的访问时间,
        while self.history and ctime-self.history[-1]>60:
            self.history.pop()
        # (4)判断,当列表小于3,说明一分钟以内访问不足三次,把当前时间插入到列表第一个位置,返回True,顺利通过
        # (5)当大于等于3,说明一分钟内访问超过三次,返回False验证失败
        if len(self.history)<3:
            self.history.insert(0,ctime)
            return True
        else:
            return False
    #返回一个数字,给用户提示还是多少秒
    def wait(self):
        import time
        ctime=time.time()
        return 60-(ctime-self.history[-1])

四.版本控制

  • 127.0.0.1/test1/?version=v1/
from rest_framework.versioning import QueryParameterVersioning
class Test1(APIView):
    versioning_class=QueryParameterVersioning #是个类,不是列表
    def get(self,request,*args,**kwargs):
        print(request.version)
        return HttpResponse('ok')
  • 127.0.0.1/v1/test2/
class Test2(APIView):
    versioning_class=URLPathVersioning
    def get(self,request,*args,**kwargs):
        print(request.version)
        return Response('ok')
settings.py
REST_FARMEWORK = {
	'DEFAULT_VERSIONING_CLASS':'rest_framework.versioning.QueryParameterVersioning',#这是全局配置
    'VERSION_PARAM': 'version',
    'DEFAULT_VERSION':'v1',
    'ALLOWED_VERSIONS':['v1','v2'],
}
反向解析
url1 = request.versioning_scheme.reverse(viewname='ttt',request=request)

url(r'^(?P[v1|v2]+)/test2/$', views.Test2.as_view(),name='ttt'),

        

五.响应器

项目中用:(返回的格式,只是json格式)

REST_FARMEWORK = {
    'DEFAULT_RENDERER_CLASSES':['rest_framework.renderers.JSONRenderer']
}

查找模板的时候:先从自己app里找,找不到去项目,再找不到去各个app里找。

六.分页器

简单分页
127.0.0.1/course/page=3
偏移分页
127.0.0.1/course/offset=10&limit=5
加密分页
后台返回的url:127.0.0.1/course/page=sdaseq
class BookSer(ModelSerializer):
    class Meta:
        model = models.Book
        fields = '__all__'
from rest_framework.pagination import PageNumberPagination, CursorPagination, LimitOffsetPagination

'''
class Book(APIView):
    def get(self,request):
        ret = models.Book.objects.all()
        # 实例化一个对象
        mypage = PageNumberPagination()

        mypage.page_size = 2
        mypage.page_size_query_param = 'size'
        mypage.max_page_size = 5

        page_list = mypage.paginate_queryset(ret,request,self)
        ser = BookSer(instance=page_list,many=True)

        
        # 1.在settings里面配置参数
        # 2.写一个类,继承它,重写属性
        # 3.在对象里修改
        
         # 每页显示个数
         # page_size = api_settings.PAGE_SIZE 
         # 页码的参数的key,例如:127.0.0.1/books/?page=1
         # page_query_param = 'page'     
         # 指定每页显示条数,127.0.0.1/books/?page=1&size=3
         # page_size_query_param=None
         # 最大显示条数
         # max_page_size = None  
    
        return Response(ser.data)
        # 页面多了总条数,上一页和下一页的链接
        return mypage.get_paginated_response(ser.data)
'''

'''
class Book(APIView):
    def get(self, request):
        ret = models.Book.objects.all()
        # 实例化一个对象
        mypage = LimitOffsetPagination()
        mypage.default_limit = 2
        page_list = mypage.paginate_queryset(ret, request, self)
        ser = BookSer(instance=page_list, many=True)

        # 1.在settings里面配置参数
        # 2.写一个类,继承它,重写属性
        # 3.在对象里修改

        # 每页显示个数
        # default_limit = api_settings.PAGE_SIZE

        # 从标杆往后取多少个,例如:127.0.0.1/books/?offset=1&limit=3
        # limit_query_param = 'limit'

        # 指定标杆,从标杆往后取。127.0.0.1/books/?offset=1
        # offset_query_param = 'offset'

        # 最大显示条数
        # max_limit = None

        # return Response(ser.data)
        # 页面多了上一页和下一页
        return mypage.get_paginated_response(ser.data)
'''

class Book(APIView):
    def get(self, request):
        ret = models.Book.objects.all()
        # 实例化一个对象
        mypage = CursorPagination()
        mypage.ordering = 'id'
        mypage.page_size = 2

        page_list = mypage.paginate_queryset(ret, request, self)
        ser = BookSer(instance=page_list, many=True)

        # 1.在settings里面配置参数
        # 2.写一个类,继承它,重写属性
        # 3.在对象里修改

        # 每页显示个数
        # page_size = api_settings.PAGE_SIZE

        # 页码的参数key。例如:127.0.0.1/books/?cursor
        # cursor_query_param = 'cursor'


        # 最大显示条数
        # max_page_size = None

        # 排序
        # ordering = '-created'

        # return Response(ser.data)
        # 页面多了总条数,上一页和下一页的链接
        return mypage.get_paginated_response(ser.data)

你可能感兴趣的:(rest_framework)