Python中的上下文管理器

目录

  • 1. 上下文管理协议
  • 2. 常见的上下文管理器
    • 2.1 open()
    • 2.2 torch.no_grad()
  • 3. 异常处理
    • 3.1 try...except...else...finally
  • 4. contextlib
    • 4.1 contextmanager
    • 4.2 closing
    • 4.3 suppress
  • 5. 多个上下文管理器
  • 6. 深入探讨

1. 上下文管理协议

Python中的 with ... as ... 语句是一种上下文管理协议(Context Management Protocol),它用于自动管理资源,无论过程中是否发生错误,都可以保证对资源的正确处理。

最常见的使用场景是文件的读写。通常我们打开一个文件后,读写完之后需要关闭它。如果手动管理这个过程,可能会遇到因为忘记关闭文件或者程序在关闭文件前出错而导致的问题。with ... as ... 语句可以自动处理这些问题,确保文件在操作完成后被关闭。

示例代码如下:

with open('file.txt', 'r') as f:
    content = f.read()
# 在这一行,f已经被自动关闭,无需手动调用f.close()

上下文管理协议由两个方法组成:__enter__()__exit__()这两个方法需要在一个类中实现__enter__() 方法在实例化后被调用,它的返回值会被赋给 as 关键字后的变量(如果没有提供返回值则不需要 as,例如 torch.no_grad())。__exit__() 方法在 with 语句体执行结束后被调用,无论 with 语句体是否发生异常,都会执行这个方法。

实现了上下文管理协议的对象被称为上下文管理器,只有上下文管理器才可以使用 with ... as ... 语句。

例如:

class MyContext:
    def __init__(self):
        print('init')

    def __enter__(self):
        print('enter')
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        print('exit')

    def excute(self):
        print('excute')


with MyContext() as mc:
    mc.excute()

# init
# enter
# excute
# exit

2. 常见的上下文管理器

2.1 open()

open() 函数会返回 _io.TextIOWrapper 对象,该对象就是一个上下文管理器。_io.TextIOWrapper 类是用C语言实现的,关于其上下文协议,以下给出了这部分源码的简化Python实现:

class _IOBase:
    def __enter__(self):
        # 返回 self,使得能够在 with ... as 变量 中使用
        return self

    def __exit__(self, type, value, traceback):
        # 在退出上下文时自动关闭文件
        self.close()


class TextIOWrapper(_IOBase):
    # TextIOWrapper 具体的实现 ...
    pass

我们当然也可以基于 open() 函数实现一个自己的文件管理器:

class TextFileWrapper:
    def __init__(self, file, encoding='utf-8', mode='r'):
        self.file = file
        self.encoding = encoding
        self.mode = mode

    def __enter__(self):
        self.open_file = open(self.file, self.mode, encoding=self.encoding)
        return self.open_file

    def __exit__(self, exc_type, exc_value, traceback):
        self.open_file.close()


with TextFileWrapper('example.txt', mode='w') as f:
    f.write('Hello, world!')

with TextFileWrapper('example.txt') as f:
    content = f.read()
    print(content)

注意到 _io.TextIOWrapper 还是一个可迭代对象,因此我们可以对其调用 list() 来获取其中的内容。返回的列表中,每一个元素都代表了文件中的一行内容(会包含末尾的换行符)。

2.2 torch.no_grad()

torch.no_grad() 也是一个上下文管理器,不同之处在于,它的 __enter__() 方法没有返回值,因此也就不需要 as

class no_grad(_DecoratorContextManager):
    def __init__(self) -> None:
        if not torch._jit_internal.is_scripting():
            super().__init__()
        self.prev = False

    def __enter__(self) -> None:
        self.prev = torch.is_grad_enabled()
        torch.set_grad_enabled(False)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch.set_grad_enabled(self.prev)

self.prev 是一个布尔变量,保存了在进入上下文之前的梯度计算状态,随后执行 torch.set_grad_enabled(False) 来关闭梯度计算。退出上下文后,torch.set_grad_enabled(self.prev) 用来恢复到之前的梯度计算状态。

我们通常会在模型推理阶段关闭梯度计算:

with torch.no_grad():
	# inference code...

注意到 no_grad 类继承了 _DecoratorContextManager,因此它还可以当作装饰器来用:

@torch.no_grad()
def inference(*args, **kwargs):
    # Your code...


inference()

3. 异常处理

__exit__() 方法中必须提供三个参数exc_typeexc_valexc_tb(不一定非要这样起名):

  • exc_type:异常类型(type)。exc_type 是一个 Python 类型对象,表示引发的异常的类型。
  • exc_val:异常值(value)。exc_val 是一个包含有关异常的详细信息的对象。它可以包含有关异常的描述性信息,以便你了解异常的具体内容。
  • exc_tb:异常回溯(traceback)。exc_tb 是一个 traceback 对象,用于跟踪异常发生的位置。它包含有关异常的堆栈跟踪信息,显示了异常是如何传播到当前代码位置的。

如果 with 语句体中没有发生异常,上述三个参数的值均为 None

class MyContext:
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        print(f"exc_type: {exc_type}")
        print(f"exc_val: {exc_val}")
        print(f"exc_tb: {exc_tb}")


with MyContext() as mc:
    print(f"mc: {mc}")

# mc: <__main__.MyContext object at 0x101042ce0>
# exc_type: None
# exc_val: None
# exc_tb: None

我们可以手动构造异常:

with MyContext() as mc:
    a = 3 + '2'
    print(f"in context")

print(f"out context")

# exc_type: 
# exc_val: unsupported operand type(s) for +: 'int' and 'str'
# exc_tb: 
# Traceback (most recent call last):
#   File "/temp.py", line 12, in 
#     a = 3 + '2'
# TypeError: unsupported operand type(s) for +: 'int' and 'str'

可以看到两条 print 语句都没有被执行。这说明上下文管理器在遇到异常时会直接执行 __exit__()

__exit__() 方法的默认返回值是 None,代表不抑制异常。我们可以让它返回 True 以抑制异常:

class MyContext:
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        print(f"exc_type: {exc_type}")
        print(f"exc_val: {exc_val}")
        print(f"exc_tb: {exc_tb}")
        return True


with MyContext() as mc:
    a = 3 + '2'
    print(f"in context")

print(f"out context")

# exc_type: 
# exc_val: unsupported operand type(s) for +: 'int' and 'str'
# exc_tb: 
# out context

由此可以总结出,上下文管理器在遇到异常时会直接执行 __exit__()。上下文之外的语句是否会被执行取决于是否在 __exit__() 中抑制了异常(否则异常会被传播到上下文之外)。

3.1 try…except…else…finally

除了上下文管理器,我们还可以用 try...except...else...finally 来优雅地处理异常。

  • try块: 通常会在这个块中放置可能会引发异常的代码。

  • except块: 如果在try块中发生异常,Python会跳转到except块来处理异常。你可以在except块中指定要捕获的异常类型,或者只使用except来捕获所有异常。你还可以访问异常对象以获取异常的详细信息。

  • else块: 只有在try块中没有引发异常时,才会执行else块中的代码。这个块用于处理正常情况下的逻辑。

  • finally块: finally块中的代码总是会执行,无论是否发生异常。这里通常放置一些清理代码,例如关闭文件或释放资源。

一些示例:

try:
    print('try')
except:
    print('except')
else:
    print('else')
finally:
    print('finally')
# try
# else
# finally

try:
    print('try')
    a = 1 / 0
except:
    print('except')
else:
    print('else')
finally:
    print('finally')
# try
# except
# finally

我们可以使用 except 来捕获特定的异常:

try:
    print('try')
    a = 1 / 0
except ZeroDivisionError as e:
    print(e)
else:
    print('else')
finally:
    print('finally')
# try
# division by zero
# finally

很多时候我们不知道try块中可能会引发什么样的异常,为此我们可以使用绝大部分异常的一个父类 Exception(点击查看Python中的所有异常):

try:
    print('try')
    a = 1 / 0
except Exception as e:
    print(e)
else:
    print('else')
finally:
    print('finally')
# try
# division by zero
# finally

如果你不放心甚至可以使用 BaseException,但对于大部分场景 Exception 已经足够。

假如没有捕获到,那么异常还是会抛出:

try:
    print('try')
    a = 1 / 0
except FileExistsError as e:
    print(e)
else:
    print('else')
finally:
    print('finally')
# try
# finally
# Traceback (most recent call last):
#   File "/temp.py", line 3, in 
#     a = 1 / 0
# ZeroDivisionError: division by zero

我们可以使用多个 except 对每种异常分别处理:

try:
    print('try')
    a = 1 / 0
except ZeroDivisionError as e:
    print(e)
except FileExistsError as e:
    print(e)
else:
    print('else')
finally:
    print('finally')
# try
# division by zero
# finally

也可以同时对多个异常进行处理:

try:
    print('try')
    a = 1 / 0
except (ZeroDivisionError, FileExistsError) as e:
    print(e)
else:
    print('else')
finally:
    print('finally')
# try
# division by zero
# finally

注意,无论你怎样使用 except,它都会仅仅捕获try块中的第一个异常,且try块中第一个异常之后到第一个except之前的语句都不会被执行。

4. contextlib

4.1 contextmanager

每次构建自己的上下文管理器都要写一个类未免有些麻烦,好在 contextlib.contextmanager 提供了一种简洁的实现方式。

contextmanager 用来装饰一个函数(必须是生成器)并将其变成上下文管理器。函数中 yield 之前的语句会被当成 __enter__() 方法执行,yield 之后的语句会被当成 __exit__() 方法执行,yield 的返回值会被赋给 as 后的变量。

@contextlib.contextmanager
def my_context():
    print('enter')
    yield 1
    print('exit')


with my_context() as mc:
    print(mc)

# enter
# 1
# exit

需要注意的是,contextmanager 并不会自动处理被装饰函数内部发生的异常,我们需要手动用 try...except...else...finally 语句进行处理。

4.2 closing

closing 用于创建一个上下文管理器,主要用于确保在使用完一个对象后关闭它。通常,closing 函数用于处理需要手动关闭的资源,比如文件、网络连接等。

先来看一下 closing 的源码:

class closing(AbstractContextManager):
    def __init__(self, thing):
        self.thing = thing

    def __enter__(self):
        return self.thing

    def __exit__(self, *exc_info):
        self.thing.close()

可以看出它会自动调用传入参数的 close 方法。

class Test:
    def close(self):
        print('closed')

    def __str__(self):
        return 'MyTest'


with closing(Test()) as t:
    print(t)
# MyTest
# closed

4.3 suppress

suppress 用于临时禁止或忽略指定的异常,通常在需要处理异常但不想中断程序流程的情况下使用。

先来看一下源码:

class suppress(AbstractContextManager):
    def __init__(self, *exceptions):
        self._exceptions = exceptions

    def __enter__(self):
        pass

    def __exit__(self, exctype, excinst, exctb):
        return exctype is not None and issubclass(exctype, self._exceptions)

可以看出只有触发了异常并且异常是 exceptions 中的一个时,__exit__() 才返回 True

with contextlib.suppress(ZeroDivisionError):
    print(1)
    raise ZeroDivisionError
    print(2)
# 1

with contextlib.suppress(ZeroDivisionError):
    print(1)
    raise FileExistsError
    print(2)
# 1
# Traceback (most recent call last):
#   File "/1.py", line 11, in 
#     raise FileExistsError
# FileExistsError

with contextlib.suppress(ZeroDivisionError, FileExistsError, FileNotFoundError):
    print(1)
    raise FileNotFoundError
    print(2)
# 1

5. 多个上下文管理器

我们可以使用 with 语句同时处理多个上下文管理器(3.1版本之后):

with A() as a, B() as b:
    BODY

当需要处理的上下文管理器过多时,我们可以使用圆括号将它们括起来,这样更加简洁(3.10版本之后):

with (
    A() as a,
    B() as b,
    C() as c,
):
    BODY

有了这种语法,我们就可以同时对多个文件进行读写了:

with open('a.txt') as fa, open('b.txt') as fb:
    content_a = fa.readlines()
    content_b = fb.readlines()

假如我们要同时处理的文件有很多个,一个个写 open 不太现实,这时候我们就可以自己定义一个上下文管理器了:

class open_many:
    def __init__(self, files: List[str], mode='r'):
        self.files = files
        self.mode = mode

    def __enter__(self):
        self.fds = [open(file=f, mode=self.mode) for f in self.files]
        return self.fds

    def __exit__(self, exc_type, exc_val, exc_tb):
        for f in self.fds:
            f.close()


with open_many(['a.txt', 'b.txt', 'c.txt'], 'r') as files:
    content = [f.readlines() for f in files]

我们也可以使用 contextmanager 来实现这一功能:

@contextlib.contextmanager
def open_many(files: List[str], mode='r'):
    fds = [open(file=f, mode=mode) for f in files]
    try:
        yield fds
    finally:
        for f in fds:
            f.close()

6. 深入探讨

sys.exc_info() 记录了当前线程中的异常信息,它是一个三元组 (exc_type, exc_val, exc_tb)。当没有异常发生时,它的值为 (None, None, None)

我们可以使用 try...except... 来捕获有异常情况下 sys.exc_info() 的值:

import sys

try:
    a = 1 / 0
except:
    typ, val, tb = sys.exc_info()
    print(typ)
    print(val)
    print(tb)
# 
# division by zero
# 

基于此,我们可以尝试从另一个角度审视 with...as... 语句:

with EXPRESSION as TARGET:
    BODY

一开始会执行 EXPRESSION 进行实例化,假设实例化后的对象是 manager,之后会调用 manager__enter__() 方法并把返回结果赋值给 TARGET。然后开始执行 BODY,如果 BODY 没有异常,那么就调用 manager__exit__() 方法正常退出;如果 BODY 有异常,则我们要看 manager__exit__() 方法是否返回 True,如果不返回,则抛出异常。

我们可以用 try...except...finally 来重新表述 with...as...

manager = (EXPRESSION)
enter = type(manager).__enter__
exit = type(manager).__exit__
hit_except = False

try:
    TARGET = enter(manager)
    SUITE
except:
    hit_except = True
    if not exit(manager, *sys.exc_info()):
        raise
finally:
    if not hit_except:
        exit(manager, None, None, None)

你可能感兴趣的:(Python,python,开发语言,上下文管理)