Django rest framework simple JWT 的源码分析

本文讲述的是 Django REST framework SimpleJWT(rest_framework_simplejwt)获取 token 的源码流程。
我们知道,在项目中引入 rest_framework_simplejwt 包后,只需要在 urls.py 中添加如下路由:

from django.urls import path
from rest_framework_simplejwt.views import TokenObtainPairView, TokenRefreshView

urlpatterns = [
	# step1: a POST to obtain/ is dispatched to TokenObtainPairView.post() (step2).
	# path() matches the whole route; the old url() treated 'obtain/' as an
	# unanchored regex and was removed in Django 4.0.
	path('obtain/', TokenObtainPairView.as_view(), name='token_obtain_pair'),
	# Companion endpoint (not covered by this walkthrough).
	path('refresh/', TokenRefreshView.as_view(), name='token_refresh'),
]

所以本文重点讲解的就是该 URL 背后的源码流程,具体请参考下面的调用流程(代码中的 stepN 标注即调用顺序)。

结论:
RefreshToken 实例的 payload (是个 dict ) 中会含有 token_type/exp/jti/user_id。
结合认证类 rest_framework_simplejwt.authentication.JWTAuthentication 的 authenticate 方法的实现,我们可以宏观的理解下 JWT 的认证本质:
给你一个 token 字符串,里面隐含了 user_id 和过期时间。当用户带着这个 token 发起请求时,后端只需要校验 token 的签名并解码(签名校验通过,说明这是服务端签发且未被篡改的合法 token),就能得到一个 JSON 格式的 payload(参考《JWT 到底是个什么串,里面包含了什么信息》)。然后从中取出 user_id 去 auth_user 表中查找对应的用户,就能知道当前请求是谁发送的了;再把其中的过期时间 exp 与当前时间比较,就能知道 token 是否已过期。设计得真好!
这样做的意义是用非敏感信息 user_id 代替 username/password 等敏感信息,来表明发送请求的用户是谁。由于 token 带有服务端签名,客户端无法伪造或篡改其中的 user_id。需要注意的是,JWT 的 payload 只是经过 base64url 编码而并未加密,任何人都能解码查看,因此不要在其中存放敏感信息。
其实不管是哪种认证方式要解决的问题都是发请求的人是谁,是否有权限获取想查看的数据。

扩展:
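我们能否自定义 token 中的 payload 呢(这个灵感来自 xdbamp 后端改造项目)?比如想把用户的部门信息加入到 payload 中。答案当然是肯定的:继承 TokenObtainPairSerializer 并重写 get_token(即下文的 step4),在父类 get_token 返回的 token 上像字典一样添加自定义字段即可(原理见 step6_5)。由于 step9 会把 refresh token 中除 token_type/exp/jti 之外的字段都拷贝到 access token,这些自定义字段也会出现在 access token 中。详细做法请参考《自定义 JWT 中包含的信息》,简直完美!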

D:\WorkSpace\Archiver\archiver_gitcode\venv\Lib\site-packages\rest_framework_simplejwt\views.py
class TokenViewBase(generics.GenericAPIView):
	# Empty tuples: no permission checks and no authenticators run for token
	# endpoints; the request body itself carries what is being exchanged.
	permission_classes = ()
	authentication_classes = ()

	# Set by concrete views, e.g. TokenObtainPairView.
	serializer_class = None

	www_authenticate_realm = 'api'

	def get_authenticate_header(self, request):
		# Builds '<first AUTH_HEADER_TYPES entry> realm="api"' for the
		# WWW-Authenticate response header.
		return '{0} realm="{1}"'.format(
			AUTH_HEADER_TYPES[0],
			self.www_authenticate_realm,
		)

	def post(self, request, *args, **kwargs):
		# step2: instantiate serializer_class (TokenObtainPairSerializer for
		# TokenObtainPairView) with the request body.
		serializer = self.get_serializer(data=request.data)

		try:
			# Runs the serializer's validate(); for TokenObtainPairSerializer
			# that is where the token pair is built (step3).
			serializer.is_valid(raise_exception=True)
		except TokenError as e:
			raise InvalidToken(e.args[0])

		# validated_data is the dict returned by validate(), i.e.
		# {'refresh': ..., 'access': ...} for the obtain-pair view.
		return Response(serializer.validated_data, status=status.HTTP_200_OK)
		
class TokenObtainPairView(TokenViewBase):
	"""
	Takes a set of user credentials and returns an access and refresh JSON web
	token pair to prove the authentication of those credentials.
	"""
	# All request handling lives in TokenViewBase.post(); this view only picks
	# the serializer that post() instantiates (step2) and validates (step3).
	serializer_class = serializers.TokenObtainPairSerializer
	

			
		
D:\WorkSpace\Archiver\archiver_gitcode\venv\Lib\site-packages\rest_framework_simplejwt\serializers.py
from .tokens import RefreshToken, SlidingToken, UntypedToken
class TokenObtainPairSerializer(TokenObtainSerializer):
	@classmethod
	def get_token(cls, user):
		# step4: build a RefreshToken for the authenticated user via
		# RefreshToken.for_user (step5_x / step6_x). Overriding this classmethod
		# is the hook for adding custom claims to the payload.
		return RefreshToken.for_user(user)

	def validate(self, attrs):
		# The parent validate() (TokenObtainSerializer, not shown here) is
		# presumably what checks the credentials and sets self.user, which is
		# read below.
		data = super(TokenObtainPairSerializer, self).validate(attrs)

		# step3: create the refresh token (see get_token / step4).
		refresh = self.get_token(self.user)

		# Converting to text goes through Token.__str__, which signs and
		# encodes the payload with token_backend.
		data['refresh'] = text_type(refresh)
		# step7: derive the access token from the refresh token (step8/step9).
		data['access'] = text_type(refresh.access_token)

		return data

			
			
			
			
D:\WorkSpace\Archiver\archiver_gitcode\venv\Lib\site-packages\rest_framework_simplejwt\tokens.py			
class RefreshToken(BlacklistMixin, Token):
	token_type = 'refresh'
	lifetime = api_settings.REFRESH_TOKEN_LIFETIME
	# Claims never copied into a derived access token: a new AccessToken
	# already gets its own type, expiry and id from Token.__init__.
	no_copy_claims = (api_settings.TOKEN_TYPE_CLAIM, 'exp', 'jti')

	@property
	def access_token(self):
		"""
		Returns an access token created from this refresh token.  Copies all
		claims present in this refresh token to the new access token except
		those claims listed in the `no_copy_claims` attribute.
		"""
		# step8: create a brand-new AccessToken instance (token=None, so
		# Token.__init__ fills in its token_type, exp and jti).
		access = AccessToken()

		# Use instantiation time of refresh token as relative timestamp for
		# access token "exp" claim.  This ensures that both a refresh and
		# access token expire relative to the same time if they are created as
		# a pair.
		access.set_exp(from_time=self.current_time)

		# step9: per step6_x, this refresh token's payload holds
		# token_type/exp/jti/user_id. Every claim except token_type/exp/jti is
		# copied, so in this flow only user_id (plus any custom claims added
		# in get_token) ends up in the access token.
		no_copy = self.no_copy_claims
		for claim, value in self.payload.items():
			if claim in no_copy:
				continue
			access[claim] = value

		return access

		
class AccessToken(Token):
	# Built by RefreshToken.access_token (step8). Token.__init__ writes
	# token_type into the payload and uses lifetime to compute "exp".
	token_type = 'access'
	lifetime = api_settings.ACCESS_TOKEN_LIFETIME

		
class BlacklistMixin(object):
	"""
	If the `rest_framework_simplejwt.token_blacklist` app was configured to be
	used, tokens created from `BlacklistMixin` subclasses will insert
	themselves into an outstanding token list and also check for their
	membership in a token blacklist.
	"""
	# The methods below are defined conditionally when the class body runs:
	# without the token_blacklist app in INSTALLED_APPS this mixin defines
	# nothing, and RefreshToken.for_user resolves straight to Token.for_user
	# (step5_2).
	if 'rest_framework_simplejwt.token_blacklist' in settings.INSTALLED_APPS:
		def verify(self, *args, **kwargs):
			# Reject blacklisted tokens before the regular verification in
			# super().verify().
			self.check_blacklist()

			super(BlacklistMixin, self).verify(*args, **kwargs)

		def check_blacklist(self):
			"""
			Checks if this token is present in the token blacklist.  Raises
			`TokenError` if so.
			"""
			jti = self.payload['jti']

			if BlacklistedToken.objects.filter(token__jti=jti).exists():
				raise TokenError(_('Token is blacklisted'))

		def blacklist(self):
			"""
			Ensures this token is included in the outstanding token list and
			adds it to the blacklist.
			"""
			jti = self.payload['jti']
			exp = self.payload['exp']

			# Ensure outstanding token exists with given jti
			token, _ = OutstandingToken.objects.get_or_create(
				jti=jti,
				defaults={
					'token': str(self),
					'expires_at': datetime_from_epoch(exp),
				},
			)

			return BlacklistedToken.objects.get_or_create(token=token)

		@classmethod
		def for_user(cls, user):
			"""
			Adds this token to the outstanding token list.
			"""
			# step5_1: Token.for_user (step5_2) builds the token through
			# super(); the result is then recorded as an OutstandingToken row.
			token = super(BlacklistMixin, cls).for_user(user)

			jti = token['jti']
			exp = token['exp']

			OutstandingToken.objects.create(
				user=user,
				jti=jti,
				token=str(token),
				created_at=token.current_time,
				expires_at=datetime_from_epoch(exp),
			)

			return token
	
	
@python_2_unicode_compatible
class Token(object):
	"""
	A class which validates and wraps an existing JWT or can be used to build a
	new JWT.
	"""
	# Subclasses (AccessToken, RefreshToken, ...) must set both; __init__
	# refuses to build a token while either is None.
	token_type = None
	lifetime = None

	def __init__(self, token=None, verify=True):
		"""
		!!!! IMPORTANT !!!! MUST raise a TokenError with a user-facing error
		message if the given token is invalid, expired, or otherwise not safe
		to use.
		"""
		if self.token_type is None or self.lifetime is None:
			raise TokenError(_('Cannot create token with no type or lifetime'))

		self.token = token
		self.current_time = aware_utcnow()

		# Set up token
		if token is not None:
			# An encoded token was provided
			from .state import token_backend

			# Decode token
			try:
				self.payload = token_backend.decode(token, verify=verify)
			except TokenBackendError:
				raise TokenError(_('Token is invalid or expired'))

			if verify:
				self.verify()
		else:
			# New token.  Skip all the verification steps.
			# step6_2: put token_type into the payload ('refresh' when reached
			# through RefreshToken.for_user).
			self.payload = {api_settings.TOKEN_TYPE_CLAIM: self.token_type}

			# Set "exp" claim with default value
			# step6_3: add exp (current_time + lifetime) to the payload.
			self.set_exp(from_time=self.current_time, lifetime=self.lifetime)

			# Set "jti" claim
			# step6_4: add jti (JWT ID, a random uuid4 hex) to the payload.
			self.set_jti()

	# The dunder methods below expose the payload dict directly, so a token
	# can be used like a dict (e.g. token[api_settings.USER_ID_CLAIM] = ...).
	def __repr__(self):
		return repr(self.payload)

	def __getitem__(self, key):
		return self.payload[key]

	def __setitem__(self, key, value):
		self.payload[key] = value

	def __delitem__(self, key):
		del self.payload[key]

	def __contains__(self, key):
		return key in self.payload

	def get(self, key, default=None):
		return self.payload.get(key, default)

	def __str__(self):
		"""
		Signs and returns a token as a base64 encoded string.
		"""
		from .state import token_backend

		return token_backend.encode(self.payload)

	def verify(self):
		"""
		Performs additional validation steps which were not performed when this
		token was decoded.  This method is part of the "public" API to indicate
		the intention that it may be overridden in subclasses.
		"""
		# According to RFC 7519, the "exp" claim is OPTIONAL
		# (https://tools.ietf.org/html/rfc7519#section-4.1.4).  As a more
		# correct behavior for authorization tokens, we require an "exp"
		# claim.  We don't want any zombie tokens walking around.
		self.check_exp()

		# Ensure token id is present
		if 'jti' not in self.payload:
			raise TokenError(_('Token has no id'))

		self.verify_token_type()

	def verify_token_type(self):
		"""
		Ensures that the token type claim is present and has the correct value.
		"""
		try:
			token_type = self.payload[api_settings.TOKEN_TYPE_CLAIM]
		except KeyError:
			raise TokenError(_('Token has no type'))

		if self.token_type != token_type:
			raise TokenError(_('Token has wrong type'))

	def set_jti(self):
		"""
		Populates the "jti" claim of a token with a string where there is a
		negligible probability that the same string will be chosen at a
		later time.

		See here:
		https://tools.ietf.org/html/rfc7519#section-4.1.7
		"""
		self.payload['jti'] = uuid4().hex

	def set_exp(self, claim='exp', from_time=None, lifetime=None):
		"""
		Updates the expiration time of a token.
		"""
		if from_time is None:
			from_time = self.current_time

		if lifetime is None:
			lifetime = self.lifetime

		self.payload[claim] = datetime_to_epoch(from_time + lifetime)

	def check_exp(self, claim='exp', current_time=None):
		"""
		Checks whether a timestamp value in the given claim has passed (since
		the given datetime value in `current_time`).  Raises a TokenError with
		a user-facing error message if so.
		"""
		if current_time is None:
			current_time = self.current_time

		try:
			claim_value = self.payload[claim]
		except KeyError:
			raise TokenError(format_lazy(_("Token has no '{}' claim"), claim))

		claim_time = datetime_from_epoch(claim_value)
		if claim_time <= current_time:
			raise TokenError(format_lazy(_("Token '{}' claim has expired"), claim))

	@classmethod
	def for_user(cls, user):
		"""
		Returns an authorization token for the given user that will be provided
		after authenticating the user's credentials.
		"""
		# step5_2: reached via super() from BlacklistMixin.for_user (step5_1)
		# when the blacklist app is installed, otherwise called directly.
		user_id = getattr(user, api_settings.USER_ID_FIELD)
		# Non-int ids are converted to text (presumably so they can be encoded
		# into the JWT payload).
		if not isinstance(user_id, int):
			user_id = text_type(user_id)

		# step6_1: create a new token of the calling class (RefreshToken in
		# this flow); token is None, so __init__ runs steps 6_2-6_4.
		token = cls()
		# step6_5: add user_id to the payload. Token defines __setitem__ and
		# related methods, so a RefreshToken instance can be used like a dict.
		token[api_settings.USER_ID_CLAIM] = user_id

		return token

你可能感兴趣的:Django、auth