在创建一个符合Python风格的对象(1)中,定义了一个二维向量 Vector2d
类,现在以该类为基础,继续扩展,定义表示多维向量的Vector
类。
支持的功能如下:
- 基本的序列协议,
__len__
和__getitem__
- 正确表述拥有很多元素的实例
- 适当的切片支持,用于生产新的
Vector
实例 - 综合各个元素的值计算散列值
- 自定义的格式语言扩展
此外,通过 __getattr__
方法实现属性的动态存取,以此取代 Vector2d
使用的只读特性——不过,序列类型通常不会这么做。
下面来一步步实现。
1.为了支持N维向量,让构造函数接受可迭代对象
def __init__(self, components):
# 把 Vector 的分量保存在一个数组中
self._components = array(self.typecode, components)
2.为了支持迭代,使用self.components
构建一个迭代器
def __iter__(self):
return iter(self._components)
3.使用reprlib.repr()
函数获取 self._components
的有限长度表示形式(如 array('d', [0.0, 1.0, 2.0, 3.0, 4.0, ...])
def __repr__(self):
components = reprlib.repr(self._components)
# 去掉前面的 array('d' 和后面的 )。
components = components[components.find('['):-1]
return 'Vector({})'.format(components)
4.直接使用self.components
构建bytes
对象
def __bytes__(self):
return (bytes(ord([self.typecode])) + bytes(self._components))
5计算模
def __abs__(self):
"""计算各分量的平方之和,然后再使用 sqrt 方法开平方"""
return math.sqrt(sum(x * x for x in self))
6.针对frombytes
,直接把 memoryview
传给构造方法,不用像前面那样使用 *
拆包
@classmethod
def frombytes(cls, octets):
typecode = chr(octets[0])
memv = memoryview(octets[1:]).cast(typecode)
return cls(memv)
7.为了支持序列协议,实现__len__
和__getitem__
方法
def __len__(self):
return len(self._components)
def __getitem__(self, index):
"""自定义切片操作"""
cls = type(self)
# 如果 index 参数的值是 slice 对象,调用类的构造方法,使用 _components 数组的切片构建一个新 Vector 实例
if isinstance(index, slice):
return cls(self._components[index])
# 如果 index 是 int 或其他整数类型,那就返回 _components 中相应的元素
elif isinstance(index, numbers.Integral):
return self._components[index]
# 否则,抛出异常
else:
msg = '{.__name__} indices must be integers'
raise TypeError(msg.format(cls))
8.动态存取属性
因为现在是N维向量,使用Vector2d
中获取属性的方式显然太麻烦。
要想依旧使用my_obj.x
方式获取属性,可以实现__getattr__
方法,因为属性查找失败后,解释器会调用 __getattr__
方法。
# 定义几个可以获取的常用分量
shortcut__names = 'xyzt'
def __getattr__(self, name):
"""检查所查找的属性是不是 shortcut__names 中的某个字母,如果是,那么返回对应的分量。"""
cls = type(self)
# 如果属性名只有一个字母,可能是shortcut_names 中的一个
if len(name) == 1:
# 找到所在位置
pos = cls.shortcut_names.find(name)
if 0 <= pos < len(self._components):
return self._components[pos]
msg = '{.__name__ !r} object has no attribute {!r}'
raise AttributeError(msg.format(cls, name))
但是仅仅实现这样一个方法还不够,需要注意到对于实例v
,如果执行了v.x
命令,实际上v
对象就有x
属性了,因此使用v.x
不会调用__getattr__
方法。
为了避免上述情况,需要改写Vector类中设置属性的逻辑,通过自定义__setattr__
方法实现。
def __setattr__(self, name, value):
cls = type(self)
# 特别处理名称是单个字符的属性
if len(name) == 1:
# 如果 name 是 shortcut_names 中的一个,设置特殊的错误消息
if name in cls.shortcut_names:
error = 'readonly attribute {attr_name!r}'
# 如果 name 是小写字母,为所有小写字母设置一个错误消息
elif name.islower():
error = "can't set attributes 'a' to 'z' in {cls_name!r}"
# 否则,把错误消息设为空字符串
else:
error = ''
# 如果有错误消息,抛出 AttributeError
if error:
msg = error.format(cls_name=cls.__name__, attr_name=name)
raise AttributeError(msg)
# 默认情况:在超类上调用 __setattr__ 方法,提供标准行为
super().__setattr__(name, value)
在类中声明 __slots__
属性也可以防止设置新实例属性。但是不建议只为了避免创建实例属性而使用 __slots__
属性。__slots__
属性只应该用于节省内存,而且仅当内存严重不足时才应该这么做。
另外,为了将该类实例变成是可散列的,需要保持Vector
是不可变的。
9.支持散列和快速等值测试
def __eq__(self, other):
# 首先要检查两个操作数的长度是否相同,因为 zip 函数会在最短的那个操作数耗尽时停止,而且不发出警告。
# 然后再依次比较两个序列中的每一个元素
return (len(self) == len(other) and all(a == b for a, b in zip(self, other)))
def __hash__(self):
# 创建一个生成器表达式,惰性计算各个分量的散列值
hashes = (hash(x) for x in self)
# 把 hashes 提供给 reduce 函数,使用 xor 函数计算聚合的散列值;第三个参数,0 是初始值
return functools.reduce(operator.xor, hashes, 0)
10.格式化
Vector
类支持 N 个维度,所以这里使用球面坐标,格式后缀定义为'h'
。这里的难点主要是涉及数学原理,理解意思即可。具体可以查看n 维球体
def angle(self, n):
"""使用公式计算某个角坐标"""
r = math.sqrt(sum(x * x for x in self[n:]))
a = math.atan2(r, self[n-1])
if (n == len(self) - 1) and (self[-1] < 0):
return math.pi * 2 - a
else:
return a
def angles(self):
"""创建生成器表达式,按需计算所有角坐标"""
return (self.angle(n) for n in range(1, len(self)))
def __format__(self, fmt_spec=''):
if fmt_spec.endswith('h'): # 超球面坐标
fmt_spec = fmt_spec[:-1]
# 使用 itertools.chain 函数生成生成器表达式,无缝迭代向量的模和各个角坐标
coords = itertools.chain([abs(self)], self.angles())
outer_fmt = '<{}>'
else:
coords = self
outer_fmt = '({})'
components = (format(c, fmt_spec) for c in coords)
return outer_fmt.format(', '.join(components))
下面给出完整代码
from array import array
import reprlib
import math
import numbers
import functools
import operator
import itertools
class Vector:
typecode = 'd'
def __init__(self, components):
self._components = array(self.typecode, components)
def __iter__(self):
return iter(self._components)
def __repr__(self):
components = reprlib.repr(self._components)
components = components[components.find('['):-1]
return 'Vector({})'.format(components)
def __str__(self):
return str(tuple(self))
def __bytes__(self):
return (bytes([ord(self.typecode)]) +
bytes(self._components))
def __eq__(self, other):
return (len(self) == len(other) and
all(a == b for a, b in zip(self, other)))
def __hash__(self):
hashes = (hash(x) for x in self)
return functools.reduce(operator.xor, hashes, 0)
def __abs__(self):
return math.sqrt(sum(x * x for x in self))
def __bool__(self):
return bool(abs(self))
def __len__(self):
return len(self._components)
def __getitem__(self, index):
cls = type(self)
if isinstance(index, slice):
return cls(self._components[index])
elif isinstance(index, numbers.Integral):
return self._components[index]
else:
msg = '{.__name__} indices must be integers'
raise TypeError(msg.format(cls))
shortcut_names = 'xyzt'
def __getattr__(self, name):
cls = type(self)
if len(name) == 1:
pos = cls.shortcut_names.find(name)
if 0 <= pos < len(self._components):
return self._components[pos]
msg = '{.__name__!r} object has no attribute {!r}'
raise AttributeError(msg.format(cls, name))
def angle(self, n):
r = math.sqrt(sum(x * x for x in self[n:]))
a = math.atan2(r, self[n-1])
if (n == len(self) - 1) and (self[-1] < 0):
return math.pi * 2 - a
else:
return a
def angles(self):
return (self.angle(n) for n in range(1, len(self)))
def __format__(self, fmt_spec=''):
if fmt_spec.endswith('h'): # 超球面坐标
fmt_spec = fmt_spec[:-1]
coords = itertools.chain([abs(self)],
self.angles())
outer_fmt = '<{}>'
else:
coords = self
outer_fmt = '({})'
components = (format(c, fmt_spec) for c in coords)
return outer_fmt.format(', '.join(components))
@classmethod
def frombytes(cls, octets):
typecode = chr(octets[0])
memv = memoryview(octets[1:]).cast(typecode)
return cls(memv)