pytorch模型及参数的保存和提取

保存网络:

1.  torch.save(net1, 'net.pkl')  #保存整个网络

2.  torch.save(net1.state_dict(), 'net_params.pkl')  #只保存网络中的参数,(速度快,占内存少)

 

提取网络:

1.  net2 = torch.load('net.pkl')

2. net3 = torch.nn.Sequential(

              torch.nn.Linear(1,10),

              torch.nn.ReLU(),

              torch.nn.Linear(10,1)

               )  

   net3.load_state_dict(torch.load('net_params.pkl'))

 

你可能感兴趣的:(pytorch模型及参数的保存和提取)