Pytorch自己的网络模型的保存和加载

1 自己写的网络模型的保存(两种方式)

1.1 第一种方式(保存整个网络结构+网络模型参数)

 torch.save(net, 'net.pth')

1.2 第二种方式(只保存网络模型参数)

这种方式是官方推荐的,因为它占的内存比第一种方式小,但是也不会小很多。但是我不推荐使用,因为使用起来比第一种要麻烦很多。

torch.save(net.state_dict(), 'net_params.pth')

2 自己写的网络模型的加载

首选需要说明的一点是,不管上述的那种方式,在我们加载网络模型的时候都需要有预设的网络结构,例如下边代码,否则会提示找不到相应的module

2.1 第一种加载方式

model = torch.load('net.pth') 

实例代码

#加载整个网络
 
class AlexNet(nn.Module):
    def __init__(self):
        super(AlexNet,self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 5)
        self.pool1 = nn.MaxPool2d(3, 2)
        self.conv2 = nn.Conv2d(64, 64, 5)
        self.pool2 = nn.MaxPool2d(3, 2)
        self.fc1 = nn.Linear(1024, 384)
        self.fc2 = nn.Linear(384, 192)
        self.fc3 = nn.Linear(192, 10)
 
    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.softmax(self.fc3(x))
        return x
 
net = torch.load("TestSave.pkl")#加载整个模型时直接用这句就可以实例化网络,并且把CUDA上运行这个属性也继承了过来
net.eval() #加上这句后效果更好

2.2 第二种加载方式

model_object.load_state_dict(torch.load('net_params.pth')) 

实例代码

#只加载网络参数
 
class AlexNet(nn.Module):
    def __init__(self):
        super(AlexNet,self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 5)
        self.pool1 = nn.MaxPool2d(3, 2)
        self.conv2 = nn.Conv2d(64, 64, 5)
        self.pool2 = nn.MaxPool2d(3, 2)
        self.fc1 = nn.Linear(1024, 384)
        self.fc2 = nn.Linear(384, 192)
        self.fc3 = nn.Linear(192, 10)
 
    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.softmax(self.fc3(x))
        return x
 
net = AlexNet()#只加载网络参数的时候需要自行实例化网络
net.cuda()#并设置网络运行在cpu还是gpu上
 
net.load_state_dict(torch.load('net_params.pth'))#再加载网络的参数
net.eval()

3 注意:

1.只加载网络参数的速度比加载整个网络快得多
2.pth、pkl格式效果相同,ckpt是tensorflow的格式

参考文章

保存加载模型的两种方式

你可能感兴趣的:(Pytorch学习,pytorch,网络,python)