在前面说的 APIView 中封装了三大认证,分别为认证、权限、频率。认证即登录认证,权限表示该用户是否有权限访问接口,频率表示用户指定时间内能访问接口的次数
为了方便举例说明,事先定义好模型表
from django.db import models
from django.contrib.auth.models import AbstractUser
class Book(models.Model):
name = models.CharField(max_length=32)
price = models.DecimalField(max_digits=5, decimal_places=2)
# 关联关系
publish = models.ForeignKey(to='Publish', on_delete=models.CASCADE)
authors = models.ManyToManyField(to='Author') # 自动生成中间表
def __str__(self):
return self.name
class Meta:
verbose_name_plural = '图书表'
def publish_info(self):
return {'name': self.publish.name, 'city': self.publish.city}
def authors_info(self):
return [{'name': i.name, 'age': i.age} for i in self.authors.all()]
class Author(models.Model):
name = models.CharField(max_length=32)
age = models.IntegerField()
author_detail = models.OneToOneField(to='AuthorDetail', on_delete=models.CASCADE)
def __str__(self):
return self.name
class Meta:
verbose_name_plural = '作者表'
class AuthorDetail(models.Model):
telephone = models.BigIntegerField()
birthday = models.DateField()
addr = models.CharField(max_length=64)
class Meta:
verbose_name_plural = '作者详情表'
class Publish(models.Model):
name = models.CharField(max_length=32)
city = models.CharField(max_length=32)
email = models.EmailField()
def __str__(self):
return self.name
class Meta:
verbose_name_plural = '出版社表'
class MyUser(AbstractUser):
user_type = models.IntegerField(choices=((1, '超级用户'), (2, 'VIP用户'), (3, '普通用户')), default=3)
class Meta:
verbose_name_plural = '用户表'
class UserToken(models.Model):
# 用于存放uuid
token = models.CharField(max_length=32)
user = models.OneToOneField(to=MyUser, on_delete=models.CASCADE)
自定义登录认证可以生成随机字符串,并添加进 UserToken 表中,每一次登录都会生成,若该字符串已存在则更新,不存在则新建。
登录视图函数
import uuid
from rest_framework.viewsets import ViewSet
from django.contrib import auth
# 继承 ViewSet,可以自动添加路由,也可以使用 action 装饰器自定义请求方法
class LoginView(ViewSet):
# 登录功能不需要认证,设为空列表
authentication_classes = []
permission_classes = []
@action(methods=['post', ], detail=False)
def login(self, request):
username = request.data.get('username')
password = request.data.get('password')
# 验证用户名以及密码是否正确
user = auth.authenticate(request, username=username, password=password)
if not user:
return Response({'code': 10001, 'msg': '用户名或密码错误'})
else:
uid = uuid.uuid4()
# defaults 是需要更新的数据, user可以理解为是筛选的条件
models.UserToken.objects.update_or_create(defaults={'token': uid}, user=user)
return Response({'code': 10000, 'msg': '登录成功', 'token': uid})
自定义认证表需要创建认证类,首先继承拓展 BaseAuthentication
导入语句:from rest_framework.authentication import BaseAuthentication
自定义认证类
from rest_framework.authentication import BaseAuthentication
from rest_framework.exceptions import AuthenticationFailed # 异常
class MyAuthentication(BaseAuthentication):
# 重写 BaseAuthentication 中的 authenticate 方法
def authenticate(self, request):
# 在请求头中获取用户登录的 token 字符串
token = request.query_params.get('token')
# 判断该字符串是否存在
user_token = UserToken.objects.filter(token=token).first()
if user_token:
# 返回的第一个值是当前登录用户,第二个值是 token
return user_token.user, token
else:
# 若不存在则报 AuthenticationFailed 异常
raise AuthenticationFailed('请先登录')
在配置文件中添加配置,DEFAULT_AUTHENTICATION_CLASSES 的值为列表,其包含认证类路径
REST_FRAMEWORK={
"DEFAULT_AUTHENTICATION_CLASSES":["app01.auth.MyAuthentication",]
}
局部使用,只需要在视图类里加入:
authentication_classes = [MyAuthentication, ]
可以选择某一些方法可以认证,在视图类中添加 get_authenticators
def get_authenticators(self):
if self.request.method != 'GET':
return [MyAuthentication(), ]
总结流程
-SessionAuthentication 之前老的 session 认证登录方式用,后期不用
-BasicAuthentication 基本认证方式
-TokenAuthentication 使用 token 认证方式,也可以自己写
可以在配置文件中配置全局默认的认证方案
REST_FRAMEWORK = {
'DEFAULT_AUTHENTICATION_CLASSES': (
'rest_framework.authentication.SessionAuthentication', # session认证
'rest_framework.authentication.BasicAuthentication', # 基本认证
)
}
也可以在每个视图中通过设置authentication_classess属性来设置
from rest_framework.authentication import SessionAuthentication, BasicAuthentication
from rest_framework.views import APIView
class ExampleView(APIView):
authentication_classes = [SessionAuthentication, BasicAuthentication]
...
登录认证成功后,还需要认证权限,有一些接口需要指定权限才能访问。所以权限需要和登录认证相关联。每个人的权限在表中默认设为普通用户。
自定义权限需要继承 BasePermission 编写权限类
导入语句:from rest_framework.permissions import BasePermission
自定义权限类
from rest_framework.permissions import BasePermission
from .models import UserToken, MyUser
class MyPermission(BasePermission):
# message 为认证失败提示信息
message = ''
# 需要重写 has_permission 方法,原方法默认返回 True
def has_permission(self, request, view):
# 获取当前登录用户
user = request.user
# 获取当前用户权限类型
user_type = user.user_type
if user_type == 1:
# 权限符合返回 True
return True
else:
# 权限不符合,添加提示信息,并返回 False
self.message = '你是: %s,权限不够' % user.get_user_type_display()
return False
全局使用也是在配置文件中添加
REST_FRAMEWORK={
"DEFAULT_AUTHENTICATION_CLASSES":["app01.auth.MyAuthentication",],
"DEFAULT_PERMISSION_CLASSES":["app01.auth.MyPermission",]
}
局部使用,只需要在视图类里加入该权限类即可:
permission_classes = [MyPermission,]
在视图类中添加 get_permissions 判断如果请求方式符合就去认证
def get_permissions(self):
if self.request.method == 'DELETE':
return [MyPermission(), ]
总结流程
from rest_framework.permissions import AllowAny,IsAuthenticated,IsAdminUser,IsAuthenticatedOrReadOnly
-AllowAny 允许所有用户
-IsAdminUser 校验是不是 auth 的超级管理员权限
-IsAuthenticated 后面用,验证用户是否登录,登录后才有权限,没登录就没有权限
-IsAuthenticatedOrReadOnly 了解即可
全局使用
可以在配置文件中全局设置默认的权限管理类,如下
REST_FRAMEWORK = {
....
'DEFAULT_PERMISSION_CLASSES': (
'rest_framework.permissions.IsAuthenticated',
)
}
如果未指明,则采用如下默认配置
'DEFAULT_PERMISSION_CLASSES': (
'rest_framework.permissions.AllowAny',
)
局部使用
也可以在具体的视图中通过 permission_classes 属性来设置,如下
from rest_framework.permissions import IsAuthenticated
from rest_framework.views import APIView
class ExampleView(APIView):
permission_classes = (IsAuthenticated,)
...
频率类
from rest_framework.throttling import SimpleRateThrottle
class MyThrottle(SimpleRateThrottle):
# 该属性作为键名在 setting 配置文件中使用
scope = 'count_time'
# 重写 get_cache_key 方法,该方法返回的值会被作为限制的依据
def get_cache_key(self, request, view):
return request.META.get('REMOTE_ADDR')
setting 配置
REST_FRAMEWORK = {
"DEFAULT_THROTTLE_RATES": {
# 频率类中scope对应的值
'count_time': '3/m', # 数字/s m h d
},
}
REST_FRAMEWORK = {
"DEFAULT_THROTTLE_RATES": {
# 频率类中scope对应的值
'count_time': '3/m', # 数字/s m h d
},
'DEFAULT_THROTTLE_CLASSES': ['app01.auth.MyThrottle', ]
}
在视图类中添加,同样,局部禁用只要赋值空列表即可
throttle_classes = [MyThrottle, ]
总结流程
request.META
中获取 'REMOTE_ADDR'
scope
,该属性的值会作为频率的键名,在 setting 配置文件 REST_FRAMEWORK 中的 DEFAULT_THROTTLE_RATES 配置,键名是 scope,键值是字符串,格式为 'x/y'
,x 表示访问的次数,y 表示访问的时间区间(可以为 s(秒)、m(份)、h(时)、d(天))AnonRateThrottle 内置频率类的功能:对于登录用户不限制次数,只未登录用户限制次数,限制的次数需要在配置文件中配置。使用也支持全局和局部
配置文件
REST_FRAMEWORK = {
'DEFAULT_THROTTLE_CLASSES': (
'rest_framework.throttling.AnonRateThrottle',
),
'DEFAULT_THROTTLE_RATES': {
'anon': '3/m',
}
}
UserRateThrottle 内置频率类的功能:限制登录用户的频率,限制的次数需要在配置文件中配置。也支持全局和局部使用
配置文件
REST_FRAMEWORK = {
'DEFAULT_THROTTLE_CLASSES': (
'rest_framework.throttling.UserRateThrottle'
),
'DEFAULT_THROTTLE_RATES': {
'user': '10/m'
}
}
# 自定义的逻辑
#(1)取出访问者ip
#(2)判断当前ip不在访问字典里,添加进去,并且直接返回True,表示第一次访问,在字典里,继续往下走
#(3)循环判断当前ip的列表,有值,并且当前时间减去列表的最后一个时间大于60s,把这种数据pop掉,这样列表中只有60s以内的访问时间,
#(4)判断,当列表小于3,说明一分钟以内访问不足三次,把当前时间插入到列表第一个位置,返回True,顺利通过
#(5)当大于等于3,说明一分钟内访问超过三次,返回False验证失败
class MyThrottles():
VISIT_RECORD = {}
def __init__(self):
self.history=None
def allow_request(self,request, view):
#(1)取出访问者ip
# print(request.META)
ip=request.META.get('REMOTE_ADDR')
import time
ctime=time.time()
# (2)判断当前ip不在访问字典里,添加进去,并且直接返回True,表示第一次访问
if ip not in self.VISIT_RECORD:
self.VISIT_RECORD[ip]=[ctime,]
return True
self.history=self.VISIT_RECORD.get(ip)
# (3)循环判断当前ip的列表,有值,并且当前时间减去列表的最后一个时间大于60s,把这种数据pop掉,这样列表中只有60s以内的访问时间,
while self.history and ctime-self.history[-1]>60:
self.history.pop()
# (4)判断,当列表小于3,说明一分钟以内访问不足三次,把当前时间插入到列表第一个位置,返回True,顺利通过
# (5)当大于等于3,说明一分钟内访问超过三次,返回False验证失败
if len(self.history)<3:
self.history.insert(0,ctime)
return True
else:
return False
def wait(self):
import time
ctime=time.time()
return 60-(ctime-self.history[-1])
AnonRateThrottle
限制所有匿名未认证用户,使用 IP
区分用户。
使用 DEFAULT_THROTTLE_RATES[‘anon’] 来设置频次
UserRateThrottle
限制认证用户,使用 User id
来区分。
使用 DEFAULT_THROTTLE_RATES[‘user’] 来设置频次
ScopedRateThrottle
限制用户对于每个视图的访问频次,使用 ip 或 user id
三大认证的顺序是:登录 ----> 权限 ----> 频率 。
已知的是三大认证是在 APIView 中封装,其源码如下
dispatch
def dispatch(self, request, *args, **kwargs):
...
try:
# 三大认证
self.initial(request, *args, **kwargs)
...
except Exception as exc:
...
...
initial
def initial(self, request, *args, **kwargs):
...
# 认证
self.perform_authentication(request)
# 权限
self.check_permissions(request)
# 频率
self.check_throttles(request)
在代码中由上而下执行,因此有了三大认证的顺序
局部使用三大认证是在视图类中使用,例如 authentication_classes、permission_classes、throttle_classes,要想知道如何在视图类中配置就可以进行认证的可以分析其源码。
def perform_authentication(self, request):
request.user
rest_framework.request Request
中查看 Request 源码,发现其有一个名为 user 的方法,如下所示class Request:
...
@property
def user(self):
if not hasattr(self, '_user'):
with wrap_attributeerrors():
self._authenticate()
return self._user
_authenticate
方法class Request:
def _authenticate(self):
for authenticator in self.authenticators:
try:
user_auth_tuple = authenticator.authenticate(self)
except exceptions.APIException:
self._not_authenticated()
raise
if user_auth_tuple is not None:
self._authenticator = authenticator
self.user, self.auth = user_auth_tuple
return
self._not_authenticated()
for authenticator in self.authenticators
语句中的 self.authenticators 是我们在视图类上添加的认证类,但是是以: [认证类(), 认证类2()] 的形式存在。所以 authenticator 表示的是 认证类()user_auth_tuple = authenticator.authenticate(self)
调用认证类中的 authenticate 方法,返回的元组(一个是登录用户、一个是 token)用 user_auth_tuple 接收。except exceptions.APIException:
捕获的是 APIException,我们抛出的是 AuthenticationFailed,但是由于其继承了 APIException,相当于也捕获了。self.user, self.auth = user_auth_tuple
如果返回了两个值,第一个值给了 request.user ,第二个值给了 request.auth。因此认证过后 request.user 会有当前用户。class Request:
def __init__(self, request, parsers=None, authenticators=None,
negotiator=None, parser_context=None):
...
self.authenticators = authenticators or ()
request = self.initialize_request(request, *args, **kwargs)
class APIView(View):
...
def initialize_request(self, request, *args, **kwargs):
"""
Returns the initial request object.
"""
parser_context = self.get_parser_context(request)
return Request(
request,
parsers=self.get_parsers(),
authenticators=self.get_authenticators(),
negotiator=self.get_content_negotiator(),
parser_context=parser_context
)
authenticators=self.get_authenticators()
由 get_authenticators() 赋值class APIView(View):
...
def get_permissions(self):
return [permission() for permission in self.permission_classes]
同样的查看在 dispatch 中 initial 方法里的 check_permissions 源码
check_permissions
class APIView(View):
...
def check_permissions(self, request):
for permission in self.get_permissions():
if not permission.has_permission(request, self):
self.permission_denied(
request,
message=getattr(permission, 'message', None),
code=getattr(permission, 'code', None)
)
for permission in self.get_permissions():
permission 表示的是权限类对象if not permission.has_permission(request, self):
判断 has_permission 方法返回的是 True 还是 False,如果为 False 则执行 permission_denied 方法报异常message=getattr(permission, 'message', None),
获取在视图类中编写的 messageget_permissions
def get_permissions(self):
return [permission() for permission in self.permission_classes]
permission_denied
def permission_denied(self, request, message=None, code=None):
if request.authenticators and not request.successful_authenticator:
raise exceptions.NotAuthenticated()
raise exceptions.PermissionDenied(detail=message, code=code)
查看 check_throttles 源码
check_throttles
def check_throttles(self, request):
throttle_durations = []
for throttle in self.get_throttles():
if not throttle.allow_request(request, self):
throttle_durations.append(throttle.wait())
if throttle_durations:
durations = [
duration for duration in throttle_durations
if duration is not None
]
duration = max(durations, default=None)
self.throttled(request, duration)
for throttle in self.get_throttles():
throttle 表示的是频率类对象,get_throttles() 表示的是频率类对象列表。if not throttle.allow_request(request, self):
获取对象的 allow_request 方法,返回 True 就是没有频率限制住,返回 False 就是被频率限制了我们可以查看 SimpleRateThrottle 的 allow_request 方法
allow_request
def allow_request(self, request, view):
if self.rate is None:
return True
self.key = self.get_cache_key(request, view)
if self.key is None:
return True
self.history = self.cache.get(self.key, [])
self.now = self.timer()
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
if len(self.history) >= self.num_requests:
return self.throttle_failure()
return self.throttle_success()
if self.rate is None:
对 rate 进行判断,rate 属性在其双下 _init_ 方法中实现_init_
def __init__(self):
if not getattr(self, 'rate', None):
# 没有该属性调用 get_rate
self.rate = self.get_rate()
self.num_requests, self.duration = self.parse_rate(self.rate)
get_rate
def get_rate(self):
if not getattr(self, 'scope', None):
msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
self.__class__.__name__)
raise ImproperlyConfigured(msg)
try:
return self.THROTTLE_RATES[self.scope]
except KeyError:
msg = "No default throttle rate set for '%s' scope" % self.scope
raise ImproperlyConfigured(msg)
THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
parse_rate
def parse_rate(self, rate):
if rate is None:
return (None, None)
num, period = rate.split('/')
num_requests = int(num)
duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
return (num_requests, duration)
num, period = rate.split('/')
对 rate 进行切割,这里我们假设频率为 5/m,那么切割后的 num 为 5、period 为 mduration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
[period[0]] 只取出了第一个值然后去字典中取值,也就是说,定义 scope 的时候不一定要写 m,只要以 m 开头即可。回到 allow_request 方法
allow_request
def allow_request(self, request, view):
if self.rate is None:
return True
self.key = self.get_cache_key(request, view)
if self.key is None:
return True
self.history = self.cache.get(self.key, [])
self.now = self.timer()
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
if len(self.history) >= self.num_requests:
return self.throttle_failure()
return self.throttle_success()
self.key = self.get_cache_key(request, view)
调用 get_cache_key 方法,该方法需要自己重写的,用于指定判断的依据。self.history = self.cache.get(self.key, [])
从缓存中拿数据,取出的数据是时间的列表,类似于 [时间2, 时间1],没有则赋值空列表 self.now = self.timer()
timer() 是类属性,加括号调用了,获取时间。源码如下所示timer = time.time
self.history.pop()
把所有超过时间的数据都剔除,self.history 只剩限定时间内的访问时间if len(self.history) >= self.num_requests:
大于等于配置的次数执行 throttle_failure 返回 False,否则执行 throttle_success 把当前时间插入,并返回 True