模型保存和加载

模型的保存和加载各有两种方法

模型保存方法1 :模型结构+模型参数

# 保存模型方式1
torch.save(vgg16_true,'./models/vgg16_true.pth')
torch.save(vgg16_fulse,'./models/vgg16_false.pth')

相应模型加载方法:

# 保存模型方式1(保存模型结构+参数)相应加载模型方式
vgg16_true = torch.load('./models/vgg16_true.pth')
print(vgg16_true)

模型保存方式2:模型参数(官方推荐) ,因为这个方式,储存量小

# # 把网络模型的参数,保存下来,储存成字典的形式
torch.save(vgg16_true.state_dict(),'./models/vgg16_true_2.pth')
torch.save(vgg16_false.state_dict(),'./models/vgg16_false_2.pth')

相应模型加载方法:

# 保存模型方式2(模型参数)相应加载模型方式
module = torch.load('./models/vgg16_true_2.pth')  # 数据呈字典形式
vgg16 = torchvision.models.vgg16(pretrained=False)  # 新建网络模型
vgg16 = torch.load(module)  # 网络模型加载字典形式参数

关于参数pretrained=True or pretrained=False

# 这两个模型可以用debug看一下里面的参数,有很大的不同(初始化参数,偏置bias全为0)
vgg16_true = torchvision.models.vgg16(pretrained=True)  # 模型结构+训练好的参数
vgg16_fulse = torchvision.models.vgg16(pretrained=False)  # 模型结构+初始化参数

vgg16_true :# 模型结构+训练好的参数
模型保存和加载_第1张图片模型保存和加载_第2张图片vgg16_fulse :# 模型结构+初始化参数
模型保存和加载_第3张图片

你可能感兴趣的:(python)