torch.load()

torch.load(f, map_location=None, pickle_module=, **pickle_load_args)[source]

从文件中加载一个用torch.save()保存的对象。

load()使用Python的unpickling工具,但是专门处理存储,它是张量的基础。他们首先在CPU上并行化,然后移动到保存它们的设备。如果失败(例如,因为运行时系统没有某些设备),就会引发异常。但是,可以使用map_location参数动态地将存储重新映射到另一组设备。storage参数是存储的初始反序列化,驻留在CPU上。storage参数是存储的初始反序列化,驻留在CPU上。每个序列化存储都有一个与之关联的位置标记,它标识保存它的设备,这个标记是传递给map_location的第二个参数。内置的位置标签是“cpu”为cpu张量和“cuda:device_id”(例如:device_id)。“cuda:2”)表示cuda张量。map_location应该返回None或一个存储。如果map_location返回一个存储,它将被用作最终的反序列化对象,已经移动到正确的设备。否则,torch.load()将退回到默认行为,就好像没有指定map_location一样。如果map_location 是可以调用的,那么对于带有两个参数:存储和位置的序列化存储将被调用一次。如果map_location是一个torch.device对象或一个包含设备标签的字符串,它表示所有张量应该被加载的位置。否则,如果map_location是一个dict,它将用于将文件中出现的位置标记(键)重新映射为指定存储位置的位置标记(值)。用户扩展可以使用torch.serialize.register_package()注册他们自己的位置标签、标记和反序列化方法。

参数:

 

  • name 类似文件的对象(必须实现read(),:meth ' readline ',:meth ' tell '和:meth ' seek '),或者是包含文件的字符串。

  • map_location – 函数、torch.device或者字典指明如何重新映射存储位置。

  • pickle_module – 用于unpickling元数据和对象的模块(必须匹配用于序列化文件的pickle_module)

  • pickle_load_args – (仅适用于Python 3)传递给pickle_module.load()和pickle_module.Unpickler()的可选关键字参数,例如errors=…

警告:

load()隐式地使用pickle模块,这是不安全的。可以构造恶意pickle数据,在unpickle期间执行任意代码。永远不要加载可能来自不受信任的数据源或可能被篡改的数据。只加载你信任的数据。

注意:

当你在包含GPU张量的文件上调用torch.load()时,默认情况下这些张量会被加载到GPU。你可以调用torch.load(.., map_location='cpu'),然后load_state_dict()以避免在加载一个模型检查点时GPU内存激增。

注意:

默认情况下,我们将字节字符串解码为utf-8。这是为了避免一个常见的错误情况UnicodeDecodeError: 'ascii' codec can't decode byte 0x...在python3中加载由python2保存的文件时。如果这个默认是不正确的,你可以使用一个额外的编码关键字参数指定应该如何加载这些对象,例如,encoding='latin1'中的一个解码字符串使用latin1编码中的一个,和encoding='bytes'让他们作为字节数组可以解码后byte_array.decode (…)。

例:

>>> torch.load('tensors.pt')
# Load all tensors onto the CPU
>>> torch.load('tensors.pt', map_location=torch.device('cpu'))
# Load all tensors onto the CPU, using a function
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
# Load all tensors onto GPU 1
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
# Map tensors from GPU 1 to GPU 0
>>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
# Load tensor from io.BytesIO object
>>> with open('tensor.pt', 'rb') as f:
        buffer = io.BytesIO(f.read())
>>> torch.load(buffer)
# Load a module with 'ascii' encoding for unpickling
>>> torch.load('module.pt', encoding='ascii')

 

你可能感兴趣的:(Pytorch)