PyTorch Learning 19: Saving and Loading Models (Serialization and Deserialization)

Table of Contents

      • I. Serialization and Deserialization
      • II. Saving and Loading Models in PyTorch
        • 1. Saving a Model
        • 2. Loading a Model
      • III. Resuming Training from a Checkpoint

I. Serialization and Deserialization

  A trained model lives in memory, but memory is not persistent storage, so to keep a model for future use we must move it from memory to disk. This is model saving and loading, i.e., serialization and deserialization. Serialization writes an in-memory object to disk as a binary byte stream; deserialization reads that byte stream back from disk and rebuilds the object in memory, so the model can be used again. In short, serialization and deserialization let us persist data and models.
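As a minimal sketch of the idea, any picklable PyTorch object (a tensor, a dict, a whole module) can be serialized with torch.save and restored with torch.load; the file name demo.pt below is arbitrary:

import torch

t = torch.arange(4.)                 # an object living in memory
torch.save(t, "demo.pt")             # serialize: memory -> disk, as a binary stream
t_restored = torch.load("demo.pt")   # deserialize: disk -> memory
print(torch.equal(t, t_restored))    # True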

II. Saving and Loading Models in PyTorch

  Models can be saved and loaded in two ways: saving the entire model, or saving only the model parameters (the state_dict).

1. Saving a Model

  • torch.save
    Saves (serializes) an object
    Main parameters:
    • obj: the object to save (a model, tensor, parameters, dict, etc.)
    • f: the output path (the location on disk to save to)

Mode 1: save the entire Module

torch.save(net, path)

Mode 2: save only the model parameters (officially recommended)

state_dict = net.state_dict()
torch.save(state_dict, path)

  Let's use some code to observe the difference between the two saving modes. (By PyTorch convention, saved files usually carry a .pt or .pth extension; the .pkl used here works just as well, since the underlying format is pickle-based.)

import torch
import numpy as np
import torch.nn as nn


class LeNet2(nn.Module):
    def __init__(self, classes):
        super(LeNet2, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x

    def initialize(self):
        for p in self.parameters():
            p.data.fill_(20191104)


net = LeNet2(classes=2019)

# "模拟训练"
# 查看第一个卷积层的第一个卷积核的参数
print("训练前: ", net.features[0].weight[0, ...])
net.initialize()
print("训练后: ", net.features[0].weight[0, ...])

path_model = "./model.pkl"
path_state_dict = "./model_state_dict.pkl"

# Save the entire model
torch.save(net, path_model)

# Save only the model parameters
net_state_dict = net.state_dict()
torch.save(net_state_dict, path_state_dict)

In the current working directory this produces two files: model.pkl and model_state_dict.pkl.
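As a quick sanity check (just a sketch; the exact sizes will vary), we can confirm both files exist and compare their sizes. The whole-model file is typically somewhat larger, since it pickles the module structure in addition to the tensors:

import os

for p in [path_model, path_state_dict]:
    print(p, os.path.getsize(p), "bytes")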

2. Loading a Model

  • torch.load
    Loads (deserializes) an object
    Main parameters:
    • f: the file path
    • map_location: where to map the loaded tensors, CPU or GPU (see the sketch after Mode 2 below)

Mode 1: load the entire Module (this relies on pickle, so the class definition, LeNet2 here, must be importable at load time)

net_load = torch.load(path_model)

Mode 2: load the model parameters (officially recommended)

state_dict_load = torch.load(path_state_dict)
# Build a fresh model
net_new = LeNet2(classes=2019)
# Load the parameters into it
net_new.load_state_dict(state_dict_load)
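One practical note on map_location, as referenced above: a checkpoint saved on a GPU machine can be remapped to CPU at load time. A minimal sketch (the weights_only=True flag is an assumption about your PyTorch version; it exists in newer releases, roughly 1.13 onward, and restricts unpickling to plain tensors and containers):

# Remap tensors that were saved on GPU to CPU at load time
state_dict_cpu = torch.load(path_state_dict, map_location="cpu")

# On recent PyTorch (>= 1.13), optionally restrict what gets unpickled:
# state_dict_cpu = torch.load(path_state_dict, map_location="cpu", weights_only=True)

net_new = LeNet2(classes=2019)
net_new.load_state_dict(state_dict_cpu)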

  Let's use some code to observe the difference between the two loading modes.

import torch
import numpy as np
import torch.nn as nn


class LeNet2(nn.Module):
    def __init__(self, classes):
        super(LeNet2, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x

    def initialize(self):
        for p in self.parameters():
            p.data.fill_(20191104)


# ================================== load net ===========================

path_model = "./model.pkl"
net_load = torch.load(path_model)
# Print the model structure
print(net_load)

# ================================== load state_dict ===========================

path_state_dict = "./model_state_dict.pkl"
state_dict_load = torch.load(path_state_dict)

print(state_dict_load.keys())

# ================================== update state_dict ===========================

net_new = LeNet2(classes=2019)

print("加载前: ", net_new.features[0].weight[0, ...])
net_new.load_state_dict(state_dict_load)
print("加载后: ", net_new.features[0].weight[0, ...])
LeNet2(
  (features): Sequential(
    (0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=400, out_features=120, bias=True)
    (1): ReLU()
    (2): Linear(in_features=120, out_features=84, bias=True)
    (3): ReLU()
    (4): Linear(in_features=84, out_features=2019, bias=True)
  )
)

# The tensors stored under these keys are all filled with 20191104
odict_keys(['features.0.weight', 'features.0.bias', 'features.3.weight', 'features.3.bias', 'classifier.0.weight', 'classifier.0.bias', 'classifier.2.weight', 'classifier.2.bias', 'classifier.4.weight', 'classifier.4.bias'])

Before loading:  tensor([[[ 0.1091,  0.0194, -0.0126, -0.0682,  0.1105],
         [-0.0846, -0.0898, -0.1119,  0.0616, -0.1045],
         [-0.0696,  0.0017, -0.0754,  0.0078, -0.0538],
         [-0.0208, -0.0018,  0.0816, -0.0684,  0.0662],
         [ 0.0453,  0.0400, -0.0184, -0.0038, -0.0663]],
        [[ 0.0226,  0.0914,  0.0725, -0.0074,  0.0676],
         [ 0.0181, -0.0345,  0.0596, -0.0887, -0.0850],
         [ 0.0062,  0.0234,  0.0859,  0.0583,  0.1044],
         [-0.1049,  0.0995, -0.0015,  0.0407,  0.0565],
         [ 0.0461, -0.0932,  0.0327, -0.0590, -0.0506]],
        [[-0.0758,  0.1004,  0.0691, -0.1083,  0.0960],
         [-0.1005, -0.0571,  0.0786,  0.0127,  0.0799],
         [-0.0487, -0.1015, -0.0787,  0.1073, -0.1136],
         [-0.0696,  0.0429, -0.0608,  0.0483, -0.1019],
         [-0.0581,  0.0846, -0.0643,  0.1144, -0.1064]]],
       grad_fn=<SelectBackward>)
After loading:  tensor([[[20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.]],
        [[20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.]],
        [[20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.]]],
       grad_fn=<SelectBackward>)

III. Resuming Training from a Checkpoint

  Saving and loading are not just for preserving a finished model; they are also very useful during training itself. The common technique is checkpoint resumption: if training is killed unexpectedly, we can continue from the last saved checkpoint instead of retraining from scratch. To make this possible, the training loop must periodically save state. Which data need to be saved? During iterative training, both the model and the optimizer carry state that keeps changing, so both should go into the checkpoint; the loss (and the current epoch) can also be saved to record where training stood.

# Core checkpoint contents
checkpoint = {
    "model_state_dict": net.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "epoch": epoch
}
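Before the full scripts, here is a compact sketch of the save/resume round trip, assuming net, optimizer, and an LR scheduler already exist. Persisting the scheduler's state_dict is a small extension beyond the dict above, not something the original code does:

# Save: typically at the end of every N-th epoch
checkpoint = {
    "model_state_dict": net.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "scheduler_state_dict": scheduler.state_dict(),  # extension: also persist the LR schedule
    "epoch": epoch,
}
torch.save(checkpoint, "./checkpoint_{}_epoch.pkl".format(epoch))

# Resume: rebuild net / optimizer / scheduler as usual, then restore their state
checkpoint = torch.load("./checkpoint_{}_epoch.pkl".format(epoch))
net.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
start_epoch = checkpoint["epoch"]  # the loop then continues from start_epoch + 1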

  Now let's simulate an interrupted training run.

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import transforms
# RMBDataset and LeNet are defined in earlier lessons of this course;
# adjust the import paths to match your project layout.
from tools.my_dataset import RMBDataset
from model.lenet import LeNet

rmb_label = {"1": 0, "100": 1}

# Hyperparameters
checkpoint_interval = 5
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1


# ============================ step 1/5 Data ============================

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
split_dir = os.path.abspath(os.path.join(BASE_DIR, "..", "..", "materials", "rmb_split"))
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

if not os.path.exists(split_dir):
    raise Exception(r"Data directory {} does not exist; go back to lesson-06\1_split_dataset.py to generate it".format(split_dir))

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomGrayscale(p=0.8),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Build Dataset instances
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# Build DataLoaders
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5 Model ============================

net = LeNet(classes=2)
net.initialize_weights()

# ============================ step 3/5 Loss function ============================
criterion = nn.CrossEntropyLoss()                                                   # choose the loss function

# ============================ step 4/5 Optimizer ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)                        # choose the optimizer
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)     # set the LR decay schedule

# ============================ step 5/5 Training ============================
train_curve = list()
valid_curve = list()

start_epoch = -1  # fresh run: training starts at epoch 0
for epoch in range(start_epoch+1, MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # accumulate classification statistics
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().numpy()

        # log training info
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

    scheduler.step()  # update the learning rate

    if (epoch+1) % checkpoint_interval == 0:

        checkpoint = {"model_state_dict": net.state_dict(),
                      "optimizer_state_dict": optimizer.state_dict(),
                      "epoch": epoch}
        path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
        torch.save(checkpoint, path_checkpoint)

    if epoch > 5:
        print("Simulating an unexpected interruption...")
        break

    # validate the model
    if (epoch+1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().sum().numpy()

                loss_val += loss.item()

            valid_curve.append(loss.item())
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val/len(valid_loader), correct_val / total_val))


train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval  # valid_curve holds one point per epoch, so convert its indices to iteration counts
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()
Training:Epoch[000/010] Iteration[010/013] Loss: 0.6543 Acc:58.13%
Valid:	 Epoch[000/010] Iteration[003/003] Loss: 0.3089 Acc:64.62%
Training:Epoch[001/010] Iteration[010/013] Loss: 0.2285 Acc:94.38%
Valid:	 Epoch[001/010] Iteration[003/003] Loss: 0.0121 Acc:95.38%
Training:Epoch[002/010] Iteration[010/013] Loss: 0.1686 Acc:96.25%
Valid:	 Epoch[002/010] Iteration[003/003] Loss: 0.0582 Acc:95.38%
Training:Epoch[003/010] Iteration[010/013] Loss: 0.2316 Acc:90.00%
Valid:	 Epoch[003/010] Iteration[003/003] Loss: 0.0753 Acc:89.74%
Training:Epoch[004/010] Iteration[010/013] Loss: 0.1923 Acc:91.88%
Valid:	 Epoch[004/010] Iteration[003/003] Loss: 0.0067 Acc:93.33%
Training:Epoch[005/010] Iteration[010/013] Loss: 0.0573 Acc:98.75%
Valid:	 Epoch[005/010] Iteration[003/003] Loss: 0.0062 Acc:98.97%
Training:Epoch[006/010] Iteration[010/013] Loss: 0.0070 Acc:99.38%
Simulating an unexpected interruption...

Training was interrupted after the fifth epoch; the most recent checkpoint is checkpoint_4_epoch.pkl (saved at the end of epoch 4), so we resume training from there.

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import transforms
# RMBDataset and LeNet are defined in earlier lessons of this course;
# adjust the import paths to match your project layout.
from tools.my_dataset import RMBDataset
from model.lenet import LeNet

rmb_label = {"1": 0, "100": 1}

# Hyperparameters
checkpoint_interval = 5
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1


# ============================ step 1/5 Data ============================

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
split_dir = os.path.abspath(os.path.join(BASE_DIR, "..", "..", "materials", "rmb_split"))
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

if not os.path.exists(split_dir):
    raise Exception(r"Data directory {} does not exist; go back to lesson-06\1_split_dataset.py to generate it".format(split_dir))

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomGrayscale(p=0.8),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Build Dataset instances
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# Build DataLoaders
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5 Model ============================

net = LeNet(classes=2)
net.initialize_weights()

# ============================ step 3/5 Loss function ============================
criterion = nn.CrossEntropyLoss()                                                   # choose the loss function

# ============================ step 4/5 Optimizer ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)                        # choose the optimizer
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)     # set the LR decay schedule


# ============================ step 5+/5 Resume from checkpoint ============================
path_checkpoint = "./checkpoint_4_epoch.pkl"
checkpoint = torch.load(path_checkpoint)

net.load_state_dict(checkpoint['model_state_dict'])

optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

start_epoch = checkpoint['epoch']

# Tell the scheduler how many epochs have elapsed, so the LR schedule
# continues from where it left off instead of restarting
scheduler.last_epoch = start_epoch

# ============================ step 5/5 Training ============================
train_curve = list()
valid_curve = list()

for epoch in range(start_epoch + 1, MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # accumulate classification statistics
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().numpy()

        # log training info
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

    scheduler.step()  # update the learning rate

    if (epoch+1) % checkpoint_interval == 0:

        checkpoint = {"model_state_dict": net.state_dict(),
                      "optimizer_state_dic": optimizer.state_dict(),
                      "loss": loss,
                      "epoch": epoch}
        path_checkpoint = "./checkpint_{}_epoch.pkl".format(epoch)
        torch.save(checkpoint, path_checkpoint)

    # if epoch > 5:
    #     print("Simulating an unexpected interruption...")
    #     break

    # validate the model
    if (epoch+1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().sum().numpy()

                loss_val += loss.item()

            valid_curve.append(loss.item())
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val/len(valid_loader), correct_val / total_val))


train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval  # valid_curve holds one point per epoch, so convert its indices to iteration counts
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

Observation: the model resumes training from epoch 5.

Training:Epoch[005/010] Iteration[010/013] Loss: 0.0603 Acc:97.50%
Valid:	 Epoch[005/010] Iteration[003/003] Loss: 0.0031 Acc:97.95%
Training:Epoch[006/010] Iteration[010/013] Loss: 0.0151 Acc:99.38%
Valid:	 Epoch[006/010] Iteration[003/003] Loss: 0.0021 Acc:99.49%
Training:Epoch[007/010] Iteration[010/013] Loss: 0.1177 Acc:97.50%
Valid:	 Epoch[007/010] Iteration[003/003] Loss: 0.0020 Acc:97.95%
Training:Epoch[008/010] Iteration[010/013] Loss: 0.0032 Acc:100.00%
Valid:	 Epoch[008/010] Iteration[003/003] Loss: 0.0019 Acc:98.97%
Training:Epoch[009/010] Iteration[010/013] Loss: 0.0160 Acc:99.38%
Valid:	 Epoch[009/010] Iteration[003/003] Loss: 0.0017 Acc:98.97%


