1.torch.load()函数介绍

首先介绍一下序列化和反序列化,下面会用到

一、什么是序列化与反序列化

序列化是将对象状态转换为可保持或传输的字节序列的过程。序列化的补集是反序列化,反序列化是将字节流转换为对象。两个过程一起保证能够存储和传输数据。

序列化最重要的作用:在传递和保存对象时.保证对象的完整性和可传递性。对象转换为有序字节流,以便在网络上传输或者保存在本地文件中。

二、开始介绍

torch.load(f, map_location=None, pickle_module=pickle, *, weights_only=False, **pickle_load_args)

从文件中加载用 torch.save ()保存的对象。

torch.load ()使用 Python 的unpickling(解包)工具处理存储,特别是以张量为基础的存储。它们首先在 CPU上反序列化,然后移动到保存它们的设备上。如果失败(例如,因为运行时系统没有特定的设备) ,将引发异常,比如你在gpu上训练保存的模型,而在cpu上加载,可能会报错。此时,使用 map _ location 参数将存储动态地重新映射到另一组设备,比如map_location=torch.device('cpu')意思是映射到cpu上,在cpu上加载模型,无论你这个模型从哪里训练保存的。

如果 map _ location 是可调用的,那么对于每个带有两个参数的序列化存储,它将被调用一次: Storage和location。存储参数将是驻留在 CPU 上的存储的初始反序列化。每个序列化存储都有一个与之关联的位置标记,用与标识它保存在哪个设备上,这个标记是传递给map _ location的第二个参数。内置的位置标签是对于CPU张量是“ cpu”和 对于CUDA张量是“ CUDA: device _ id”(例如“ cuda: 2”)。Map _ location应该返回 None或者一个存储。如果 map _ location 返回一个存储,它将被用作最终的反序列化对象,已经移动到了正确的设备上。否则,torch.load ()将退回到默认行为,就像没有指定map _ location一样。

如果 map _ location 是 torch.device 对象或包含设备标签的字符串,它指示加载所有张量的位置。

否则,如果 map _ location 是一个 dict,它将被用于重新映射出现在文件中的位置标签(key),以指定存储(value)的放置位置。

用户扩展可以使用 torch.seralization.register _ package ()注册自己的位置标记、标记和反序列化方法。

参数:

  • f (Union[str, PathLike, BinaryIO, IO[bytes]]) – a file-like object (has to implement read(), readline(), tell(), and seek()), or a string or os.PathLike object containing a file name

  • map_location (Optional[Union[Callable[[Tensor, str], Tensor], device, str, Dict[str, str]]]) – a function, torch.device, string or a dict specifying how to remap storage locations

  • pickle_module (Optional[Any]) – module used for unpickling metadata and objects (has to match the pickle_module used to serialize file)

  • weights_only (bool) – Indicates whether unpickler should be restricted to loading only tensors, primitive types and dictionaries

  • pickle_load_args (Any) – (Python 3 only) optional keyword arguments passed over to pickle_module.load() and pickle_module.Unpickler(), e.g., errors=....

返回:

Any

警告:

torch.load () ,除非 weight _ only 参数设置为 True,否则隐式使用 pickle 模块,这是已知的不安全的。可以构造恶意 pickle 数据,在解除 pickle 期间执行任意代码。永远不要加载可能来自不安全模式下的不可信源或可能已被篡改的数据,只加载您信任的数据。

注意:

当您对一个包含 GPU 张量的文件使用 torch.load ()调用它时,这些张量将默认加载到 GPU。可以调用 torch.load (...,map _ location = ‘ cpu’) ,然后 load _ state _ dict () ,以避免在加载模型checkpoint时出现 GPU RAM 急增。

还有一个注意点没写,看起来很高深

例子:

>>> torch.load('tensors.pt')
# Load all tensors onto the CPU
>>> torch.load('tensors.pt', map_location=torch.device('cpu'))
# Load all tensors onto the CPU, using a function
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
# Load all tensors onto GPU 1
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
# Map tensors from GPU 1 to GPU 0
>>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
# Load tensor from io.BytesIO object
>>> with open('tensor.pt', 'rb') as f:
...     buffer = io.BytesIO(f.read())
>>> torch.load(buffer)
# Load a module with 'ascii' encoding for unpickling
>>> torch.load('module.pt', encoding='ascii')

你可能感兴趣的:(python,开发语言)