Saving and loading a PyTorch model

There are two approaches.

'''
    Approach 1
    Save the whole model; the saved file takes more disk space
'''
torch.save(net, "../model/model.pkl")
net = torch.load("../model/model.pkl")
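# Note: torch.save(net, ...) pickles the entire module, so the class definition
# (the Net class from the listing below) must still be importable when the file is
# loaded back. Depending on your PyTorch version (an assumption about your setup),
# you may also need torch.load("../model/model.pkl", weights_only=False).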

'''
    Approach 2
    Save only the model parameters; the model must be constructed
    before the parameters can be loaded back into it
'''
torch.save(net.state_dict(), "params.pkl")
net.load_state_dict(torch.load("params.pkl"))
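Spelled out, the second approach looks like the minimal sketch below. It assumes the Net(13, 1) model defined in the full listing further down and the params.pkl file written by the save call above:

import torch

# the architecture must exist before the parameters can be restored into it
net = Net(13, 1)
net.load_state_dict(torch.load("params.pkl"))
net.eval()  # switch to evaluation mode before running inference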

Below is the code for the complete workflow.

import torch

# data: load the housing dataset
import numpy as np
import re  # regular expression

ff = open("../housing.data").readlines()
data = []
for item in ff:
    out = re.sub(r"\s{2,}", " ", item).strip()  # collapse repeated whitespace; .strip() removes leading/trailing spaces
    print(out)
    data.append(out.split(" "))

data = np.array(data).astype(np.float64)  # np.float was removed in newer NumPy versions; use np.float64
# print(data.shape)  # (506, 14)
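# Side note (equivalent alternative): since ../housing.data keeps one whitespace-separated
# record per line, the same array could also be produced with data = np.loadtxt("../housing.data").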

Y = data[:, -1]
X = data[:, :-1]
'''
print(Y.shape)
print(X.shape)
'''

Y_train = Y[0:496]
X_train = X[0:496, ...]

Y_test = Y[496:]
X_test = X[496:, ...]
# print(Y_train.shape)
# print(X_train.shape)
# print(Y_test.shape)
# print(X_test.shape)

# net: define the network
class Net(torch.nn.Module):
    def __init__(self, n_feature, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, 10)
        self.predict = torch.nn.Linear(10, n_output)

    def forward(self, x):
        _out = self.hidden(x)
        _out = torch.relu(_out)
        _out = self.predict(_out)
        return _out


net = Net(13, 1)

# loss: define the loss function
loss_func = torch.nn.MSELoss()

# optimizer: define the optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)

# training: start training
for i in range(10000):
    x_data = torch.tensor(X_train, dtype=torch.float32)
    y_data = torch.tensor(Y_train, dtype=torch.float32)
    pred = net(x_data)  # call the module directly rather than net.forward()
    pred = torch.squeeze(pred)
    loss = loss_func(pred, y_data)*0.001
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print("ite:{}, loss:{}".format(i, loss))
    print(pred[0:10])
    print(y_data[0:10])
    # test: evaluate on the test set interleaved with training
    with torch.no_grad():  # no gradients are needed for evaluation
        x_data = torch.tensor(X_test, dtype=torch.float32)
        y_data = torch.tensor(Y_test, dtype=torch.float32)
        pred = net(x_data)
        pred = torch.squeeze(pred)
        loss_test = loss_func(pred, y_data)*0.001
        print("ite:{}, loss_test:{}".format(i, loss_test))

'''
    Saving the model
    Approach 1
    Save the whole model; the saved file takes more disk space
'''
# torch.save(net, "../model/model.pkl")
# net = torch.load("../model/model.pkl")

'''
    Approach 2
    Save only the model parameters; the model must be constructed
    before the parameters can be loaded back into it
'''
# torch.save(net.state_dict(), "params.pkl")
# net.load_state_dict(torch.load("params.pkl"))
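
As a closing check, here is a minimal sketch of restoring the saved parameters into a fresh model and evaluating it on the test split. It assumes the torch.save(net.state_dict(), "params.pkl") line above has been uncommented and run, and it reuses Net, loss_func, X_test, and Y_test from the script:

# rebuild the architecture and restore the trained parameters
restored_net = Net(13, 1)
restored_net.load_state_dict(torch.load("params.pkl"))
restored_net.eval()

# run the restored model on the held-out rows
with torch.no_grad():
    x_data = torch.tensor(X_test, dtype=torch.float32)
    y_data = torch.tensor(Y_test, dtype=torch.float32)
    pred = torch.squeeze(restored_net(x_data))
    print("restored model test loss:", loss_func(pred, y_data)*0.001)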
