The newer version of Caffe is structured a bit more modularly: all of the solver methods have been split out into a separate solvers directory, which contains every solver implementation.
Let's start with the central piece, the SGD solver family:
The header file, sgd_solvers.hpp:
#ifndef CAFFE_SGD_SOLVERS_HPP_
#define CAFFE_SGD_SOLVERS_HPP_
#include <string>
#include <vector>
#include "caffe/solver.hpp"
namespace caffe {
//SGD optimization solver
template <typename Dtype>
class SGDSolver : public Solver<Dtype> {
public:
explicit SGDSolver(const SolverParameter& param)
: Solver<Dtype>(param) { PreSolve(); }//forward the parameters to Solver, then run PreSolve()
explicit SGDSolver(const string& param_file)
: Solver<Dtype>(param_file) { PreSolve(); }
virtual inline const char* type() const { return "SGD"; }//returns the solver type string "SGD"
const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }
protected:
void PreSolve();
Dtype GetLearningRate();//get the current learning rate
virtual void ApplyUpdate();
virtual void Normalize(int param_id);//normalization
virtual void Regularize(int param_id);//regularization
virtual void ComputeUpdateValue(int param_id, Dtype rate);//compute the update value
virtual void ClipGradients();//gradient clipping
//a family of snapshot operations
virtual void SnapshotSolverState(const string& model_filename);
virtual void SnapshotSolverStateToBinaryProto(const string& model_filename);
virtual void SnapshotSolverStateToHDF5(const string& model_filename);
virtual void RestoreSolverStateFromHDF5(const string& state_file);
virtual void RestoreSolverStateFromBinaryProto(const string& state_file);
//history_ maintains the historical momentum data. update_ maintains update-related data and is not needed in snapshots. temp_ maintains other information that might be needed when computing gradients or updates, and is also not needed in snapshots.
vector<shared_ptr<Blob<Dtype> > > history_, update_, temp_;
//disallow copy and assignment
DISABLE_COPY_AND_ASSIGN(SGDSolver);
};
//Nesterov's accelerated gradient: among first-order methods it achieves the optimal convergence rate on smooth convex problems, so it converges very quickly.
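//For reference, the update it computes is roughly (in Caffe's convention, where the
//final value written into the diff is subtracted from the weights):
//  V_{t+1} = mu * V_t + lr * grad(W_t);  W_{t+1} = W_t - ((1 + mu) * V_{t+1} - mu * V_t)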
template <typename Dtype>
class NesterovSolver : public SGDSolver<Dtype> {
public:
explicit NesterovSolver(const SolverParameter& param)
: SGDSolver<Dtype>(param) {}
explicit NesterovSolver(const string& param_file)
: SGDSolver<Dtype>(param_file) {}
virtual inline const char* type() const { return "Nesterov"; }
protected:
virtual void ComputeUpdateValue(int param_id, Dtype rate);
DISABLE_COPY_AND_ASSIGN(NesterovSolver);
};
//AdaGrad (adaptive gradient) is a gradient-based optimization method with a per-parameter adaptive learning rate.
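//For reference, the per-element update is roughly (eps is the solver's delta parameter):
//  G_{t+1} = G_t + grad^2;  W_{t+1} = W_t - lr * grad / (sqrt(G_{t+1}) + eps)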
template <typename Dtype>
class AdaGradSolver : public SGDSolver<Dtype> {
public:
explicit AdaGradSolver(const SolverParameter& param)
: SGDSolver<Dtype>(param) { constructor_sanity_check(); }
explicit AdaGradSolver(const string& param_file)
: SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
virtual inline const char* type() const { return "AdaGrad"; }
protected:
virtual void ComputeUpdateValue(int param_id, Dtype rate);
void constructor_sanity_check() {
CHECK_EQ(0, this->param_.momentum())
<< "Momentum cannot be used with AdaGrad.";
}
DISABLE_COPY_AND_ASSIGN(AdaGradSolver);
};
//RMSProp was proposed by Tieleman in a Coursera course lecture; it is also a gradient-based optimization method.
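//For reference, the per-element update is roughly (decay = rms_decay, eps = delta):
//  MS_{t+1} = decay * MS_t + (1 - decay) * grad^2
//  W_{t+1} = W_t - lr * grad / (sqrt(MS_{t+1}) + eps)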
template <typename Dtype>
class RMSPropSolver : public SGDSolver<Dtype> {
public:
explicit RMSPropSolver(const SolverParameter& param)
: SGDSolver<Dtype>(param) { constructor_sanity_check(); }
explicit RMSPropSolver(const string& param_file)
: SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
virtual inline const char* type() const { return "RMSProp"; }
protected:
virtual void ComputeUpdateValue(int param_id, Dtype rate);
void constructor_sanity_check() {
CHECK_EQ(0, this->param_.momentum())
<< "Momentum cannot be used with RMSProp.";
CHECK_GE(this->param_.rms_decay(), 0)
<< "rms_decay should lie between 0 and 1.";
CHECK_LT(this->param_.rms_decay(), 1)
<< "rms_decay should lie between 0 and 1.";
}
DISABLE_COPY_AND_ASSIGN(RMSPropSolver);
};
//The basic idea of AdaDelta is to approximate a second-order Newton method using only first-order information.
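//For reference, it keeps running averages of both squared gradients and squared
//updates, roughly (decay = momentum, eps = delta):
//  E[g^2]_t = decay * E[g^2]_{t-1} + (1 - decay) * grad^2
//  dW_t = -sqrt((E[dW^2]_{t-1} + eps) / (E[g^2]_t + eps)) * grad
//  E[dW^2]_t = decay * E[dW^2]_{t-1} + (1 - decay) * dW_t^2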
template <typename Dtype>
class AdaDeltaSolver : public SGDSolver<Dtype> {
public:
explicit AdaDeltaSolver(const SolverParameter& param)
: SGDSolver<Dtype>(param) { AdaDeltaPreSolve(); }
explicit AdaDeltaSolver(const string& param_file)
: SGDSolver<Dtype>(param_file) { AdaDeltaPreSolve(); }
virtual inline const char* type() const { return "AdaDelta"; }
protected:
void AdaDeltaPreSolve();
virtual void ComputeUpdateValue(int param_id, Dtype rate);
DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver);
};
/**
* @brief AdamSolver, an algorithm for first-order gradient-based optimization
* of stochastic objective functions, based on adaptive estimates of
* lower-order moments. Described in [1].
*
* [1] D. P. Kingma and J. L. Ba, "ADAM: A Method for Stochastic Optimization."
* arXiv preprint arXiv:1412.6980v8 (2014).
*/
// Adam is another gradient-based optimization method.
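//For reference, the per-element update is roughly (beta1 = momentum, beta2 = momentum2,
//eps = delta):
//  m_t = beta1 * m_{t-1} + (1 - beta1) * grad
//  v_t = beta2 * v_{t-1} + (1 - beta2) * grad^2
//  W_t = W_{t-1} - lr * (sqrt(1 - beta2^t) / (1 - beta1^t)) * m_t / (sqrt(v_t) + eps)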
template <typename Dtype>
class AdamSolver : public SGDSolver<Dtype> {
public:
explicit AdamSolver(const SolverParameter& param)
: SGDSolver<Dtype>(param) { AdamPreSolve(); }
explicit AdamSolver(const string& param_file)
: SGDSolver<Dtype>(param_file) { AdamPreSolve(); }
virtual inline const char* type() const { return "Adam"; }
protected:
void AdamPreSolve();
virtual void ComputeUpdateValue(int param_id, Dtype rate);
DISABLE_COPY_AND_ASSIGN(AdamSolver);
};
} // namespace caffe
#endif // CAFFE_SGD_SOLVERS_HPP_
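Before moving to the implementation, here is a minimal sketch of how these classes are typically used. The REGISTER_SOLVER_CLASS macros (see the end of the .cpp file) register each class with Caffe's solver factory, so the type field of the solver prototxt selects the class at runtime. The file name lenet_solver.prototxt below is just a placeholder:

#include <boost/shared_ptr.hpp>
#include "caffe/caffe.hpp"

int main() {
  caffe::SolverParameter solver_param;
  // Parse the solver definition; its "type" field ("SGD", "Adam", ...)
  // decides which registered solver class gets instantiated.
  caffe::ReadSolverParamsFromTextFileOrDie("lenet_solver.prototxt",
                                           &solver_param);
  boost::shared_ptr<caffe::Solver<float> >
      solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));
  solver->Solve();  // runs the training loop; each step ends in ApplyUpdate()
  return 0;
}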
The implementation lives in sgd_solver.cpp; note that the file name has no trailing "s".
#include <string>
#include <vector>
#include "caffe/sgd_solvers.hpp"
#include "caffe/util/hdf5.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/upgrade_proto.hpp"
namespace caffe {
//For plots of the learning rate policies below, see http://blog.csdn.net/langb2014/article/details/51274376
// Return the current learning rate. The currently implemented learning rate
// policies are as follows:
// - fixed: always return base_lr.
// - step: return base_lr * gamma ^ (floor(iter / step))
// - exp: return base_lr * gamma ^ iter
// - inv: return base_lr * (1 + gamma * iter) ^ (- power)
// - multistep: similar to step but it allows non uniform steps defined by
// stepvalue
// - poly: the effective learning rate follows a polynomial decay, to be
// zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power)
// - sigmoid: the effective learning rate follows a sigmoid decay
// return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))
//
// where base_lr, max_iter, gamma, step, stepvalue and power are defined
// in the solver parameter protocol buffer, and iter is the current iteration.
//Get the current learning rate; the different lr policies are implemented below.
template <typename Dtype>
Dtype SGDSolver<Dtype>::GetLearningRate() {
Dtype rate;
const string& lr_policy = this->param_.lr_policy();
if (lr_policy == "fixed") {
rate = this->param_.base_lr();
} else if (lr_policy == "step") {
this->current_step_ = this->iter_ / this->param_.stepsize();
rate = this->param_.base_lr() *
pow(this->param_.gamma(), this->current_step_);
} else if (lr_policy == "exp") {
rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_);
} else if (lr_policy == "inv") {
rate = this->param_.base_lr() *
pow(Dtype(1) + this->param_.gamma() * this->iter_,
- this->param_.power());
} else if (lr_policy == "multistep") {
if (this->current_step_ < this->param_.stepvalue_size() &&
this->iter_ >= this->param_.stepvalue(this->current_step_)) {
this->current_step_++;
LOG(INFO) << "MultiStep Status: Iteration " <<
this->iter_ << ", step = " << this->current_step_;
}
rate = this->param_.base_lr() *
pow(this->param_.gamma(), this->current_step_);
} else if (lr_policy == "poly") {
rate = this->param_.base_lr() * pow(Dtype(1.) -
(Dtype(this->iter_) / Dtype(this->param_.max_iter())),
this->param_.power());
} else if (lr_policy == "sigmoid") {
rate = this->param_.base_lr() * (Dtype(1.) /
(Dtype(1.) + exp(-this->param_.gamma() * (Dtype(this->iter_) -
Dtype(this->param_.stepsize())))));
} else {
LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
}
return rate;
}
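As a quick check of the formulas above, here is a small self-contained sketch (with made-up hyperparameters: base_lr = 0.01; gamma = 0.1 and stepsize = 10000 for "step"; gamma = 1e-4 and power = 0.75 for "inv") that evaluates two of the policies at a few iterations:

#include <cmath>
#include <cstdio>

int main() {
  const double base_lr = 0.01;
  const int stepsize = 10000;
  for (int iter = 0; iter <= 30000; iter += 10000) {
    // step: base_lr * gamma ^ floor(iter / stepsize), gamma = 0.1
    double step_lr = base_lr * std::pow(0.1, iter / stepsize);
    // inv: base_lr * (1 + gamma * iter) ^ (-power), gamma = 1e-4, power = 0.75
    double inv_lr = base_lr * std::pow(1.0 + 1e-4 * iter, -0.75);
    std::printf("iter %5d  step: %g  inv: %g\n", iter, step_lr, inv_lr);
  }
  return 0;  // step drops 10x every 10000 iterations; inv decays smoothly
}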
//PreSolve allocates the solver's bookkeeping blobs: history_ maintains the historical momentum data; update_ maintains update-related data and is not needed in snapshots; temp_ maintains other information that may be needed when computing gradients or updates, and is also not needed in snapshots. One blob of each kind is allocated per learnable parameter, with a matching shape.
template <typename Dtype>
void SGDSolver<Dtype>::PreSolve() {
// Initialize the history
const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
history_.clear();
update_.clear();
temp_.clear();
for (int i = 0; i < net_params.size(); ++i) {
const vector<int>& shape = net_params[i]->shape();
history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
update_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
temp_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
}
}
//Gradient clipping
template <typename Dtype>
void SGDSolver<Dtype>::ClipGradients() {
const Dtype clip_gradients = this->param_.clip_gradients();
if (clip_gradients < 0) { return; }
const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
Dtype sumsq_diff = 0;
for (int i = 0; i < net_params.size(); ++i) {
sumsq_diff += net_params[i]->sumsq_diff();
}
const Dtype l2norm_diff = std::sqrt(sumsq_diff);
//if the L2 norm of all parameter diffs exceeds the threshold, scale them down by scale_factor
if (l2norm_diff > clip_gradients) {
Dtype scale_factor = clip_gradients / l2norm_diff;
LOG(INFO) << "Gradient clipping: scaling down gradients (L2 norm "
<< l2norm_diff << " > " << clip_gradients << ") "
<< "by scale factor " << scale_factor;
for (int i = 0; i < net_params.size(); ++i) {
net_params[i]->scale_diff(scale_factor);
}
}
}
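The clipping arithmetic is easy to verify in isolation. A self-contained sketch of the same computation on a toy 4-element gradient, with a made-up threshold of 10:

#include <cmath>
#include <cstdio>

int main() {
  const double clip_gradients = 10.0;      // made-up threshold
  double diff[4] = {6.0, 8.0, 12.0, 9.0};  // toy gradient
  double sumsq = 0.0;
  for (int i = 0; i < 4; ++i) sumsq += diff[i] * diff[i];
  const double l2norm = std::sqrt(sumsq);  // sqrt(325) ~= 18.03 > 10
  if (l2norm > clip_gradients) {
    const double scale = clip_gradients / l2norm;  // ~0.5547
    for (int i = 0; i < 4; ++i) diff[i] *= scale;  // rescale in place
  }
  std::printf("clipped norm: %g\n", std::sqrt(diff[0] * diff[0] +
      diff[1] * diff[1] + diff[2] * diff[2] + diff[3] * diff[3]));  // 10
  return 0;
}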
//Apply the update
template <typename Dtype>
void SGDSolver<Dtype>::ApplyUpdate() {
CHECK(Caffe::root_solver());
Dtype rate = GetLearningRate();
if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
}
ClipGradients();
for (int param_id = 0; param_id < this->net_->learnable_params().size();
++param_id) {
Normalize(param_id);
Regularize(param_id);
ComputeUpdateValue(param_id, rate);
}
this->net_->Update();
}
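Note the division of labor here: Normalize, Regularize, and ComputeUpdateValue only write the final step into each parameter's diff; the closing this->net_->Update() call then walks the learnable blobs and performs data = data - diff (Blob::Update is essentially an axpy with coefficient -1). This is why the solver subclasses above only have to override ComputeUpdateValue.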
//Normalize the accumulated gradient
template <typename Dtype>
void SGDSolver<Dtype>::Normalize(int param_id) {
if (this->param_.iter_size() == 1) { return; }
// Scale gradient to counterbalance accumulation.
const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
//scale the accumulated diff by 1 / iter_size
const Dtype accum_normalization = Dtype(1.) / this->param_.iter_size();
switch (Caffe::mode()) {
case Caffe::CPU: {
caffe_scal(net_params[param_id]->count(), accum_normalization,
net_params[param_id]->mutable_cpu_diff());
break;
}
case Caffe::GPU: {
#ifndef CPU_ONLY
caffe_gpu_scal(net_params[param_id]->count(), accum_normalization,
net_params[param_id]->mutable_gpu_diff());
#else
NO_GPU;
#endif
break;
}
default:
LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
}
}
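This only matters when iter_size > 1: Caffe then accumulates gradients over iter_size forward/backward passes before each update, so the effective batch size is batch_size * iter_size, and multiplying the summed diff by 1 / iter_size turns the sum back into an average. With iter_size == 1 the function returns immediately.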
//Apply L2 or L1 weight decay to the gradient
template <typename Dtype>
void SGDSolver<Dtype>::Regularize(int param_id) {
const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
const vector<float>& net_params_weight_decay =
this->net_->params_weight_decay();
Dtype weight_decay = this->param_.weight_decay();
string regularization_type = this->param_.regularization_type();
Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
switch (Caffe::mode()) {
case Caffe::CPU: {
if (local_decay) {
if (regularization_type == "L2") {
//add the weight decay term; if axpy looks unfamiliar, revisit the earlier post on
//math_functions: http://blog.csdn.net/langb2014/article/details/50986678
caffe_axpy(net_params[param_id]->count(),
local_decay,
net_params[param_id]->cpu_data(),
net_params[param_id]->mutable_cpu_diff());
} else if (regularization_type == "L1") {
caffe_cpu_sign(net_params[param_id]->count(),
net_params[param_id]->cpu_data(),
temp_[param_id]->mutable_cpu_data());
caffe_axpy(net_params[param_id]->count(),
local_decay,
temp_[param_id]->cpu_data(),
net_params[param_id]->mutable_cpu_diff());
} else {
LOG(FATAL) << "Unknown regularization type: " << regularization_type;
}
}
break;
}
case Caffe::GPU: {
#ifndef CPU_ONLY
if (local_decay) {
if (regularization_type == "L2") {
// add weight decay
caffe_gpu_axpy(net_params[param_id]->count(),
local_decay,
net_params[param_id]->gpu_data(),
net_params[param_id]->mutable_gpu_diff());
} else if (regularization_type == "L1") {
caffe_gpu_sign(net_params[param_id]->count(),
net_params[param_id]->gpu_data(),
temp_[param_id]->mutable_gpu_data());
caffe_gpu_axpy(net_params[param_id]->count(),
local_decay,
temp_[param_id]->gpu_data(),
net_params[param_id]->mutable_gpu_diff());
} else {
LOG(FATAL) << "Unknown regularization type: " << regularization_type;
}
}
#else
NO_GPU;
#endif
break;
}
default:
LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
}
}
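The net effect of Regularize is to fold the decay term into the gradient before the update: diff += local_decay * data for L2, and diff += local_decay * sign(data) for L1. A self-contained sketch of the L2 branch on a toy parameter (the 0.0005 decay is a made-up value):

#include <cstdio>

int main() {
  const double local_decay = 0.0005;  // made-up weight decay
  double data[3] = {0.5, -2.0, 0.0};  // toy weights
  double diff[3] = {0.1, 0.1, 0.1};   // toy gradient
  for (int i = 0; i < 3; ++i) {
    diff[i] += local_decay * data[i];  // L2: what caffe_axpy does above
    // L1 would instead add the sign, as caffe_cpu_sign + caffe_axpy do:
    // diff[i] += local_decay * (data[i] > 0 ? 1.0 : data[i] < 0 ? -1.0 : 0.0);
  }
  std::printf("%g %g %g\n", diff[0], diff[1], diff[2]);  // 0.10025 0.099 0.1
  return 0;
}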
#ifndef CPU_ONLY
template <typename Dtype>
void sgd_update_gpu(int N, Dtype* g, Dtype* h, Dtype momentum,
Dtype local_rate);
#endif
//Compute the update value
template <typename Dtype>
void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
const vector<float>& net_params_lr = this->net_->params_lr();
Dtype momentum = this->param_.momentum();
Dtype local_rate = rate * net_params_lr[param_id];
// Compute the update to history, then copy it to the parameter diff.
switch (Caffe::mode()) {
case Caffe::CPU: {
caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
net_params[param_id]->cpu_diff(), momentum,
history_[param_id]->mutable_cpu_data());
caffe_copy(net_params[param_id]->count(),
history_[param_id]->cpu_data(),
net_params[param_id]->mutable_cpu_diff());
break;
}
case Caffe::GPU: {
#ifndef CPU_ONLY
sgd_update_gpu(net_params[param_id]->count(),
net_params[param_id]->mutable_gpu_diff(),
history_[param_id]->mutable_gpu_data(),
momentum, local_rate);
#else
NO_GPU;
#endif
break;
}
default:
LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
}
}
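Written per element, this is the classic momentum update: the caffe_cpu_axpby call computes history = momentum * history + local_rate * diff, and the result is copied back into the diff so that net_->Update() can apply W = W - diff. A scalar trace with made-up numbers (momentum 0.9, local_rate 0.01, constant gradient 1.0) shows the step size ramping up toward local_rate / (1 - momentum) = 0.1:

#include <cstdio>

int main() {
  const double momentum = 0.9, local_rate = 0.01;  // made-up hyperparameters
  double h = 0.0, w = 0.0;  // momentum history and a single weight
  for (int t = 1; t <= 5; ++t) {
    const double g = 1.0;               // pretend the gradient is constant
    h = momentum * h + local_rate * g;  // the caffe_cpu_axpby step
    w -= h;                             // what net_->Update() does with the diff
    std::printf("t=%d  h=%.5f  w=%.5f\n", t, h, w);
  }
  return 0;
}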
//The snapshot code is straightforward, so it is not covered in detail here.
template <typename Dtype>
void SGDSolver<Dtype>::SnapshotSolverState(const string& model_filename) {
switch (this->param_.snapshot_format()) {
case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:
SnapshotSolverStateToBinaryProto(model_filename);
break;
case caffe::SolverParameter_SnapshotFormat_HDF5:
SnapshotSolverStateToHDF5(model_filename);
break;
default:
LOG(FATAL) << "Unsupported snapshot format.";
}
}
template <typename Dtype>
void SGDSolver<Dtype>::SnapshotSolverStateToBinaryProto(
const string& model_filename) {
SolverState state;
state.set_iter(this->iter_);
state.set_learned_net(model_filename);
state.set_current_step(this->current_step_);
state.clear_history();
for (int i = 0; i < history_.size(); ++i) {
// Add history
BlobProto* history_blob = state.add_history();
history_[i]->ToProto(history_blob);
}
string snapshot_filename = Solver<Dtype>::SnapshotFilename(".solverstate");
LOG(INFO)
<< "Snapshotting solver state to binary proto file " << snapshot_filename;
WriteProtoToBinaryFile(state, snapshot_filename.c_str());
}
template <typename Dtype>
void SGDSolver<Dtype>::SnapshotSolverStateToHDF5(
const string& model_filename) {
string snapshot_filename =
Solver<Dtype>::SnapshotFilename(".solverstate.h5");
LOG(INFO) << "Snapshotting solver state to HDF5 file " << snapshot_filename;
hid_t file_hid = H5Fcreate(snapshot_filename.c_str(), H5F_ACC_TRUNC,
H5P_DEFAULT, H5P_DEFAULT);
CHECK_GE(file_hid, 0)
<< "Couldn't open " << snapshot_filename << " to save solver state.";
hdf5_save_int(file_hid, "iter", this->iter_);
hdf5_save_string(file_hid, "learned_net", model_filename);
hdf5_save_int(file_hid, "current_step", this->current_step_);
hid_t history_hid = H5Gcreate2(file_hid, "history", H5P_DEFAULT, H5P_DEFAULT,
H5P_DEFAULT);
CHECK_GE(history_hid, 0)
<< "Error saving solver state to " << snapshot_filename << ".";
for (int i = 0; i < history_.size(); ++i) {
ostringstream oss;
oss << i;
hdf5_save_nd_dataset<Dtype>(history_hid, oss.str(), *history_[i]);
}
H5Gclose(history_hid);
H5Fclose(file_hid);
}
template <typename Dtype>
void SGDSolver<Dtype>::RestoreSolverStateFromBinaryProto(
const string& state_file) {
SolverState state;
ReadProtoFromBinaryFile(state_file, &state);
this->iter_ = state.iter();
if (state.has_learned_net()) {
NetParameter net_param;
ReadNetParamsFromBinaryFileOrDie(state.learned_net().c_str(), &net_param);
this->net_->CopyTrainedLayersFrom(net_param);
}
this->current_step_ = state.current_step();
CHECK_EQ(state.history_size(), history_.size())
<< "Incorrect length of history blobs.";
LOG(INFO) << "SGDSolver: restoring history";
for (int i = 0; i < history_.size(); ++i) {
history_[i]->FromProto(state.history(i));
}
}
template <typename Dtype>
void SGDSolver<Dtype>::RestoreSolverStateFromHDF5(const string& state_file) {
hid_t file_hid = H5Fopen(state_file.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT);
CHECK_GE(file_hid, 0) << "Couldn't open solver state file " << state_file;
this->iter_ = hdf5_load_int(file_hid, "iter");
if (H5LTfind_dataset(file_hid, "learned_net")) {
string learned_net = hdf5_load_string(file_hid, "learned_net");
this->net_->CopyTrainedLayersFrom(learned_net);
}
this->current_step_ = hdf5_load_int(file_hid, "current_step");
hid_t history_hid = H5Gopen2(file_hid, "history", H5P_DEFAULT);
CHECK_GE(history_hid, 0) << "Error reading history from " << state_file;
int state_history_size = hdf5_get_num_links(history_hid);
CHECK_EQ(state_history_size, history_.size())
<< "Incorrect length of history blobs.";
for (int i = 0; i < history_.size(); ++i) {
ostringstream oss;
oss << i;
hdf5_load_nd_dataset<Dtype>(history_hid, oss.str().c_str(), 0,
kMaxBlobAxes, history_[i].get());
}
H5Gclose(history_hid);
H5Fclose(file_hid);
}
INSTANTIATE_CLASS(SGDSolver);
REGISTER_SOLVER_CLASS(SGD);
} // namespace caffe
Next, we will work through the individual solver implementations one by one.