A read-through of PyTorch's SGD source code

Call signature:

torch.optim.SGD(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False)

momentum: momentum factor
dampening: dampening factor applied to the gradient
weight_decay: L2 regularization coefficient
nesterov: whether to use Nesterov momentum
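
A minimal usage sketch (the linear model, loss, and random data here are placeholders for illustration):

    import torch

    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01,
                                momentum=0.9, weight_decay=1e-4)

    x, y = torch.randn(32, 10), torch.randn(32, 1)
    optimizer.zero_grad()                              # clear old gradients
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()                                    # fill p.grad for every parameter
    optimizer.step()                                   # apply the update implemented below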

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay'] # weight decay coefficient
            momentum = group['momentum'] # momentum factor, e.g. 0.9 or 0.8
            dampening = group['dampening'] # dampening factor for the gradient
            nesterov = group['nesterov'] # whether to use Nesterov momentum

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                if weight_decay != 0: # apply L2 regularization
                    # add_ modifies in place: d_p = d_p + weight_decay*p.data
                    d_p.add_(weight_decay, p.data)
                if momentum != 0:
                    param_state = self.state[p] # state accumulated so far, v(t-1)
                    # accumulate the momentum
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        # the previous momentum buffer
                        buf = param_state['momentum_buffer']
                        # buf = buf*momentum + (1-dampening)*d_p
                        buf.mul_(momentum).add_(1 - dampening, d_p)
                    if nesterov: # with Nesterov momentum
                        # d_p = d_p + momentum*buf
                        d_p = d_p.add(momentum, buf)
                    else:
                        d_p = buf
                # p = p - lr*d_p
                p.data.add_(-group['lr'], d_p)

        return loss
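
To make the control flow above concrete, here is a standalone sketch of the same update written as a plain function. Note it uses the current Tensor.add_(other, alpha=...) signature rather than the older add_(scalar, tensor) form that appears in the source above; the function name sgd_step and the state dict are my own scaffolding:

    import torch

    def sgd_step(params, state, lr, momentum=0.0, dampening=0.0,
                 weight_decay=0.0, nesterov=False):
        """One SGD step over an iterable of tensors; state maps param -> momentum buffer."""
        with torch.no_grad():
            for p in params:
                if p.grad is None:
                    continue
                d_p = p.grad.clone()                 # copy, so p.grad itself stays untouched
                if weight_decay != 0:
                    d_p.add_(p, alpha=weight_decay)  # d_p = d_p + weight_decay*p
                if momentum != 0:
                    if p not in state:
                        buf = state[p] = d_p.clone() # first step: v = g (no dampening)
                    else:
                        buf = state[p]
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)  # v = mu*v + (1-tau)*g
                    d_p = d_p.add(buf, alpha=momentum) if nesterov else buf
                p.add_(d_p, alpha=-lr)               # p = p - lr*d_p

Passing the same state dict on every call preserves the momentum buffers across steps, playing the role of self.state[p] in the real optimizer.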

PyTorch's SGD therefore implements the following update (the two formula figures from the original post are reconstructed here from the code). Writing $g_t$ for the gradient, $\mu$ for momentum, $\tau$ for dampening and $\lambda$ for weight_decay, the gradient is first augmented to $\tilde g_t = g_t + \lambda p_{t-1}$, and then

$v_t = \mu v_{t-1} + (1-\tau)\,\tilde g_t$

$p_t = p_{t-1} - lr \cdot v_t$

With Nesterov momentum the last line becomes $p_t = p_{t-1} - lr\,(\tilde g_t + \mu v_t)$.
In the code:

d_p = d_p + weight_decay * p.data  # weight decay; in effect L2 regularization
buf = buf * momentum + (1 - dampening) * d_p  # momentum accumulation, i.e. v

With Nesterov momentum, d_p = d_p + momentum * buf; otherwise d_p = buf. The final update is p = p - lr * d_p.
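
A quick numeric check of these formulas against the real optimizer (the concrete values are arbitrary):

    import torch

    p = torch.tensor([1.0], requires_grad=True)
    opt = torch.optim.SGD([p], lr=0.1, momentum=0.9, weight_decay=0.01)

    p.grad = torch.tensor([0.5])
    opt.step()
    # By hand: d_p = 0.5 + 0.01*1.0 = 0.505; first step, so buf = d_p = 0.505
    # p = 1.0 - 0.1*0.505 = 0.9495
    print(p.data)  # tensor([0.9495])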

A note on the L2 regularization step:

$loss = ||y-y^{*}|| + \frac{\lambda}{2}||w||^2$

Taking the derivative,

$\frac{\partial loss}{\partial w} = \frac{\partial ||y-y^{*}||}{\partial w} + \lambda w$

The first term on the right-hand side has already been computed by loss.backward(); the second term is exactly what the weight-decay code above adds, after which the parameter update proceeds. weight_decay is $\lambda$.
For plain SGD without momentum, weight decay and L2 regularization are therefore equivalent, as the sketch below verifies.
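
A small sketch checking that equivalence; $\lambda = 0.1$ and the starting weights are arbitrary:

    import torch

    lam = 0.1
    w1 = torch.tensor([2.0], requires_grad=True)
    w2 = torch.tensor([2.0], requires_grad=True)

    # Route 1: add the L2 term to the loss explicitly
    loss = (w1 ** 2).sum() + (lam / 2) * (w1 ** 2).sum()
    loss.backward()
    torch.optim.SGD([w1], lr=0.1).step()

    # Route 2: let the optimizer apply weight_decay
    loss = (w2 ** 2).sum()
    loss.backward()
    torch.optim.SGD([w2], lr=0.1, weight_decay=lam).step()

    print(w1.data, w2.data)  # both tensor([1.5800])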

One more caveat: d_p = p.grad.data does not make a copy. d_p aliases the gradient tensor, so the in-place add_ in the weight-decay branch of this version of the source writes directly into p.grad.data.
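
The aliasing is easy to demonstrate at the tensor level; clone the gradient first if it must be preserved:

    import torch

    p = torch.tensor([1.0], requires_grad=True)
    p.grad = torch.tensor([0.5])

    d_p = p.grad.data        # an alias, not a copy
    d_p.add_(0.01 * p.data)  # the in-place weight-decay style add...
    print(p.grad)            # tensor([0.5100]) -- ...has changed p.grad as well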
