pytorch模型导入、修改、保存、读取

pytorch对模型的操作

导入

这里以vgg11的classifier层为例
导入模型:

vgg11 = torchvision.models.vgg11(pretrained=False)
print(vgg11)

其中classifier的输出如下
pytorch模型导入、修改、保存、读取_第1张图片

修改

主要的修改方式:

# 在某层中添加层
vgg11.classifier.add_module('new_linear', nn.Linear(1000, 10))
# 修改某层
vgg11.classifier[6] = nn.Linear(4096, 10)
print(vgg11)

修改后:
pytorch模型导入、修改、保存、读取_第2张图片

保存与读取

# 模型的保存与读取
# 方式1,保存模型和参数
torch.save(vgg11, 'vgg11_method1.pth')
# 在读取时需要保证原模型已经引入
model = torch.load('vgg11_method1.pth')

# 方式2,只保存模型参数,一个字典形式(官方推荐)
torch.save(vgg11.state_dict(), 'vgg11_method2.pth')

vgg11_new = torchvision.models.vgg11(pretrained=False)
vgg11_new.load_state_dict(torch.load('vgg11_method2.pth'))

你可能感兴趣的:(pytorch,pytorch,python,深度学习)