Python 自定义类中的函数和运算符重载

Python 自定义类中的函数和运算符重载_第1张图片

原文:https://realpython.com/operator-function-overloading/

如果你曾在字符串(str)对象上进行过 + 或 * 运算,你一定注意到它跟整数或浮点数对象的行为差异:

>>> # 加法
>>> 1 + 2
3

>>> # 拼接字符串
>>> 'Real' + 'Python'
'RealPython'


>>> # 乘法
>>> 3 * 2
6

>>> # 重复字符串
>>> 'Python' * 3
'PythonPythonPython'

你可能想知道,为什么同一个内置的操作符或函数,作用在不同类的对象上面会展现出不同的行为。这种现象被称为运算符重载或者函数重载。本文将帮助你理解这个机制,今后你可以运用到你的自定义类中,让你的编码更 Pythonic 。

以下你将学到:

  • Python 处理运算符和内置函数的API
  • len() 以及其它内置函数背后的“秘密”
  • 如何让你的类可以使用运算符进行运算
  • 如何让你的类与内置函数的操作保持兼容及行为一致

此外,后面还将提供一个具体的类的实例。它的实例对象的行为与运算符及内置函数的行为保持一致。

Python 数据模型

假设,你有一个用来表示在线购物车的类,包含一个购物车(列表)和一名顾客(字符串或者其它表示顾客类的实例)。

这种情形下,很自然地需要获取购物车的列表长度。Python 的新手可能会考虑在他的类中实现一个叫 get_cart_len() 的方法来处理这个需求。实际上,你只需要配置一下,当我们传入购物车实例对象时,使用内置函数 len() 就可以返回购物车的长度。

另一个场景中,我们可能需要添加某些商品到购物车。某些新手同样会想要实现一个叫 append_to_cart() 的方法来处理获取一个项,并将它添加到购物车列表中。其实你只需配置一下 + 运算符就可以实现将项目添加到购物车列表的操作。

Python 使用特定的方法来处理这些过程。这些特殊的方法都有一个特定的命名约定,以双下划线开始,后面跟命名标识符,最后以双下划线结束。

本质上讲,每一种内置的函数或运算符都对应着对象的特定方法。比如,len()方法对应内置 len() 函数,而 add() 方法对应 + 运算符。

默认情况下,绝大多数内置函数和运算符不会在你的类中工作。你需要在类定义中自己实现对应的特定方法,实例对象的行为才会和内置函数和运算符行为保持一致。当你完成这个过程,内置函数或运算符的操作才会如预期一样工作

这些正是数据模型帮你完成的过程(文档的第3部分)。该文档中列举了所有可用的特定方法,并提供了重载它们的方法以便你在自己的对象中使用。

我们看看这意味着什么。

趣事:由于这些方法的特殊命名方式,它们又被称作 dunder 方法,是双下划线方法的简称。有时候它们也被称作特殊方法或魔术方法。我们更喜欢 dunder 方法这个叫法。

len() 和 [] 的内部运行机制

每一个 Python 类都为内置函数或运算符定义了自己的行为方式。当你将某个实例对象传入内置函数或使用运算符时,实际上等同于调用带相应参数的特定方法。

如果有一个内置函数,func(),它关联的特定方法是 func(),Python 解释器解释为类似于 obj.func() 的函数调用,obj 就是实例对象。如果是运算符操作,比如 opr ,关联的特定方法是 opr(),Python 将 obj1 obj2 解释为类似于 obj1.opr(obj2) 的形式。
所以,当你在实例对象上调用 len() 时,Python 将它处理为 obj.len() 调用。当你在可迭代对象上使用 [] 运算符来获取指定索引位置上的值时,Python 将它处理为 itr.getitem(index),itr 表示可迭代对象,index 表示你要索引的位置。

因此,在你定义自己的类时,你可以重写关联的函数或运算符的行为。因为,Python 在后台调用的是你定义的方法。我们看个例子来理解这种机制:

>>> a = 'Real Python'
>>> b = ['Real', 'Python']
>>> len(a)
11
>>> a.__len__()
11
>>> b[0]
'Real'
>>> b.__getitem__(0)
'Real'

如你所见,当你分别使用函数或者关联的特定方法时,你获得了同样的结果。实际上,如果你使用内置函数 dir() 列出一个字符串对象的所有方法和属性,你也可以在里面找到这些特定方法:

>>> dir(a)
['__add__',
 '__class__',
 '__contains__',
 '__delattr__',
 '__dir__',
 ...,
 '__iter__',
 '__le__',
 '__len__',
 '__lt__',
 ...,
 'swapcase',
 'title',
 'translate',
 'upper',
 'zfill']

如果内置函数或运算符的行为没有在类中特定方法中定义,你会得到一个类型错误。
那么,如何在你的类中使用特定方法呢?

重载内置函数

数据模型中定义的大多数特定方法都可以用来改变 len, abs, hash, divmod 等内置函数的行为。你只需要在你的类中定义好关联的特定方法就好了。下面举几个栗子:

用 len() 函数获取你对象的长度

要更改 len() 的行为,你需要在你的类中定义 len() 这个特定方法。每次你传入类的实例对象给 len() 时,它都会通过你定义的 len() 来返回结果。下面,我们来实现前面 order 类的 len() 函数的行为:

>>> class Order:
...     def __init__(self, cart, customer):
...         self.cart = list(cart)
...         self.customer = customer
...
...     def __len__(self):
...         return len(self.cart)
...
>>> order = Order(['banana', 'apple', 'mango'], 'Real Python')
>>> len(order)
3

如你所见,你现在可以直接使用 len() 来获得购物车列表长度。相比 order.get_cart_len() 调用方式,使用 len() 更符合“队列长度”这个直观表述,你的代码调用更 Pythonic,更符合直观习惯。如果你没有定义 len() 这个方法,当你调用 len() 时就会返回一个类型错误:

>>> class Order:
...     def __init__(self, cart, customer):
...         self.cart = list(cart)
...         self.customer = customer
...
>>> order = Order(['banana', 'apple', 'mango'], 'Real Python')
>>> len(order)  # Calling len when no __len__
Traceback (most recent call last):
  File "", line 1, in 
TypeError: object of type 'Order' has no len()

此外,当你重载 len() 时,你需要记住的是 Python 需要该函数返回的是一个整数值,如果你的方法函数返回的是除整数外的其它值,也会报类型错误(TypeError)。此做法很可能是为了与 len() 通常用于获取序列的长度这种用途(序列的长度只能是整数)保持一致:

>>> class Order:
...     def __init__(self, cart, customer):
...         self.cart = list(cart)
...         self.customer = customer
...
...     def __len__(self):
...         return float(len(self.cart))  # Return type changed to float
...
>>> order = Order(['banana', 'apple', 'mango'], 'Real Python')
>>> len(order)
Traceback (most recent call last):
  File "", line 1, in 
TypeError: 'float' object cannot be interpreted as an integer

让你的对象提供 abs() 运算

你可以通过定义类的 abs() 方法来控制内置函数 abs() 作用于实例对象时的行为。abs() 函数对返回值没有约束,只是在你的类没有定义关联的特定方法时会得到类型错误。
在表示二维空间向量的类中, abs() 函数可以被用来获取向量的长度。下面演示如何做:

>>> class Vector:
...     def __init__(self, x_comp, y_comp):
...         self.x_comp = x_comp
...         self.y_comp = y_comp
...
...     def __abs__(self):
...         return (x * x + y * y) ** 0.5
...
>>> vector = Vector(3, 4)
>>> abs(vector)
5.0

这样表述为“向量的绝对值”相对于 vector.get_mag() 这样的调用会显得更直观。

通过 str() 提供更加美观的对象输出格式

内置函数 str() 通常用于将类实例转换为字符串对象,更准确地说,为普通用户提供更友好的字符串表示方式,而不仅仅是面向程序员。通过在你的类定义中实现 str() 特定方法你可以自定义你的对象使用 str() 输出时的字符串输出格式。此外,当你使用 print() 输出你的对象时 Python 实际上调用的也是 str() 方法。
我们将在 Vector 类中实现 Vector 对象的输出格式为 xi+yj。负的 Y 方向的分量输出格式使用迷你语言来处理:

>>> class Vector:
...     def __init__(self, x_comp, y_comp):
...         self.x_comp = x_comp
...         self.y_comp = y_comp
...
...     def __str__(self):
...         # By default, sign of +ve number is not displayed
...         # Using `+`, sign is always displayed
...         return f'{self.x_comp}i{self.y_comp:+}j'
...
>>> vector = Vector(3, 4)
>>> str(vector)
'3i+4j'
>>> print(vector)
3i+4j

需要注意的是 str() 必须返回一个字符串对象,如果我们返回值的类型为非字符串类型,将会报类型错误。

使用 repr() 来显示你的对象

repr() 内置函数通常用来获取对象的可解析字符串表示形式。如果一个对象是可解析的,这意味着使用 repr 再加上 eval() 此类函数,Python 就可以通过字符串表述来重建对象。要定义 repr() 函数的行为,你可以通过定义 repr() 方法来实现。

这也是 Python 在 REPL(交互式)会话中显示一个对象所使用的方式 。如果 repr() 方法没有定义,你在 REPL 会话中试图输出一个对象时,会得到类似 <main.Vector object at 0x...> 这样的结果。我们来看 Vector 类这个例子的实际运行情况:

>>> class Vector:
...     def __init__(self, x_comp, y_comp):
...         self.x_comp = x_comp
...         self.y_comp = y_comp
...
...     def __repr__(self):
...         return f'Vector({self.x_comp}, {self.y_comp})'
...

>>> vector = Vector(3, 4)
>>> repr(vector)
'Vector(3, 4)'

>>> b = eval(repr(vector))
>>> type(b), b.x_comp, b.y_comp
(__main__.Vector, 3, 4)

>>> vector  # Looking at object; __repr__ used
'Vector(3, 4)'

注意:如果 str() 方法没有定义,当在对象上调用 str() 函数,Python 会使用 repr() 方法来代替,如果两者都没有定义,默认输出为 <main.Vector ...>。在交互环境中 repr() 是用来显示对象的唯一方式,类定义中缺少它,只会输出 <main.Vector ...>。
尽管,这是官方推荐的两者行为的区别,但在很多流行的库中实际上都忽略了这种行为差异,而交替使用它们。
关于 repr() 和 str() 的问题推荐阅读 Dan Bader 写的这篇比较出名的文章:Python 字符串转换 101:为什么每个类都需要定义一个 “repr”

使用 bool() 提供布尔值判断

内置的函数 bool() 可以用来提供真值检测,要定义它的行为,你可以通过定义 bool() (Python 2.x版是 nonzero())特定方法来实现。
此处的定义将供所有需要判断真值的上下文(比如 if 语句)中使用。比如,前面定义的 Order 类,某个实例中可能需要判断购物车长度是否为非零。用来检测是否继续处理订单:

>>> class Order:
...     def __init__(self, cart, customer):
...         self.cart = list(cart)
...         self.customer = customer
...
...     def __bool__(self):
...         return len(self.cart) > 0
...
>>> order1 = Order(['banana', 'apple', 'mango'], 'Real Python')
>>> order2 = Order([], 'Python')

>>> bool(order1)
True
>>> bool(order2)
False

>>> for order in [order1, order2]:
...     if order:
...         print(f"{order.customer}'s order is processing...")
...     else:
...         print(f"Empty order for customer {order.customer}")

Real Python's order is processing...
Empty order for customer Python

注意:如果类的 bool() 特定方法没有定义, len() 方法返回值将会用来做真值判断,如果是一个非零值则为真,零值为假。如果两个方法都没有被定义,此类的所有实例检测都会被判断为真值。

还有更多用来重载内置函数的特定方法,你可以在官方文档中找到它们的用法,下面我们开始讨论运算符重载的问题。

重载内置运算符

要改变一个运算符的行为跟改变函数的行为一样,很简单。你只需在类中定义好对应的特定方法,运算符就会按照你设定的方式运行。

跟上面的特定方法不同的是,这些方法定义中,除了接收自身(self)这个参数外,它还需要另一个参数
下面,我们看几个例子。

让你的对象能够使用 + 运算符做加法运算

与 + 运算符对应的特定方法是 add() 方法。添加一个自定义的 add() 方法将会改变该运算符的行为。建议让 add() 方法返回一个新的实例对象而不要修改调用的实例本身。在 Python 中,这种行为非常常见:

>>> a = 'Real'
>>> a + 'Python'  # Gives new str instance
'RealPython'
>>> a  # Values unchanged
'Real'
>>> a = a + 'Python'  # Creates new instance and assigns a to it
>>> a
'RealPython'

你会发现上面例子中字符串对象进行 + 运算会返回一个新的字符串,原来的字符串本身并没有被改变。要改变这种方式,我们需要显式地将生成的新实例赋值给 a。

我们将在 Order 类中实现通过 + 运算符来将新的项目添加到购物车中。我们遵循推荐的方法,运算后返回一个新的 Order 实例对象而不是直接更改现有实例对象的值:

>>> class Order:
...     def __init__(self, cart, customer):
...         self.cart = list(cart)
...         self.customer = customer
...
...     def __add__(self, other):
...         new_cart = self.cart.copy()
...         new_cart.append(other)
...         return Order(new_cart, self.customer)
...
>>> order = Order(['banana', 'apple'], 'Real Python')

>>> (order + 'orange').cart  # New Order instance
['banana', 'apple', 'mango']
>>> order.cart  # Original instance unchanged
['banana', 'apple']

>>> order = order + 'mango'  # Changing the original instance
>>> order.cart
['banana', 'apple', 'mango']

同样的,还有其他的 sub(), mul() 等等特定方法,它们分别对应 -*,等等运算符。它们也都是返回新的实例对象。

一种快捷方式:+= 运算符

+= 运算符通常作为表达式 obj1 = obj1 + obj2 的一种快捷方式。对应的特定方法是 iadd(),该方法会直接修改自身的值,返回的结果可能是自身也可能不是自身。这一点跟 add() 方法有很大的区别,后者是生成新对象作为结果返回。

大致来说,+= 运算符等价于:

>>> result = obj1 + obj2
>>> obj1 = result

上面,result 是 iadd() 返回的值。第二步赋值是 Python 自动处理的,也就是说你无需显式地用表达式 obj1 = obj1 + obj2 将结果赋值给 obj1 。
我们将在 Order 类中实现这个功能,这样我们就可以使用 += 来添加新项目到购物车中:

>>> class Order:
...     def __init__(self, cart, customer):
...         self.cart = list(cart)
...         self.customer = customer
...
...     def __iadd__(self, other):
...         self.cart.append(other)
...         return self
...
>>> order = Order(['banana', 'apple'], 'Real Python')
>>> order += 'mango'
>>> order.cart
['banana', 'apple', 'mango']

如上所见,所有的更改是直接作用在对象自身上,并返回自身。如果我们让它返回一些随机值比如字符串、整数怎样?

>>> class Order:
...     def __init__(self, cart, customer):
...         self.cart = list(cart)
...         self.customer = customer
...
...     def __iadd__(self, other):
...         self.cart.append(other)
...         return 'Hey, I am string!'
...
>>> order = Order(['banana', 'apple'], 'Real Python')
>>> order += 'mango'
>>> order
'Hey, I am string!'

尽管,我们往购物车里添加的是相关的项,但购物车的值却变成了 iadd() 返回的值。Python 在后台隐式处理这个过程。如果你在方法实现中忘记处理返回内容,可能会出现令人惊讶的行为:

>>> class Order:
...     def __init__(self, cart, customer):
...         self.cart = list(cart)
...         self.customer = customer
...
...     def __iadd__(self, other):
...         self.cart.append(other)
...
>>> order = Order(['banana', 'apple'], 'Real Python')
>>> order += 'mango'
>>> order  # No output
>>> type(order)
NoneType

Python 中所有的函数(方法)默认都是返回 None,因此,order 的值被设置为默认值 None,交互界面不会有输出显示。如果检查 order 的类型,显示为 NoneType 类型。因此,你需要确保在 iadd() 的实现中返回期望得到的结果而不是其他什么东东。

iadd() 类似, isub(), imul(), idiv() 等特定方法相应地定义了 -=, *=, /= 等运算符的行为。

注意:当 iadd() 或者同系列的方法没有在你的类中定义,而你又在你的对象上使用这些运算符时。Python 会用 add() 系列方法来替代并返回结果。通常来讲,如果 add() 系列方法能够返回预期正确的结果,不使用 iadd() 系列的方法是一种安全的方式。

Python 的文档提供了这些方法的详细说明。此外,可以看看当使用不可变类型涉及到的 +=及其他运算符需要注意到的附加说明的代码实例。

使用 [] 运算符来索引和分片你的对象

[] 运算符被称作索引运算符,在 Python 各上下文中都有用到,比如获取序列某个索引的值,获取字典某个键对应的值,或者对序列的切片操作。你可以通过 getitem() 特定方法来控制该运算符的行为。

我们设置一下 Order 类的定义,让我们可以直接获取购物车对象中的项:

>>> class Order:
...     def __init__(self, cart, customer):
...         self.cart = list(cart)
...         self.customer = customer
...
...     def __getitem__(self, key):
...         return self.cart[key]
...
>>> order = Order(['banana', 'apple'], 'Real Python')
>>> order[0]
'banana'
>>> order[-1]
'apple'

你可能会注意到上面的例子中, getitem() 方法的参数名并不是 index 而是 key。这是因为,参数主要接收三种类型的值:整数值,通常是一个索引或字典的键值;字符串,字典的键值;切片对象,序列对象的切片。当然,也可能会有其他的值类型,但这三种是最常见的形式。
因为我们的内部数据结构是一个列表,我们可以使用 [] 运算符来对列表进行切片,这时 key 参数会接收一个切片对象。这就是在类中定义 getitem() 方法的最大优势。只要你使用的数据结构支持切片操作(列表、元组、字符串等等),你就可以定义你的对象直接对数据进行切片:

>>> order[1:]
['apple']
>>> order[::-1]
['apple', 'banana']

注意:有一个类似的 setitem() 特定方法可定义类似 obj[x] = y 这种行为。此方法除自身外还需要两个参数,一般称为 key 和 value,用来更改指定 key 索引的值。

逆运算符:让你的类在数学计算上正确

在你定义了 add(), sub(), mul(),以及类似的方法后,类实例作为左侧操作数时可以正确运行,但如果作为右侧操作数则不会正常工作:

>>> class Mock:
...     def __init__(self, num):
...         self.num = num
...     def __add__(self, other):
...         return Mock(self.num + other)
...
>>> mock = Mock(5)
>>> mock = mock + 6
>>> mock.num
11

>>> mock = 6 + Mock(5)
Traceback (most recent call last):
  File "", line 1, in 
TypeError: unsupported operand type(s) for +: 'int' and 'Mock'

如果你的类表示的是一个数学实体,比如向量、坐标或复数,运算符应该在这两种方式下都能正确运算,因为它是有效的数学运算规则。此外,如果某个运算符仅仅在操作数为左侧时才工作,这在数学上违背了交换律规则。因此,为了保证在数学上的正确,Python 为你提供了反向计算的 radd(), rsub(), rmul()等特定方法。

这些方法处理类似 x + obj, x - obj, 以及 x * obj 形式的运算,其中 x 不是一个类实例对象。和 add() 及其他方法一样,这些方法也应该返回一个新的实例对象,而不是修改自身。
我们在 Order 类中定义 radd() 方法,这样就可以将某些项操作数放在购物车对象前面进行添加。这还可以用在购物车内订单是按照优先次序排列的情况。:

>>> class Order:
...     def __init__(self, cart, customer):
...         self.cart = list(cart)
...         self.customer = customer
...
...     def __add__(self, other):
...         new_cart = self.cart.copy()
...         new_cart.append(other)
...         return Order(new_cart, self.customer)
...
...     def __radd__(self, other):
...         new_cart = self.cart.copy()
...         new_cart.insert(0, other)
...         return Order(new_cart, self.customer)
...
>>> order = Order(['banana', 'apple'], 'Real Python')

>>> order = order + 'orange'
>>> order.cart
['banana', 'apple', 'orange']

>>> order = 'mango' + order
>>> order.cart
['mango', 'banana', 'apple', 'orange']

一个完整的例子

想要掌握以上所有的关键点,最好自己实现一个包含以上所有操作的自定义类。我们自己来造一个轮子,实现一个复数的自定义类 CustomComplex。这个类的实例将支持各种内置函数和运算符,行为表现上非常类似于 Python 自带的复数类:

from math import hypot, atan, sin, cos

class CustomComplex:
    def __init__(self, real, imag):
        self.real = real
        self.imag = imag

构造函数只支持一种调用方式,即 CustomComplex(a, b)。它通过位置参数来表示复数的实部和虚部。我们在这个类中定义两个方法 conjugate() 和 argz()。它们分别提供复数共轭和复数的辐角:

def conjugate(self):
    return self.__class__(self.real, -self.imag)

def argz(self):
    return atan(self.imag / self.real)

注意: class 并不是特定方法,只是默认的一个类属性通常指向类本身。这里我们跟调用构造函数一样来对它进行调用,换句话来说其实调用的就是 CustomComplex(real, imag)。这样调用是为了防止今后更改类名时要再次重构代码。
下一步,我们配置 abs() 返回复数的模:

def __abs__(self):
    return hypot(self.real, self.imag)

我们遵循官方建议的 repr() 和 str() 两者差异,用第一个来实现可解析的字符串输出,用第二个来实现“更美观”的输出。 repr() 方法简单地返回 CustomComplex(a, b) 字符串,这样我们在调用 eval() 重建对象时很方便。 str() 方法用来返回带括号的复数输出形式,比例 (a+bj):

def __repr__(self):
    return f"{self.__class__.__name__}({self.real}, {self.imag})"

def __str__(self):
    return f"({self.real}{self.imag:+}j)"

数学上讲,我们可以进行两个复数相加或者将一个实数和复数相加。我们定义 + 运算符来实现这个功能。方法将会检测运算符右侧的类型,如果是一个整数或者浮点数,它将只增加实部(因为任意实数都可以看做是 a+0j),当类型是复数时,它会同时更改实部和虚部:

def __add__(self, other):
    if isinstance(other, float) or isinstance(other, int):
        real_part = self.real + other
        imag_part = self.imag

    if isinstance(other, CustomComplex):
        real_part = self.real + other.real
        imag_part = self.imag + other.imag

    return self.__class__(real_part, imag_part)

同样,我们定义 -* 运算符的行为:

def __sub__(self, other):
    if isinstance(other, float) or isinstance(other, int):
        real_part = self.real - other
        imag_part = self.imag

    if isinstance(other, CustomComplex):
        real_part = self.real - other.real
        imag_part = self.imag - other.imag

    return self.__class__(real_part, imag_part)

def __mul__(self, other):
    if isinstance(other, int) or isinstance(other, float):
        real_part = self.real * other
        imag_part = self.imag * other

    if isinstance(other, CustomComplex):
        real_part = (self.real * other.real) - (self.imag * other.imag)
        imag_part = (self.real * other.imag) + (self.imag * other.real)

    return self.__class__(real_part, imag_part)

因为加法和乘法可以交换操作数,我们可以在反向运算符 radd() 和 rmul() 方法中这样调用 add() 和 mul() 。此外,减法运算的操作数是不可以交换的,所以需要 rsub() 方法的行为:

def __radd__(self, other):
    return self.__add__(other)

def __rmul__(self, other):
    return self.__mul__(other)

def __rsub__(self, other):
    # x - y != y - x
    if isinstance(other, float) or isinstance(other, int):
        real_part = other - self.real
        imag_part = -self.imag

    return self.__class__(real_part, imag_part)

注意:你也许发现我们并没有增添一个构造函数来处理 CustomComplex 实例。因为这种情形下,两个操作数都是类的实例, rsub() 方法并不负责处理实际的运算,仅仅是调用 sub() 方法来处理。这是一个微妙但是很重要的细节。
现在我们来看看另外两个运算符:== 和 != 。这两个分别对应的特定方法是 eq() 和 ne()。如果两个复数的实部和虚部都相同则两者是相等的。只要两个部分任意一个不相等两者就不相等:

def __eq__(self, other):
    # Note: generally, floats should not be compared directly
    # due to floating-point precision
    return (self.real == other.real) and (self.imag == other.imag)

def __ne__(self, other):
    return (self.real != other.real) or (self.imag != other.imag)

注意:浮点指南这篇文章讨论了浮点数比较和浮点精度的问题,它涉及到一些浮点数直接比较的一些注意事项,这与我们在这里要处理的情况有点类似。
同样,我们也可以通过简单的公式来提供复数的幂运算。我们通过定义 pow() 特定方法来设置内置函数 pow() 和 ** 运算符的行为:

def __pow__(self, other):
    r_raised = abs(self) ** other
    argz_multiplied = self.argz() * other

    real_part = round(r_raised * cos(argz_multiplied))
    imag_part = round(r_raised * sin(argz_multiplied))

    return self.__class__(real_part, imag_part)

注意:认真看看方法的定义。我们调用 abs() 来获取复数的模。所以,我们一旦为特定功能函数或运算符定义好了特定方法,它就可以被用于此类的其他方法中。

我们创建这个类的两个实例,一个拥有正的虚部,一个拥有负的虚部:

>>> a = CustomComplex(1, 2)
>>> b = CustomComplex(3, -4)

字符串表示:

>>> a
CustomComplex(1, 2)
>>> b
CustomComplex(3, -4)
>>> print(a)
(1+2j)
>>> print(b)
(3-4j)

使用 eval() 和 repr()重建对象

>>> b_copy = eval(repr(b))
>>> type(b_copy), b_copy.real, b_copy.imag
(__main__.CustomComplex, 3, -4)

加减乘法:

>>> a + b
CustomComplex(4, -2)
>>> a - b
CustomComplex(-2, 6)
>>> a + 5
CustomComplex(6, 2)
>>> 3 - a
CustomComplex(2, -2)
>>> a * 6
CustomComplex(6, 12)
>>> a * (-6)
CustomComplex(-6, -12)

相等和不等检测:

>>> a == CustomComplex(1, 2)
True
>>> a ==  b
False
>>> a != b
True
>>> a != CustomComplex(1, 2)
False

最后,将复数加到某个幂上:

```python
>>> a ** 2
CustomComplex(-3, 4)
>>> b ** 5
CustomComplex(-237, 3116)

正如你所见到的,我们自定义类的对象外观及行为上类似于内置的对象而且很 Pythonic。

回顾总结

本教程中,你学习了 Python 数据模型,以及如何通过数据模型来构建 Pythonic 的类。学习了改变 len(), abs(), str(), bool() 等内置函数的行为,以及改变 +, -, *, **, 等内置运算符的行为。
如果想要进一步地了解数据模型、函数和运算符重载,请参考以下资源:

  • Python 文档,数据模型的第 3.3 节,特定方法名
  • 流畅的 Python(Fluent Python by Luciano Ramalho)
  • Python 技巧(Python Tricks)

本示例的完整代码:

from math import hypot, atan, sin, cos

class CustomComplex():
    """
    A class to represent a complex number, a+bj.
    Attributes:
        real - int, representing the real part
        imag - int, representing the imaginary part

    Implements the following:

    * Addition with a complex number or a real number using `+`
    * Multiplication with a complex number or a real number using `*`
    * Subtraction of a complex number or a real number using `-`
    * Calculation of absolute value using `abs`
    * Raise complex number to a power using `**`
    * Nice string representation using `__repr__`
    * Nice user-end viewing using `__str__`

    Notes:
        * The constructor has been intentionally kept simple
        * It is configured to support one kind of call:
            CustomComplex(a, b)
        * Error handling was avoided to keep things simple
    """

    def __init__(self, real, imag):
        """
        Initializes a complex number, setting real and imag part
        Arguments:
            real: Number, real part of the complex number
            imag: Number, imaginary part of the complex number
        """
        self.real = real
        self.imag = imag

    def conjugate(self):
        """
        Returns the complex conjugate of a complex number
        Return:
            CustomComplex instance
        """
        return CustomComplex(self.real, -self.imag)

    def argz(self):
        """
        Returns the argument of a complex number
        The argument is given by:
            atan(imag_part/real_part)
        Return:
            float
        """
        return atan(self.imag / self.real)

    def __abs__(self):
        """
        Returns the modulus of a complex number
        Return:
            float
        """
        return hypot(self.real, self.imag)

    def __repr__(self):
        """
        Returns str representation of an instance of the 
        class. Can be used with eval() to get another 
        instance of the class
        Return:
            str
        """
        return f"CustomComplex({self.real}, {self.imag})"


    def __str__(self):
        """
        Returns user-friendly str representation of an instance 
        of the class
        Return:
            str
        """
        return f"({self.real}{self.imag:+}j)"

    def __add__(self, other):
        """
        Returns the addition of a complex number with
        int, float or another complex number
        Return:
            CustomComplex instance
        """
        if isinstance(other, float) or isinstance(other, int):
            real_part = self.real + other
            imag_part = self.imag

        if isinstance(other, CustomComplex):
            real_part = self.real + other.real
            imag_part = self.imag + other.imag

        return CustomComplex(real_part, imag_part)

    def __sub__(self, other):
        """
        Returns the subtration from a complex number of
        int, float or another complex number
        Return:
            CustomComplex instance
        """
        if isinstance(other, float) or isinstance(other, int):
            real_part = self.real - other
            imag_part = self.imag

        if isinstance(other, CustomComplex):
            real_part = self.real - other.real
            imag_part = self.imag - other.imag

        return CustomComplex(real_part, imag_part)

    def __mul__(self, other):
        """
        Returns the multiplication of a complex number with
        int, float or another complex number
        Return:
            CustomComplex instance
        """
        if isinstance(other, int) or isinstance(other, float):
            real_part = self.real * other
            imag_part = self.imag * other

        if isinstance(other, CustomComplex):
            real_part = (self.real * other.real) - (self.imag * other.imag)
            imag_part = (self.real * other.imag) + (self.imag * other.real)

        return CustomComplex(real_part, imag_part)

    def __radd__(self, other):
        """
        Same as __add__; allows 1 + CustomComplex('x+yj')
        x + y == y + x
        """
        pass

    def __rmul__(self, other):
        """
        Same as __mul__; allows 2 * CustomComplex('x+yj')
        x * y == y * x
        """
        pass

    def __rsub__(self, other):
        """
        Returns the subtraction of a complex number from
        int or float
        x - y != y - x
        Subtration of another complex number is not handled by __rsub__
        Instead, __sub__ handles it since both sides are instances of
        this class
        Return:
            CustomComplex instance
        """
        if isinstance(other, float) or isinstance(other, int):
            real_part = other - self.real
            imag_part = -self.imag

        return CustomComplex(real_part, imag_part)

    def __eq__(self, other):
        """
        Checks equality of two complex numbers
        Two complex numbers are equal when:
            * Their real parts are equal AND
            * Their imaginary parts are equal
        Return:
            bool
        """
        # note: comparing floats directly is not a good idea in general
        # due to floating-point precision
        return (self.real == other.real) and (self.imag == other.imag)

    def __ne__(self, other):
        """
        Checks inequality of two complex numbers
        Two complex numbers are unequal when:
            * Their real parts are unequal OR
            * Their imaginary parts are unequal
        Return:
            bool
        """
        return (self.real != other.real) or (self.imag != other.imag)

    def __pow__(self, other):
        """
        Raises a complex number to a power
        Formula:
            z**n = (r**n)*[cos(n*agrz) + sin(n*argz)j], where
            z = complex number
            n = power
            r = absolute value of z
            argz = argument of z
        Return:
            CustomComplex instance
        """
        r_raised = abs(self) ** other
        argz_multiplied = self.argz() * other

        real_part = round(r_raised * cos(argz_multiplied))
        imag_part = round(r_raised * sin(argz_multiplied))

        return CustomComplex(real_part, imag_part)

本作品采用:知识共享 署名-非商业性使用-禁止演绎 4.0 国际许可协议进行许可。

你可能感兴趣的:(Python 自定义类中的函数和运算符重载)