pytorch保存参数及模型的两种方式

在pytorch中已经帮助我们写好了保存模型的方法,一般有两种方式,方法一是保存整个模型,方法二是只保存模型参数状态。

其实对于原生python也有保存模型的方式,可以使用dump、load来保存。

class Linear(nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = nn.Parameter(torch.randn(3, 4))
        self.b1 = nn.Parameter(torch.randn(1, 3))
        
        self.w2 = nn.Parameter(torch.randn(3, 2))
        self.b2 = nn.Parameter(torch.randn(1, 2))
        
    def forward(self, x):
        x = F.linear(x, self.w1, self.b1)
        return F.linear(x, self.w2, self.b2)

方法一

对于方法一是将整个模型全部保存,包括模型参数及模型的结构都会保存,所以模型较重,读写速度较慢,而且这种方式容易出错,虽然方法使用简单,但是不推荐使用。

model = Linear()

torch.save(model, 'model.pth')
new_model = torch.load('model.pth')

方法二

方法二是只保存模型对应的参数,创建好一个新的模型,我们只需要将参数读取到新的模型中即可,这种方法尤为推荐,只不过相对方法一写起来会复杂一点。

torch.save(model.state_dict(), 'model.pth')

new_model = Linear()
new_model.load_state_dict(torch.load('model.pth'))

注意一点,二者读取参数时,方法一不需要实例化新模型,直接读取model.pth就会返回一个模型,因为方法一会保存模型的结构信息。

对于方法二需要提前定义模型,因为方法二只保存参数,我们需要先实例化一个模型,然后把读取的模型参数加载到模型中。

你可能感兴趣的:(PyTorch,pytorch,深度学习,python,人工智能,神经网络)