从Pytorch源码看.pt文件

Pytorch中张量的保存与加载

保存张量

在Pytorch中,一个约定俗成的方法是使用.pt扩展的文件格式来保存张量,使用的方法为torch.save()。

函数原型与参数说明

import torch

def save(obj, f: Union[str, os.PathLike, BinaryIO],
         pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True) -> None:
    """
    pytorch框架的原型代码
    """
    pass

# 参数说明
# obj:要保存的对象,类型为tensor
# f:保存的文件名,可以是文件路径(包含文件名的字符串)、可以是字符流、也可以是文件对象
# pickle:Python中的一个模块,实现了用于序列化和反序列化Python对象结构的二进制协议
# pickle_module:用来协议化元数据和对象的协议
# pickle_protocol:可以指定来默认覆盖的协议

# 使用save方法
def save_tensor():
    # 直接保存为一个张量
    x = torch.Tensor([1, 2, 3])
    torch.save(x, 'save_tensor.pt')
    # 保存为字符流的格式
    buffer = io.BytesIO()
    torch.save(x, buffer)

加载张量

在Pytorch中,使用torch.load()方法加载torch.save()方法保存的文件。

函数原型与参数说明

import torch

def load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
    """
    Pytorch框架的原型代码
    """
    pass

# 参数说明
# f:保存的文件名
# map_location:加载位置,即将这个张量加载到哪,可选的内容包括:函数、torch.device、字符串以及指定如何重新映射存储的字典
# pickle_module:用来协议化元数据和对象的协议
# pickle_load_args:需要加载的pickle模块的参数设置。这个包含的内容相当丰富,感兴趣的可以去阅读Pytorch的官方手册

# 使用load方法
def tensor_load():
    # 小白式加载(最常用)
    torch.load('save_tensor.pt')
    # 加载到CPU中
    torch.load('save_tensor.pt', map_location=torch.device('cpu'))
    # 使用函数加载到CPU中
    torch.load('save_tensor.pt', map_location=lambda storage, loc: storage)
    # 加载到GPU1中
    torch.load('save_tensor.pt', map_location=lambda storage, loc: storage.cuda(1))
    # 从GPU0加载到GPU1中
    torch.load('save_tensor.pt', map_location={'cuda: 1': 'cuda: 0'})
    # 指定加载的编码方式
    torch.load('save_tensor.pt', encoding='ascii')
    # 加载字符流格式的张量
    with open('save_tensor.pt', 'rb') as f:
                buffer = io.BytesIO(f.read())
    torch.load(buffer)

你可能感兴趣的:(Python,深度学习)