类和接口
Python作为面向对象语言,继承多态和封装有良好的应用,如何编写可维护的代码呢?
- Item37: 组合类而不是嵌套多层的Built-in类型
假设现在要记录一群学生(不知道姓名)的分数。我可以定义一个类来把姓名存储为字典。
class SimpleGradebook:
def __init__(self):
self._grades = {}
def add_student(self, name):
self._grades[name] = []
def report_grade(self, name, score):
self._grades[name].append(score)
def average_grade(self, name):
grades = self._grades[name]
return sum(grades) / len(grades)
book = SimpleGradebook()
book.add_student('Isaac Newton')
book.report_grade('Isaac Newton', 90)
book.report_grade('Isaac Newton', 95)
book.report_grade('Isaac Newton', 85)
print(book.average_grade('Isaac Newton'))
>>>
90.0
字典及相关的built-in类型容易用,但是有过度扩展的危险。比如现在不止想保存分数,还想保存对应科目:
from collections import defaultdict
class BySubjectGradebook:
def __init__(self):
self._grades = {} # Outer dict
def add_student(self, name):
self._grades[name] = defaultdict(list) # Inner dict
这足够直接且符合直觉,多层的字典似乎也还能管理。继续修改对应的代码:
def report_grade(self, name, subject, grade):
by_subject = self._grades[name]
grade_list = by_subject[subject]
grade_list.append(grade)
def average_grade(self, name):
by_subject = self._grades[name]
total, count = 0, 0
for grades in by_subject.values():
total += sum(grades)
count += len(grades)
return total / count
book = BySubjectGradebook()
book.add_student('Albert Einstein')
book.report_grade('Albert Einstein', 'Math', 75)
book.report_grade('Albert Einstein', 'Math', 65)
book.report_grade('Albert Einstein', 'Gym', 90)
book.report_grade('Albert Einstein', 'Gym', 95)
print(book.average_grade('Albert Einstein'))
>>>
81.25
假如现在又有新的需求,需要变为不同测试带有不同的权重:(不止是分数,还有权重)
class WeightedGradebook:
def __init__(self):
self._grades = {}
def add_student(self, name):
self._grades[name] = defaultdict(list)
def report_grade(self, name, subject, score, weight):
by_subject = self._grades[name]
grade_list = by_subject[subject]
grade_list.append((score, weight))
def average_grade(self, name):
by_subject = self._grades[name]
score_sum, score_count = 0, 0
for subject, scores in by_subject.items():
subject_avg, total_weight = 0, 0
for score, weight in scores:
subject_avg += score * weight
total_weight += weight
score_sum += subject_avg / total_weight
score_count += 1
return score_sum / score_count
book = WeightedGradebook()
book.add_student('Albert Einstein')
book.report_grade('Albert Einstein', 'Math', 75, 0.05)
book.report_grade('Albert Einstein', 'Math', 65, 0.15)
book.report_grade('Albert Einstein', 'Math', 70, 0.80)
book.report_grade('Albert Einstein', 'Gym', 100, 0.40)
book.report_grade('Albert Einstein', 'Gym', 85, 0.60)
print(book.average_grade('Albert Einstein'))
>>>
80.25
超过一层的嵌套尽量就不要继续用了。(维护噩梦)
应该重构成类。
grades = []
grades.append((95, 0.45))
grades.append((85, 0.55))
total = sum(score * weight for score, weight in grades)
total_weight = sum(weight for _, weight in grades)
average_grade = total / total_weight
如果要加一些教师评价,可能就会引入很多下划线_:
grades = []
grades.append((95, 0.45, 'Great job'))
grades.append((85, 0.55, 'Better next time'))
total = sum(score * weight for score, weight, _ in grades)
total_weight = sum(weight for _, weight, _ in grades)
average_grade = total / total_weight
这里,namedtuple刚好符合要求:
from collections import namedtuple
Grade = namedtuple('Grade', ('score', 'weight'))
但是,namedtuple也有限制:
不能指定默认参数。
当你的数据有很多可选的属性时,这点就很不好。属性多的时候用built-in可能更合适。
namedtuple的属性值仍可访问。如果不能控制它们的使用,最好还是显式地定义一个新的类。
class Subject:
def __init__(self):
self._grades = []
def report_grade(self, score, weight):
self._grades.append(Grade(score, weight))
def average_grade(self):
total, total_weight = 0, 0
for grade in self._grades:
total += grade.score * grade.weight
total_weight += grade.weight
return total / total_weight
class Student:
def __init__(self):
self._subjects = defaultdict(Subject)
def get_subject(self, name):
return self._subjects[name]
def average_grade(self):
total, count = 0, 0
for subject in self._subjects.values():
total += subject.average_grade()
count += 1
return total / count
class Gradebook:
def __init__(self):
self._students = defaultdict(Student)
def get_student(self, name):
return self._students[name]
book = Gradebook()
albert = book.get_student('Albert Einstein')
math = albert.get_subject('Math')
math.report_grade(75, 0.05)
math.report_grade(65, 0.15)
math.report_grade(70, 0.80)
gym = albert.get_subject('Gym')
gym.report_grade(100, 0.40)
gym.report_grade(85, 0.60)
print(albert.average_grade())
>>>
80.25
- Item38: 对于简单的接口,接受函数而不是类
许多built-in的API允许传递函数。这些钩子(hooks)被API回调。比如:sort函数的key参数可以传递函数:
names = ['Socrates', 'Archimedes', 'Plato', 'Aristotle']
names.sort(key=len)
print(names)
>>>
['Plato', 'Socrates', 'Aristotle', 'Archimedes']
当然,还有很多例子,比如defaultdict的参数也可以是类名或者函数,就是需要返回默认的值。
如果定义为每次返回0:
def log_missing():
print('Key added')
return 0
先构建出current的result,再增量地加回去。默认值为log_missing返回的0。
from collections import defaultdict
current = {'green': 12, 'blue': 3}
increments = [
('red', 5),
('blue', 17),
('orange', 9),
]
result = defaultdict(log_missing, current)
print('Before:', dict(result))
for key, amount in increments:
result[key] += amount
print('After: ', dict(result))
>>>
Before: {'green': 12, 'blue': 3}
Key added
Key added
After: {'green': 12, 'blue': 20, 'red': 5, 'orange': 9}
假如现在在添加的时候,需要统计添加的列别的数目,如下:(利用了闭包的属性,可以在内部进行统计。)
def increment_with_report(current, increments):
added_count = 0
def missing():
nonlocal added_count # Stateful closure
added_count += 1
return 0
result = defaultdict(missing, current)
for key, amount in increments:
result[key] += amount
return result, added_count
尽管defaultdict不知道missing这个hook保持了什么状态信息,最终结果也可以得到为2。
result, count = increment_with_report(current, increments)
assert count == 2
其它的语言可能可以定义一个类来保持状态,然后传递这个实例的方法:
class CountMissing:
def __init__(self):
self.added = 0
def missing(self):
self.added += 1
return 0
同样也是可以达到效果:
counter = CountMissing()
result = defaultdict(counter.missing, current) # Method ref
for key, amount in increments:
result[key] += amount
assert counter.added == 2
虽然类比闭包清晰一些,但是CountMissing类的目的不是很显而易见,直到看到defaultdict的时候。(谁创建,谁调用missing,这个类未来需要其它的puclic方法吗?)
python允许类定义__call__的方法,调用callable时,如果该类实现了__call__会返回true。
class BetterCountMissing:
def __init__(self):
self.added = 0
def __call__(self):
self.added += 1
return 0
counter = BetterCountMissing()
assert counter() == 0
assert callable(counter)
当key缺失的时候,会调用一次counter,即其call方法。
counter = BetterCountMissing()
result = defaultdict(counter, current) # Relies on __call__
for key, amount in increments:
result[key] += amount
assert counter.added == 2
这样,就可以很方便的实现上面的需求。
- Item39: 用@classmethod多态来泛化(泛型)地构建对象
不止对象支持多态,类也同样支持,有什么好处?
多态允许多个类在一个层级制度下实现它们自己的特有的版本。这意味着许多类可以提供不同的功能给同一个接口或者抽象类。
比如,现在在写MapReduce的实现,要一个公共的抽象类来表示输入数据:
class InputData:
def read(self):
raise NotImplementedError
从磁盘上的文件读数据:
class PathInputData(InputData):
def __init__(self, path):
super().__init__()
self.path = path
def read(self):
with open(self.path) as f:
return f.read()
我可以有很多种InputData,比如NetworkInputData。而对于MapReduce的worker来说,需要输入和消费这些数据:
class Worker:
def __init__(self, input_data):
self.input_data = input_data
self.result = None
def map(self):
raise NotImplementedError
def reduce(self, other):
raise NotImplementedError
此时,有一个具体的获取行数的Worker:
class LineCountWorker(Worker):
def map(self):
data = self.input_data.read() # 读数据
self.result = data.count('\n') # 当前数据的行数
def reduce(self, other):
self.result += other.result # 合并其它的Worker的结果。
似乎需要一个helper函数来生成数据。
import os
def generate_inputs(data_dir):
for name in os.listdir(data_dir):
yield PathInputData(os.path.join(data_dir, name))
然后根据这些数据,来生成worker:
def create_workers(input_list):
workers = []
for input_data in input_list:
workers.append(LineCountWorker(input_data))
return workers
然后调用map来分散到各个线程计算,最后用reduce来产生最终结果:
from threading import Thread
def execute(workers):
threads = [Thread(target=w.map) for w in workers]
for thread in threads: thread.start()
for thread in threads: thread.join()
first, *rest = workers
for worker in rest:
first.reduce(worker)
return first.result
最后把几个helper连接到一起返回结果:
def mapreduce(data_dir):
inputs = generate_inputs(data_dir)
workers = create_workers(inputs)
return execute(workers)
随机生成一些文件,发现可以工作得很好:
import os
import random
def write_test_files(tmpdir):
os.makedirs(tmpdir)
for i in range(100):
with open(os.path.join(tmpdir, str(i)), 'w') as f:
f.write('\n' * random.randint(0, 100))
tmpdir = 'test_inputs'
write_test_files(tmpdir)
result = mapreduce(tmpdir)
print(f'There are {result} lines')
>>>
There are 4360 lines
问题出现在哪?mapreduce方法不够泛化。如果我要写另一种InputData或者Worker的子类,需要重写上面的几个方法来匹配。
最好的方式是用类多态(因为init只有一个,对每个InputData的子类来写适配的constructor不合理。)
使用了@classmethod来创建新的InputData:
class GenericInputData:
def read(self):
raise NotImplementedError
@classmethod
def generate_inputs(cls, config):
raise NotImplementedError
用config来找到字典值来处理:
class PathInputData(GenericInputData):
...
@classmethod
def generate_inputs(cls, config):
data_dir = config['data_dir']
for name in os.listdir(data_dir):
yield cls(os.path.join(data_dir, name))
类似地,可以创建泛型Worker。用cls()创建特定的子类。
class GenericWorker:
def __init__(self, input_data):
self.input_data = input_data
self.result = None
def map(self):
raise NotImplementedError
def reduce(self, other):
raise NotImplementedError
@classmethod
def create_workers(cls, input_class, config):
workers = []
for input_data in input_class.generate_inputs(config):
workers.append(cls(input_data))
return workers
注意到调用input_class.generate_inputs是类的多态。可以看到create_workers调用了cls()来提供额外的方式来构建GenericWorker(用到__init__)
class LineCountWorker(GenericWorker):
...
最后,重写mapreduce函数:
def mapreduce(worker_class, input_class, config):
workers = worker_class.create_workers(input_class,
config)
return execute(workers)
config = {'data_dir': tmpdir}
result = mapreduce(LineCountWorker, PathInputData, config)
print(f'There are {result} lines')
>>>
There are 4360 lines
可以看出,通过@classmethod的cls可以建立具体类的连接。
- Item40: 用super来初始化父类
古老且简单的方式来初始化父类是直接调用父类的__init__方法:
class MyBaseClass:
def __init__(self, value):
self.value = value
class MyChildClass(MyBaseClass):
def __init__(self):
MyBaseClass.__init__(self, 5)
但是在许多情况下失效。比如定义类来操作实例变量value。
class TimesTwo:
def __init__(self):
self.value *= 2
class PlusFive:
def __init__(self):
self.value += 5
构建的时候,继承的时候是匹配结果的顺序。
class OneWay(MyBaseClass, TimesTwo, PlusFive):
def __init__(self, value):
MyBaseClass.__init__(self, value)
TimesTwo.__init__(self)
PlusFive.__init__(self)
结果为:
foo = OneWay(5)
print('First ordering value is (5 * 2) + 5 =', foo.value)
>>>
First ordering value is (5 * 2) + 5 = 15
另一种是定义一样的父类但是不一样的顺序:
class AnotherWay(MyBaseClass, PlusFive, TimesTwo):
def __init__(self, value):
MyBaseClass.__init__(self, value)
TimesTwo.__init__(self)
PlusFive.__init__(self)
定义和实现的顺序不同。这种顺序比较难发现,对于新手来说不友好。
bar = AnotherWay(5)
print('Second ordering value is', bar.value)
>>>
Second ordering value is 15
另一个问题发生在菱形继承。比如两个类继承同一个类:
class TimesSeven(MyBaseClass):
def __init__(self, value):
MyBaseClass.__init__(self, value)
self.value *= 7
class PlusNine(MyBaseClass):
def __init__(self, value):
MyBaseClass.__init__(self, value)
self.value += 9
然后定义一个类继承这两个类:
class ThisWay(TimesSeven, PlusNine):
def __init__(self, value):
TimesSeven.__init__(self, value)
PlusNine.__init__(self, value)
foo = ThisWay(5)
print('Should be (5 * 7) + 9 = 44 but is', foo.value)
>>>
Should be (5 * 7) + 9 = 44 but is 14
由于__init__再次被调用,因此结果变为5+9=14,如果情况更复杂的话,这点是比较难以debug的。
为了解决这些问题,Python自带了super自建的函数还有标准方法解析顺序(MRO)。super确保了公共的父类只运行一次。MRO定义了父类被初始化的顺序(以C3线性(C3 linearization)算法的顺序进行)
class TimesSevenCorrect(MyBaseClass):
def __init__(self, value):
super().__init__(value)
self.value *= 7
class PlusNineCorrect(MyBaseClass):
def __init__(self, value):
super().__init__(value)
self.value += 9
现在,正确地运行如下:
class GoodWay(TimesSevenCorrect, PlusNineCorrect):
def __init__(self, value):
super().__init__(value)
foo = GoodWay(5)
print('Should be 7 * (5 + 9) = 98 and is', foo.value)
>>>
Should be 7 * (5 + 9) = 98 and is 98
顺序看着是反着来的,实际是根据MRO的顺序来的:
mro_str = '\n'.join(repr(cls) for cls in GoodWay.mro())
print(mro_str)
>>>
super的两个参数:MRO父视图的类类型、访问这个视图的实例。
class ExplicitTrisect(MyBaseClass):
def __init__(self, value):
super(ExplicitTrisect, self).__init__(value)
self.value /= 3
对于object实例的初始化,参数不是要求的。(因为如果使用super(),编译器会自动提供正确的参数__class__和self,因此,下面几种都是等价的。)
class AutomaticTrisect(MyBaseClass):
def __init__(self, value):
super(__class__, self).__init__(value)
self.value /= 3
class ImplicitTrisect(MyBaseClass):
def __init__(self, value):
super().__init__(value)
self.value /= 3
assert ExplicitTrisect(9).value == 3
assert AutomaticTrisect(9).value == 3
assert ImplicitTrisect(9).value == 3
- Item41: 考虑用Mix-in类来组合功能性
最好还是避免多继承,考虑编写mix-in(定义了小的、额外的方法类,供子类使用)。
比如,假如现在需要从内存表示转换Python对象到序列化的字典:
class ToDictMixin:
def to_dict(self):
return self._traverse_dict(self.__dict__)
用hasattr来进行动态属性访问,用isinstance来进行动态类检查。并且访问实例字典__dict__:
def _traverse_dict(self, instance_dict):
output = {}
for key, value in instance_dict.items():
output[key] = self._traverse(key, value)
return output
def _traverse(self, key, value):
if isinstance(value, ToDictMixin):
return value.to_dict()
elif isinstance(value, dict):
return self._traverse_dict(value)
elif isinstance(value, list):
return [self._traverse(key, i) for i in value]
elif hasattr(value, '__dict__'):
return self._traverse_dict(value.__dict__)
else:
return value
这里定义了一个类来使得字典表达为二叉树:
class BinaryTree(ToDictMixin):
def __init__(self, value, left=None, right=None):
self.value = value
self.left = left
self.right = right
# 把大量的对象转换成字典变得容易:
tree = BinaryTree(10,
left=BinaryTree(7, right=BinaryTree(9)),
right=BinaryTree(13, left=BinaryTree(11)))
print(tree.to_dict())
>>>
{'value': 10,
'left': {'value': 7,
'left': None,
'right': {'value': 9, 'left': None, 'right':
None}},
'right': {'value': 13,
'left': {'value': 11, 'left': None, 'right':
None},
'right': None}}
定义了BinaryTree的子类,带着父节点的引用。这个循环引用可能会导致ToDictMixin.to_dict无限循环:
class BinaryTreeWithParent(BinaryTree):
def __init__(self, value, left=None,
right=None, parent=None):
super().__init__(value, left=left, right=right)
self.parent = parent
解决方案就是重写(override)此类中的_traverse方法,使得方法只处理数值,避免mix-in带来循环。这里给了父节点的数值,否则就用默认的实现。
def _traverse(self, key, value):
if (isinstance(value, BinaryTreeWithParent) and
key == 'parent'):
return value.value # Prevent cycles
else:
return super()._traverse(key, value)
调用BinaryTreeWithParent.to_dict没有问题,因为循环引用的属性不被允许:
root = BinaryTreeWithParent(10)
root.left = BinaryTreeWithParent(7, parent=root)
root.left.right = BinaryTreeWithParent(9, parent=root.left)
print(root.to_dict())
>>>
{'value': 10,
'left': {'value': 7,
'left': None,
'right': {'value': 9,
'left': None,
'right': None,
'parent': 7},
'parent': 10},
'right': None,
'parent': None}
可以使得拥有类型BinaryTreeWithParent的属性的类自动和ToDictMixin工作得很好。
class NamedSubTree(ToDictMixin):
def __init__(self, name, tree_with_parent):
self.name = name
self.tree_with_parent = tree_with_parent
my_tree = NamedSubTree('foobar', root.left.right)
print(my_tree.to_dict()) # No infinite loop
>>>
{'name': 'foobar',
'tree_with_parent': {'value': 9,
'left': None,
'right': None,
'parent': 7}}
Mix-in可以被组合。比如,需要提供JSON序列化:
import json
class JsonMixin:
@classmethod
def from_json(cls, data):
kwargs = json.loads(data)
return cls(**kwargs)
def to_json(self):
return json.dumps(self.to_dict())
JsonMixin定义了两个方法,下面是数据中心的拓扑结构:
class DatacenterRack(ToDictMixin, JsonMixin):
def __init__(self, switch=None, machines=None):
self.switch = Switch(**switch)
self.machines = [
Machine(**kwargs) for kwargs in machines]
class Switch(ToDictMixin, JsonMixin):
def __init__(self, ports=None, speed=None):
self.ports = ports
self.speed = speed
class Machine(ToDictMixin, JsonMixin):
def __init__(self, cores=None, ram=None, disk=None):
self.cores = cores
self.ram = ram
self.disk = disk
这里测试了从json中加载对象,然后序列化回json的整个闭环:
serialized = """{
"switch": {"ports": 5, "speed": 1e9},
"machines": [
{"cores": 8, "ram": 32e9, "disk": 5e12},
{"cores": 4, "ram": 16e9, "disk": 1e12},
{"cores": 2, "ram": 4e9, "disk": 500e9}
]
}"""
deserialized = DatacenterRack.from_json(serialized)
roundtrip = deserialized.to_json()
assert json.loads(serialized) == json.loads(roundtrip)
可以看出,用这种插件类的方式,也可以实现很多灵活性。
- Item42: 使用公有属性而不是私有属性
在Python中,有两种可见性:public和private
class MyObject:
def __init__(self):
self.public_field = 5
self.__private_field = 10
def get_private_field(self):
return self.__private_field
公有直接访问:
foo = MyObject()
assert foo.public_field == 5
私有通过get方法获得:
assert foo.get_private_field() == 10
直接访问会引发Error:
foo.__private_field
>>>
Traceback ...
AttributeError: 'MyObject' object has no attribute '__private_field'
类方法同样有访问私有属性的权限,因为它们在类内被声明:
class MyOtherObject:
def __init__(self):
self.__private_field = 71
@classmethod
def get_private_field_of_instance(cls, instance):
return instance.__private_field
bar = MyOtherObject()
assert MyOtherObject.get_private_field_of_instance(bar) == 71
继承访问不到父类的私有域:
class MyParentObject:
def __init__(self):
self.__private_field = 71
class MyChildObject(MyParentObject):
def get_private_field(self):
return self.__private_field
baz = MyChildObject()
baz.get_private_field()
>>>
Traceback ...
AttributeError: 'MyChildObject' object has no attribute
'_MyChildObject__private_field'
私有域的实现是简单地把属性名做了个转换。比如__private_field其实被转换成_MyChildObject__private_field。如果是指代父类的__private_field,则是被转换成了_MyParentObject__private_field。知道这个规则的话,就可以直接访问到对应的属性值了:
assert baz._MyParentObject__private_field == 71
或者直接通过__dict__来查看类内的属性:
print(baz.__dict__)
>>>
{'_MyParentObject__private_field': 71}
Python为了功能性,用户实际上可以绕开private。
根据Item2的PEP8的风格指引:一个下划线_protected_field表示保护域,表示使用类的外界用户需要小心处理。而私有域则是不希望被外界使用和继承。
class MyStringClass:
def __init__(self, value):
self.__value = value
def get_value(self):
return str(self.__value)
foo = MyStringClass(5)
assert foo.get_value() == '5'
这是错误的方式。
class MyIntegerSubclass(MyStringClass):
def get_value(self):
return int(self._MyStringClass__value)
foo = MyIntegerSubclass('5')
assert foo.get_value() == 5
class MyBaseClass:
def __init__(self, value):
self.__value = value
def get_value(self):
return self.__value
class MyStringClass(MyBaseClass):
def get_value(self):
return str(super().get_value()) # Updated
class MyIntegerSubclass(MyStringClass):
def get_value(self):
return int(self._MyStringClass__value) # Not updated
foo = MyIntegerSubclass(5)
foo.get_value()
>>>
Traceback ...
AttributeError: 'MyIntegerSubclass' object has no attribute
'_MyStringClass__value'
最好还是以protected的形式,同时给出注释,告诉其他人这是内部的。
class MyStringClass:
def __init__(self, value):
# This stores the user-supplied value for the object.
# It should be coercible to a string. Once assigned
in
# the object it should be treated as immutable.
self._value = value
...
需要考虑的是使用私有属性来区分变量名:
class ApiClass:
def __init__(self):
self._value = 5
def get(self):
return self._value
class Child(ApiClass):
def __init__(self):
super().__init__()
self._value = 'hello' # Conflicts
a = Child()
print(f'{a.get()} and {a._value} should be different')
>>>
hello and hello should be different
为了减少变量名被覆盖的风险,区别域是一种可行的选择:
class ApiClass:
def __init__(self):
self.__value = 5 # Double underscore
def get(self):
return self.__value # Double underscore
class Child(ApiClass):
def __init__(self):
super().__init__()
self._value = 'hello' # OK!
a = Child()
print(f'{a.get()} and {a._value} are different')
>>>
5 and hello are different
- Item43: 继承collections.abc,来定制Container类型
每个Python类是一个容器,封装属性和功能。同时内部还提供了很多的容器类型(比如:list,tuple,set和dict)。比如现在要统计元素的频率:
class FrequencyList(list):
def __init__(self, members):
super().__init__(members)
def frequency(self):
counts = {}
for item in self:
counts[item] = counts.get(item, 0) + 1
return counts
通过继承list,可以得到list的基础功能。然后可以定义方法来提供定制的功能:
foo = FrequencyList(['a', 'b', 'a', 'c', 'b', 'a', 'd'])
print('Length is', len(foo))
foo.pop()
print('After pop:', repr(foo))
print('Frequency:', foo.frequency())
>>>
Length is 7
After pop: ['a', 'b', 'a', 'c', 'b', 'a']
Frequency: {'a': 3, 'b': 2, 'c': 1}
现在,假设我要提供一个类似list的取下标功能,但是针对二叉树的结点:
class BinaryNode:
def __init__(self, value, left=None, right=None):
self.value = value
self.left = left
self.right = right
如何使得这个类像序列一样工作?即:
bar = [1, 2, 3]
bar[0]
# 实际上就是:
bar.__getitem__(0)
可以提供__getitem__的实现:使用前序遍历,每次记录index。
class IndexableNode(BinaryNode):
def _traverse(self):
if self.left is not None:
yield from self.left._traverse()
yield self
if self.right is not None:
yield from self.right._traverse()
def __getitem__(self, index):
for i, item in enumerate(self._traverse()):
if i == index:
return item.value
raise IndexError(f'Index {index} is out of range')
可以构建二叉树如下:
tree = IndexableNode(
10,
left=IndexableNode(
5,
left=IndexableNode(2),
right=IndexableNode(
6,
right=IndexableNode(7))),
right=IndexableNode(
15,
left=IndexableNode(11)))
可以像list一样进行访问:
print('LRR is', tree.left.right.right.value)
print('Index 0 is', tree[0])
print('Index 1 is', tree[1])
print('11 in the tree?', 11 in tree)
print('17 in the tree?', 17 in tree)
print('Tree is', list(tree))
>>>
LRR is 7
Index 0 is 2
Index 1 is 5
11 in the tree? True
17 in the tree? False
Tree is [2, 5, 6, 7, 10, 11, 15]
问题是实现了__getitem__对于list的功能并不齐全,比如:
len(tree)
>>>
Traceback ...
TypeError: object of type 'IndexableNode' has no len()
此时要实现__len__:
class SequenceNode(IndexableNode):
def __len__(self):
for count, _ in enumerate(self._traverse(), 1):
pass
return count
tree = SequenceNode(
10,
left=SequenceNode(
5,
left=SequenceNode(2),
right=SequenceNode(
6,
right=SequenceNode(7))),
right=SequenceNode(
15,
left=SequenceNode(11))
)
print('Tree length is', len(tree))
>>>
Tree length is 7
不幸的是,count和index方法还是无法使用。这就使得自己定义容器类比较困难。为了避免这个困难,collections.abc有一系列的抽象类提供:
from collections.abc import Sequence
class BadType(Sequence):
pass
foo = BadType()
>>>
Traceback ...
TypeError: Can't instantiate abstract class BadType with abstract methods __getitem__, __len__
同时继承Sequence,可以满足一些方法,比如index,count等的使用:
class BetterNode(SequenceNode, Sequence):
pass
tree = BetterNode(
10,
left=BetterNode(
5,
left=BetterNode(2),
right=BetterNode(
6,
right=BetterNode(7))),
right=BetterNode(
15,
left=BetterNode(11))
)
print('Index of 7 is', tree.index(7))
print('Count of 10 is', tree.count(10))
>>>
Index of 7 is 3
Count of 10 is 1
还有更多的比如Set和MutableMapping,可以来实现来匹配Python自建的容器类。排序也是如此(见Item73)