Django DRF JWT模块验证源码解读

最近刚好有项目用到 Django 的 DRF,正好借此机会学习一下其中 JWT 认证的实现。
DRF 官网推荐使用的是 djangorestframework-simplejwt,具体原因不详,可能与下文分析的 rest_framework_jwt 已停止维护有关:
https://www.django-rest-framework.org/api-guide/authentication/#json-web-token-authentication
两者的代码思路差不多。这里是基于网上资料较多的 rest_framework_jwt(djangorestframework-jwt)进行分析的,但要注意这个库已经不再维护更新(见其 GitHub issue #484)。

该库对 JWT 做了深度封装,在完整实现 JWT 的同时做到了配置简单快捷:只需在 urls 配置中引入 obtain_jwt_token(它本身就是一个 view)即可。

from django.urls import re_path
from rest_framework_jwt.views import obtain_jwt_token
# POST the username/password to this endpoint to obtain a JWT.
# The pattern is a raw string (regex convention) and is anchored with "$"
# so only the exact path matches; without it, e.g. "api-token-auth/anything"
# would also be routed to this view.
urlpatterns = [
    re_path(r'^api-token-auth/$', obtain_jwt_token),
]

追踪 obtain_jwt_token 的源码,可以找到其对应 view 中 POST 方法的实现:

    def post(self, request, *args, **kwargs):
        """Validate the posted credentials and respond with a JWT.

        On success the response body is built by
        ``jwt_response_payload_handler``; if ``api_settings.JWT_AUTH_COOKIE``
        is set, the token is additionally stored in an HttpOnly cookie that
        expires after ``api_settings.JWT_EXPIRATION_DELTA``.

        On failure the serializer's errors are returned with HTTP 400, so the
        view always returns a response object instead of ``None``.
        """
        serializer = self.get_serializer(data=request.data)
        # is_valid() ends up calling JSONWebTokenSerializer.validate(), which
        # authenticates the user and encodes the token.
        if not serializer.is_valid():
            # Without this branch the method would fall through and return
            # None, which is not a valid view response.
            return Response(serializer.errors, status=400)

        user = serializer.object.get('user') or request.user
        token = serializer.object.get('token')
        response_data = jwt_response_payload_handler(token, user, request)
        response = Response(response_data)
        if api_settings.JWT_AUTH_COOKIE:
            expiration = (datetime.utcnow() +
                          api_settings.JWT_EXPIRATION_DELTA)
            response.set_cookie(api_settings.JWT_AUTH_COOKIE,
                                token,
                                expires=expiration,
                                httponly=True)
        return response

其中下面这一句创建了序列化器对象,这里的 serializer 对应的是 JSONWebTokenSerializer;随后调用的 serializer.is_valid() 会对该对象中的数据进行校验。is_valid() 的源码如下:

# In the view's post(): build the (JSONWebTokenSerializer) serializer from the request body.
self.get_serializer(data=request.data)
def is_valid(self, raise_exception=False):
    """Validate ``self.initial_data`` (once) and report whether it passed.

    The result of ``run_validation`` is cached in ``_validated_data`` and any
    errors in ``_errors``; later calls reuse them without re-validating.

    Args:
        raise_exception: if true and validation produced errors, raise
            ``ValidationError`` instead of returning ``False``.

    Returns:
        ``True`` when there are no validation errors, else ``False``.

    Raises:
        AssertionError: the serializer defines the removed v2
            ``restore_object()`` hook, or was built without ``data=``.
        ValidationError: errors exist and ``raise_exception`` is true.
    """
    assert not hasattr(self, 'restore_object'), (
        'Serializer `%s.%s` has old-style version 2 `.restore_object()` '
        'that is no longer compatible with REST framework 3. '
        'Use the new-style `.create()` and `.update()` methods instead.' %
        (self.__class__.__module__, self.__class__.__name__)
    )

    assert hasattr(self, 'initial_data'), (
        'Cannot call `.is_valid()` as no `data=` keyword argument was '
        'passed when instantiating the serializer instance.'
    )

    # Only validate the first time; _validated_data doubles as a "done" flag.
    if not hasattr(self, '_validated_data'):
        try:
            self._validated_data = self.run_validation(self.initial_data)
        except ValidationError as exc:
            self._validated_data = {}
            self._errors = exc.detail
        else:
            self._errors = {}

    if self._errors and raise_exception:
        raise ValidationError(self.errors)

    return not bool(self._errors)

其中的重点是下面这一句:

# Inside is_valid(): the actual validation pipeline is delegated to run_validation().
self.run_validation(self.initial_data)

再看 run_validation() 的具体内容:

def run_validation(self, data=empty):
    """
    We override the default `run_validation`, because the validation
    performed by validators and the `.validate()` method should
    be coerced into an error dictionary with a non-field-errors key
    (DRF's ``NON_FIELD_ERRORS_KEY``, ``'non_field_errors'`` by default).

    Returns the validated value; raises ``ValidationError`` on failure.
    """
    # validate_empty_values() can report that no further validation is
    # needed, in which case ``data`` is returned as-is.
    (is_empty_value, data) = self.validate_empty_values(data)
    if is_empty_value:
        return data

    # Per-field conversion; errors raised here are outside the try below
    # and therefore propagate unchanged.
    value = self.to_internal_value(data)
    try:
        self.run_validators(value)
        # validate through serializer validate (JSONWebTokenSerializer
        # overrides validate() to authenticate and issue the token)
        value = self.validate(value)
        assert value is not None, '.validate() should return the validated data'
    except (ValidationError, DjangoValidationError) as exc:
        raise ValidationError(detail=as_serializer_error(exc))

    return value

最终调用了关键的校验方法 validate(),该方法在 JSONWebTokenSerializer 中被重写:

# Inside run_validation(): dispatch to the serializer's own validate() hook.
value = self.validate(value)

def validate(self, attrs):
    """Authenticate the posted credentials and issue a JWT for the user.

    Args:
        attrs: validated input containing the username field (named by
            ``self.username_field``) and ``'password'``.

    Returns:
        ``{'token': <encoded JWT>, 'user': <authenticated user>}``.

    Raises:
        serializers.ValidationError: a credential is missing/empty, no
            backend accepted the credentials, or the account is inactive.
    """
    field = self.username_field
    credentials = {field: attrs.get(field), 'password': attrs.get('password')}

    # Both values must be present and non-empty before hitting any backend.
    if not all(credentials.values()):
        msg = _('Must include "{username_field}" and "password".')
        msg = msg.format(username_field=field)
        raise serializers.ValidationError(msg)

    user = authenticate(**credentials)
    if not user:
        raise serializers.ValidationError(
            _('Unable to log in with provided credentials.'))

    # ModelBackend already refuses inactive users (user_can_authenticate),
    # so this mainly guards against other backends.
    if not user.is_active:
        raise serializers.ValidationError(_('User account is disabled.'))

    # Default payload, e.g.
    # {'user_id': 1, 'username': 'axioum', 'exp': datetime.datetime(...), 'email': ''}
    payload = jwt_payload_handler(user)
    return {'token': jwt_encode_handler(payload), 'user': user}

其中最为关键的是下面这一句:

# Inside validate(): hand the credentials to Django's configured authentication backends.
user = authenticate(**credentials)

authenticate 方法的内容如下。获取认证后端的函数是 _get_backends(return_tuples=True),它会读取 settings.AUTHENTICATION_BACKENDS 中配置的后端;而这里的 settings 并不是函数,而是一个 LazySettings 对象(settings = LazySettings()),负责提供所有配置项。

def authenticate(request=None, **credentials):
    """Return a user object if some configured backend accepts ``credentials``.

    Backends from ``_get_backends`` are tried in order:

    * a backend whose ``authenticate`` cannot be called with these arguments
      is skipped;
    * a backend raising ``PermissionDenied`` stops the search entirely;
    * the first backend returning a non-``None`` user wins, and its dotted
      path is recorded on ``user.backend``.

    If no user is found, ``user_login_failed`` is sent (credentials passed
    through ``_clean_credentials``) and ``None`` is returned.
    """
    for backend, backend_path in _get_backends(return_tuples=True):
        # Probe the signature separately so that a TypeError raised *inside*
        # backend.authenticate() is not mistaken for "wrong arguments".
        try:
            inspect.getcallargs(backend.authenticate, request, **credentials)
        except TypeError:
            continue

        try:
            user = backend.authenticate(request, **credentials)
        except PermissionDenied:
            # Hard veto: do not consult any remaining backend.
            break

        if user is not None:
            user.backend = backend_path
            return user

    user_login_failed.send(
        sender=__name__,
        credentials=_clean_credentials(credentials),
        request=request,
    )
    return None

def _get_backends(return_tuples=False):
    """Load every backend listed in ``settings.AUTHENTICATION_BACKENDS``.

    Returns a list of backend objects, or of ``(backend, dotted_path)``
    pairs when ``return_tuples`` is true, in configuration order.

    Raises:
        ImproperlyConfigured: the setting lists no backends.
    """
    loaded = [
        (load_backend(backend_path), backend_path)
        for backend_path in settings.AUTHENTICATION_BACKENDS
    ]
    if not loaded:
        raise ImproperlyConfigured(
            'No authentication backends have been defined. Does '
            'AUTHENTICATION_BACKENDS contain anything?'
        )
    if return_tuples:
        return loaded
    return [backend for backend, _path in loaded]

# Module-level settings proxy; the real settings module is only loaded
# (via LazySettings._setup) on the first attribute access.
settings = LazySettings()

LazySettings 及其基类 LazyObject 的写法非常值得学习,其中大量运用了 Python 特有的“黑魔法”(如重载 __getattr__ / __setattr__ / __delattr__),还需要花些时间仔细研究。LazySettings 对象在模块导入时就已创建,但真正的配置模块要到第一次访问某个配置项时才会加载;运行 runserver 时很早就会触发这一过程。另外一个会用到 LazyObject 的地方是认证中间件。

class LazySettings(LazyObject):
    """
    A lazy proxy for either global Django settings or a custom settings object.
    The user can manually configure settings prior to using them. Otherwise,
    Django uses the settings module pointed to by DJANGO_SETTINGS_MODULE.
    """
    def _setup(self, name=None):
        """
        Load the settings module pointed to by the environment variable. This
        is used the first time settings are needed, if the user hasn't
        configured settings manually.

        Raises ImproperlyConfigured if the environment variable is unset or
        empty; ``name`` only makes that error message more specific.
        """
        settings_module = os.environ.get(ENVIRONMENT_VARIABLE)
        if not settings_module:
            desc = ("setting %s" % name) if name else "settings"
            raise ImproperlyConfigured(
                "Requested %s, but settings are not configured. "
                "You must either define the environment variable %s "
                "or call settings.configure() before accessing settings."
                % (desc, ENVIRONMENT_VARIABLE))

        self._wrapped = Settings(settings_module)

    def __repr__(self):
        """Return a debug representation without forcing settings to load.

        The class name is hard-coded because otherwise the wrapped
        ``Settings`` object's class name would be shown.
        """
        if self._wrapped is empty:
            return '<LazySettings [Unevaluated]>'
        return '<LazySettings "%(settings_module)s">' % {
            'settings_module': self._wrapped.SETTINGS_MODULE,
        }

    def __getattr__(self, name):
        """Return the value of a setting and cache it in self.__dict__."""
        # Only reached when normal attribute lookup fails, i.e. on a cache miss.
        if self._wrapped is empty:
            self._setup(name)
        val = getattr(self._wrapped, name)
        # Once stored in the instance dict, later reads bypass __getattr__.
        self.__dict__[name] = val
        return val

    def __setattr__(self, name, value):
        """
        Set the value of setting. Clear all cached values if _wrapped changes
        (@override_settings does this) or clear single values when set.
        """
        if name == '_wrapped':
            self.__dict__.clear()
        else:
            self.__dict__.pop(name, None)
        super().__setattr__(name, value)

    def __delattr__(self, name):
        """Delete a setting and clear it from cache if needed."""
        super().__delattr__(name)
        self.__dict__.pop(name, None)

    def configure(self, default_settings=global_settings, **options):
        """
        Called to manually configure the settings. The 'default_settings'
        parameter sets where to retrieve any unspecified values from (its
        argument must support attribute access (__getattr__)).

        Raises RuntimeError if settings were already configured or loaded.
        """
        if self._wrapped is not empty:
            raise RuntimeError('Settings already configured.')
        holder = UserSettingsHolder(default_settings)
        for name, value in options.items():
            setattr(holder, name, value)
        self._wrapped = holder

    @property
    def configured(self):
        """Return True if the settings have already been configured."""
        return self._wrapped is not empty


class Settings:
    """Concrete settings object built from a Python settings module.

    Attributes are first copied from the upper-case names of
    ``global_settings`` and then overridden by every upper-case name in the
    module named by ``settings_module``. The overridden names are recorded
    so ``is_overridden()`` can report them.
    """

    def __init__(self, settings_module):
        """Load defaults, then the user settings module, then validate.

        Raises:
            ImproperlyConfigured: a setting in ``tuple_settings`` is not a
                list/tuple, or SECRET_KEY is empty.
            ValueError: TIME_ZONE has no matching file although
                /usr/share/zoneinfo exists.

        Side effect: when ``time.tzset`` exists and TIME_ZONE is set, this
        writes os.environ['TZ'] and calls time.tzset().
        """
        # update this dict from global settings (but only for ALL_CAPS settings)
        for setting in dir(global_settings):
            if setting.isupper():
                setattr(self, setting, getattr(global_settings, setting))

        # store the settings module in case someone later cares
        self.SETTINGS_MODULE = settings_module

        mod = importlib.import_module(self.SETTINGS_MODULE)

        # These settings must be lists or tuples; other types are rejected below.
        tuple_settings = (
            "INSTALLED_APPS",
            "TEMPLATE_DIRS",
            "LOCALE_PATHS",
        )
        # Names explicitly defined by the user module (see is_overridden()).
        self._explicit_settings = set()
        for setting in dir(mod):
            if setting.isupper():
                setting_value = getattr(mod, setting)

                if (setting in tuple_settings and
                        not isinstance(setting_value, (list, tuple))):
                    raise ImproperlyConfigured("The %s setting must be a list or a tuple. " % setting)
                setattr(self, setting, setting_value)
                self._explicit_settings.add(setting)

        if not self.SECRET_KEY:
            raise ImproperlyConfigured("The SECRET_KEY setting must not be empty.")

        if self.is_overridden('DEFAULT_CONTENT_TYPE'):
            warnings.warn('The DEFAULT_CONTENT_TYPE setting is deprecated.', RemovedInDjango30Warning)

        if hasattr(time, 'tzset') and self.TIME_ZONE:
            # When we can, attempt to validate the timezone. If we can't find
            # this file, no check happens and it's harmless.
            zoneinfo_root = Path('/usr/share/zoneinfo')
            zone_info_file = zoneinfo_root.joinpath(*self.TIME_ZONE.split('/'))
            if zoneinfo_root.exists() and not zone_info_file.exists():
                raise ValueError("Incorrect timezone setting: %s" % self.TIME_ZONE)
            # Move the time zone info into os.environ. See ticket #2315 for why
            # we don't do this unconditionally (breaks Windows).
            os.environ['TZ'] = self.TIME_ZONE
            time.tzset()

    def is_overridden(self, setting):
        """Return True if ``setting`` was defined in the user settings module."""
        return setting in self._explicit_settings

    def __repr__(self):
        """Return e.g. ``<Settings "mysite.settings">``."""
        return '<%(cls)s "%(settings_module)s">' % {
            'cls': self.__class__.__name__,
            'settings_module': self.SETTINGS_MODULE,
        }

默认配置下,backend 对应的是 django.contrib.auth.backends.ModelBackend 中的 authenticate。验证过程会检查 UserModel 中是否存在 POST 提交的 username,并校验对应的 password。注意 UserModel 由主配置中的 settings.AUTH_USER_MODEL 指定(可以是自定义模型)。由于数据库中保存的 password 是哈希值而非明文(check_password() 会把提交的明文密码做哈希后再比对),创建用户时必须通过 set_password()(或 create_user())设置密码,而不能简单粗暴地直接给模型的 password 字段赋明文值,否则 JWT 验证永远无法通过。

# Inside authenticate(): each configured backend is tried in turn with the same credentials.
backend.authenticate(request, **credentials)
class ModelBackend:
    """
    Authenticates against settings.AUTH_USER_MODEL.
    """

    def authenticate(self, request, username=None, password=None, **kwargs):
        """Return the user matching ``username``/``password``, or ``None``.

        ``username`` may also be passed under the model's ``USERNAME_FIELD``
        keyword. Users rejected by ``user_can_authenticate`` are not returned.
        """
        if username is None:
            username = kwargs.get(UserModel.USERNAME_FIELD)
        if username is None or password is None:
            # Nothing to verify: skip the lookup for a NULL username and the
            # dummy hashing of a None password.
            return None
        try:
            user = UserModel._default_manager.get_by_natural_key(username)
        except UserModel.DoesNotExist:
            # Run the default password hasher once to reduce the timing
            # difference between an existing and a nonexistent user (#20760).
            UserModel().set_password(password)
        else:
            if user.check_password(password) and self.user_can_authenticate(user):
                return user
        return None

    def user_can_authenticate(self, user):
        """
        Reject users with is_active=False. Custom user models that don't have
        that attribute are allowed.
        """
        is_active = getattr(user, 'is_active', None)
        return is_active or is_active is None

    def _get_user_permissions(self, user_obj):
        """Return the permissions assigned directly to ``user_obj``."""
        return user_obj.user_permissions.all()

    def _get_group_permissions(self, user_obj):
        """Return the permissions of the groups ``user_obj`` belongs to."""
        user_groups_field = get_user_model()._meta.get_field('groups')
        user_groups_query = 'group__%s' % user_groups_field.related_query_name()
        return Permission.objects.filter(**{user_groups_query: user_obj})

    def _get_permissions(self, user_obj, obj, from_name):
        """
        Return the permissions of `user_obj` from `from_name`. `from_name` can
        be either "group" or "user" to return permissions from
        `_get_group_permissions` or `_get_user_permissions` respectively.

        Results are "app_label.codename" strings cached on the user object
        as ``_<from_name>_perm_cache``.
        """
        # Object-level permissions are not handled by this backend.
        if not user_obj.is_active or user_obj.is_anonymous or obj is not None:
            return set()

        perm_cache_name = '_%s_perm_cache' % from_name
        if not hasattr(user_obj, perm_cache_name):
            if user_obj.is_superuser:
                perms = Permission.objects.all()
            else:
                perms = getattr(self, '_get_%s_permissions' % from_name)(user_obj)
            perms = perms.values_list('content_type__app_label', 'codename').order_by()
            setattr(user_obj, perm_cache_name, {"%s.%s" % (ct, name) for ct, name in perms})
        return getattr(user_obj, perm_cache_name)

    def get_user_permissions(self, user_obj, obj=None):
        """
        Return a set of permission strings the user `user_obj` has from their
        `user_permissions`.
        """
        return self._get_permissions(user_obj, obj, 'user')

    def get_group_permissions(self, user_obj, obj=None):
        """
        Return a set of permission strings the user `user_obj` has from the
        groups they belong.
        """
        return self._get_permissions(user_obj, obj, 'group')

    def get_all_permissions(self, user_obj, obj=None):
        """Return the union of user and group permissions (cached on the user)."""
        if not user_obj.is_active or user_obj.is_anonymous or obj is not None:
            return set()
        if not hasattr(user_obj, '_perm_cache'):
            user_obj._perm_cache = {
                *self.get_user_permissions(user_obj),
                *self.get_group_permissions(user_obj),
            }
        return user_obj._perm_cache

    def has_perm(self, user_obj, perm, obj=None):
        """Return True if an active ``user_obj`` holds ``perm`` ("app.codename")."""
        return user_obj.is_active and perm in self.get_all_permissions(user_obj, obj)

    def has_module_perms(self, user_obj, app_label):
        """
        Return True if user_obj has any permissions in the given app_label.
        """
        return user_obj.is_active and any(
            perm[:perm.index('.')] == app_label
            for perm in self.get_all_permissions(user_obj)
        )

    def get_user(self, user_id):
        """Return the user with primary key ``user_id`` if allowed, else None."""
        try:
            user = UserModel._default_manager.get(pk=user_id)
        except UserModel.DoesNotExist:
            return None
        return user if self.user_can_authenticate(user) else None

你可能感兴趣的:(python,Django)