#文件读取
import torch
from torch import nn
from torch.nn import functional as F
x=torch.arange(4)
torch.save(x,'x-file')
#当前目录下新建文件
x2=torch.load("x-file")
print(x2)
输出:tensor([0, 1, 2, 3])
y=torch.zeros(4)
torch.save([x,y],"x-file")
x2,y2=torch.load("x-file")
print((x2,y2))
输出:(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))
#写入或读取从字符串映射到张量的字典
mydict={'x':x,'y':y}
torch.save(mydict,"mydict")
mydict2=torch.load("mydict")
print(mydict2)
输出:{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}
#加载和保存模型参数:
class MLP(nn.Module):
def __init__(self):
super().__init__()
self.hidden=nn.Linear(20,256)
self.output=nn.Linear(256,10)
def forward(self,x):
return self.output(F.relu(self.hidden(x)))
net=MLP()
X=torch.randn(size=(2,20))
Y=net(X)
torch.save(net.state_dict(),"mlp.params")
#实例化了原始多层感知机模型的一个备份.直接读取文件中存储的参数.
clone=MLP()
clone.load_state_dict(torch.load("mlp.params")) #把存在磁盘上的参数写回网络
clone.eval()
Y_clone=clone(X)
print(Y_clone==Y)
输出:tensor([[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True]])