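Effective Python Notes, Excerpt 5.1

Classes and Interfaces

Python is an object-oriented language with solid support for inheritance, polymorphism, and encapsulation. How do you use these features to write maintainable code?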


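  • Item37: Compose classes instead of nesting many levels of built-in types

Suppose I want to record the grades of a group of students whose names are not known in advance. I can define a class that stores the names in a dictionary.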

class SimpleGradebook:
    def __init__(self):
        self._grades = {}
    def add_student(self, name):
        self._grades[name] = []
    def report_grade(self, name, score):
        self._grades[name].append(score)
    def average_grade(self, name):
        grades = self._grades[name]
        return sum(grades) / len(grades)

book = SimpleGradebook()
book.add_student('Isaac Newton')
book.report_grade('Isaac Newton', 90)
book.report_grade('Isaac Newton', 95)
book.report_grade('Isaac Newton', 85)
print(book.average_grade('Isaac Newton'))
>>>
90.0

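Dictionaries and related built-in types are so easy to use that there is a danger of overextending them. For example, suppose I now want to store grades per subject instead of one overall list: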

from collections import defaultdict

class BySubjectGradebook:
    def __init__(self):
        self._grades = {} # Outer dict
    def add_student(self, name):
        self._grades[name] = defaultdict(list) # Inner dict

This is straightforward enough, and the multi-level dictionary still seems manageable. The report_grade and average_grade methods change accordingly:

class BySubjectGradebook:
    ...
    def report_grade(self, name, subject, grade):
        by_subject = self._grades[name]
        grade_list = by_subject[subject]
        grade_list.append(grade)

    def average_grade(self, name):
        by_subject = self._grades[name]
        total, count = 0, 0
        for grades in by_subject.values():
            total += sum(grades)
            count += len(grades)
        return total / count

book = BySubjectGradebook()
book.add_student('Albert Einstein')
book.report_grade('Albert Einstein', 'Math', 75)
book.report_grade('Albert Einstein', 'Math', 65)
book.report_grade('Albert Einstein', 'Gym', 90)
book.report_grade('Albert Einstein', 'Gym', 95)
print(book.average_grade('Albert Einstein'))
>>>
81.25

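Now suppose the requirements change again: different tests carry different weights, so each entry must store a weight along with the score: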

class WeightedGradebook:
    def __init__(self):
        self._grades = {}
    def add_student(self, name):
        self._grades[name] = defaultdict(list)
    def report_grade(self, name, subject, score, weight):
        by_subject = self._grades[name]
        grade_list = by_subject[subject]
        grade_list.append((score, weight))
    def average_grade(self, name):
        by_subject = self._grades[name]
        score_sum, score_count = 0, 0
        for subject, scores in by_subject.items():
            subject_avg, total_weight = 0, 0
            for score, weight in scores:
                subject_avg += score * weight
                total_weight += weight
            score_sum += subject_avg / total_weight
            score_count += 1
        return score_sum / score_count
book = WeightedGradebook()
book.add_student('Albert Einstein')
book.report_grade('Albert Einstein', 'Math', 75, 0.05)
book.report_grade('Albert Einstein', 'Math', 65, 0.15)
book.report_grade('Albert Einstein', 'Math', 70, 0.80)
book.report_grade('Albert Einstein', 'Gym', 100, 0.40)
book.report_grade('Albert Einstein', 'Gym', 85, 0.60)
print(book.average_grade('Albert Einstein'))
>>>
80.25

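Avoid going beyond one level of nesting; it becomes a maintenance nightmare.
Refactor into classes instead. Start at the bottom of the dependency tree with a single grade, for which a plain tuple seems sufficient at first: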

grades = []
grades.append((95, 0.45))
grades.append((85, 0.55))
total = sum(score * weight for score, weight in grades)
total_weight = sum(weight for _, weight in grades)
average_grade = total / total_weight

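If I also want to attach teacher comments, every place that unpacks the tuple needs more underscores (_) to ignore the extra fields: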

grades = []
grades.append((95, 0.45, 'Great job'))
grades.append((85, 0.55, 'Better next time'))
total = sum(score * weight for score, weight, _ in grades)
total_weight = sum(weight for _, weight, _ in grades)
average_grade = total / total_weight

This is exactly where namedtuple fits:

from collections import namedtuple

Grade = namedtuple('Grade', ('score', 'weight'))
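
As a minimal sketch of how the Grade namedtuple reads in practice, the weighted-average calculation from above can be rewritten with named fields. The weighted_average helper is a hypothetical name introduced only for this illustration.

```python
from collections import namedtuple

# Same definition as in the notes above.
Grade = namedtuple('Grade', ('score', 'weight'))

def weighted_average(grades):
    """Hypothetical helper: weighted mean of a list of Grade tuples."""
    total = sum(g.score * g.weight for g in grades)
    total_weight = sum(g.weight for g in grades)
    return total / total_weight

grades = [Grade(95, 0.45), Grade(85, 0.55)]
average = weighted_average(grades)

# Fields are read by name rather than by position,
# so no placeholder underscores are needed when unpacking.
first = grades[0]
print(first.score, first.weight)
print(average)
```

Because the fields are accessed by name, the calculation does not depend on how many fields the tuple has or on their order.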

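However, namedtuple has limitations:

It does not handle default values gracefully. (Since Python 3.7 the defaults parameter covers simple cases, but it only applies to the rightmost fields.)
This becomes unwieldy when the data has many optional attributes. With more than a handful of attributes, a built-in alternative such as the dataclasses module may be a better fit.
Attribute values of a namedtuple remain accessible by numerical index and by iteration. If you cannot control how instances are used, it is better to define a new class explicitly.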

class Subject:
    def __init__(self):
        self._grades = []
    def report_grade(self, score, weight):
        self._grades.append(Grade(score, weight))
    def average_grade(self):
        total, total_weight = 0, 0
        for grade in self._grades:
            total += grade.score * grade.weight
            total_weight += grade.weight
        return total / total_weight

class Student:
    def __init__(self):
        self._subjects = defaultdict(Subject)
    def get_subject(self, name):
        return self._subjects[name]
    def average_grade(self):
        total, count = 0, 0
        for subject in self._subjects.values():
            total += subject.average_grade()
            count += 1
        return total / count

class Gradebook:
    def __init__(self):
        self._students = defaultdict(Student)
    def get_student(self, name):
        return self._students[name]

book = Gradebook()
albert = book.get_student('Albert Einstein')
math = albert.get_subject('Math')
math.report_grade(75, 0.05)
math.report_grade(65, 0.15)
math.report_grade(70, 0.80)
gym = albert.get_subject('Gym')
gym.report_grade(100, 0.40)
gym.report_grade(85, 0.60)
print(albert.average_grade())
>>>
80.25

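  • Item38: Accept functions instead of classes for simple interfaces

Many of Python's built-in APIs let you customize behavior by passing in a function. The API calls these hooks back while it runs. For example, the key argument of the list sort method accepts a function: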

names = ['Socrates', 'Archimedes', 'Plato', 'Aristotle']
names.sort(key=len)
print(names)
>>>
['Plato', 'Socrates', 'Aristotle', 'Archimedes']

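There are many other examples. The first argument of defaultdict, for instance, can be any zero-argument callable (a class or a function) that returns the default value for a missing key.

Here is a hook that logs each missing key and returns 0 as the default: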

def log_missing():
    print('Key added')
    return 0

Build result from the current values, then apply the increments. Missing keys get the default 0 returned by log_missing.

from collections import defaultdict
current = {'green': 12, 'blue': 3}
increments = [
    ('red', 5),
    ('blue', 17),
    ('orange', 9),
]
result = defaultdict(log_missing, current)
print('Before:', dict(result))
for key, amount in increments:
    result[key] += amount
print('After: ', dict(result))
>>>
Before: {'green': 12, 'blue': 3}
Key added
Key added
After: {'green': 12, 'blue': 20, 'red': 5, 'orange': 9}

Now suppose I also need to count how many keys were missing and therefore newly added. A stateful closure can keep that count internally:

def increment_with_report(current, increments):
    added_count = 0
    def missing():
        nonlocal added_count # Stateful closure
        added_count += 1
        return 0
    result = defaultdict(missing, current)
    for key, amount in increments:
        result[key] += amount
    return result, added_count

defaultdict has no idea that the missing hook maintains state, yet the final count still comes out as 2:

result, count = increment_with_report(current, increments)
assert count == 2

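Another approach, common in other languages, is to define a small class that holds the state and pass one of its bound methods as the hook: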

class CountMissing:
    def __init__(self):
        self.added = 0
    def missing(self):
        self.added += 1
        return 0

This achieves the same result:

counter = CountMissing()
result = defaultdict(counter.missing, current) # Method ref
for key, amount in increments:
    result[key] += amount
assert counter.added == 2

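Although a class is clearer than a stateful closure, the purpose of CountMissing is not obvious in isolation; it only becomes clear once you see it used with defaultdict. (Who constructs it? Who calls missing? Will the class need other public methods in the future?)

Python lets a class define the __call__ special method, which makes its instances callable like functions. The callable built-in returns True for such instances.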

class BetterCountMissing:
    def __init__(self):
        self.added = 0
    def __call__(self):
        self.added += 1
        return 0

counter = BetterCountMissing()
assert counter() == 0
assert callable(counter)

When a key is missing, defaultdict calls counter once, which invokes its __call__ method:

counter = BetterCountMissing()
result = defaultdict(counter, current) # Relies on __call__
for key, amount in increments:
    result[key] += amount
assert counter.added == 2

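This meets the requirement above cleanly, and the __call__ method makes it clear that instances are meant to be used wherever a function is expected.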
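
The point of this item is that defaultdict only needs a zero-argument callable and does not care what kind of object provides it. The sketch below runs the same update with a built-in type, a lambda, and a stateful instance. The apply_increments helper and the CallCounter class are hypothetical names used only for this illustration.

```python
from collections import defaultdict

current = {'green': 12, 'blue': 3}
increments = [('red', 5), ('blue', 17), ('orange', 9)]

def apply_increments(default_factory):
    """Hypothetical helper: run the same update with any zero-argument callable."""
    result = defaultdict(default_factory, current)
    for key, amount in increments:
        result[key] += amount
    return dict(result)

class CallCounter:
    """Stateful hook: counts how often it was called."""
    def __init__(self):
        self.added = 0
    def __call__(self):
        self.added += 1
        return 0

counter = CallCounter()
by_type = apply_increments(int)           # int() returns 0
by_lambda = apply_increments(lambda: 0)   # anonymous function
by_instance = apply_increments(counter)   # instance with __call__

print(by_instance, counter.added)
```

All three calls produce the same dictionary; only the instance with __call__ also remembers that two keys were missing.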


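  • Item39: Use @classmethod polymorphism to construct objects generically

In Python, not only objects support polymorphism; classes do as well. What is that good for?
Polymorphism lets multiple classes in a hierarchy implement their own versions of a method, so many classes can fulfill the same interface or abstract base class while providing different functionality.
For example, suppose I am writing a MapReduce implementation and want a common class to represent the input data: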

class InputData:
    def read(self):
        raise NotImplementedError

A concrete subclass reads data from a file on disk:

class PathInputData(InputData):
    def __init__(self, path):
        super().__init__()
        self.path = path
    def read(self):
        with open(self.path) as f:
            return f.read()

There could be any number of InputData subclasses, such as a NetworkInputData. The MapReduce workers need a similar abstract interface for consuming the input data:

class Worker:
    def __init__(self, input_data):
        self.input_data = input_data
        self.result = None
    def map(self):
        raise NotImplementedError
    def reduce(self, other):
        raise NotImplementedError

Here is a concrete Worker subclass that counts newlines:

class LineCountWorker(Worker):
    def map(self):
        data = self.input_data.read() # Read the data
        self.result = data.count('\n')  # Number of newlines in this input
    def reduce(self, other):
        self.result += other.result # Merge in another worker's result

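What connects these pieces? The simplest approach is to build the objects manually with helper functions. The first one lists a directory and creates a PathInputData for each file: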

import os
def generate_inputs(data_dir):
    for name in os.listdir(data_dir):
        yield PathInputData(os.path.join(data_dir, name))

Next, create a worker for each input:

def create_workers(input_list):
    workers = []
    for input_data in input_list:
        workers.append(LineCountWorker(input_data))
    return workers

Then run the map step across multiple threads and call reduce repeatedly to combine the results into one final value:

from threading import Thread
def execute(workers):
    threads = [Thread(target=w.map) for w in workers]
    for thread in threads: thread.start()
    for thread in threads: thread.join()
    first, *rest = workers
    for worker in rest:
        first.reduce(worker)
    return first.result

Finally, connect the helpers in one function that returns the result:

def mapreduce(data_dir):
    inputs = generate_inputs(data_dir)
    workers = create_workers(inputs)
    return execute(workers)

Running it against a set of randomly generated test files works well:

import os
import random

def write_test_files(tmpdir):
    os.makedirs(tmpdir)
    for i in range(100):
        with open(os.path.join(tmpdir, str(i)), 'w') as f:
            f.write('\n' * random.randint(0, 100))

tmpdir = 'test_inputs'
write_test_files(tmpdir)

result = mapreduce(tmpdir)
print(f'There are {result} lines')

>>>
There are 4360 lines

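What is the problem? The mapreduce function is not generic at all. To support another InputData or Worker subclass, I would have to rewrite generate_inputs, create_workers, and mapreduce to match.

The best way to solve this is class method polymorphism. (Python allows only a single constructor method, __init__, and it is unreasonable to require every InputData subclass to have a compatible constructor.)
Here @classmethod provides a common interface for creating new InputData instances: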

class GenericInputData:
    def read(self):
        raise NotImplementedError
    @classmethod
    def generate_inputs(cls, config):
        raise NotImplementedError

Each subclass reads what it needs from the config dictionary. PathInputData looks up the directory to list:

class PathInputData(GenericInputData):
    ...
    @classmethod
    def generate_inputs(cls, config):
        data_dir = config['data_dir']
        for name in os.listdir(data_dir):
            yield cls(os.path.join(data_dir, name))

Similarly, create_workers can become part of a generic worker class. It uses cls() to construct the specific subclass:

class GenericWorker:
    def __init__(self, input_data):
        self.input_data = input_data
        self.result = None
    def map(self):
        raise NotImplementedError
    def reduce(self, other):
        raise NotImplementedError
    @classmethod
    def create_workers(cls, input_class, config):
        workers = []
        for input_data in input_class.generate_inputs(config):
            workers.append(cls(input_data))
        return workers

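Note that the call to input_class.generate_inputs is the class polymorphism at work. Also note that create_workers calls cls(), which gives an alternative way to construct GenericWorker objects besides calling __init__ directly.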

class LineCountWorker(GenericWorker):
    ...

Finally, rewrite the mapreduce function to be fully generic:

def mapreduce(worker_class, input_class, config):
    workers = worker_class.create_workers(input_class, config)
    return execute(workers)

config = {'data_dir': tmpdir}
result = mapreduce(LineCountWorker, PathInputData, config)
print(f'There are {result} lines')
>>>
There are 4360 lines

As this shows, the cls argument of a @classmethod is what connects the generic code to the concrete classes.
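
To illustrate the payoff, the following sketch plugs a different input source and a different worker into the same generic mapreduce function without changing it. InMemoryInputData, WordCountWorker, and the 'chunks' config key are hypothetical names invented for this example, and the map step runs sequentially here instead of in threads to keep the sketch short.

```python
class GenericInputData:
    def read(self):
        raise NotImplementedError
    @classmethod
    def generate_inputs(cls, config):
        raise NotImplementedError

class GenericWorker:
    def __init__(self, input_data):
        self.input_data = input_data
        self.result = None
    def map(self):
        raise NotImplementedError
    def reduce(self, other):
        raise NotImplementedError
    @classmethod
    def create_workers(cls, input_class, config):
        return [cls(data) for data in input_class.generate_inputs(config)]

class InMemoryInputData(GenericInputData):
    """Hypothetical input source: text chunks held in a list."""
    def __init__(self, text):
        self.text = text
    def read(self):
        return self.text
    @classmethod
    def generate_inputs(cls, config):
        for chunk in config['chunks']:   # 'chunks' is an assumed config key
            yield cls(chunk)

class WordCountWorker(GenericWorker):
    """Hypothetical worker: counts whitespace-separated words."""
    def map(self):
        self.result = len(self.input_data.read().split())
    def reduce(self, other):
        self.result += other.result

def mapreduce(worker_class, input_class, config):
    workers = worker_class.create_workers(input_class, config)
    for worker in workers:       # sequential for brevity; the notes use threads
        worker.map()
    first, *rest = workers
    for worker in rest:
        first.reduce(worker)
    return first.result

config = {'chunks': ['a b c', 'd e', 'f']}
total_words = mapreduce(WordCountWorker, InMemoryInputData, config)
print(total_words)
```

Only the two new subclasses know about in-memory chunks and word counting; mapreduce and the generic base classes stay untouched.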


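  • Item40: Initialize parent classes with super

The old, simple way to initialize a parent class is to call its __init__ method directly with the child instance: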

class MyBaseClass:
    def __init__(self, value):
        self.value = value
class MyChildClass(MyBaseClass):
    def __init__(self):
        MyBaseClass.__init__(self, 5)

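This approach breaks in many cases, notably with multiple inheritance. For example, define two parent classes that operate on the instance's value field: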

class TimesTwo:
    def __init__(self):
        self.value *= 2
class PlusFive:
    def __init__(self):
        self.value += 5

This class lists its parent classes in one order, and constructing it produces a result that matches that order:

class OneWay(MyBaseClass, TimesTwo, PlusFive):
    def __init__(self, value):
        MyBaseClass.__init__(self, value)
        TimesTwo.__init__(self)
        PlusFive.__init__(self)

The result:

foo = OneWay(5)
print('First ordering value is (5 * 2) + 5 =', foo.value)
>>>
First ordering value is (5 * 2) + 5 = 15

Here is another class that lists the same parent classes in a different order (PlusFive before TimesTwo):

class AnotherWay(MyBaseClass, PlusFive, TimesTwo):
    def __init__(self, value):
        MyBaseClass.__init__(self, value)
        TimesTwo.__init__(self)
        PlusFive.__init__(self)

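The calls to the parent __init__ methods were left in the original order, so they no longer match the class definition. This kind of mismatch is hard to spot, which makes it especially unfriendly to newcomers: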

bar = AnotherWay(5)
print('Second ordering value is', bar.value)

>>>
Second ordering value is 15

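Another problem occurs with diamond inheritance, where a subclass inherits from two classes that share the same superclass. For example, two classes inherit from MyBaseClass: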

class TimesSeven(MyBaseClass):
    def __init__(self, value):
        MyBaseClass.__init__(self, value)
        self.value *= 7
class PlusNine(MyBaseClass):
    def __init__(self, value):
        MyBaseClass.__init__(self, value)
        self.value += 9

Then define a class that inherits from both:

class ThisWay(TimesSeven, PlusNine):
    def __init__(self, value):
        TimesSeven.__init__(self, value)
        PlusNine.__init__(self, value)

foo = ThisWay(5)
print('Should be (5 * 7) + 9 = 44 but is', foo.value)
>>>
Should be (5 * 7) + 9 = 44 but is 14

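The second call to MyBaseClass.__init__ (from PlusNine.__init__) resets self.value to 5 and discards the multiplication, so the result is 5 + 9 = 14. In more complex hierarchies this is hard to debug.

To solve these problems, Python provides the built-in super function and a standard method resolution order (MRO). super ensures that a common superclass in a diamond hierarchy runs only once. The MRO defines the order in which superclasses are initialized, following the C3 linearization algorithm.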

class TimesSevenCorrect(MyBaseClass):
    def __init__(self, value):
        super().__init__(value)
        self.value *= 7
class PlusNineCorrect(MyBaseClass):
    def __init__(self, value):
        super().__init__(value)
        self.value += 9

Now the diamond hierarchy initializes correctly:

class GoodWay(TimesSevenCorrect, PlusNineCorrect):
    def __init__(self, value):
        super().__init__(value)

foo = GoodWay(5)
print('Should be 7 * (5 + 9) = 98 and is', foo.value)
>>>
Should be 7 * (5 + 9) = 98 and is 98
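
To see why the result is 7 * (5 + 9) rather than (5 * 7) + 9, the following sketch records the moment each __init__ body finishes its own work. The calls list is a hypothetical addition for tracing; the classes otherwise mirror the ones above.

```python
calls = []  # records the order in which each __init__ body does its work

class MyBaseClass:
    def __init__(self, value):
        self.value = value
        calls.append(('MyBaseClass', self.value))

class TimesSevenCorrect(MyBaseClass):
    def __init__(self, value):
        super().__init__(value)   # next class in the MRO, not necessarily MyBaseClass
        self.value *= 7
        calls.append(('TimesSevenCorrect', self.value))

class PlusNineCorrect(MyBaseClass):
    def __init__(self, value):
        super().__init__(value)
        self.value += 9
        calls.append(('PlusNineCorrect', self.value))

class GoodWay(TimesSevenCorrect, PlusNineCorrect):
    def __init__(self, value):
        super().__init__(value)

foo = GoodWay(5)
mro_names = [cls.__name__ for cls in GoodWay.mro()]
print(mro_names)
print(calls)
```

The calls travel down the MRO first, and the arithmetic happens as they return: MyBaseClass sets 5, PlusNineCorrect makes it 14, and TimesSevenCorrect makes it 98. MyBaseClass.__init__ runs exactly once.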

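The order looks backwards: one might expect TimesSevenCorrect.__init__ to run first and give (5 * 7) + 9 = 44. In fact the __init__ calls follow the MRO down to MyBaseClass, and the work after each super() call is done in reverse as the calls return: 5, then 5 + 9 = 14, then 14 * 7 = 98. The MRO can be inspected directly: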

mro_str = '\n'.join(repr(cls) for cls in GoodWay.mro())
print(mro_str)
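>>>
<class '__main__.GoodWay'>
<class '__main__.TimesSevenCorrect'>
<class '__main__.PlusNineCorrect'>
<class '__main__.MyBaseClass'>
<class 'object'>
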
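super can also be called with two arguments: the class whose MRO parent view you want to access, and the instance on which to access that view.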

class ExplicitTrisect(MyBaseClass):
    def __init__(self, value):
        super(ExplicitTrisect, self).__init__(value)
        self.value /= 3

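These arguments are not required for ordinary object initialization. When super() is called with no arguments inside a class definition, the compiler automatically supplies the correct ones (__class__ and self), so the following forms are all equivalent: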

class AutomaticTrisect(MyBaseClass):
    def __init__(self, value):
        super(__class__, self).__init__(value)
        self.value /= 3
class ImplicitTrisect(MyBaseClass):
    def __init__(self, value):
        super().__init__(value)
        self.value /= 3

assert ExplicitTrisect(9).value == 3
assert AutomaticTrisect(9).value == 3
assert ImplicitTrisect(9).value == 3

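  • Item41: Consider composing functionality with mix-in classes

It is generally best to avoid multiple inheritance. Consider writing a mix-in instead: a small class that defines only a set of additional methods for subclasses to use.

For example, suppose I need to convert a Python object from its in-memory representation into a dictionary that is ready for serialization: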

class ToDictMixin:
    def to_dict(self):
        return self._traverse_dict(self.__dict__)

The implementation uses hasattr for dynamic attribute access, isinstance for dynamic type inspection, and the instance dictionary __dict__:

class ToDictMixin:
    ...
    def _traverse_dict(self, instance_dict):
        output = {}
        for key, value in instance_dict.items():
            output[key] = self._traverse(key, value)
        return output
    def _traverse(self, key, value):
        if isinstance(value, ToDictMixin):
            return value.to_dict()
        elif isinstance(value, dict):
            return self._traverse_dict(value)
        elif isinstance(value, list):
            return [self._traverse(key, i) for i in value]
        elif hasattr(value, '__dict__'):
            return self._traverse_dict(value.__dict__)
        else:
            return value

Here is an example class that uses the mix-in to produce a dictionary representation of a binary tree:

class BinaryTree(ToDictMixin):
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right
# Converting a large number of related objects into a dictionary becomes easy:
tree = BinaryTree(10,
    left=BinaryTree(7, right=BinaryTree(9)),
    right=BinaryTree(13, left=BinaryTree(11)))
print(tree.to_dict())
>>>
{'value': 10,
'left': {'value': 7,
          'left': None,
          'right': {'value': 9, 'left': None, 'right':
          None}},
'right': {'value': 13,
          'left': {'value': 11, 'left': None, 'right':
          None},
          'right': None}}

Next, define a subclass of BinaryTree that holds a reference to its parent. This circular reference would cause the default ToDictMixin.to_dict to loop forever:

class BinaryTreeWithParent(BinaryTree):
    def __init__(self, value, left=None,
                 right=None, parent=None):
        super().__init__(value, left=left, right=right)
        self.parent = parent
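
A minimal sketch of the failure, assuming the default _traverse from ToDictMixin is left in place: the traversal follows parent back to the root and never terminates, so Python eventually raises RecursionError. The hit_cycle flag is a hypothetical name used only to record the outcome.

```python
class ToDictMixin:
    def to_dict(self):
        return self._traverse_dict(self.__dict__)
    def _traverse_dict(self, instance_dict):
        return {key: self._traverse(key, value)
                for key, value in instance_dict.items()}
    def _traverse(self, key, value):
        if isinstance(value, ToDictMixin):
            return value.to_dict()
        elif isinstance(value, dict):
            return self._traverse_dict(value)
        elif isinstance(value, list):
            return [self._traverse(key, i) for i in value]
        elif hasattr(value, '__dict__'):
            return self._traverse_dict(value.__dict__)
        else:
            return value

class BinaryTree(ToDictMixin):
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right

class BinaryTreeWithParent(BinaryTree):
    def __init__(self, value, left=None, right=None, parent=None):
        super().__init__(value, left=left, right=right)
        self.parent = parent

root = BinaryTreeWithParent(10)
root.left = BinaryTreeWithParent(7, parent=root)   # child points back at root

try:
    root.to_dict()
    hit_cycle = False
except RecursionError:
    # root -> left -> parent (root) -> left -> ... never terminates
    hit_cycle = True

print('Cycle detected:', hit_cycle)
```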

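The solution is to override _traverse in this class so that it only emits the parent's numerical value, which breaks the cycle the mix-in would otherwise follow. For every other attribute it falls back to the default implementation via super:

class BinaryTreeWithParent(BinaryTree):
    ...
    def _traverse(self, key, value):
        if (isinstance(value, BinaryTreeWithParent) and
                key == 'parent'):
            return value.value # Prevent cycles
        else:
            return super()._traverse(key, value)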

Calling BinaryTreeWithParent.to_dict now works without issue, because the circular references are no longer followed:

root = BinaryTreeWithParent(10)
root.left = BinaryTreeWithParent(7, parent=root)
root.left.right = BinaryTreeWithParent(9, parent=root.left)
print(root.to_dict())
>>>
{'value': 10,
'left': {'value': 7,
         'left': None,
         'right': {'value': 9,
                   'left': None,
                   'right': None,
                   'parent': 7},
         'parent': 10},
'right': None,
'parent': None}

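Because BinaryTreeWithParent._traverse is overridden, any class with an attribute of type BinaryTreeWithParent automatically works with ToDictMixin as well: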

class NamedSubTree(ToDictMixin):
    def __init__(self, name, tree_with_parent):
        self.name = name
        self.tree_with_parent = tree_with_parent

my_tree = NamedSubTree('foobar', root.left.right)
print(my_tree.to_dict()) # No infinite loop

>>>
{'name': 'foobar',
'tree_with_parent': {'value': 9,
                     'left': None,
                     'right': None,
                     'parent': 7}}

Mix-ins can also be composed. For example, suppose I want a mix-in that provides generic JSON serialization for any class:

import json

class JsonMixin:
    @classmethod
    def from_json(cls, data):
        kwargs = json.loads(data)
        return cls(**kwargs)
    def to_json(self):
        return json.dumps(self.to_dict())

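JsonMixin defines two methods: a class method and an instance method. It assumes the class provides to_dict and that its __init__ accepts keyword arguments. Below, a hierarchy of classes represents parts of a datacenter topology: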

class DatacenterRack(ToDictMixin, JsonMixin):
    def __init__(self, switch=None, machines=None):
        self.switch = Switch(**switch)
        self.machines = [
            Machine(**kwargs) for kwargs in machines]
class Switch(ToDictMixin, JsonMixin):
    def __init__(self, ports=None, speed=None):
        self.ports = ports
        self.speed = speed
class Machine(ToDictMixin, JsonMixin):
    def __init__(self, cores=None, ram=None, disk=None):
        self.cores = cores
        self.ram = ram
        self.disk = disk

The following tests the full round trip: deserialize the objects from JSON, then serialize them back to JSON:

serialized = """{
    "switch": {"ports": 5, "speed": 1e9},
    "machines": [
        {"cores": 8, "ram": 32e9, "disk": 5e12},
        {"cores": 4, "ram": 16e9, "disk": 1e12},
        {"cores": 2, "ram": 4e9, "disk": 500e9}
    ]
}"""
deserialized = DatacenterRack.from_json(serialized)
roundtrip = deserialized.to_json()
assert json.loads(serialized) == json.loads(roundtrip)

As this shows, pluggable mix-in classes like these provide a lot of flexibility.


  • Item42: Prefer public attributes over private ones

Python has only two kinds of attribute visibility: public and private.

class MyObject:
    def __init__(self):
        self.public_field = 5
        self.__private_field = 10
    def get_private_field(self):
        return self.__private_field

Public attributes can be accessed directly on the object:

foo = MyObject()
assert foo.public_field == 5

Private fields are reached through methods of the containing class:

assert foo.get_private_field() == 10

Accessing a private field directly from outside the class raises an exception:

foo.__private_field
>>>
Traceback ...
AttributeError: 'MyObject' object has no attribute '__private_field'

Class methods can also access private attributes, because they are declared inside the class block:

class MyOtherObject:
    def __init__(self):
        self.__private_field = 71
    @classmethod
    def get_private_field_of_instance(cls, instance):
        return instance.__private_field

bar = MyOtherObject()
assert MyOtherObject.get_private_field_of_instance(bar) == 71

A subclass cannot access its parent class's private fields:

class MyParentObject:
    def __init__(self):
        self.__private_field = 71

class MyChildObject(MyParentObject):
    def get_private_field(self):
        return self.__private_field

baz = MyChildObject()
baz.get_private_field()
>>>
Traceback ...
AttributeError: 'MyChildObject' object has no attribute
'_MyChildObject__private_field'

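Private attributes are implemented with a simple name transformation. Inside MyChildObject, __private_field is translated to _MyChildObject__private_field, whereas the attribute assigned in MyParentObject.__init__ is stored as _MyParentObject__private_field. Knowing this rule, you can access the attribute from anywhere: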

assert baz._MyParentObject__private_field == 71

Or inspect the instance's __dict__ to see the stored attribute names:

print(baz.__dict__)
>>>
{'_MyParentObject__private_field': 71}

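Python favors flexibility here: private is not strictly enforced, and users can bypass it.
Following the PEP 8 style guide (see Item2), a single leading underscore, as in _protected_field, marks a protected field that external users of the class should treat with care. Private fields signal that the attribute should not be used from outside or by subclasses. For example, a class might use a private field to hide its internal value: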

class MyStringClass:
    def __init__(self, value):
        self.__value = value
    def get_value(self):
        return str(self.__value)

foo = MyStringClass(5)
assert foo.get_value() == '5'

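This is the wrong approach. Sooner or later someone will want to subclass the class to add behavior, and the private attribute forces them to reach into the mangled name:

class MyIntegerSubclass(MyStringClass):
    def get_value(self):
        return int(self._MyStringClass__value)

foo = MyIntegerSubclass('5')
assert foo.get_value() == 5

If the class hierarchy later changes, the hard-coded private reference breaks. Here MyStringClass gains a new parent class, MyBaseClass, which now owns __value:

class MyBaseClass:
    def __init__(self, value):
        self.__value = value
    def get_value(self):
        return self.__value

class MyStringClass(MyBaseClass):
    def get_value(self):
        return str(super().get_value()) # Updated

class MyIntegerSubclass(MyStringClass):
    def get_value(self):
        return int(self._MyStringClass__value) # Not updated

foo = MyIntegerSubclass(5)
foo.get_value()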
>>>
Traceback ...
AttributeError: 'MyIntegerSubclass' object has no attribute
'_MyStringClass__value'

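It is better to use protected attributes and document each one, so that subclass authors know which internals are available and which should be left alone.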

class MyStringClass:
    def __init__(self, value):
        # This stores the user-supplied value for the object.
        # It should be coercible to a string. Once assigned in
        # the object it should be treated as immutable.
        self._value = value
...

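The one case where private attributes deserve serious consideration is avoiding name conflicts with subclasses. Here a child class unwittingly redefines an attribute of its parent: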

class ApiClass:
    def __init__(self):
        self._value = 5
    def get(self):
        return self._value
class Child(ApiClass):
    def __init__(self):
        super().__init__()
        self._value = 'hello' # Conflicts

a = Child()
print(f'{a.get()} and {a._value} should be different')

>>>
hello and hello should be different

To reduce the risk of such collisions, the parent class can use a private attribute so that the names do not overlap:

class ApiClass:
    def __init__(self):
        self.__value = 5 # Double underscore
    def get(self):
        return self.__value # Double underscore
class Child(ApiClass):
    def __init__(self):
        super().__init__()
        self._value = 'hello' # OK!
a = Child()
print(f'{a.get()} and {a._value} are different')

>>>
5 and hello are different
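
As a small sketch of why this works, the instance dictionary can be inspected to see that name mangling stores each double-underscore attribute under a class-specific name. Here the child also uses a double underscore, and the get_child_value accessor is a hypothetical addition for illustration.

```python
class ApiClass:
    def __init__(self):
        self.__value = 5            # stored as _ApiClass__value
    def get(self):
        return self.__value

class Child(ApiClass):
    def __init__(self):
        super().__init__()
        self.__value = 'hello'      # stored as _Child__value, no clash
    def get_child_value(self):      # hypothetical accessor for illustration
        return self.__value

a = Child()
stored_names = sorted(vars(a))
print(stored_names)
print(a.get(), a.get_child_value())
```

Both attributes coexist on the same instance because each class's __value is rewritten with that class's own name as a prefix.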

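  • Item43: Inherit from collections.abc for custom container types

Every Python class is a container of some kind, encapsulating attributes and functionality. Python also provides built-in container types such as list, tuple, set, and dict. For simple use cases it is natural to subclass one of them directly, for example a list that can count the frequency of its members: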

class FrequencyList(list):
    def __init__(self, members):
        super().__init__(members)
    def frequency(self):
        counts = {}
        for item in self:
            counts[item] = counts.get(item, 0) + 1
        return counts

Subclassing list gives the class all of list's standard functionality, and custom methods add the extra behavior:

foo = FrequencyList(['a', 'b', 'a', 'c', 'b', 'a', 'd'])
print('Length is', len(foo))

foo.pop()
print('After pop:', repr(foo))
print('Frequency:', foo.frequency())
>>>
Length is 7
After pop: ['a', 'b', 'a', 'c', 'b', 'a']
Frequency: {'a': 3, 'b': 2, 'c': 1}

Now suppose I want an object that supports list-like indexing but is not a list subclass, such as a binary tree node:

class BinaryNode:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right

How can this class behave like a sequence? Consider how indexing works:

bar = [1, 2, 3]
bar[0]
# is interpreted as:
bar.__getitem__(0)

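Provide a __getitem__ implementation that walks the tree depth-first (in-order: left subtree, node, right subtree) and counts nodes until it reaches the requested index: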

class IndexableNode(BinaryNode):
    def _traverse(self):
        if self.left is not None:
            yield from self.left._traverse()
        yield self
        if self.right is not None:
            yield from self.right._traverse()
    def __getitem__(self, index):
        for i, item in enumerate(self._traverse()):
            if i == index:
                return item.value
        raise IndexError(f'Index {index} is out of range')

The binary tree can be constructed as usual:

tree = IndexableNode(
    10,
    left=IndexableNode(
        5,
        left=IndexableNode(2),
        right=IndexableNode(
            6,
            right=IndexableNode(7))),
    right=IndexableNode(
        15,
        left=IndexableNode(11)))

It can then be accessed like a list, in addition to the usual left and right attributes:

print('LRR is', tree.left.right.right.value)
print('Index 0 is', tree[0])
print('Index 1 is', tree[1])
print('11 in the tree?', 11 in tree)
print('17 in the tree?', 17 in tree)
print('Tree is', list(tree))
>>>
LRR is 7
Index 0 is 2
Index 1 is 5
11 in the tree? True
17 in the tree? False
Tree is [2, 5, 6, 7, 10, 11, 15]
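
The in checks and list(tree) above work even though IndexableNode defines neither __contains__ nor __iter__, because Python falls back to calling __getitem__ with 0, 1, 2, and so on until IndexError is raised. The sketch below imitates that fallback by hand on a hypothetical Squares class that defines only __getitem__.

```python
class Squares:
    """Hypothetical container that defines only __getitem__."""
    def __init__(self, size):
        self.size = size
    def __getitem__(self, index):
        if not 0 <= index < self.size:
            raise IndexError(f'Index {index} is out of range')
        return index * index

squares = Squares(4)

# Roughly what iteration does when __iter__ is missing: call __getitem__
# with 0, 1, 2, ... and stop at the first IndexError.
manual = []
index = 0
while True:
    try:
        manual.append(squares[index])
    except IndexError:
        break
    index += 1

as_list = list(squares)
print(manual, as_list, 9 in squares, 5 in squares)
```

This fallback is also why __getitem__ must raise IndexError for out-of-range indexes: without it, iteration would never stop.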

The problem is that __getitem__ alone does not provide all the sequence semantics a list user expects. For example:

len(tree)
>>>
Traceback ...
TypeError: object of type 'IndexableNode' has no len()

That requires implementing __len__:

class SequenceNode(IndexableNode):
    def __len__(self):
        for count, _ in enumerate(self._traverse(), 1):
            pass
        return count
tree = SequenceNode(
    10,
    left=SequenceNode(
        5,
        left=SequenceNode(2),
        right=SequenceNode(
            6,
            right=SequenceNode(7))),
    right=SequenceNode(
        15,
        left=SequenceNode(11))
)
print('Tree length is', len(tree))
>>>
Tree length is 7

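Unfortunately, this still is not enough: the count and index methods that list and tuple users expect are missing. Defining a custom container type correctly is harder than it looks. To avoid the difficulty, the collections.abc module provides a set of abstract base classes that declare the methods each container type must implement. If a required method is missing, instantiation fails: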

from collections.abc import Sequence

class BadType(Sequence):
    pass

foo = BadType()
>>>
Traceback ...
TypeError: Can't instantiate abstract class BadType with abstract methods __getitem__, __len__

When a class inherits from Sequence and implements the required methods, the remaining methods such as index and count come for free:

class BetterNode(SequenceNode, Sequence):
    pass

tree = BetterNode(
    10,
    left=BetterNode(
        5,
        left=BetterNode(2),
        right=BetterNode(
            6,
            right=BetterNode(7))),
    right=BetterNode(
        15,
        left=BetterNode(11))
)

print('Index of 7 is', tree.index(7))
print('Count of 10 is', tree.count(10))

>>>
Index of 7 is 3
Count of 10 is 1
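
The same effect can be shown without a tree. In this sketch, a hypothetical Countdown class implements only __len__ and __getitem__, and Sequence supplies index, count, __contains__, __iter__, and __reversed__. Slicing is deliberately left unsupported to keep the example short.

```python
from collections.abc import Sequence

class Countdown(Sequence):
    """Hypothetical read-only sequence: n, n-1, ..., 1."""
    def __init__(self, start):
        self._start = start
    def __len__(self):
        return self._start
    def __getitem__(self, index):
        if isinstance(index, slice):
            raise TypeError('slicing is not supported in this sketch')
        if index < 0:
            index += self._start
        if not 0 <= index < self._start:
            raise IndexError(f'Index {index} is out of range')
        return self._start - index

down = Countdown(5)
# index, count, __contains__, __iter__ and __reversed__ come from Sequence.
print(list(down), down.index(2), down.count(3), 4 in down, list(reversed(down)))
```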

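There are more abstract base classes, such as Set and MutableMapping, for matching Python's built-in container types. The same idea applies to comparison and sorting (see Item73).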
