python3 自定义比较器/比较函数

python3 自定义比较器/比较函数

    • 方法一:自定义比较项
    • 方法二:自定义比较方法
    • 方法三:自定义运算符

python3 自带的排序函数 sort()sorted() 等,可以自定义 比较器,实现特殊需求,以下介绍 3 种常见的实现方法。
比如现在有二维空间中的 3 个点,需要对它们进行排序,规则是:先按 x 坐标排序,如果相同,则按 y 坐标排序。

class Pos:
	"""
	坐标类
	"""
    def __init__(self, x = 0, y = 0):
        self.x = x
        self.y = y
         
    def __str__(self):
        return ('(%s, %s)' % (self.x, self.y))

def print_list(l):
	"""
	打印坐标数组
	"""
	print(','.join([str(t) for t in l]))
	

# 待排序的3个坐标,正确的排序结果应该是:(2, 4),(2, 5),(5, 1)
l = [Pos(5, 1), Pos(2,5), Pos(2, 4)]

方法一:自定义比较项

可以通过修改 key= 参数,使用 lambda 将输入的元素映射成一个值,这个值就代表了这个元素的大小。

比如,根据坐标类的比较规则(先比较 x 再比较 y),可以将每个坐标映射成:1000 * x + y,这个值就可以代表这个坐标的大小,容易看出是满足比较规则的(前提条件是 y 不会超过 1000)。

l1 = sorted(l, key=lambda t : 1000 * t.x + t.y)
print_list(l1)

结果为:(2, 4),(2, 5),(5, 1)

方法二:自定义比较方法

可以通过修改 key= 参数,并借助内置的 cmp_to_key ,重写对两个元素的比较方法。

比如,可以直接将坐标的比较规则用代码写出来(可以新建一个函数,也可以直接使用lambda),作为新的比较方法。

from functools import cmp_to_key
def cmp(t1, t2):
    """
    比较函数,需要满足:t1>t2则返回正数,t1=t0则返回0,t1
l2`和`l3`的结果均为:`(2, 4),(2, 5),(5, 1)

方法三:自定义运算符

可以在类中重写比较函数,然后可以直接使用常规的比较运算符来判断两个对象(都属于这个类)的大小。

Python2 中可以直接重写__cmp__方法来实现比较,但是 Python3 中已经取消了,Python3 中需要细分每一个比较运算符:

__lt__: <
__gt__: >
__ge__: >=
__eq__: ==
__le__: <=

比如,我们可以在 Pos 类中重写这 5 种方法,作为不同 Pos 之间相互比较的规则。
重新定义 Pos 类:

class Pos:
    def __init__(self, x = 0, y = 0):
        self.x = x
        self.y = y
 
    def __str__(self):
        return ('(%s, %s)' % (self.x, self.y))
 
    def __lt__(self, other):
        return self.x < other.x if self.x != other.x else self.y < other.y
 
    def __gt__(self, other):
        return self.x > other.x if self.x != other.x else self.y > other.y
 
    def __ge__(self, other):
        return self.x >= other.x if self.x != other.x else self.y >= other.y
 
    def __eq__(self, other):
        return self.x == other.x and self.y == other.y
 
    def __le__(self, other):
        return self.x <= other.x if self.x != other.x else self.y <= other.y

这样的话,就可以将 Pos 类当做简单的基本类型数值,可以直接进行比较大小,也适用于所有排序方法。

#测试自定义运算符
print(Pos(5,1) > Pos(2,4)) # 输出 True
print(Pos(5,2) < Pos(5,1)) # 输出 False

# 排序
l4 = sorted(l)
print_list(l4)

结果为:(2, 4),(2, 5),(5, 1)

你可能感兴趣的:(python,数据结构)