python 重载中缀运算符

重载向量加法运算符 +

还来看我们之前的Vector例子:

from array import array
import reprlib
import math
import numbers
import functools
import operator
import itertools
import pysnooper


class Vector:
    typecode = 'd'
    shortcut_names = 'xyzt'

    def __init__(self, components):
        self._components = array(self.typecode, components)

    def __iter__(self):
        return iter(self._components)

    def __repr__(self):
        components = reprlib.repr(self._components)
        components = components[components.find('['):-1]
        return 'Vector({})'.format(components)

    def __str__(self):
        return str(tuple(self))

    def __bytes__(self):
        return (bytes([ord(self.typecode)]) +
                bytes(self._components))

    def __eq__(self, other):
        # 方案一、
        # if len(self) != len(other):
        #     return False
        # for a, b in zip(self, other):
        #     if a != b:
        #         return False
        # return True

        # 方案二、
        return len(self) == len(other) and all(a == b for a, b in zip(self, other))

    def __abs__(self):
        return math.sqrt(sum(x * x for x in self))

    def __bool__(self):
        return bool(abs(self))

    @classmethod
    def frombytes(cls, octets):
        typecode = chr(octets[0])
        memv = memoryview(octets[1:]).cast(typecode)
        return cls(memv)

    def __len__(self):
        return len(self._components)

    def __getitem__(self, index):
        cls = type(self)
        # print(cls, type(cls))
        if isinstance(index, slice):
            return cls(self._components[index])
        elif isinstance(index, numbers.Integral):
            return self._components[index]
        else:
            msg = '{cls.__name__} indices must be integers'
            raise TypeError(msg.format(cls=cls))

    def __getattr__(self, name):
        cls = type(self)
        if len(name) == 1:
            print(cls.shortcut_names)
            pos = cls.shortcut_names.find(name)
            if 0 <= pos < len(self._components):
                return self._components[pos]
        msg = '{.__name__!r} object has no attribute {!r}'
        raise AttributeError(msg.format(cls, name))

    def __setattr__(self, name, value):
        cls = type(self)
        if len(name) == 1:
            if name in cls.shortcut_names:
                error = 'readonly attribute {attr_name!r}'
            elif name.islower():
                error = "can't set attributes 'a' to 'z' in {cls_name!r}"
            else:
                error = ''
            if error:
                msg = error.format(cls_name=cls.__name__, attr_name=name)
                raise AttributeError(msg)
        super().__setattr__(name, value)

    def __hash__(self):
        hashes = (hash(x) for x in self._components)
        return functools.reduce(operator.xor, hashes, 0)

    def angle(self, n):
        r = math.sqrt(sum(x * x for x in self[n:]))
        a = math.atan2(r, self[n - 1])
        if (n == len(self) - 1) and (self[-1] < 0):
            return math.pi * 2 - a
        else:
            return a

    def angles(self):
        return (self.angle(n) for n in range(1, len(self)))

    def __format__(self, fmt_spec=''):
        if fmt_spec.endswith('h'):
            fmt_spec = fmt_spec[:-1]
            coords = itertools.chain([abs(self)], self.angles())
            outer_fmt = '<{}>'
        else:
            coords = self
            outer_fmt = '({})'
        components = (format(c, fmt_spec) for c in coords)
        return outer_fmt.format(', '.join(components))

    @pysnooper.snoop()
    def __add__(self, other):
        """
        实现向量的相加:
        :param other: 
        :return: 
        """
        try:
            pairs = itertools.zip_longest(self, other, fillvalue=0.0)
            return Vector(a + b for a, b in pairs)
        except TypeError:
            return NotImplemented

    def __radd__(self, other):
        """
        加法的交换规则,直接委托给__add__方法。
        :param other: 
        :return: 
        """
        return self + other


if __name__ == "__main__":
    # v7 = Vector(range(7))
    # print(v7)
    # print(v7[-1])
    # print(v7[1:4])
    # print(v7.x)
    # v7.K = 10
    # print(v7.K)
    # print(v7)
    v7 = Vector((1, 3))
    # v8 = Vector((3, 4, 5))
    # print(v7 + v8)
    print(v7 + '123')

通过协议接口可知,上面的代码的类是序列,我们重写的__add__方法,使其能够实现对于不同的长度的Vector相加,如: Vector((1, 3)) + Vector((3, 4, 5)) 相加。

如果中缀运算符方法抛出异常,就终止了运算符分派机制。对 TypeError 来 说,通常最好将其捕获,然后返回 NotImplemented。这样,解释器会尝试调 用反向运算符方法,如果操作数是不同的类型,对调之后,反向运算符方法 可能会正确计算。

重载标量乘法运算符: *

from array import array
import reprlib
import math
import numbers
import functools
import operator
import itertools
import pysnooper


class Vector:
    typecode = 'd'
    shortcut_names = 'xyzt'

    def __init__(self, components):
        self._components = array(self.typecode, components)

    def __iter__(self):
        return iter(self._components)

    def __repr__(self):
        components = reprlib.repr(self._components)
        components = components[components.find('['):-1]
        return 'Vector({})'.format(components)

    def __str__(self):
        return str(tuple(self))

    def __bytes__(self):
        return (bytes([ord(self.typecode)]) +
                bytes(self._components))

    def __eq__(self, other):
        # 方案一、
        # if len(self) != len(other):
        #     return False
        # for a, b in zip(self, other):
        #     if a != b:
        #         return False
        # return True

        # 方案二、
        return len(self) == len(other) and all(a == b for a, b in zip(self, other))

    def __abs__(self):
        return math.sqrt(sum(x * x for x in self))

    def __bool__(self):
        return bool(abs(self))

    @classmethod
    def frombytes(cls, octets):
        typecode = chr(octets[0])
        memv = memoryview(octets[1:]).cast(typecode)
        return cls(memv)

    def __len__(self):
        return len(self._components)

    def __getitem__(self, index):
        cls = type(self)
        # print(cls, type(cls))
        if isinstance(index, slice):
            return cls(self._components[index])
        elif isinstance(index, numbers.Integral):
            return self._components[index]
        else:
            msg = '{cls.__name__} indices must be integers'
            raise TypeError(msg.format(cls=cls))

    def __getattr__(self, name):
        cls = type(self)
        if len(name) == 1:
            print(cls.shortcut_names)
            pos = cls.shortcut_names.find(name)
            if 0 <= pos < len(self._components):
                return self._components[pos]
        msg = '{.__name__!r} object has no attribute {!r}'
        raise AttributeError(msg.format(cls, name))

    def __setattr__(self, name, value):
        cls = type(self)
        if len(name) == 1:
            if name in cls.shortcut_names:
                error = 'readonly attribute {attr_name!r}'
            elif name.islower():
                error = "can't set attributes 'a' to 'z' in {cls_name!r}"
            else:
                error = ''
            if error:
                msg = error.format(cls_name=cls.__name__, attr_name=name)
                raise AttributeError(msg)
        super().__setattr__(name, value)

    def __hash__(self):
        hashes = (hash(x) for x in self._components)
        return functools.reduce(operator.xor, hashes, 0)

    def angle(self, n):
        r = math.sqrt(sum(x * x for x in self[n:]))
        a = math.atan2(r, self[n - 1])
        if (n == len(self) - 1) and (self[-1] < 0):
            return math.pi * 2 - a
        else:
            return a

    def angles(self):
        return (self.angle(n) for n in range(1, len(self)))

    def __format__(self, fmt_spec=''):
        if fmt_spec.endswith('h'):
            fmt_spec = fmt_spec[:-1]
            coords = itertools.chain([abs(self)], self.angles())
            outer_fmt = '<{}>'
        else:
            coords = self
            outer_fmt = '({})'
        components = (format(c, fmt_spec) for c in coords)
        return outer_fmt.format(', '.join(components))

    @pysnooper.snoop()
    def __add__(self, other):
        """
        实现向量的相加:
        :param other:
        :return:
        """
        try:
            pairs = itertools.zip_longest(self, other, fillvalue=0.0)
            return Vector(a + b for a, b in pairs)
        except TypeError:
            return NotImplemented

    def __radd__(self, other):
        """
        加法的交换规则,直接委托给__add__方法。
        :param other:
        :return:
        """
        return self + other

    def __mul__(self, other):
        """
        乘积计算
        :param other:
        :return:
        """
        if isinstance(other, numbers.Real):
            return Vector(n * other for n in self)

        else:
            return NotImplemented

    def __rmul__(self, other):
        return self * other


if __name__ == "__main__":
    # v7 = Vector(range(7))
    # print(v7)
    # print(v7[-1])
    # print(v7[1:4])
    # print(v7.x)
    # v7.K = 10
    # print(v7.K)
    # print(v7)
    v7 = Vector((1, 3))
    v8 = Vector((3, 4, 5))
    print(v7 + v8)
    print(2 * v7)
    # print(v7 + '123')
    

这里实现了 __mul__ 方法 乘积运算,__rmul__ 反向乘积,直接委托给__mul__

下面是其他的一些中缀运算符:
python 重载中缀运算符_第1张图片
python 重载中缀运算符_第2张图片

        def __matmul__(self, other):
        """
        点积运算符:这里实现那个矩阵相乘
        :param other:
        :return:
        """
        try:
            return sum(a * b for a, b in zip(self, other))
        except TypeError:
            return NotImplemented

    def __rmatmul__(self, other):
        return self @ other

对于点积运算符的实现(python3.5才有的)。

众多比较运算符也是一类中缀运算符,但是规则稍有不同。

众多比较运算符
Python 解释器对众多比较运算符(==、!=、>、<、>=、<=)的处理与前文类似,不过在两 个方面有重大区别。

1、正向和反向调用使用的是同一系列方法。这方面的规则如表13-2所示。例如,对==来说, 正向和反向调用都是 __eq__ 方法,只是把参数对调了;而正向的 __gt__ 方法调用的是 反向的 __lt__ 方法,并把参数对调。

2、对 ==!= 来说,如果反向调用失败,Python 会比较对象的 ID,而不抛出 TypeError

python 重载中缀运算符_第3张图片
Python 2 之后的比较运算符后备机制都变了。对于 __ne__,现在 Python 3 返回 结果是对 __eq__ 结果的取反。对于排序比较运算符,Python 3 抛出 TypeError, 并把错误消息设为 'unorderable types: int() < tuple()'。在 Python 2 中, 这些比较的结果很怪异,会考虑对象的类型ID,而且无规律可循。然而, 比较整数和元组确实没有意义,因此此时抛出 TypeError 是这门语言的一大 进步。

    def __eq__(self, other):
        # 方案一、
        # if len(self) != len(other):
        #     return False
        # for a, b in zip(self, other):
        #     if a != b:
        #         return False
        # return True

        # 方案二、
        # return len(self) == len(other) and all(a == b for a, b in zip(self, other))
        # 方案三、
        if isinstance(other, Vector):
            return len(self) == len(other) and all(a == b for a, b in zip(self, other))
        else:
            return NotImplemented

增量赋值运算符
Vector 类已经支持增量赋值运算符 += 和 *= 了,如示例 :
python 重载中缀运算符_第4张图片
如果一个类没有实现表 13-1 列出的就地运算符,增量赋值运算符只是语法糖:
a += b 的 作用与 a = a + b 完全一样。对不可变类型来说,这是预期的行为,而且,如果定义了 __add__ 方法的话,不用编写额外的代码,+= 就能使用。

然而,如果实现了就地运算符方法,例如 __iadd__,计算 a += b 的结果时会调用就地运算 符方法。这种运算符的名称表明,它们会就地修改左操作数,而不会创建新对象作为结果。

不可变类型,如 Vector 类,一定不能实现就地特殊方法。这是明显的事实, 不过还是值得提出来。

总结:首先说明了 Python 对运算符重载施加的一些限制:禁止重载内置类型的运算符,而且 限于重载现有的运算符,不过有几个例外(is、and、or、not)。

你可能感兴趣的:(python,特殊方法,python)