pytorch数据加载、模型保存及加载

主要涉及的Pytorch官方示例下图红框部分的一些翻译及备注。
pytorch数据加载、模型保存及加载_第1张图片

1、数据加载及处理

  该部分主要是用于进行数据集加载及数据预处理说明,使用的数据集为:人脸+标注坐标。demo程序需要pandas(读取CSV文件)及scikit-image(图像变换)这两个包。

1.1、jupyter显示matplot图像

import matplotlib.pyplot as plt
%matplotlib inline   #这句是在jupyter显示图像的关键,在其它IDE中必须注释掉,否则报错  

1.2、数据集类

  torch.utils.data.Dataset 是一个处理数据的抽象类。当使用自己的数据集时需要继承Dataset类,并且重载以下成员函数:
《1》、len : 用于返回数据集的大小。
《2》、getitem : 通过下标索引取得第i个样本。

demo程序中为脸部标注样本创建了一个FaceLandmarksDataset类。在该类的__init__方法中读取csv文件,在__getitem__方法中加载图片。
我们创建的数据集样本会以一个字典表示,如下:

{'image': image, 'landmarks': landmarks}

该数据集类有一个可选参数“transform”,用于控制对图像进行的处理。

1.3、数据转换(transforms)

  几乎所有神经网络的输入都希望接收到大小固定的数据,而我们demo中的原始图像大小是不一致的。因此我们添加一些图像变换方法来处理这些图像。包括以下三个:
  Rescale: 缩放图片
  RandomCrop: 随机裁剪
  ToTensor:将numpy表示的图像转化为torch的Tensor表示
将每一个图像变换用一个可调用的类实现。这样做的好处是–进行变换时的参数不用每次都在迭代上下文传递。为此实现了类的__call__ 专有函数

__call__、__getitem__
python专有函数。
若在定义类的时候,实现__call__函数,则这个类就成为可调用的。换句话说,我们可以把这个类的实例当做函数来使用。
相当于重载了括号运算符。

例子说明:

class g_dpm(object):
    def __init__(self, g):
        self.g = g

    def __call__(self, t):
        return (self.g*t**2)/2

  计算地球场景的时候,我们就可以令e_dpm = g_dpm(9.8),s = e_dpm(t)

1.4、组合图像变换

  我们需要将demo中使用的样本图片较短的边设置为256,之后将图片裁剪为244x244大小。为此我们需要组合Rescale和RandomCrop两种变换。可以通过torchvision.transforms.Compose实现。这是一个可调用类。

scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256),
                               RandomCrop(224)])

1.5、在数据集上进行迭代处理

   如果我们在训练的每次迭代中只从数据集中抽取一张图片,那么我们会丢失很多和数据集有关的特征。因此我们每次迭代我们采用以下方法:
  《1》、批处理
  《2》、数据重混(打乱数据)
  《3》、用多个进程并行加载数据
原文:
pytorch数据加载、模型保存及加载_第2张图片
torch提供了torch.utils.data.DataLoader用于实现以上3点。torch.utils.data.DataLoader是一个迭代器,他有一个collate_fn的参数需要特别留意下,该参数用于合并一些list形式样本来形成一个小批量( merges a list of samples to form a mini-batch)。
原文:
pytorch数据加载、模型保存及加载_第3张图片

dataloader = DataLoader(transformed_dataset, batch_size=4,
                        shuffle=True, num_workers=4)

1.6、Torchvision

  torchvision包提供了一些常用的数据集类和数据处理实现。该包中最常用的数据集类是ImageFolder。
该类假设图片按照以下方式存储。
pytorch数据加载、模型保存及加载_第4张图片
上图中bees、ants都是类标签。ImageFloder使用例子:

import torch
from torchvision import transforms, datasets

data_transform = transforms.Compose([
        transforms.RandomSizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
                                           transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
                                             batch_size=4, shuffle=True,
                                             num_workers=4)

2、模型加载及保存

2.1、与模型加载有关的3个函数:

《1》、torch.save

torch.save(obj, f, pickle_module=, pickle_protocol=2)

功能:将模型保存到磁盘。该函数使用python的pickle包来序列化模型。
官方推荐两种用法:
A、仅仅保存模型参数;
B、保存整个模型;
pytorch数据加载、模型保存及加载_第5张图片
《2》、torch.load

torch.load(f, map_location=None, pickle_module=)

功能:加载由torch.load()函数保存的模型。
  该函数首先会将模型反序列化到CPU然后将模型移动到保存模型时该模型所处的设备(CPU或GPU)。如果现有机器上没有对应保存模型时的设备,则该函数会抛出异常。如果遇到这种情况,可以使用该函数的map_location参数来将模型动态映射到一系列设备上。
在这里插入图片描述

>>> torch.load('tensors.pt')
# Load all tensors onto the CPU
>>> torch.load('tensors.pt', map_location=torch.device('cpu'))
# Load all tensors onto the CPU, using a function
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
# Load all tensors onto GPU 1
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
# Map tensors from GPU 1 to GPU 0
>>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
# Load tensor from io.BytesIO object
>>> with open('tensor.pt') as f:
        buffer = io.BytesIO(f.read())
>>> torch.load(buffer)

《3》、torch.nn.Module.load_state_dict

load_state_dict(state_dict, strict=True)

功能:仅仅加载模型的参数。
pytorch数据加载、模型保存及加载_第6张图片

2.2、STATE_DICT

  在pytorch中,torch.nn.model中的可学习参数(权重,bias(偏差)等)都存储在模型的parameters成员中,可通过model.parameters()获取。
stat_dict是一个字典,该字典包含model每一层的tensor类型的可学习参数。只有包含可学习参数的网络层才能将其参数映射到state_dict字典中。
原文:
pytorch数据加载、模型保存及加载_第7张图片

例子:

#定义网络模型用于说明sate_dict
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.cov1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        
    def forward():
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = TheModelClass()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

输出:
pytorch数据加载、模型保存及加载_第8张图片
由打印知道模型参数包括两大类。一类是权重及偏差参数,另一类是Optimizer参数。optimizer的state_dict包含两个关键字:优化器的state及超参数。

2.3、保存模型及加载模型用于预测

a、保存
推荐仅仅保存模型的state_dict

torch.save(model.state_dict(), MODELPATH)

b、加载

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

Pytorch保存的模型后缀一般是.pt或者.pth
必须在加载模型后调用model.eval函数来将dropout及批归一化层设置为预测模式。如果不这么做结果出错。

2.4、保存临时模型用于预测或再训练

a、保存

torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)

  当保存一个临时模型用于预测或再训练时,需要保存比state_dict更多的参数。包括优化器的state_dict,迭代次数epoch,最后一层迭代的loss及其他任何需要的参数。
  当保存多个组件时,将多个组件以字典的形式组织,然后用torch.savee()来序列化该字典。在Pytorch中常用.tar文件后缀表示这种模型。

b、加载

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()   #预测
# - or -
model.train() #再训练

2.5、将多个模型保存在一个文件中

a、保存:

torch.save({
            'modelA_state_dict': modelA.state_dict(),
            'modelB_state_dict': modelB.state_dict(),
            'optimizerA_state_dict': optimizerA.state_dict(),
            'optimizerB_state_dict': optimizerB.state_dict(),
            ...
            }, PATH)

b、加载:

modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()

2.6、利用一个不同的模型来预热(warmstarting)待使用的模型【WARMSTARTING MODEL USING PARAMETERS FROM A DIFFERENT MODEL】

pytorch数据加载、模型保存及加载_第9张图片
加载一个模型的一部分或者加载一个不完整的网络在迁移学习或者训练一个新的复杂网络时会经常遇到。
使用已经训练过的参数,即使这些参数仅仅是待训练网络参数的一小部分,也会加快网络的训练及帮助网络更快达到收敛。

2.7、在不同设备上进行模型的保存及加载

《1》、GPU上保存,CPU上加载
pytorch数据加载、模型保存及加载_第10张图片
在这种情况下,tensor的使用的内存会自动重映射到CPU设备中。

《2》、GPU上保存,GPU上加载
pytorch数据加载、模型保存及加载_第11张图片
该场景下需要注意:必须将模型的所有输入使用.to(torch.device(“cuda”))转为GPU使用的类型。
注意:
  my_tensor.to(device)返回的是my_tensor的一个新的拷贝,该操作不会覆盖my_tensor原本的device类型(CPU或GPU)
覆盖式的写法:

 my_tensor = my_tensor.to(device)

《3》、CPU上保存,GPU上加载
pytorch数据加载、模型保存及加载_第12张图片
比起类型2,在调用load_state_dict函数时多一个map_loaction操作。其它操作同类型2.
《4》、保存并行数据模型(torch.nn.DataParallel)
pytorch数据加载、模型保存及加载_第13张图片
torch.nn.DataParallel模型是一个封装好的模型,该模型能使用GPU的并行处理操作。

3、Jupyter demo

你可能感兴趣的:(深度学习)