今天的任务很简单,就是熟悉一下Python
中的运算符重载。一般,我们想让自定义的类支持一些计算操作,比如会添加如下方法以期达到计算的目的:
class Vector:
def __init__(self, x=0, y=0):
self.x = x
self.y = y
def __repr__(self):
return 'Vector(%r, %r)' % (self.x, self.y)
def __abs__(self):
return hypot(self.x, self.y)
def __bool__(self):
return bool(abs(self))
def __add__(self, other):
x = self.x + other.x
y = self.y + other.y
return Vector(x, y)
def __mul__(self, scalar):
return Vector(self.x * scalar, self.y * scalar)
v1 = Vector(1, 2)
v2 = Vector(3, 4)
v3 = v1 + v2 # Vector(4, 6)
这种简单的方式没问题,但是python
对此也有一定的约束。
- 不能重载内置类型的运算符
- 不能新建运算符,只能重载现有的
-
is
,and
,or
,not
不可以重载
因为在其他语言中,程序员已经把重载运算符给滥用了。
一元运算符
- 取负数 实现
__neg__(self)
方法 - 取正
__pos__(self)
- 按位取反
__invert__(self)
...
这些需要遵循 : 始终返回一个新对象,不能修改self
举例:
# 这里的Vector 兼容迭代器
def __add__(self, other):
paris = itertools.zip_longest(self, other, fillvalue=0.0)
return Vector(a+b for a, b in Paris)
这里有个问题是 不支持左操作数是 非Vector
的对象,但支持但操作数
针对中缀运算符(a+b),Python提供了特殊的分派机制
def __radd__(self, other):
return self + other #直接委托给 __add__
其实这里面还涉及到一个问题,就是操作数是不可迭代对象或者迭代元素不支持该操作符,比如: vector + 1 或者将一个str
和 int
相加,这时候我们就得做出处理
def __add__(self, other):
try:
pairs = itertools.zip_longest(self, other, fillvalue=0.0)
return Vector(a + b for a, b in pairs)
except TypeError:
return NotImplemented
def __radd__(self, other):
return self + other
给出一个具体的示例:
from array import array
import reprlib
import math
import functools
import operator
import itertools
import numbers
class Vector:
typecode = 'd'
def __init__(self, components):
self._components = array(self.typecode, components)
def __iter__(self):
return iter(self._components)
def __repr__(self):
components = reprlib.repr(self._components)
components = components[components.find('['):-1]
return 'Vector({})'.format(components)
def __str__(self):
return str(tuple(self))
def __bytes__(self):
return (bytes([ord(self.typecode)]) +
bytes(self._components))
def __eq__(self, other):
if isinstance(other, Vector):
return (len(self) == len(other) and
all(a == b for a, b in zip(self, other)))
else:
return NotImplemented
def __hash__(self):
hashes = (hash(x) for x in self)
return functools.reduce(operator.xor, hashes, 0)
def __abs__(self):
return math.sqrt(sum(x * x for x in self))
def __bool__(self):
return bool(abs(self))
def __len__(self):
return len(self._components)
def __getitem__(self, index):
cls = type(self)
if isinstance(index, slice):
return cls(self._components[index])
elif isinstance(index, int):
return self._components[index]
else:
msg = '{.__name__} indices must be integers'
raise TypeError(msg.format(cls))
shortcut_names = 'xyzt'
def __getattr__(self, name):
cls = type(self)
if len(name) == 1:
pos = cls.shortcut_names.find(name)
if 0 <= pos < len(self._components):
return self._components[pos]
msg = '{.__name__!r} object has no attribute {!r}'
raise AttributeError(msg.format(cls, name))
def angle(self, n):
r = math.sqrt(sum(x * x for x in self[n:]))
a = math.atan2(r, self[n-1])
if (n == len(self) - 1) and (self[-1] < 0):
return math.pi * 2 - a
else:
return a
def angles(self):
return (self.angle(n) for n in range(1, len(self)))
def __format__(self, fmt_spec=''):
if fmt_spec.endswith('h'): # hyperspherical coordinates
fmt_spec = fmt_spec[:-1]
coords = itertools.chain([abs(self)],
self.angles())
outer_fmt = '<{}>'
else:
coords = self
outer_fmt = '({})'
components = (format(c, fmt_spec) for c in coords)
return outer_fmt.format(', '.join(components))
@classmethod
def frombytes(cls, octets):
typecode = chr(octets[0])
memv = memoryview(octets[1:]).cast(typecode)
return cls(memv)
def __add__(self, other):
try:
pairs = itertools.zip_longest(self, other, fillvalue=0.0)
return Vector(a + b for a, b in pairs)
except TypeError:
return NotImplemented
def __radd__(self, other):
return self + other
def __mul__(self, scalar):
if isinstance(scalar, numbers.Real):# 这里不使用具体的类,而是使用抽象基类,它涵盖了所需的全部类型
return Vector(n * scalar for n in self)
else:
return NotImplemented
def __rmul__(self, scalar):
return self * scalar
def __matmul__(self, other):
try:
return sum(a * b for a, b in zip(self, other))
except TypeError:
return NotImplemented
def __rmatmul__(self, other):
return self @ other # this only works in Python 3.5
接下来再看一下中缀运算符方法名
何时会调用就地运算方法呢?在使用增量赋值运算符中(a+=b ; a*=b)
如果没有实现就地运算方法,a+=b 其实是 a = a+b 创建新的实例,而不是就地修改左操作数