pytorch保存训练好的模型及pytorch自己定义损失函数

一.pytorch保存训练好的模型

假设你的模型定义为:

class Net(torch.nn.Module):

    ......

两种方式:

仅仅保存和加载模型参数:

#保存
PATH="./model.pkl"
the_model = Net()
torch.save(the_model.state_dict(), PATH)
#加载
the_model = Net()
the_model.load_state_dict(torch.load(PATH))

保存和加载整个模型

#保存
the_model=Net()
torch.save(the_model, PATH)
#加载
the_model=Net()
the_model.eval()#加这个是为了和训练时dropout等的设置保持一致
the_model = torch.load(PATH)

参考链接:

https://blog.csdn.net/u011276025/article/details/78507950

https://blog.csdn.net/u011276025/article/details/72817353

二. pytorch自己定义损失函数

两种方式,这个链接讲的很清楚:

https://blog.csdn.net/qq_27825451/article/details/95165265

你可能感兴趣的:(python)