网络模型的保存与读取

保存方式1: 模型结构+模型参数

torch.save(vgg16, 'vgg16_method1.pth')

方式一–>保存方式1, 加载模型

model1 = torch.load("vgg16_method1.pth")

保存方式2:模型参数(官方推荐,小)

torch.save(vgg16.state_dict(), 'vgg16_method2.pth')

方式二

model2 = torch.load("vgg16_method2.pth")
print(model2)

网络模型的保存与读取_第1张图片

完整版

vgg16 = torchvision.models.vgg16(weights=None)  # 不训练
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
# model2 = torch.load("vgg16_method2.pth")
print(vgg16)

网络模型的保存与读取_第2张图片

方式一的陷阱

保存

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
    def forward(self, x):
        x = self.conv1(x)
        return x
tudui = Tudui()
torch.save(tudui, 'tudui_method1.pth')

加载

model1_1 = torch.load("tudui_method1.pth")
print(model1_1)

网络模型的保存与读取_第3张图片

修改,就是要把原网络模型抄过来

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
    def forward(self, x):
        x = self.conv1(x)
        return x
# tudui = Tudui()  # 这行不用了
model1_1 = torch.load("tudui_method1.pth")
print(model1_1)

网络模型的保存与读取_第4张图片

正常写项目时肯定不会复制来复制去,一定是在一个文件夹,直接引入就好了

from model_save import *
model1_1 = torch.load("tudui_method1.pth")
print(model1_1)

同样也是可以的
网络模型的保存与读取_第5张图片

你可能感兴趣的:(神经网络,pytorch,python,深度学习,人工智能)