PYTORCH保存训练好的模型

1.保存模型的dict

#...
XXXmodel = model(*args, **kwargs)
torch.save(XXXmodel.state_dict(),'onlyDict.pkl'):

读取dict

pre_model = model(*args, **kwargs)
pre_model.load_state_dict(torch.load('onlyDict.pkl'))

2.保存整个模型(费时间、内存)

#...
XXXmodel = model(*args, **kwargs)
torch.save(XXXmodel,'fullmodel.pkl'):

读取model

pre_model = torch.load('fullmodel.pkl')

你可能感兴趣的:(python编程技巧,NLP)