Caffe SGD solver代码阅读分析

代码文件:sgd_solver.cpp

sgd更新公式推导:

这里以 L2 的regularization为例

W( t+ 1) = W ( t )  − lr ∗ wd ∗ W ( t )   − lr ∗ delta(W)  −  momentum ∗( W ( t-1 ) - W ( t ) )

               = W ( t ) −  lr ∗ [ wd ∗ W ( t ) + delta(W) ] −  momentum ∗[ W ( t-1 ) - W ( t ) ]

其中 lr = base_lr lr_mult  (全局的lr  各个层的lr因子),    momentum 为slover的超级参数,

         wd = weight_decay decay_mult  (全局的wd  各个层的wd因子)


下面我们简化一下这个公式,让它跟caffe的函数对应上:

W( t+ 1) = W ( t )  −  ComputeUpdateValue(param_id, lr)

也就是说:

ComputeUpdateValue(param_id, lr) = lr ∗ [ wd ∗ W ( t ) + delta(W) ] momentum ∗[ W ( t-1 ) - W ( t ) ]

Regularize(layer_id) = wd ∗ W ( t ) + delta(W)

Normalize(layer_id)  = Delta(W) / iter_size, 针对iter_size > 1的情况下,对weight进行平均处理

推导出:

ComputeUpdateValue(param_id, lr) = lr ∗ Regularize(param_id) momentum ∗[ W ( t-1 ) - W ( t ) ]

caffe SGDSolver的三个函数ComputeUpdateValue(param_id, lr), Regularize(param_id) Normalize(param_id) 的作用就在这个公式里了。

param_id为每个weight blog的id, caffe 循环更新每个blob

其中Delta(W) 为每个weight blog 在Backward()过程中计算出来的Blob::cpu_diff()

代码分析:

template 
void SGDSolver::ApplyUpdate() {
  Dtype rate = GetLearningRate();
  if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
    LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << this->iter_
        << ", lr = " << rate;
  }
  ClipGradients();
  for (int param_id = 0; param_id < this->net_->learnable_params().size();
       ++param_id) {
    Normalize(param_id);
    Regularize(param_id);
    ComputeUpdateValue(param_id, rate);
  }
  this->net_->Update();    //真正的执行权重更新 Blob::mutable_cpu_data() = Blob::cpu_data() - Blob::cpu_diff()

你可能感兴趣的:(深度学习)