torch模型保存

torch模型保存。

模型保存的本质就是利用pickle模块进行序列化。序列化到文件,从文件反序列化回来的对象,要么是Python自定义的对象,要么是本文件中已经定义的类。

import torch
import torch.nn as nn
import torch.optim as optim


class Model(nn.Module):

    def __init__(self, input_size, output_size):

        super(Model, self).__init__()
        self.linear1 = nn.Linear(input_size, input_size * 2)
        self.linear2 = nn.Linear(input_size * 2, output_size)

    def forward(self, inputs):

        inputs = self.linear1(inputs)
        output = self.linear2(inputs)
        return output

第一种方式

model = Model()
torch.save(model,'./model.pth')

model = torch.load('./model.pth')

第二种方式

model = Model()
torch.save(model.state_dict(), './model_state_dict.pth')

model = Model()
model.load_state_dict('./model_state_dict.pth')

你可能感兴趣的:(pytorch,深度学习,pytorch,python)