Pytorch训练流程图

Pytoch训练流程

    • 一 读取图片
    • 二 定义模型
    • 三 保存模型

一 读取图片

imageFolder and DataLoader
DataLoader可自定义,但需实现两个抽象函数,分别是__len__(selft)__getitem__(self, index)

二 定义模型


# Download and load the pretrained ResNet-18.
resnet = mm.resnet18(pretrained=True)

三 保存模型

# -------- 7. Save and load the model. --------#
# Save and load the entire model.
torch.save(resnet, 'model.mkl')
model = torch.load('model.mkl')

# Save and load only the model parameters (recommended).
torch.save(resnet.state_dict(), 'params.mkl')
resnet.load_state_dict(torch.load('params.mkl'))

你可能感兴趣的:(学习笔记)