Implementing a custom dataclass in Python

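dataclass is a decorator that has shipped with Python since 3.7. It generates an __init__ method (plus other methods such as __repr__ and __eq__) for a user-defined class, and the fields are then accessed as ordinary instance attributes.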

For example:

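from dataclasses import dataclass, KW_ONLY  # KW_ONLY was added in Python 3.10

@dataclass
class Point:
    x: float
    _: KW_ONLY  # fields after this marker are keyword-only
    y: float
    z: float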

p = Point(0, y=1.5, z=2.0)

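Defining a dataclass comes with the following constraints:
1. Every field must have a type annotation; the reason is explained later in this article.
2. The 'non-default parameter follows default parameter' rule: fields that have a default value must come after all fields that do not (keyword-only fields, e.g. those after KW_ONLY, are exempt).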
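Both constraints can be checked with the standard dataclasses module alone. Strictly speaking, a plain class attribute without an annotation is not rejected but silently ignored (it is just not a field), whereas a field() object without an annotation raises TypeError. The exact error wording may differ between Python versions, so the sketch only captures the messages.

```python
from dataclasses import dataclass, field, fields


# 1a. Only annotated names become fields; a plain class attribute is ignored.
@dataclass
class C:
    a: int
    b = 10  # no annotation: an ordinary class attribute, not a field


print([f.name for f in fields(C)])  # ['a']

# 1b. A field() object without an annotation is rejected outright.
try:
    @dataclass
    class D:
        a: int
        b = field(default=1)
    missing_annotation_error = None
except TypeError as exc:
    missing_annotation_error = str(exc)

# 2. A field without a default may not follow a field with a default.
try:
    @dataclass
    class E:
        x: int = 0
        y: int
    ordering_error = None
except TypeError as exc:
    ordering_error = str(exc)

print(missing_annotation_error)
print(ordering_error)
```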

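Also note that dataclass does not enforce field types based on the annotations, just like ordinary Python functions: even if a field is declared as str, passing a value of another type raises no error.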
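A quick demonstration of this: the generated __init__ simply assigns whatever it receives, exactly as an annotated plain function accepts any argument. If runtime validation is needed, it has to be written by hand, for example in a __post_init__ method.

```python
from dataclasses import dataclass


@dataclass
class User:
    name: str
    age: int


# Annotations are not checked at runtime: these wrong types are accepted silently.
u = User(name=123, age="not a number")
print(u)  # User(name=123, age='not a number')


# Same behaviour with a plain function: annotations are only metadata.
def greet(name: str) -> str:
    return f"hello {name}"


print(greet(42))  # hello 42
```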

To find out where custom logic can be plugged in, let's first look at how dataclass is implemented.

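The principle is simple: dataclass generates a number of methods for the class, chiefly __init__, but also __repr__, __eq__ and others. Each of them is ultimately built by the internal helper _create_fn, which assembles the method's source text and compiles it with exec(). To generate a method, it needs the field names, their defaults, and the locals (such as field types and default values) that the generated code refers to.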
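To make the mechanism concrete, here is a minimal sketch of the same idea: build the source text of __init__ from a list of field names and turn it into a function with exec(). The helper make_init is hypothetical and far simpler than CPython's _create_fn (no type annotations, defaults or injected locals).

```python
def make_init(field_names):
    """Hypothetical helper: build an __init__ that assigns each named argument."""
    args = ", ".join(["self", *field_names])
    # One assignment line per field; an empty body needs 'pass'.
    body = "\n".join(f"    self.{name} = {name}" for name in field_names) or "    pass"
    src = f"def __init__({args}):\n{body}\n"
    ns = {}
    exec(src, {}, ns)  # compile the source text; the new function lands in ns
    return ns["__init__"]


class Point:
    pass


Point.__init__ = make_init(["x", "y"])
p = Point(1, 2)
print(p.x, p.y)  # 1 2
```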

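Collecting the field annotations and defaults takes two steps. First, the class's `__annotations__` attribute gives a dict mapping every field name to its annotation; note that this dict preserves definition order. Then each field's default is read back as a class attribute. This approach is why every field must have an annotation: if a class attribute had a default but no annotation, it would appear in the class `__dict__` but have no key in `__annotations__`, so its original position among the fields could not be recovered, which would break the generated functions.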
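The two-step lookup can be reproduced on a plain class. This is a simplified sketch of the idea described above, not the exact code in dataclasses.py; the class Config and its attributes are made up for illustration.

```python
class Config:
    host: str
    port: int = 8080
    debug = False  # has a default but no annotation


# Step 1: the annotation dict, in definition order.
annotations = Config.__annotations__
print(list(annotations))  # ['host', 'port']

# Step 2: defaults are read back as class attributes; 'host' has none.
defaults = {name: getattr(Config, name) for name in annotations if hasattr(Config, name)}
print(defaults)  # {'port': 8080}

# 'debug' lives in the class __dict__ but not in __annotations__,
# so its position among the fields cannot be recovered from the annotations.
print("debug" in Config.__dict__, "debug" in annotations)  # True False
```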

The following screenshot (taken from CPython's dataclasses.py) shows how the class fields are collected:

[Figure 1: excerpt of the field-collection code in CPython's dataclasses.py]

With the mechanism understood, we can implement a custom dataclass. The rough skeleton looks like this:

from enum import EnumMeta


def my_dataclass(_cls):
    def _create_fn(name, args, body, *, globals=None, locals=None,
                   return_type):
        """To generate function in class."""
        # Reference: Source code of dataclasses.dataclass
        # Doc link: https://docs.python.org/3/library/dataclasses.html
        # Reference code link:
        # https://github.com/python/cpython/blob/17b16e13bb444001534ed6fccb459084596c8bcf/Lib/dataclasses.py#L412
        # Note: the caller's `locals` dict is mutated (BUILTINS and _return_type are added).
        if locals is None:
            locals = {}
        if 'BUILTINS' not in locals:
            import builtins
            locals['BUILTINS'] = builtins
        locals['_return_type'] = return_type
        return_annotation = '->_return_type'
        args = ','.join(args)
        body = '\n'.join(f'  {b}' for b in body)
        # Compute the text of the entire function.
        txt = f' def {name}({args}){return_annotation}:\n{body}'
        local_vars = ', '.join(locals.keys())
        txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}"
        ns = {}
        exec(txt, globals, ns)
        return ns['__create_fn__'](**locals)

    def _create_init_fn(cls, fields):
        """Generate the __init__ function for user-defined class."""
        # Reference code link:
        # https://github.com/python/cpython/blob/17b16e13bb444001534ed6fccb459084596c8bcf/Lib/dataclasses.py#L523
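        def _get_data_type_from_annotation(anno):
            # Map an annotation to the concrete Python type used in __init__'s signature.
            # NOTE: `Enum` and `_Param` below are presumably types from the author's own DSL
            # (a dsl.types-style module not shown here), not enum.Enum; supply your own.
            # A Python enum class (its metaclass is EnumMeta) is used as-is.
            if isinstance(anno, EnumMeta):
                return anno
            # An instance annotation such as Input(min, max) is replaced by its class (Input).
            k_type = anno if type(anno) is type else type(anno)
            # DSL Enum parameter: use its wrapped enum class, falling back to str.
            if k_type is Enum:
                return anno._enum_class if anno._enum_class else str
            # DSL parameter types expose their underlying Python type as DATA_TYPE.
            return k_type.DATA_TYPE if issubclass(k_type, _Param) else k_type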

        # Resolve every annotation on the class to a concrete type.
        init_func_annotation = {
            k: _get_data_type_from_annotation(v)
            for k, v in getattr(cls, '__annotations__', {}).items()
        }
        # The generated code refers to each field type through a local named _type_<field>.
        locals = {f'_type_{key}': init_func_annotation[key] for key in fields}
        # Collect defaults of optional fields; each field object is expected to
        # expose `_has_default` and `default` (supplied by your own field type).
        default_keys = {key for key, val in fields.items() if val._has_default}
        locals.update({f'_dflt_{key}': fields[key].default for key in default_keys})
        # `fields` must already be ordered so that no field without a default follows
        # one with a default; otherwise exec() raises SyntaxError on the generated def.
        _init_param = ['self']
        for key in fields:
            param = f'{key}:_type_{key}=_dflt_{key}' if key in default_keys else f'{key}:_type_{key}'
            _init_param.append(param)
        body_lines = [f'self.{key}={key}' for key in fields]
        # If no body lines, use 'pass'.
        if not body_lines:
            body_lines = ['pass']
        return _create_fn('__init__', _init_param, body_lines, locals=locals, return_type=None)


    def _process_class(cls, all_fields):
        """Generate some functions into class."""
        setattr(cls, '__init__', _create_init_fn(cls, all_fields))
        # set some other function for your class
        # setattr(cls, '__repr__', _create_repr_fn(all_fields))
        return cls

    def _wrap(cls):
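        # Placeholder: supply your own function that returns an ordered
        # {field_name: field_object} dict for cls.
        all_fields = your_own_get_fields_function(cls)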
        return _process_class(cls, all_fields)

    return _wrap(_cls)
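Because the skeleton leaves your_own_get_fields_function open and depends on DSL types that are not shown, here is a self-contained sketch of the same approach under simplifying assumptions: annotations are plain Python types, and FieldInfo, get_fields and simple_dataclass are hypothetical names introduced only for this example. It reuses the public dataclasses.MISSING sentinel to mean "no default".

```python
from dataclasses import MISSING  # reuse the stdlib sentinel for "no default"


class FieldInfo:
    """Hypothetical field record with the attributes _create_init_fn reads."""

    def __init__(self, name, type_, default=MISSING):
        self.name = name
        self.type = type_
        self.default = default
        self._has_default = default is not MISSING


def get_fields(cls):
    """Hypothetical stand-in for your_own_get_fields_function(cls).

    Iterates __annotations__ (definition order) and reads each default
    from the class namespace, as described earlier.
    """
    return {
        name: FieldInfo(name, type_, cls.__dict__.get(name, MISSING))
        for name, type_ in getattr(cls, "__annotations__", {}).items()
    }


def _create_fn(name, args, body, *, ns_locals):
    """Trimmed _create_fn: wrap the def in a factory whose parameters are ns_locals."""
    inner = f" def {name}({','.join(args)}):\n" + "\n".join(f"  {b}" for b in body)
    txt = f"def __create_fn__({', '.join(ns_locals)}):\n{inner}\n return {name}"
    ns = {}
    exec(txt, {}, ns)
    return ns["__create_fn__"](**ns_locals)


def simple_dataclass(cls):
    """Generate __init__ from plain type annotations (no DSL types)."""
    fields = get_fields(cls)
    ns_locals = {f"_type_{k}": f.type for k, f in fields.items()}
    ns_locals.update({f"_dflt_{k}": f.default for k, f in fields.items() if f._has_default})
    params = ["self"] + [
        f"{k}:_type_{k}=_dflt_{k}" if f._has_default else f"{k}:_type_{k}"
        for k, f in fields.items()
    ]
    body = [f"self.{k}={k}" for k in fields] or ["pass"]
    cls.__init__ = _create_fn("__init__", params, body, ns_locals=ns_locals)
    return cls


@simple_dataclass
class Point:
    x: float
    y: float = 0.0


p = Point(1.5)
print(p.x, p.y)  # 1.5 0.0
```

Passing types and defaults in as parameters of the outer __create_fn__ factory is the same trick the skeleton uses: the generated source can then refer to them by name without putting them in module globals.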
