创建一个符合Python风格的对象(2)

在创建一个符合Python风格的对象(1)中,定义了一个二维向量 Vector2d 类,现在以该类为基础,继续扩展,定义表示多维向量的Vector类。
支持的功能如下:

  • 基本的序列协议,__len____getitem__
  • 正确表述拥有很多元素的实例
  • 适当的切片支持,用于生产新的Vector实例
  • 综合各个元素的值计算散列值
  • 自定义的格式语言扩展

此外,通过 __getattr__ 方法实现属性的动态存取,以此取代 Vector2d 使用的只读特性——不过,序列类型通常不会这么做。

下面来一步步实现。

1.为了支持N维向量,让构造函数接受可迭代对象

    def __init__(self, components):
        # 把 Vector 的分量保存在一个数组中
        self._components = array(self.typecode, components)

2.为了支持迭代,使用self.components构建一个迭代器

    def __iter__(self):
        return iter(self._components)

3.使用reprlib.repr() 函数获取 self._components 的有限长度表示形式(如 array('d', [0.0, 1.0, 2.0, 3.0, 4.0, ...])

    def __repr__(self):
        components = reprlib.repr(self._components)
        # 去掉前面的 array('d' 和后面的 )。
        components = components[components.find('['):-1]
        return 'Vector({})'.format(components)

4.直接使用self.components构建bytes对象

    def __bytes__(self):
        return (bytes(ord([self.typecode])) + bytes(self._components))

5计算模

    def __abs__(self):
        """计算各分量的平方之和,然后再使用 sqrt 方法开平方"""
        return math.sqrt(sum(x * x for x in self))

6.针对frombytes,直接把 memoryview 传给构造方法,不用像前面那样使用 * 拆包

@classmethod
def frombytes(cls, octets):
    typecode = chr(octets[0])
    memv = memoryview(octets[1:]).cast(typecode)
    return cls(memv) 

7.为了支持序列协议,实现__len____getitem__方法

    def __len__(self):
        return len(self._components)

    def __getitem__(self, index):
        """自定义切片操作"""
        cls = type(self)
        #  如果 index 参数的值是 slice 对象,调用类的构造方法,使用 _components 数组的切片构建一个新 Vector 实例
        if isinstance(index, slice):
            return cls(self._components[index])
        # 如果 index 是 int 或其他整数类型,那就返回 _components 中相应的元素
        elif isinstance(index, numbers.Integral):
            return self._components[index]
        # 否则,抛出异常
        else:
            msg = '{.__name__} indices must be integers'
            raise TypeError(msg.format(cls))

8.动态存取属性
因为现在是N维向量,使用Vector2d中获取属性的方式显然太麻烦。
要想依旧使用my_obj.x方式获取属性,可以实现__getattr__方法,因为属性查找失败后,解释器会调用 __getattr__ 方法。

    # 定义几个可以获取的常用分量
    shortcut__names = 'xyzt'

    def __getattr__(self, name):
        """检查所查找的属性是不是 shortcut__names 中的某个字母,如果是,那么返回对应的分量。"""
        cls = type(self)
        # 如果属性名只有一个字母,可能是shortcut_names 中的一个
        if len(name) == 1:
            # 找到所在位置
            pos = cls.shortcut_names.find(name)
            if 0 <= pos < len(self._components):
                return self._components[pos]
        msg = '{.__name__ !r} object has no attribute {!r}'
        raise AttributeError(msg.format(cls, name))

但是仅仅实现这样一个方法还不够,需要注意到对于实例v,如果执行了v.x命令,实际上v对象就有x属性了,因此使用v.x不会调用__getattr__方法。
为了避免上述情况,需要改写Vector类中设置属性的逻辑,通过自定义__setattr__方法实现。

def __setattr__(self, name, value):
    cls = type(self)
    # 特别处理名称是单个字符的属性
    if len(name) == 1:
        # 如果 name 是 shortcut_names 中的一个,设置特殊的错误消息
        if name in cls.shortcut_names:
            error = 'readonly attribute {attr_name!r}'
        # 如果 name 是小写字母,为所有小写字母设置一个错误消息
        elif name.islower():
            error = "can't set attributes 'a' to 'z' in {cls_name!r}"
        # 否则,把错误消息设为空字符串
        else:
            error = ''
        #  如果有错误消息,抛出 AttributeError
        if error:
            msg = error.format(cls_name=cls.__name__, attr_name=name)
            raise AttributeError(msg)
    # 默认情况:在超类上调用 __setattr__ 方法,提供标准行为
    super().__setattr__(name, value)

在类中声明 __slots__ 属性也可以防止设置新实例属性。但是不建议只为了避免创建实例属性而使用 __slots__ 属性。__slots__ 属性只应该用于节省内存,而且仅当内存严重不足时才应该这么做。
另外,为了将该类实例变成是可散列的,需要保持Vector是不可变的。

9.支持散列和快速等值测试

    def __eq__(self, other):
        # 首先要检查两个操作数的长度是否相同,因为 zip 函数会在最短的那个操作数耗尽时停止,而且不发出警告。
        # 然后再依次比较两个序列中的每一个元素
        return (len(self) == len(other) and all(a == b for a, b in zip(self, other)))

    def __hash__(self):
        # 创建一个生成器表达式,惰性计算各个分量的散列值
        hashes = (hash(x) for x in self)
        # 把 hashes 提供给 reduce 函数,使用 xor 函数计算聚合的散列值;第三个参数,0 是初始值
        return functools.reduce(operator.xor, hashes, 0)

10.格式化
Vector 类支持 N 个维度,所以这里使用球面坐标,格式后缀定义为'h'。这里的难点主要是涉及数学原理,理解意思即可。具体可以查看n 维球体

def angle(self, n):
    """使用公式计算某个角坐标"""
    r = math.sqrt(sum(x * x for x in self[n:]))
    a = math.atan2(r, self[n-1])
    if (n == len(self) - 1) and (self[-1] < 0):
        return math.pi * 2 - a
    else:
        return a

def angles(self):
    """创建生成器表达式,按需计算所有角坐标"""
    return (self.angle(n) for n in range(1, len(self)))

def __format__(self, fmt_spec=''):
    if fmt_spec.endswith('h'):  # 超球面坐标
        fmt_spec = fmt_spec[:-1]
        # 使用 itertools.chain 函数生成生成器表达式,无缝迭代向量的模和各个角坐标
        coords = itertools.chain([abs(self)], self.angles())
        outer_fmt = '<{}>'
    else:
        coords = self
        outer_fmt = '({})'
    components = (format(c, fmt_spec) for c in coords)
    return outer_fmt.format(', '.join(components))

下面给出完整代码

from array import array
import reprlib
import math
import numbers
import functools
import operator
import itertools


class Vector:
    typecode = 'd'

    def __init__(self, components):
        self._components = array(self.typecode, components)

    def __iter__(self):
        return iter(self._components)

    def __repr__(self):
        components = reprlib.repr(self._components)
        components = components[components.find('['):-1]
        return 'Vector({})'.format(components)

    def __str__(self):
        return str(tuple(self))

    def __bytes__(self):
        return (bytes([ord(self.typecode)]) +
                bytes(self._components))

    def __eq__(self, other):
        return (len(self) == len(other) and
                all(a == b for a, b in zip(self, other)))

    def __hash__(self):
        hashes = (hash(x) for x in self)
        return functools.reduce(operator.xor, hashes, 0)

    def __abs__(self):
        return math.sqrt(sum(x * x for x in self))

    def __bool__(self):
        return bool(abs(self))

    def __len__(self):
        return len(self._components)

    def __getitem__(self, index):
        cls = type(self)
        if isinstance(index, slice):
            return cls(self._components[index])
        elif isinstance(index, numbers.Integral):
            return self._components[index]
        else:
            msg = '{.__name__} indices must be integers'
            raise TypeError(msg.format(cls))

    shortcut_names = 'xyzt'

    def __getattr__(self, name):
        cls = type(self)
        if len(name) == 1:
            pos = cls.shortcut_names.find(name)
            if 0 <= pos < len(self._components):
                return self._components[pos]
        msg = '{.__name__!r} object has no attribute {!r}'
        raise AttributeError(msg.format(cls, name))

    def angle(self, n):
        r = math.sqrt(sum(x * x for x in self[n:]))
        a = math.atan2(r, self[n-1])
        if (n == len(self) - 1) and (self[-1] < 0):
            return math.pi * 2 - a
        else:
            return a

    def angles(self):
        return (self.angle(n) for n in range(1, len(self)))

    def __format__(self, fmt_spec=''):
        if fmt_spec.endswith('h'):  # 超球面坐标
            fmt_spec = fmt_spec[:-1]
            coords = itertools.chain([abs(self)],
                                     self.angles())
            outer_fmt = '<{}>'
        else:
            coords = self
            outer_fmt = '({})'
        components = (format(c, fmt_spec) for c in coords)
        return outer_fmt.format(', '.join(components))

    @classmethod
    def frombytes(cls, octets):
        typecode = chr(octets[0])
        memv = memoryview(octets[1:]).cast(typecode)
        return cls(memv)

你可能感兴趣的:(创建一个符合Python风格的对象(2))