其实Pytorch模型保存还是挺简单的,但是不同方式也有优劣之分吧。有时候,我们不仅仅需要保存模型参数,而有时需要保存训练的所有现场,包括优化器的内容。即有时候是只保存参数,但有时候需要保存模型训练的全过程。
我们实际上保存的是模型的参数,没有保存模型的结构的完整信息。
即,保存的模型是以字典形式保存的,所以被称作为state_dict。上面实际上我们按照已经定义好的模型进行加载,所以使用model.load_state_dict。其中的键信息实际是原本模型的层次的名字,因此模型在重新读取的时候,需要我们先实例化完全一致的结构,再进行参数的加载。
如果model是pytorch的nn.module继承而来的,那么如下:
model_path = os.path.join(output, 'model.pth')
torch.save(model.state_dict(), model_path)
这里有.pth
的格式存储,还有.pkl
格式,以及.pt
的格式。
之后,如果要进行推理或者使用时加载模型,只需要模型的结构对应,就可以直接加载:
model.load_state_dict(torch.load(args.model_path))
# args.model_path就是模型的路径字符串,比如'model.pth'
总结如下:
state_dict()
获取模型的参数,而不保存结构load_state_dict
方法,其参数不是文件路径,而是 torch.load(PATH)这是完整的存储了模型的信息的方法,包括模型的参数信息、模型的结构信息、参数等等所有内容。和方法一相比,弊端是会占用更大的信息,优势是,我们不需要知道文件中的模型究竟是什么样的,直接读取即可使用了:
torch.save(model, PATH)
model = torch.load(PATH)
有时我们不仅要保存模型,还要连带保存一些其他的信息。比如在训练过程中保存一些 checkpoint
,往往除了模型,还要保存它的epoch
、loss
、optimizer
等信息,以便于加载后对这些 checkpoint 继续训练等操作;或者再比如,有时候需要将多个模型一起打包保存等。
这里我们主要将多个内容放入一个字典进行保存:
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
...
}, PATH)
加载的时候,我们需要将各个对应的元素按照原本的类别,进行数据初始化,例如优化器必须还是之前的优化器,模型还是之前的模型结构(主要这里例子是state_dict,不然直接保存模型也是可以的)
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
我们时常会涉及到,在有GPU的服务器进行训练,但是在CPU上进行推理和使用的情况。正常的CPU训练、CPU加载或者GPU训练、GPU使用,都是没问题的,主要是设备不同时的问题。
最为正常和一般的情况,照常操作,不过还是别忘记把模型放到GPU上去。
GPUidx=0
device = torch.device('cuda:{}'.format(GPUidx) if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 64 # number of data points in each batch
N_EPOCHS = 15 # times to run the model on complete data
INPUT_DIM = 28 * 28 # size of each input
HIDDEN_DIM = 256 # hidden dimension
LATENT_DIM = 20 # latent vector dimension
encoder = Encoder(INPUT_DIM, HIDDEN_DIM, LATENT_DIM) # encoder
decoder = Decoder(LATENT_DIM, HIDDEN_DIM, INPUT_DIM) # decoder
VAEmodel = VAE(encoder, decoder).to(device)# vae
VAEmodel.load_state_dict(torch.load(modelpath))
保存的行为一致,我们只需要在torch.load时,对相应的参数map_location
进行设置即可:
torch.save(net.state_dict(), PATH)
device = torch.device("cpu")
loaded_net = Net()
loaded_net.load_state_dict(torch.load(PATH, map_location=device))
虽然一般不太可能,但还是啰嗦一下
torch.save(net.state_dict(), PATH)
device = torch.device("cuda")
loaded_net = Net()
loaded_net.load_state_dict(torch.load(PATH, map_location=device))
# or
loaded_net.to(device)