pytorch保存训练模型以及加载模型

【1】训练中保存网络参数
model_save_dir = 为你要保存的位置
torch.save(net.state_dict(), model_save_dir)
【2】测试中加载保存的网络
net.load_state_dict(torch.load(‘model_save_dir’))

你可能感兴趣的:(pytorch保存训练模型以及加载模型)