optimizer.state_dict()和optimizer.param_groups的区别

参考
pytorch包含多种优化算法用于网络参数的更新,比如常用的SGD、Adam、LBFGS以及RMSProp等。使用中可以发现各种优化算法的使用方式几乎相同,是因为父类optimizer【1】定义了各个子类(即SGD等)的核心行为,下面是optimizer类注释:

class Optimizer(object):
    r"""Base class for all optimizers.
    Arguments:
        params (iterable): an iterable of :class:`torch.Tensor` s or
            :class:`dict` s. Specifies what Tensors should be optimized.
    """

其中首句“所有优化器的基类” 表明所有的优化器都必须继承optimizer类,下面来分析optimizer类的的各个实例函数。

1、optimizer.param_groups

优化器需要保存学习率等参数的值,所以optimizer类需要用实例属性来存储这些参数,也就是__init__()中的self.param_groups,下面的代码通过一个全连接网络来测试优化器的param_groups包含哪些参数:

net = nn.Linear(2, 2)
# 权重矩阵初始化为1
nn.init.constant_(net.weight, val=100)
nn.init.constant_(net.bias, val=20)
optimizer = optim.SGD(net.parameters(), lr=0.01)
print(optimizer.param_groups)
得到
[{'params': [Parameter containing:
tensor([[ 100.,  100.],
        [ 100.,  100.]]), Parameter containing:
tensor([20,, 20])], 'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}]

其中2x2的矩阵是net的权重矩阵,1x2为偏置矩阵其余为优化器的其它参数,所以说param_groups保存了优化器的全部数据,这个下面的state_dict()不同。

2、优化器状态optimizer.state_dict()

def state_dict(self):
    r"""Returns the state of the optimizer as a :class:`dict` """
    # Save ids instead of Tensors
    def pack_group(group):
        # 对"params"和其它的键采用不同规则
        packed = {k: v for k, v in group.items() if k != 'params'}
        # 这里并没有保存参数的值,而是保存参数的id
        packed['params'] = [id(p) for p in group['params']]
        return packed
    # 对self.param_groups进行遍历
    param_groups = [pack_group(g) for g in self.param_groups]
    # Remap state to use ids as keys
    packed_state = {(id(k) if isinstance(k, torch.Tensor) else k): v
                    for k, v in self.state.items()}
    # 返回状态和参数组,其中参数组才是优化器的参数
    return {
        'state': packed_state,
        'param_groups': param_groups,
    }
print(optimizer.state_dict()["param_groups"])

可以到优化器的完整参数如下:

[{'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 
'nesterov': False, 'params': [2149749904224, 2149749906312]}]

你可能感兴趣的:(pytorch,python,开发语言)