翻译torch.load函数

Hello,大家好!下面是对torch.load函数的翻译~

torch.load(f, map_location=None, pickle_module=, **pickle_load_args)

从文件加载用torch.save()保存的对象。

torch.load()使用Python的拆封(Unpickling:从存储的字符串文件中提取原始Python对象的过程,叫做unpickling)能力,但是会特别处理张量的存储空间。他们首先在CPU上反序列化,然后移动到保存他们的设备上。如果这个过程失败了(例如,因为运行系统没有某些设备),会导致例外的发生。然而,可以通过map_location参数将存储空间动态映射到另一套设备上。

如果map_location是可调用的,它将为每个序列化存储空间调用一次,并含有两个参数:storage和location。storage参数将会是CPU上的存储空间的初始反序列化。每一个序列化的存储空间都有一个与之相关的location标签,该标签表示了它存储的设备,并且是传给map_location的第二个参数。CPU张量的location标签是'cpu',CUDA张量的location标签为'cuda:device_id' (e.g. 'cuda:2')。map_location应该返回None或存储空间。如果map_location返回的是存储空间,它将被用作最后的反序列化对象,并且该对象已经被移到正确的设备上。否则,torch.load()就会恢复默认模式,就好像map_location没有指定一样。

如果map_location是torch.device对象或包含设备标签的字符串,那么它指明了所有张量将要被加载的位置。

相反,如果map_location是一个字典,它将被用作将出现在文件(keys)中的位置标签重新映射到那些明确指明存储空间(values)位置的标签。

用户扩展可以使用torch.serialization.register_package()来注册自己的位置标签、标记和反序列化方法。

参数

  • f:类似于文件的对象(必须实现read(),:meth`readline`, :meth`tell`, and :meth`seek`),或者是包含文件名称的字符串。
  • map_location:一个函数,torch.device,字符串或字典,明确如何重映射存储空间位置。
  • pickle_module:用于解开元数据和对象的模块(必须与序列化文件的pickle_module相匹配)
  • pickle_load_args:(只有Python3才有)可选择的关键字参数,并传递给pickle_module.load()和pickle_module.Unpickler(),比如,errors=...。

注意

但你在一个包含GPU张量的文件上调用torch.load()时,那些张量将会被默认的加载到GPU。当加载模型检查点时,可以通过调用torch.load(.., map_location='cpu'),然后load_state_dict(),来避免GPU RAM的激增。

默认情况下,我们将字节字符串解码为utf-8。这是为了避免在python3加载由python2保存的文件时出现的普遍的错误UnicodeDecodeError: 'ascii' codec can't decode byte 0x...。如果这个默认值是错误的话,你可以使用额外的encoding关键字参数来明确这些对象是如何被加载的,如encoding='latin1'用latin1将他们解码到字符串,encoding='bytes'将他们保持为字节数组,从而后续可以被byte_array.decode(...)解码。

例子

>>> torch.load('tensors.pt')
# Load all tensors onto the CPU
>>> torch.load('tensors.pt', map_location=torch.device('cpu'))
# Load all tensors onto the CPU, using a function
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
# Load all tensors onto GPU 1
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
# Map tensors from GPU 1 to GPU 0
>>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
# Load tensor from io.BytesIO object
>>> with open('tensor.pt', 'rb') as f:
        buffer = io.BytesIO(f.read())
>>> torch.load(buffer)
# Load a module with 'ascii' encoding for unpickling
>>> torch.load('module.pt', encoding='ascii')

 

 

你可能感兴趣的:(深度学习,机器学习,pytorch)