pytorch文档阅读(五)如何保存、加载网络模型

1.网络的保存

torch.save()有两种方法

1)仅保存网络参数

torch.save(net.state_dict(), 'net_params.pkl')

2)保存整个网络结构

 torch.save(net, 'net.pkl')

2.网络的加载

1)仅加载参数

model_object.load_state_dict(torch.load('net_params.pkl')) 

2)加载整个模型

model = torch.load('net.pkl') 

两种方法在载入模型时都需要有预设的网络结构,例如下边代码,否则会提示找不到相应的module

#加载整个网络

class AlexNet(nn.Module):
    def __init__(self):
        super(AlexNet,self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 5)
        self.pool1 = nn.MaxPool2d(3, 2)
        self.conv2 = nn.Conv2d(64, 64, 5)
        self.pool2 = nn.MaxPool2d(3, 2)
        self.fc1 = nn.Linear(1024, 384)
        self.fc2 = nn.Linear(384, 192)
        self.fc3 = nn.Linear(192, 10)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.softmax(self.fc3(x))
        return x

net = torch.load("TestSave.pkl")#加载整个模型时直接用这句就可以实例化网络,并且把CUDA上运行这个属性也继承了过来
#只加载网络参数

class AlexNet(nn.Module):
    def __init__(self):
        super(AlexNet,self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 5)
        self.pool1 = nn.MaxPool2d(3, 2)
        self.conv2 = nn.Conv2d(64, 64, 5)
        self.pool2 = nn.MaxPool2d(3, 2)
        self.fc1 = nn.Linear(1024, 384)
        self.fc2 = nn.Linear(384, 192)
        self.fc3 = nn.Linear(192, 10)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.softmax(self.fc3(x))
        return x

net = AlexNet()#只加载网络参数的时候需要自行实例化网络
net.cuda()#并设置网络运行在cpu还是gpu上

net.load_state_dict(torch.load('net_params.pkl'))#再加载网络的参数

注意:

1.只加载网络参数的速度比加载整个网络快得多

2.pth、pkl格式效果相同,ckpt是tensorflow的格式

参考链接:

https://www.jb51.net/article/139102.htm

https://www.jianshu.com/p/0eda629e4007

你可能感兴趣的:(pytorch,机器学习)