class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
Parameters: Conv2d(in_channels, out_channels, kernel_size, ...); the remaining arguments keep the defaults shown above.
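To see how these parameters behave, here is a minimal sketch (the 1x5x32x32 input size is chosen purely for illustration) that runs one Conv2d layer forward and prints the output shape; with kernel_size=3, stride=1 and padding=0 each spatial dimension shrinks by 2:

import torch
import torch.nn as nn

conv = nn.Conv2d(in_channels=5, out_channels=10, kernel_size=3)  # stride=1, padding=0 are the defaults
x = torch.randn(1, 5, 32, 32)   # (batch, in_channels, H, W); illustrative sizes
y = conv(x)
print(y.shape)                  # torch.Size([1, 10, 30, 30])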
Note: once a convolutional layer is defined, the corresponding parameter entries are generated automatically in the model's state_dict.
Each convolutional layer you define adds two tensors to the model's state_dict:
1) conv2d.weight, shape = [out_channels, in_channels, kernel_size, kernel_size]
2) conv2d.bias, shape = [out_channels]
Example:
# define a convolutional layer conv1
conv1 = torch.nn.Conv2d(5, 10, 3)
# the model's state_dict then automatically contains the following parameters
conv1.weight    torch.Size([10, 5, 3, 3])
conv1.bias      torch.Size([10])
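These shapes can also be checked directly with a small, self-contained sketch; note that for a bare layer the state_dict keys are just weight and bias, and the conv1. prefix only appears once the layer is registered inside a model:

import torch

conv1 = torch.nn.Conv2d(5, 10, 3)
for name, tensor in conv1.state_dict().items():
    print(name, tensor.size())
# weight torch.Size([10, 5, 3, 3])
# bias torch.Size([10])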
Example 2:
#-*-coding:utf-8-*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# define model
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# initialize the model
model = TheModelClass()
# initialize the optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# print the model's state_dict
print("model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, '\t', model.state_dict()[param_tensor].size())

# print the optimizer's state_dict
print("\noptimizer's state_dict")
for var_name in optimizer.state_dict():
    print(var_name, '\t', optimizer.state_dict()[var_name])
print("\nprint particular param")
print('\n',model.conv1.weight.size())
print('\n',model.conv1.weight)
Output:
model's state_dict:
conv1.weight torch.Size([6, 3, 5, 5])
conv1.bias torch.Size([6])
conv2.weight torch.Size([16, 6, 5, 5])
conv2.bias torch.Size([16])
fc1.weight torch.Size([120, 400])
fc1.bias torch.Size([120])
fc2.weight torch.Size([84, 120])
fc2.bias torch.Size([84])
fc3.weight torch.Size([10, 84])
fc3.bias torch.Size([10])
optimizer's state_dict
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [139985995931368, 139984655959456, 139984655959600, 139984655959672, 139984655959744, 139984655959816, 139984655959888, 139984655959960, 139984655960032, 139984655960104]}]
print particular param
torch.Size([6, 3, 5, 5])
Parameter containing:
tensor([[[[-0.0080, -0.0503, 0.0092, -0.1068, -0.0789],
[-0.1028, -0.0067, -0.1015, -0.0660, 0.1107],
[ 0.0733, 0.0195, -0.0236, 0.0244, 0.0168],
[-0.0310, -0.0915, 0.0267, -0.0465, -0.0112],
[-0.0876, -0.0579, -0.0689, -0.0397, -0.1020]],
[[ 0.0148, -0.0605, -0.0428, -0.0280, -0.0038],
[-0.0452, 0.0938, 0.0793, -0.0857, 0.0700],
[-0.0463, -0.0326, -0.0130, 0.0460, 0.0138],
[ 0.1144, 0.0173, -0.0178, -0.0745, 0.0625],
[ 0.0713, 0.0400, -0.0596, -0.0878, -0.0773]],
[[ 0.0782, 0.0849, -0.0777, 0.0770, -0.0115],
[-0.0918, -0.0262, 0.0067, 0.0481, 0.0812],
[ 0.0411, -0.1067, 0.0187, 0.0250, 0.0964],
[ 0.0076, 0.0715, -0.0559, 0.0888, -0.0787],
[-0.0894, 0.0258, 0.1001, -0.0621, -0.0245]]],
[[[-0.0464, -0.0124, -0.0204, -0.0179, 0.0263],
[ 0.1148, 0.0955, -0.0630, 0.0382, -0.0889],
[ 0.1114, 0.0027, -0.0478, -0.0857, -0.0735],
[ 0.0446, 0.0893, -0.0671, 0.0066, -0.0356],
[-0.1027, 0.0593, -0.0410, -0.0647, 0.0377]],
[[-0.0145, 0.0259, -0.0488, -0.1128, -0.0441],
[-0.0269, -0.0213, 0.0958, -0.0159, -0.1011],
[ 0.0614, -0.0445, -0.0642, -0.0092, 0.0317],
[ 0.0399, -0.0608, -0.0156, 0.1112, 0.0865],
[ 0.0679, -0.0030, 0.0948, 0.0804, -0.0644]],
[[ 0.0625, 0.0002, -0.0690, 0.0803, -0.0091],
[ 0.0073, 0.1063, 0.0663, 0.0094, -0.0997],
[-0.0938, 0.0973, -0.0571, -0.0281, -0.0008],
[ 0.0502, -0.0266, -0.0459, -0.0831, 0.0589],
[ 0.1062, 0.0144, 0.0318, 0.0814, 0.0641]]],
[[[ 0.0706, 0.0121, -0.0918, 0.0571, -0.0780],
[ 0.0068, 0.0786, -0.0118, 0.0070, 0.0367],
[-0.0983, -0.0742, 0.0878, 0.1115, -0.0342],
[ 0.0682, -0.1151, 0.0689, -0.1039, -0.0854],
[-0.0185, 0.0474, -0.0282, -0.0707, -0.0105]],
[[-0.0562, 0.0887, 0.0002, 0.0974, 0.1088],
[-0.0568, 0.0291, 0.0522, -0.0791, -0.0136],
[ 0.0480, 0.0764, 0.1015, 0.0315, -0.0715],
[ 0.0078, 0.1052, 0.0647, -0.0707, -0.0269],
[-0.0742, 0.1057, 0.0410, 0.0867, -0.0098]],
[[-0.0847, 0.0005, 0.0210, 0.1104, -0.0865],
[ 0.0424, -0.0321, -0.0856, 0.0761, -0.1053],
[-0.0995, 0.0792, 0.0428, 0.0239, 0.0532],
[-0.0705, 0.0683, -0.0691, 0.0287, -0.0657],
[-0.0518, -0.0395, 0.0270, 0.0997, -0.0581]]],
[[[ 0.0071, 0.1119, 0.0198, 0.0697, 0.0853],
[-0.0718, -0.0216, -0.0026, 0.0939, 0.0791],
[ 0.0584, -0.0262, 0.0226, 0.0166, -0.0898],
[ 0.1004, -0.0992, 0.0630, 0.0591, 0.0152],
[-0.0731, -0.0343, 0.0821, 0.0518, -0.0257]],
[[-0.0847, 0.1124, -0.0815, -0.0989, 0.0975],
[ 0.0750, -0.0998, -0.0341, 0.0603, 0.0299],
[ 0.0504, -0.0782, -0.0870, 0.0940, -0.0717],
[-0.0387, 0.1046, -0.0216, 0.0870, -0.0550],
[-0.0772, 0.0888, 0.0341, 0.0018, 0.0923]],
[[-0.0257, -0.0024, -0.0461, 0.0309, -0.0204],
[ 0.0782, -0.1152, -0.1073, -0.0128, -0.1088],
[ 0.0238, 0.0951, -0.1048, 0.1055, 0.1090],
[ 0.0984, -0.0634, 0.0864, 0.1067, -0.1024],
[-0.0499, 0.1054, 0.0025, -0.0640, -0.0089]]],
[[[-0.0263, 0.0849, -0.0872, -0.0457, -0.1010],
[-0.0327, 0.0176, -0.0301, 0.0329, 0.0561],
[-0.0325, 0.0409, -0.0862, 0.0603, -0.0904],
[-0.0352, 0.0723, 0.0955, -0.0478, -0.1055],
[-0.0711, -0.0076, -0.0725, -0.0856, 0.0413]],
[[ 0.0999, -0.0613, -0.0390, -0.1126, 0.0182],
[ 0.0302, 0.0699, 0.0263, 0.0594, 0.0965],
[-0.0062, 0.0779, 0.0010, 0.0617, 0.0596],
[ 0.0058, -0.0344, 0.0266, -0.0754, -0.0667],
[ 0.0120, 0.1121, -0.0693, 0.0516, 0.0863]],
[[-0.0897, -0.0838, -0.0126, 0.0938, 0.0570],
[ 0.0729, 0.0482, 0.0066, 0.0559, -0.0951],
[ 0.0750, 0.0592, 0.0550, 0.0671, 0.0661],
[-0.1132, -0.0496, -0.0931, 0.0659, -0.0453],
[ 0.0177, 0.0018, 0.0622, 0.0571, 0.1092]]],
[[[ 0.0697, 0.0629, 0.0071, 0.0266, 0.0199],
[-0.1087, 0.1084, 0.0488, -0.0162, 0.1147],
[-0.0944, -0.1005, -0.0494, 0.0163, -0.0477],
[ 0.0199, -0.0245, 0.0768, -0.0319, -0.0087],
[ 0.0823, 0.1125, -0.0000, -0.0238, -0.0647]],
[[ 0.0107, -0.0313, -0.0060, 0.0010, 0.0102],
[-0.0748, 0.0240, -0.0658, -0.0524, 0.0908],
[-0.0921, -0.1004, -0.0492, 0.0021, 0.0020],
[-0.1136, 0.0122, 0.0324, 0.0125, 0.0843],
[-0.0888, 0.0573, 0.0286, 0.0672, 0.0266]],
[[-0.0215, -0.0275, -0.0994, 0.1052, 0.1087],
[ 0.0008, -0.1082, -0.0890, 0.0155, 0.0612],
[ 0.0211, 0.0042, -0.0483, 0.0919, -0.1100],
[-0.0703, -0.0263, -0.0256, -0.0122, -0.0594],
[-0.0150, -0.0508, -0.0393, -0.1073, 0.0849]]]],
requires_grad=True)
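Note that the optimizer's state entry above is empty because optimizer.step() has not been called yet; SGD with momentum only allocates its momentum buffers after the first update. A minimal sketch showing the state filling in (the random input and target below are made up purely for illustration and reuse the model and optimizer defined above):

inputs = torch.randn(4, 3, 32, 32)            # fake batch: 4 images of 3x32x32
targets = torch.randint(0, 10, (4,))          # fake class labels
loss = F.cross_entropy(model(inputs), targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(optimizer.state_dict()['state'])        # now contains a momentum_buffer per parameter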
References:
https://pytorch.org/docs/stable/nn.html#conv2d
https://pytorch.org/tutorials/beginner/saving_loading_models.html
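Following the second reference, the state_dict is exactly what gets serialized when saving a model. A minimal sketch of the recommended save/load pattern (the file name model_weights.pth is an arbitrary example):

# save only the learned parameters
torch.save(model.state_dict(), 'model_weights.pth')

# load them back into a freshly constructed model of the same class
model2 = TheModelClass()
model2.load_state_dict(torch.load('model_weights.pth'))
model2.eval()   # set dropout/batchnorm layers to eval mode before inference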