pytorch1.0加载自己训练一半的模型

pytorch 1.0.1
pycharm

网络上有很多答案,几乎都是0.4版本的,其实到了1.0就非常容易了。

def load_checkpoint(model, checkpoint_PATH):

    model_CKPT = torch.load(checkpoint_PATH) # 之前模型的路径索引
    model.load_state_dict(model_CKPT)
    print('loading checkpoint!')
    return model

是不是感觉巨简单…然后只要在net初始化之后使用就好了。

 net = ResNet101()
 opt = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)

 net = load_checkpoint(net,
                       checkpoint_PATH="./model_52.pth")

你可能感兴趣的:(torch)