add_param_group()
此函数是向optimizer
中添加优化参数param
import torch
import torch.optim as optim
w1 = torch.randn(3, 3)
w1.requires_grad = True
w2 = torch.randn(3, 3)
w2.requires_grad = True
o = optim.Adam([w1])
print("o.param_groups:-----------------------------")
print(o.param_groups)
'''
[{'amsgrad': False,
'betas': (0.9, 0.999),
'eps': 1e-08,
'lr': 0.001,
'params': [tensor([[ 2.9064, -0.2141, -0.4037],
[-0.5718, 1.0375, -0.6862],
[-0.8372, 0.4380, -0.1572]])],
'weight_decay': 0}]
'''
print("o.param_groups:-----------------------------")
o.add_param_group({'params': w2, 'eps': 1e-7})
print(o.param_groups)
'''
[{'amsgrad': False,
'betas': (0.9, 0.999),
'eps': 1e-08,
'lr': 0.001,
'params': [tensor([[ 2.9064, -0.2141, -0.4037],
[-0.5718, 1.0375, -0.6862],
[-0.8372, 0.4380, -0.1572]])],
'weight_decay': 0},
{'amsgrad': False,
'betas': (0.9, 0.999),
'eps': 1e-07,
'lr': 0.001,
'params': [tensor([[-0.0560, 0.4585, -0.7589],
[-0.1994, 0.4557, 0.5648],
[-0.1280, -0.0333, -1.1886]])],
'weight_decay': 0}]
'''
load_state_dict(state_dict)
一般来说之预训练模型加载到state_dict中,然后加载到优化器中
step()
在初始化optimizer
的时候我们会传入模型参数model.parameters()
,在反向传播之后我们会得到参数的梯度,此步骤的作用就是用梯度以及优化器来更新参数
zero.grad()
将参数的梯度设为0,如果每次更新参数之后不把梯度设为0的话到下次求梯度时,梯度会累加
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
# define model
class TheModelClass(nn.Module):
def __init__(self):
super(TheModelClass,self).__init__()
self.conv1=nn.Conv2d(3, 6, 5)
self.pool=nn.MaxPool2d(2, 2)
self.conv2=nn.Conv2d(6, 16, 5)
self.fc1=nn.Linear(16*5*5, 120)
self.fc2=nn.Linear(120, 84)
self.fc3=nn.Linear(84, 10)
def forward(self,x):
x=self.pool(F.relu(self.conv1(x)))
x=self.pool(F.relu(self.conv2(x)))
x=x.view(-1,16*5*5)
x=F.relu(self.fc1(x))
x=F.relu(self.fc2(x))
x=self.fc3(x)
return x
def main():
# Initialize model
model = TheModelClass()
# Initialize optimizer
optimizer=optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
'''
model的state_dict()与optimizer的略有不同
model:
torch.nn.Module模块中的state_dict只包含卷积层和全连接层的参数
当网络中存在batchnorm时,例如vgg网络结构,torch.nn.Module模块中
的state_dict也会存放batchnorm's running_mean
optimizer:
state_dict字典对象包含state和param_groups的字典对象,而param_groups key
对应的value也是一个由学习率,动量等参数组成的一个字典对象
'''
# print model state_dict
print('Model.state_dict: ')
model_param = model.state_dict()
for param_tensor in model_param:
# print key value字典
print(param_tensor, '\t', model.state_dict()[param_tensor].size())
'''
Model.state_dict:
conv1.weight torch.Size([6, 3, 5, 5])
conv1.bias torch.Size([6])
conv2.weight torch.Size([16, 6, 5, 5])
conv2.bias torch.Size([16])
fc1.weight torch.Size([120, 400])
fc1.bias torch.Size([120])
fc2.weight torch.Size([84, 120])
fc2.bias torch.Size([84])
fc3.weight torch.Size([10, 84])
fc3.bias torch.Size([10])
'''
# print optimizer state_dict
print('Optimizer state_dict: ')
optim_param = optimizer.state_dict()
for var_name in optim_param:
print(var_name, '\t', optimizer.state_dict()[var_name])
'''
Optimizer state_dict:
state {}
param_groups [{'lr': 0.001,
'momentum': 0.9,
'dampening': 0,
'weight_decay': 0,
'nesterov': False,
'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]
'''
# optimizer params
params_groups = optimizer.param_groups
print("optimizer.param_groups: ")
print(params_groups)
'''
此处会打印出每一个参数的具体数值
这里假设是torch.nn.Linear(1, 1),有两个参数 w&&b
[{'params': [Parameter containing:
tensor([[0.8117]], device='cuda:0', requires_grad=True),
Parameter containing:
tensor([0.8024], device='cuda:0', requires_grad=True)],
'lr': 0.001,
'momentum': 0,
'dampening': 0,
'weight_decay': 0,
'nesterov': False}]
'''
if __name__=='__main__':
main()
param_groups
方法可以返回关于优化器内部的参数字典,其中包含有model.parameters()
的具体每一项的数值以及optimizer
的相关参数,比如下面对于torch.nn.SGD
会返回其中每一个参数字典
import torch.optim
import torch.nn
model = torch.nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
params_groups = optimizer.param_groups
print("optimizer.param_groups: ")
print(params_groups)
'''
optimizer.param_groups:
[{'params': [Parameter containing:
tensor([[0.7701]], requires_grad=True),
Parameter containing:
tensor([0.8500], requires_grad=True)],
'lr': 0.001,
'momentum': 0.9,
'dampening': 0,
'weight_decay': 0,
'nesterov': False}]
'''