PyTorch入门(三)模块的保存与加载

  本文将介绍如何使用PyTorch保存模块和加载模型。

PyTorch模型保存与加载

  在PyTorch中,一个torch.nn.Module模型的可训练参数(即权重与偏移项)保存在模型的参数parameters),使用model.parameters()获得)中。一个state_dict就是一个简单的Python字典,将每层映射到其参数张量。PyTorch的模型文件以.pt.pth为后缀。使用函数torch.save保存模型,使用函数torch.load加载模型。

  PyTorch有两种保存与加载模型的方式,一种是保存整个模型(包括模型结构及参数值),另一种是只保存模型的参数值(即state_dict)。

  1. 保存整个网络结构信息和模型参数信息:
torch.save(model_object, './model.pth')

直接加载即可使用:

model = torch.load('./model.pth')
  1. 只保存网络的模型参数
torch.save(model_object.state_dict(), './params.pth')

加载则要先从本地网络模块导入网络,然后再加载参数:

from models import Model
model = Model()
model.load_state_dict(torch.load('./params.pth'))

示例代码

  我们以文章PyTorch入门(二)搭建MLP模型实现分类任务中的二分类MLP模型为例,来演示如何在PyTorch中保存模型和加载代码。

  只保存模型参数值的示例Python代码(save_model.py)如下:

# -*- coding: utf-8 -*-
from numpy import vstack
from pandas import read_csv
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score
import torch
from torch import Tensor
from torch.optim import SGD
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn import Linear, ReLU, Sigmoid, Module, BCELoss
from torch.nn.init import kaiming_uniform_, xavier_uniform_


# dataset definition
class CSVDataset(Dataset):
    # load the dataset
    def __init__(self, path):
        # load the csv file as a dataframe
        df = read_csv(path, header=None)
        # store the inputs and outputs
        self.X = df.values[:, :-1]
        self.y = df.values[:, -1]
        # ensure input data is floats
        self.X = self.X.astype('float32')
        # label encode target and ensure the values are floats
        self.y = LabelEncoder().fit_transform(self.y)
        self.y = self.y.astype('float32')
        self.y = self.y.reshape((len(self.y), 1))

    # number of rows in the dataset
    def __len__(self):
        return len(self.X)

    # get a row at an index
    def __getitem__(self, idx):
        return [self.X[idx], self.y[idx]]

    # get indexes for train and test rows
    def get_splits(self, n_test=0.3):
        # determine sizes
        test_size = round(n_test * len(self.X))
        train_size = len(self.X) - test_size
        # calculate the split
        return random_split(self, [train_size, test_size])


# model definition
class MLP(Module):
    # define model elements
    def __init__(self, n_inputs):
        super(MLP, self).__init__()
        # input to first hidden layer
        self.hidden1 = Linear(n_inputs, 10)
        kaiming_uniform_(self.hidden1.weight, nonlinearity='relu')
        self.act1 = ReLU()
        # second hidden layer
        self.hidden2 = Linear(10, 8)
        kaiming_uniform_(self.hidden2.weight, nonlinearity='relu')
        self.act2 = ReLU()
        # third hidden layer and output
        self.hidden3 = Linear(8, 1)
        xavier_uniform_(self.hidden3.weight)
        self.act3 = Sigmoid()

    # forward propagate input
    def forward(self, X):
        # input to first hidden layer
        X = self.hidden1(X)
        X = self.act1(X)
        # second hidden layer
        X = self.hidden2(X)
        X = self.act2(X)
        # third hidden layer and output
        X = self.hidden3(X)
        X = self.act3(X)
        return X


# prepare the dataset
def prepare_data(path):
    # load the dataset
    dataset = CSVDataset(path)
    # calculate split
    train, test = dataset.get_splits()
    # prepare data loaders
    train_dl = DataLoader(train, batch_size=32, shuffle=True)
    test_dl = DataLoader(test, batch_size=1024, shuffle=False)
    return train_dl, test_dl


# train the model
def train_model(train_dl, model):
    # define the optimization
    criterion = BCELoss()
    optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
    # enumerate epochs
    for epoch in range(100):
        # enumerate mini batches
        for i, (inputs, targets) in enumerate(train_dl):
            # clear the gradients
            optimizer.zero_grad()
            # compute the model output
            yhat = model(inputs)
            # calculate loss
            loss = criterion(yhat, targets)
            # credit assignment
            loss.backward()
            print("epoch: {}, batch: {}, loss: {}".format(epoch, i, loss.data))
            # update model weights
            optimizer.step()


# evaluate the model
def evaluate_model(test_dl, model):
    predictions, actuals = [], []
    for i, (inputs, targets) in enumerate(test_dl):
        # evaluate the model on the test set
        yhat = model(inputs)
        # retrieve numpy array
        yhat = yhat.detach().numpy()
        actual = targets.numpy()
        actual = actual.reshape((len(actual), 1))
        # round to class values
        yhat = yhat.round()
        # store
        predictions.append(yhat)
        actuals.append(actual)
    predictions, actuals = vstack(predictions), vstack(actuals)
    # calculate accuracy
    acc = accuracy_score(actuals, predictions)
    return acc


# make a class prediction for one row of data
def predict(row, model):
    # convert row to data
    row = Tensor([row])
    # make prediction
    yhat = model(row)
    # retrieve numpy array
    yhat = yhat.detach().numpy()
    return yhat


if __name__ == '__main__':
    # prepare the data
    path = './data/ionosphere.csv'
    train_dl, test_dl = prepare_data(path)
    print(len(train_dl.dataset), len(test_dl.dataset))
    # define the network
    model = MLP(34)
    print(model)
    # train the model
    train_model(train_dl, model)
    torch.save(model.state_dict(), 'binary_classification.pth')
    print(model.state_dict())
    # evaluate the model
    acc = evaluate_model(test_dl, model)
    print('Accuracy: %.3f' % acc)

运行代码,会输出该MLP模型的参数值(state_dict)如下:

OrderedDict([('hidden1.weight', tensor([[-4.3042e-02, -1.3315e-01, -3.5050e-01, -1.4949e-01, -1.6642e-01,
         ......), ('hidden1.bias', tensor([ 0.2563, -0.0024, -0.1276,  0.1943, -0.2728, -0.2992,  0.3130,  0.0245,
        -0.0381,  0.4498])), ('hidden2.weight', tensor([[-0.5759, -0.9750,  1.0027,  0.5148,  0.6903,  0.3534, -1.0665,  0.1220,
         -0.0757,  0.4448], ......), ('hidden2.bias', tensor([ 1.7468e-01,  5.9972e-02, -4.2997e-02, -2.2675e-01,  8.3250e-01,
        -3.2392e-04,  3.9665e-01, -2.5674e-01])), ('hidden3.weight', tensor([[ 1.3292, -0.6698, -0.2412,  1.0923, -2.5248,  0.3479, -1.1331, -0.0240]])), ('hidden3.bias', tensor([-0.8218]))])

值得注意的是,state_dict输出的格式为Python字典结构。保存为文件名称为binary_classification.pth。

  接着我们加载该模型文件,并对新数据进行预测,示例代码(load_model.py)如下:

# -*- coding: utf-8 -*-
import torch
from torch import Tensor

from save_model import MLP

model = MLP(34)
state_dict = torch.load('./binary_classification.pth')
model.load_state_dict(state_dict)
print(model)
# make a single prediction (expect class=1)
row = [1, 0, 0.99539, -0.05889, 0.85243, 0.02306, 0.83398, -0.37708, 1, 0.03760, 0.85243, -0.17755, 0.59755, -0.44945,
       0.60536, -0.38223, 0.84356, -0.38542, 0.58212, -0.32192, 0.56971, -0.29674, 0.36946, -0.47357, 0.56811, -0.51171,
       0.41078, -0.46168, 0.21266, -0.34090, 0.42267, -0.54487, 0.18641, -0.45300]
row = Tensor([row])
# make prediction
yhat = model(row)
# retrieve numpy array
yhat = yhat.detach().numpy()
print('Predicted: %.3f (class=%d)' % (yhat, yhat.round()))

  如果我们想保存、加载整个模型及模型参数,则在模型保存代码(save_model.py)中使用代码:

torch.save(model, 'binary_classification.pth')

加载模型部分代码如下:

# -*- coding: utf-8 -*-
import torch
from torch import Tensor

from save_model import MLP

model = torch.load('./binary_classification.pth')

# make a single prediction (expect class=1)
row = [1, 0, 0.99539, -0.05889, 0.85243, 0.02306, 0.83398, -0.37708, 1, 0.03760, 0.85243, -0.17755, 0.59755, -0.44945,
       0.60536, -0.38223, 0.84356, -0.38542, 0.58212, -0.32192, 0.56971, -0.29674, 0.36946, -0.47357, 0.56811, -0.51171,
       0.41078, -0.46168, 0.21266, -0.34090, 0.42267, -0.54487, 0.18641, -0.45300]
row = Tensor([row])
# make prediction
yhat = model(row)
# retrieve numpy array
yhat = yhat.detach().numpy()
print('Predicted: %.3f (class=%d)' % (yhat, yhat.round()))

需要注意的是,模型结构MLP类仍需在代码中(虽然后面代码中并没有用到MLP类),这样模型才能加载成功,否则会报模型加载失败。

总结

  本文简单介绍了如何在PyTorch中保存和加载模型。本文介绍的模型代码已开源,Github地址为:https://github.com/percent4/PyTorch_Learning。后续将持续介绍PyTorch内容,欢迎大家关注~

你可能感兴趣的:(算法,pytorch)