【pytorch】模型的保存与加载|| Dataloader数据加载器

Pytorch模型保存与加载,并在加载的模型基础上继续训练

系统学习Pytorch笔记三:Pytorch数据读取机制(DataLoader)与图像预处理模块(transforms)

一、只保存参数

1. 保存

一般地,采用一条语句即可保存参数:

torch.save(model.state_dict(), path)

其中model指定义的模型实例变量,如 model=vgg16( ), path是保存参数的路径,如 path=‘./model.pth’ , path=‘./model.tar’, path=‘./model.pkl’, 保存参数的文件一定要有后缀扩展名。

特别地,如果还想保存某一次训练采用的优化器、epochs等信息,可将这些信息组合起来构成一个字典,然后将字典保存起来:

state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
torch.save(state, path)

2. 加载
针对上述第一种情况,也只需要一句即可加载模型:

model.load_state_dict(torch.load(path))

针对上述第二种以字典形式保存的方法,加载方式如下:

checkpoint = torch.load(path)  # load model
model.load_state_dict(checkpoint['model'])   # load parameters
optimizer.load_state_dict(checkpoint['optimizer'])    # load optimizer
epoch = checkpoint(['epoch'])           # load epoch for training continue

二、保存整个模型

1. 保存

torch.save(model, path)

2. 加载

model = torch.load(path)

三、在训练中pytorch通过Dataloader加载数据

torch.utils.data.DataLoader(): 构建可迭代的数据装载器, 我们在训练的时候,每一个for循环,每一次iteration,就是从DataLoader中获取一个batch_size大小的数据的。

例如:

train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

【pytorch】模型的保存与加载|| Dataloader数据加载器_第1张图片

你可能感兴趣的:(pytorch,人工智能,python)