Flask-WTF源码分析中关于CSRF_TOKEN的生成和验证有问题,这里重新分析一下这个流程
-
csrf_token的生成
1.生成一个csrf对象,这个对象主要用于生成和校验csrf_token的
2.生成UnboundField(CSRFTokenField)对象,插入到我们自定义Field列表中
3.在BaseForm中调用process方法,进而调用Field中的process方法生成current_token的值
4.wiget的call方法在渲染input标签时,会将input的value属性赋值current_token
...
#wtforms.form.BaseForm
extra_fields = []
#根据配置,是否需要生成csrf_token
if meta.csrf:
#生成一个csrf对象,这个对象用于生成和校验csrf_token的
self._csrf = meta.build_csrf(self)
#初始化一个CSRFTokenField对象,并加入到extra_fields这个列表中
extra_fields.extend(self._csrf.setup_form(self))
...
#wtforms.meta.DefaultMeta
def build_csrf(self, form):
#csrf_class在DefaultMeta是None,但是在其子类中有实现是_FlaskFormCSRF这个私有类,这个if条件是满足的,不为None
if self.csrf_class is not None:
#返回这个_FlaskFormCSRF的对象
return self.csrf_class()
#上一篇文章分析错误出现在这里,误以为返回的是SessionCSRF这个对象
from wtforms.csrf.session import SessionCSRF
return SessionCSRF()
#wtforms.csrf.core.CSRF
field_class = CSRFTokenField
def setup_form(self, form):
meta = form.meta
#得到field_name=csrf_token
field_name = meta.csrf_field_name
#建立CSRFTokenField的对象,但是这里返回的应该是UnboundField对象
unbound_field = self.field_class(
label='CSRF Token',
csrf_impl=self
)
#返回一个列表[('csrf_token,UnboundField')]
#后续会调用UnboundField的bind方法建立CSRFTokenField对象
return [(field_name, unbound_field)]
#wtforms.csrf.core.CSRFTokenField
def process(self, *args):
super(CSRFTokenField, self).process(*args)
#self.csrf_Impl=_FlaskFormCSRF
#也就是调用_FlaskFormCSRF的generate_csrf_token()生成csrf_token
self.current_token = self.csrf_impl.generate_csrf_token(self)
#flask_wtf.csrf.generate_csrf
def generate_csrf(secret_key=None, token_key=None):
secret_key = _get_config(
secret_key, 'WTF_CSRF_SECRET_KEY', current_app.secret_key,
message='A secret key is required to use CSRF.'
)
field_name = _get_config(
token_key, 'WTF_CSRF_FIELD_NAME', 'csrf_token',
message='A field name is required to use CSRF.'
)
if field_name not in g:
if field_name not in session:
#同时放到session中一份
session[field_name] = hashlib.sha1(os.urandom(64)).hexdigest()
#使用URLSafeTimedSerializer对象对这个csrf_token进行签名并放到全局对象g中
s = URLSafeTimedSerializer(secret_key, salt='wtf-csrf-token')
setattr(g, field_name, s.dumps(session[field_name]))
return g.get(field_name)
至此,csrf_token的值生成完毕!
-
csrf_token的验证
1.调用表单的validate_on_submit方法,来开始校验
2.判断表单中是否有自定义的校验
3.没有则遍历Form中所有的Field的validate()方法进行校验
#表单提交时调用这个方法
if form.validate_on_submit():
#wtforms.form.Form
def validate(self):
extra = {}
for name in self._fields:
#是否有自定义的Field验证
inline = getattr(self.__class__, 'validate_%s' % name, None)
if inline is not None:
extra[name] = [inline]
#调用父类BaseForm的validate方法
return super(Form, self).validate(extra)
#wtforms.form.BaseForm
def validate(self, extra_validators=None):
self._errors = None
success = True
for name, field in iteritems(self._fields):
if extra_validators is not None and name in extra_validators:
extra = extra_validators[name]
else:
extra = tuple()
#调用CSRFTokenField的validate校验,进而调用Field的validate方法
if not field.validate(self, extra):
success = False
return success
#wtforms.fields.core.Field
def validate(self, form, extra_validators=tuple()):
self.errors = list(self.process_errors)
stop_validation = False
# Call pre_validate
try:
#调用CSRFTokenField的pre_validate方法
self.pre_validate(form)
except StopValidation as e:
if e.args and e.args[0]:
self.errors.append(e.args[0])
stop_validation = True
except ValueError as e:
self.errors.append(e.args[0])
# Run validators
if not stop_validation:
chain = itertools.chain(self.validators, extra_validators)
stop_validation = self._run_validation_chain(form, chain)
# Call post_validate
try:
#调用CSRFTokenField的pre_validate方法
self.post_validate(form, stop_validation)
except ValueError as e:
self.errors.append(e.args[0])
return len(self.errors) == 0
#wtforms.csrf.core.CSRFTokenField
def pre_validate(self, form):
# self.csrf_impl=flask_wtf.csrf._FlaskFormCSRF
self.csrf_impl.validate_csrf_token(form, self)
#flask_wtf.csrf.validate_csrf
def validate_csrf(data, secret_key=None, time_limit=None, token_key=None):
secret_key = _get_config(
secret_key, 'WTF_CSRF_SECRET_KEY', current_app.secret_key,
message='A secret key is required to use CSRF.'
)
field_name = _get_config(
token_key, 'WTF_CSRF_FIELD_NAME', 'csrf_token',
message='A field name is required to use CSRF.'
)
#csrf_token的有效期
time_limit = _get_config(
time_limit, 'WTF_CSRF_TIME_LIMIT', 3600, required=False
)
#csrf_token丢失
if not data:
raise ValidationError('The CSRF token is missing.')
#csrf_token没在session中
if field_name not in session:
raise ValidationError('The CSRF session token is missing.')
s = URLSafeTimedSerializer(secret_key, salt='wtf-csrf-token')
#判断csrf_token是否过期或是是否有效
try:
token = s.loads(data, max_age=time_limit)
except SignatureExpired:
raise ValidationError('The CSRF token has expired.')
except BadData:
raise ValidationError('The CSRF token is invalid.')
if not safe_str_cmp(session[field_name], token):
raise ValidationError('The CSRF tokens do not match.')
码字不易,喜欢就留下个小心心吧!