Pytorch state_dict介绍

Introduce

在pytorch中,torch.nn.Module模块中的state_dict变量存放训练过程中需要学习的权重和偏执系数,state_dict作为python的字典对象将每一层的参数映射成tensor张量,需要注意的是torch.nn.Module模块中的state_dict只包含卷积层和全连接层的参数,当网络中存在batchnorm时,例如vgg网络结构,torch.nn.Module模块中的state_dict也会存放batchnorm's running_mean,关于batchnorm详解可见https://blog.csdn.net/wzy_zju/article/details/81262453

torch.optim模块中的Optimizer优化器对象也存在一个state_dict对象,此处的state_dict字典对象包含state和param_groups的字典对象,而param_groups key对应的value也是一个由学习率,动量等参数组成的一个字典对象。

因为state_dict本质上Python字典对象,所以可以很好地进行保存、更新、修改和恢复操作(python字典结构的特性),从而为PyTorch模型和优化器增加了大量的模块化。

Sample

通过一个简单的案例来输出state_dict字典对象中存放的变量

  1. #encoding:utf-8

  2.  
  3. import torch

  4. import torch.nn as nn

  5. import torch.optim as optim

  6. import torchvision

  7. import numpy as mp

  8. import matplotlib.pyplot as plt

  9. import torch.nn.functional as F

  10.  
  11. #define model

  12. class TheModelClass(nn.Module):

  13. def __init__(self):

  14. super(TheModelClass,self).__init__()

  15. self.conv1=nn.Conv2d(3,6,5)

  16. self.pool=nn.MaxPool2d(2,2)

  17. self.conv2=nn.Conv2d(6,16,5)

  18. self.fc1=nn.Linear(16*5*5,120)

  19. self.fc2=nn.Linear(120,84)

  20. self.fc3=nn.Linear(84,10)

  21.  
  22. def forward(self,x):

  23. x=self.pool(F.relu(self.conv1(x)))

  24. x=self.pool(F.relu(self.conv2(x)))

  25. x=x.view(-1,16*5*5)

  26. x=F.relu(self.fc1(x))

  27. x=F.relu(self.fc2(x))

  28. x=self.fc3(x)

  29. return x

  30.  
  31. def main():

  32. # Initialize model

  33. model = TheModelClass()

  34.  
  35. #Initialize optimizer

  36. optimizer=optim.SGD(model.parameters(),lr=0.001,momentum=0.9)

  37.  
  38. #print model's state_dict

  39. print('Model.state_dict:')

  40. for param_tensor in model.state_dict():

  41. #打印 key value字典

  42. print(param_tensor,'\t',model.state_dict()[param_tensor].size())

  43.  
  44. #print optimizer's state_dict

  45. print('Optimizer,s state_dict:')

  46. for var_name in optimizer.state_dict():

  47. print(var_name,'\t',optimizer.state_dict()[var_name])

  48.  
  49.  
  50.  
  51. if __name__=='__main__':

  52. main()

  53.  

  54. 具体的输出结果如下:可以很清晰的观测到state_dict中存放的key和value的值

  55. Model.state_dict:

  56. conv1.weight torch.Size([6, 3, 5, 5])

  57. conv1.bias torch.Size([6])

  58. conv2.weight torch.Size([16, 6, 5, 5])

  59. conv2.bias torch.Size([16])

  60. fc1.weight torch.Size([120, 400])

  61. fc1.bias torch.Size([120])

  62. fc2.weight torch.Size([84, 120])

  63. fc2.bias torch.Size([84])

  64. fc3.weight torch.Size([10, 84])

  65. fc3.bias torch.Size([10])

  66. Optimizer,s state_dict:

  67. state {}

  68. param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [367949288, 367949432, 376459056, 381121808, 381121952, 381122024, 381121880, 381122168, 381122096, 381122312]}]

你可能感兴趣的:(Pytorch state_dict介绍)