pytorch保存和加载模型state_dict

保存模型:

torch.save({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, datadir)

加载模型

model = model_class(num_classes=num_classes) # 定义模型
state = torch.load(datadir)
model.load_state_dict(state['state_dict'])

你可能感兴趣的:(pytorch,pytorch,深度学习,python)