[python] 使用装饰器强制要求函数的参数类型和返回值类型

import inspect
from functools import wraps


class enforce_types:
    """Decorator factory that enforces annotated argument types at call time.

    Every argument whose parameter carries an annotation usable with
    ``isinstance`` is checked before the wrapped function runs.  Parameters
    without an annotation, and annotations that ``isinstance`` cannot
    evaluate (string annotations, subscripted generics such as
    ``list[int]``), are left unchecked instead of being rejected.  The
    annotation of ``*args`` / ``**kwargs`` is applied to each collected value.

    Args:
        check_return: When true, the return value is also checked against
            the function's return annotation.  Defaults to ``False``.

    The decorated function raises ``TypeError`` when an argument, or the
    checked return value, is not an instance of its annotation.
    """

    def __init__(self, check_return=False):
        self.check_return = check_return

    @staticmethod
    def _violates(value, annotation):
        """Return True only when ``value`` demonstrably fails ``annotation``."""
        if annotation is inspect.Parameter.empty:
            # No annotation means there is nothing to enforce.
            return False
        try:
            return not isinstance(value, annotation)
        except TypeError:
            # isinstance() rejects this annotation object itself (a string,
            # a subscripted generic, ...), so the check cannot be performed.
            return False

    def __call__(self, func):
        kinds = inspect.Parameter

        @wraps(func)
        def wrapper(*args, **kwargs):
            # The signature is looked up on each call, exactly as before, so
            # nothing is introspected at decoration time.
            params = inspect.signature(func).parameters
            ordered = list(params.values())
            where = f"{func.__module__}.{func.__name__}"

            # Validate positional arguments.  ``slot`` stays on a ``*args``
            # parameter once reached, so every extra positional value is
            # checked against that parameter's annotation.
            slot = 0
            for position, value in enumerate(args, start=1):
                if slot >= len(ordered):
                    break  # too many positionals: let func report the arity error
                param = ordered[slot]
                if param.kind in (kinds.KEYWORD_ONLY, kinds.VAR_KEYWORD):
                    break  # cannot be filled positionally: let func report it
                if param.kind != kinds.VAR_POSITIONAL:
                    slot += 1
                if self._violates(value, param.annotation):
                    raise TypeError(
                        f"Expected {param.annotation} but got {type(value)} for argument '{param.name}' "
                        f"in function '{where}' at position {position}")

            # Validate keyword arguments.  Names that do not match a
            # keyword-capable parameter end up in ``**kwargs`` (if declared)
            # and are checked against its annotation.
            var_keyword = next((p for p in ordered if p.kind == kinds.VAR_KEYWORD), None)
            for arg_name, value in kwargs.items():
                param = params.get(arg_name)
                if param is None or param.kind not in (kinds.POSITIONAL_OR_KEYWORD, kinds.KEYWORD_ONLY):
                    param = var_keyword
                if param is not None and self._violates(value, param.annotation):
                    raise TypeError(
                        f"Expected {param.annotation} but got {type(value)} for keyword argument '{arg_name}' "
                        f"in function '{where}'")

            result = func(*args, **kwargs)

            if self.check_return:
                return_type = func.__annotations__.get('return')
                if return_type and self._violates(result, return_type):
                    # Not every annotation object has __name__; fall back to
                    # the annotation itself for the message.
                    expected = getattr(return_type, '__name__', return_type)
                    raise TypeError(f"Return value '{result}' doesn't match the expected type '{expected}'")

            return result

        return wrapper


@enforce_types()
def func1(a: int, b: str) -> float:
    """Return half of the sum of ``a`` and the integer parsed from ``b``.

    Only the argument types are enforced here; the return value is not
    checked because ``check_return`` is left at its default.
    """
    return (a + int(b)) / 2


@enforce_types(check_return=True)
def func2(a: int, b: str) -> float:
    """Return the sum of ``a`` and the integer parsed from ``b``.

    The sum is an ``int`` although the annotation declares ``float``; this
    mismatch is intentional and exercises the return-type check enabled by
    ``check_return=True``.
    """
    total = int(b) + a
    return total  # an int, not the annotated float


if __name__ == '__main__':
    # Demonstration calls; uncomment the ones you want to try.

    # func1(1, "2")  # valid call

    # func1(1, 2)  # wrong argument type: raises TypeError

    func2(1, "2")  # wrong return type: raises TypeError


你可能感兴趣的:(python,python,开发语言)