pytorch 保存、读取 tensor 和 numpy数据

一. pytorch 保存、读取 tensor

  1. 首先导包:
import torch

save_torch = torch.Tensor([[1, 2, 3, 4],
                           [2, 34, 5, 6]])
  1. 保存 tensor
torch.save(save_torch, 'test_save_tensor.pt')
  1. 读取 tensor
load_torch = torch.load('test_save_tensor.pt')
  1. 完整测试代码
import torch

save_torch = torch.Tensor([[1, 2, 3, 4],
                           [2, 34, 5, 6]])
print(save_torch)
torch.save(save_torch, 'test_save_tensor.pt') # 保存
load_torch = torch.load('test_save_tensor.pt') # 读取
print(load_torch)
  1. 保存网络结构:model是自己定义的网络结构:
# 保存整个网络
torch.save(net, PATH.pth) 
# 保存网络中的参数, 速度快,占空间少
torch.save(net.state_dict(),PATH.pth)
#--------------------------------------------------
#针对上面一般的保存方法,加载的方法分别是:
model_dict=torch.load(PATH)
model_dict=model.load_state_dict(torch.load(PATH))

二. pytorch 保存、读取numpy

  1. Numpy保存数据:利用numpy.save()函数将array保存为.npy格式的数据:
import numpy as np
np.save('where/you/wanto/store/output',arr)    #numpy 会自动加上.npy后缀
  1. Numpy读取数据
b = np.load('here/you/wanto/store/output.npy')

三. 相关链接

pytorch 保存、读取 tensor 数据
Python Numpy Pytorch 保存数据
pytorch中的tensor以numpy形式进行输出保存
PyTorch教程-7:PyTorch中保存与加载tensor和模型详解
pytorch 模型输出特征 保存npy

你可能感兴趣的:(Pytorch基础,numpy,pytorch,python,深度学习)