
class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)







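         1) conv2d.weight  shape = [out_channels (输出channels), in_channels // groups (输入channels // groups), kH, kW]
            (with the defaults groups=1 and an int kernel_size k, this is [out_channels, in_channels, k, k])

         2) conv2d.bias    shape = [out_channels (输出channels)]  (present only when bias=True)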


import torch  # needed here: this snippet runs before the main import block below

# Define a convolutional layer conv1 (in_channels=5, out_channels=10, kernel_size=3).
conv1 = torch.nn.Conv2d(5,10,3)
# When this layer is registered on a model as attribute `conv1`, the model's
# state_dict automatically stores the following parameters:
#   conv1.weight    torch.Size([10, 5, 3, 3])
#   conv1.bias      torch.Size([10])


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# define model
class TheModelClass(nn.Module):
    """Small LeNet-style CNN: two conv/pool stages followed by three linear layers.

    Two 5x5 convolutions (no padding), each followed by 2x2 max-pooling, shrink
    a 32x32 input as 32 -> 28 -> 14 -> 10 -> 5. That is why ``fc1`` expects
    ``16 * 5 * 5`` input features. Inputs must have 3 channels (see ``conv1``).
    """

    def __init__(self):
        # nn.Module.__init__ must run before any submodule is assigned. Without it,
        # the first `self.conv1 = ...` raises AttributeError
        # ("cannot assign module before Module.__init__() call").
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)  # reused by both stages; has no parameters
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return raw class scores (no softmax) with shape ``(batch, 10)``.

        ``x`` is assumed to be ``(batch, 3, H, W)``, with H and W such that the
        feature map is 5x5 after both conv/pool stages (e.g. 32x32).
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything except the batch dimension. Unlike view(-1, 400),
        # this never silently folds a wrong spatial size into the batch axis;
        # a mismatch surfaces as a shape error in fc1 instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

# Initialize the model.
model = TheModelClass()

# Initialize the optimizer: SGD with momentum over all model parameters.
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print the model's state_dict: one line per entry name and its tensor shape.
# Iterating .items() builds the state_dict once instead of once per entry.
print("model's state_dict:")
for param_tensor, tensor in model.state_dict().items():
    print(param_tensor, "\t", tensor.size())

# Print the optimizer's state_dict: its top-level keys and their values.
print("\noptimizer's state_dict")
for var_name, value in optimizer.state_dict().items():
    print(var_name, "\t", value)

# Access a single parameter directly through the submodule attribute.
print("\nprint particular param")
print("\n", model.conv1.weight.size())
print("\n", model.conv1.weight)


model's state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])

optimizer's state_dict
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [139985995931368, 139984655959456, 139984655959600, 139984655959672, 139984655959744, 139984655959816, 139984655959888, 139984655959960, 139984655960032, 139984655960104]}]

print particular param

 torch.Size([6, 3, 5, 5])

 Parameter containing:
