The newer version of Caffe has a more standardized module structure: all of the optimization methods have been split out into a separate solvers directory, which holds every solver implementation. Let's start with the centerpiece, the SGD solvers.

The header file sgd_solvers.hpp:
#ifndef CAFFE_SGD_SOLVERS_HPP_
#define CAFFE_SGD_SOLVERS_HPP_

#include <string>
#include <vector>

#include "caffe/solver.hpp"

namespace caffe {

// The SGD optimization solver
template <typename Dtype>
class SGDSolver : public Solver<Dtype> {
 public:
  explicit SGDSolver(const SolverParameter& param)
      : Solver<Dtype>(param) { PreSolve(); }  // reuse Solver's setup, then run PreSolve()
  explicit SGDSolver(const string& param_file)
      : Solver<Dtype>(param_file) { PreSolve(); }
  virtual inline const char* type() const { return "SGD"; }  // return the solver type
  const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }

 protected:
  void PreSolve();
  Dtype GetLearningRate();  // get the current learning rate
  virtual void ApplyUpdate();
  virtual void Normalize(int param_id);    // normalize the gradient
  virtual void Regularize(int param_id);   // regularization
  virtual void ComputeUpdateValue(int param_id, Dtype rate);  // compute the update value
  virtual void ClipGradients();  // clip the gradients
  // Snapshot-related operations
  virtual void SnapshotSolverState(const string& model_filename);
  virtual void SnapshotSolverStateToBinaryProto(const string& model_filename);
  virtual void SnapshotSolverStateToHDF5(const string& model_filename);
  virtual void RestoreSolverStateFromHDF5(const string& state_file);
  virtual void RestoreSolverStateFromBinaryProto(const string& state_file);
  // history_ maintains the old momentum data. update_ maintains update-related
  // data and is not needed in snapshots. temp_ maintains other information
  // that may be needed when computing gradients or updates, and is likewise
  // not needed in snapshots.
  vector<shared_ptr<Blob<Dtype> > > history_, update_, temp_;

  DISABLE_COPY_AND_ASSIGN(SGDSolver);  // forbid copy and assignment
};

// Nesterov's accelerated gradient, regarded as an optimal method for convex
// optimization, converges very quickly.
template <typename Dtype>
class NesterovSolver : public SGDSolver<Dtype> {
 public:
  explicit NesterovSolver(const SolverParameter& param)
      : SGDSolver<Dtype>(param) {}
  explicit NesterovSolver(const string& param_file)
      : SGDSolver<Dtype>(param_file) {}
  virtual inline const char* type() const { return "Nesterov"; }

 protected:
  virtual void ComputeUpdateValue(int param_id, Dtype rate);

  DISABLE_COPY_AND_ASSIGN(NesterovSolver);
};

// AdaGrad (adaptive gradient) is a gradient-based optimization method
template <typename Dtype>
class AdaGradSolver : public SGDSolver<Dtype> {
 public:
  explicit AdaGradSolver(const SolverParameter& param)
      : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
  explicit AdaGradSolver(const string& param_file)
      : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
  virtual inline const char* type() const { return "AdaGrad"; }

 protected:
  virtual void ComputeUpdateValue(int param_id, Dtype rate);
  void constructor_sanity_check() {
    CHECK_EQ(0, this->param_.momentum())
        << "Momentum cannot be used with AdaGrad.";
  }

  DISABLE_COPY_AND_ASSIGN(AdaGradSolver);
};

// RMSProp, proposed by Tieleman in a Coursera course lecture, is also a
// gradient-based optimization method
template <typename Dtype>
class RMSPropSolver : public SGDSolver<Dtype> {
 public:
  explicit RMSPropSolver(const SolverParameter& param)
      : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
  explicit RMSPropSolver(const string& param_file)
      : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
  virtual inline const char* type() const { return "RMSProp"; }

 protected:
  virtual void ComputeUpdateValue(int param_id, Dtype rate);
  void constructor_sanity_check() {
    CHECK_EQ(0, this->param_.momentum())
        << "Momentum cannot be used with RMSProp.";
    CHECK_GE(this->param_.rms_decay(), 0)
        << "rms_decay should lie between 0 and 1.";
    CHECK_LT(this->param_.rms_decay(), 1)
        << "rms_decay should lie between 0 and 1.";
  }

  DISABLE_COPY_AND_ASSIGN(RMSPropSolver);
};

// The basic idea of AdaDelta is to approximate a second-order Newton method
// using only first-order information.
template <typename Dtype>
class AdaDeltaSolver : public SGDSolver<Dtype> {
 public:
  explicit AdaDeltaSolver(const SolverParameter& param)
      : SGDSolver<Dtype>(param) { AdaDeltaPreSolve(); }
  explicit AdaDeltaSolver(const string& param_file)
      : SGDSolver<Dtype>(param_file) { AdaDeltaPreSolve(); }
  virtual inline const char* type() const { return "AdaDelta"; }

 protected:
  void AdaDeltaPreSolve();
  virtual void ComputeUpdateValue(int param_id, Dtype rate);

  DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver);
};

/**
 * @brief AdamSolver, an algorithm for first-order gradient-based optimization
 *        of stochastic objective functions, based on adaptive estimates of
 *        lower-order moments. Described in [1].
 *
 * [1] D. P. Kingma and J. L. Ba, "ADAM: A Method for Stochastic Optimization."
 *     arXiv preprint arXiv:1412.6980v8 (2014).
 */
// Adam is a gradient-based optimization method
template <typename Dtype>
class AdamSolver : public SGDSolver<Dtype> {
 public:
  explicit AdamSolver(const SolverParameter& param)
      : SGDSolver<Dtype>(param) { AdamPreSolve(); }
  explicit AdamSolver(const string& param_file)
      : SGDSolver<Dtype>(param_file) { AdamPreSolve(); }
  virtual inline const char* type() const { return "Adam"; }

 protected:
  void AdamPreSolve();
  virtual void ComputeUpdateValue(int param_id, Dtype rate);

  DISABLE_COPY_AND_ASSIGN(AdamSolver);
};

}  // namespace caffe

#endif  // CAFFE_SGD_SOLVERS_HPP_
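One thing worth noticing in this header: every solver derives from SGDSolver, and most subclasses override nothing but ComputeUpdateValue(); the learning-rate schedule, normalization, regularization, and snapshot machinery all live in the base class. Below is a minimal standalone sketch of that pattern. It compiles on its own and is not Caffe code: the Blob/protobuf plumbing is replaced by plain vectors, and only the SGD and Nesterov update formulas mirror what Caffe computes.

#include <cstddef>
#include <iostream>
#include <vector>

// Minimal stand-in for Caffe's Solver/SGDSolver hierarchy; all types here
// are simplified for illustration.
class Solver {
 public:
  virtual ~Solver() {}
  virtual const char* type() const = 0;
  // In Caffe this walks all learnable parameters; here, one flat vector.
  virtual void ComputeUpdateValue(std::vector<float>& grad, float rate) = 0;
};

class SGDSolver : public Solver {
 public:
  explicit SGDSolver(float momentum) : momentum_(momentum) {}
  const char* type() const override { return "SGD"; }
  // Vanilla momentum: h = momentum * h + rate * g, then g <- h.
  void ComputeUpdateValue(std::vector<float>& grad, float rate) override {
    history_.resize(grad.size(), 0.f);
    for (std::size_t i = 0; i < grad.size(); ++i) {
      history_[i] = momentum_ * history_[i] + rate * grad[i];
      grad[i] = history_[i];
    }
  }
 protected:
  float momentum_;
  std::vector<float> history_;  // analogous to SGDSolver::history_
};

// A subclass only swaps in a different update rule.
class NesterovSolver : public SGDSolver {
 public:
  using SGDSolver::SGDSolver;
  const char* type() const override { return "Nesterov"; }
  void ComputeUpdateValue(std::vector<float>& grad, float rate) override {
    history_.resize(grad.size(), 0.f);
    for (std::size_t i = 0; i < grad.size(); ++i) {
      const float h_prev = history_[i];
      history_[i] = momentum_ * h_prev + rate * grad[i];
      // Nesterov update: (1 + momentum) * h_new - momentum * h_old.
      grad[i] = (1 + momentum_) * history_[i] - momentum_ * h_prev;
    }
  }
};

int main() {
  std::vector<float> grad = {0.5f, -0.25f};
  NesterovSolver solver(0.9f);
  solver.ComputeUpdateValue(grad, 0.01f);
  std::cout << solver.type() << " update: " << grad[0] << ", " << grad[1] << "\n";
}

This is essentially why adding a new solver to Caffe mostly means writing one ComputeUpdateValue(), one type(), and a REGISTER_SOLVER_CLASS call.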
The implementation file is named sgd_solver.cpp, without the 's'.

#include <string>
#include <vector>

#include "caffe/sgd_solvers.hpp"
#include "caffe/util/hdf5.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/upgrade_proto.hpp"

namespace caffe {

// For plots of the lr policies, see
// http://blog.csdn.net/langb2014/article/details/51274376
// Return the current learning rate. The currently implemented learning rate
// policies are as follows:
//    - fixed: always return base_lr.
//    - step: return base_lr * gamma ^ (floor(iter / step))
//    - exp: return base_lr * gamma ^ iter
//    - inv: return base_lr * (1 + gamma * iter) ^ (- power)
//    - multistep: similar to step but it allows non uniform steps defined by
//      stepvalue
//    - poly: the effective learning rate follows a polynomial decay, to be
//      zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power)
//    - sigmoid: the effective learning rate follows a sigmod decay
//      return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))
//
// where base_lr, max_iter, gamma, step, stepvalue and power are defined
// in the solver parameter protocol buffer, and iter is the current iteration.
// Get the learning rate; each policy is implemented below.
template <typename Dtype>
Dtype SGDSolver<Dtype>::GetLearningRate() {
  Dtype rate;
  const string& lr_policy = this->param_.lr_policy();
  if (lr_policy == "fixed") {
    rate = this->param_.base_lr();
  } else if (lr_policy == "step") {
    this->current_step_ = this->iter_ / this->param_.stepsize();
    rate = this->param_.base_lr() *
        pow(this->param_.gamma(), this->current_step_);
  } else if (lr_policy == "exp") {
    rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_);
  } else if (lr_policy == "inv") {
    rate = this->param_.base_lr() *
        pow(Dtype(1) + this->param_.gamma() * this->iter_,
            - this->param_.power());
  } else if (lr_policy == "multistep") {
    if (this->current_step_ < this->param_.stepvalue_size() &&
        this->iter_ >= this->param_.stepvalue(this->current_step_)) {
      this->current_step_++;
      LOG(INFO) << "MultiStep Status: Iteration " << this->iter_
                << ", step = " << this->current_step_;
    }
    rate = this->param_.base_lr() *
        pow(this->param_.gamma(), this->current_step_);
  } else if (lr_policy == "poly") {
    rate = this->param_.base_lr() * pow(Dtype(1.) -
        (Dtype(this->iter_) / Dtype(this->param_.max_iter())),
        this->param_.power());
  } else if (lr_policy == "sigmoid") {
    rate = this->param_.base_lr() * (Dtype(1.) /
        (Dtype(1.) + exp(-this->param_.gamma() * (Dtype(this->iter_) -
          Dtype(this->param_.stepsize())))));
  } else {
    LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
  }
  return rate;
}
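To get a feel for the numbers these policies produce, here is a standalone sketch (plain C++, no Caffe dependencies) that evaluates the step, inv, and poly formulas; the parameter values (base_lr = 0.01, gamma = 0.1, and so on) are made up for illustration.

#include <cmath>
#include <cstdio>

int main() {
  const double base_lr = 0.01, gamma = 0.1, power = 0.75;
  const int stepsize = 1000, max_iter = 3000;
  for (int iter = 0; iter <= max_iter; iter += 1000) {
    // "step": integer division gives floor(iter / stepsize)
    double step_lr = base_lr * std::pow(gamma, iter / stepsize);
    // "inv": base_lr * (1 + gamma * iter) ^ (-power)
    double inv_lr = base_lr * std::pow(1.0 + gamma * iter, -power);
    // "poly": decays to zero at max_iter
    double poly_lr = base_lr * std::pow(1.0 - double(iter) / max_iter, power);
    std::printf("iter %4d: step=%.6f inv=%.6f poly=%.6f\n",
                iter, step_lr, inv_lr, poly_lr);
  }
}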
// What is this for? history_ keeps the old momentum data. update_ keeps
// update-related data and is not needed in snapshots. temp_ keeps other
// information that may be needed when computing gradients or updates, and is
// also not needed in snapshots. PreSolve fills all three vectors with blobs
// shaped like the learnable parameters.
template <typename Dtype>
void SGDSolver<Dtype>::PreSolve() {
  // Initialize the history
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  history_.clear();
  update_.clear();
  temp_.clear();
  for (int i = 0; i < net_params.size(); ++i) {
    const vector<int>& shape = net_params[i]->shape();
    history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
    update_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
    temp_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
  }
}

// Clip the gradients
template <typename Dtype>
void SGDSolver<Dtype>::ClipGradients() {
  const Dtype clip_gradients = this->param_.clip_gradients();
  if (clip_gradients < 0) { return; }
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  Dtype sumsq_diff = 0;
  for (int i = 0; i < net_params.size(); ++i) {
    sumsq_diff += net_params[i]->sumsq_diff();
  }
  const Dtype l2norm_diff = std::sqrt(sumsq_diff);
  // If the L2 norm exceeds the threshold, scale all diffs down by scale_factor
  if (l2norm_diff > clip_gradients) {
    Dtype scale_factor = clip_gradients / l2norm_diff;
    LOG(INFO) << "Gradient clipping: scaling down gradients (L2 norm "
              << l2norm_diff << " > " << clip_gradients << ") "
              << "by scale factor " << scale_factor;
    for (int i = 0; i < net_params.size(); ++i) {
      net_params[i]->scale_diff(scale_factor);
    }
  }
}

// Apply the update
template <typename Dtype>
void SGDSolver<Dtype>::ApplyUpdate() {
  CHECK(Caffe::root_solver());
  Dtype rate = GetLearningRate();
  if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
    LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
  }
  ClipGradients();
  for (int param_id = 0; param_id < this->net_->learnable_params().size();
       ++param_id) {
    Normalize(param_id);
    Regularize(param_id);
    ComputeUpdateValue(param_id, rate);
  }
  this->net_->Update();
}

// Normalization
template <typename Dtype>
void SGDSolver<Dtype>::Normalize(int param_id) {
  if (this->param_.iter_size() == 1) { return; }
  // Scale gradient to counterbalance accumulation.
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  const Dtype accum_normalization = Dtype(1.) / this->param_.iter_size();
  switch (Caffe::mode()) {
  case Caffe::CPU: {
    caffe_scal(net_params[param_id]->count(), accum_normalization,
        net_params[param_id]->mutable_cpu_diff());
    break;
  }
  case Caffe::GPU: {
#ifndef CPU_ONLY
    caffe_gpu_scal(net_params[param_id]->count(), accum_normalization,
        net_params[param_id]->mutable_gpu_diff());
#else
    NO_GPU;
#endif
    break;
  }
  default:
    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
  }
}
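ClipGradients() above is nothing more than a global L2-norm rescale across all parameter diffs. The same arithmetic in a standalone sketch, with hypothetical values standing in for this->param_.clip_gradients() and the accumulated diffs:

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  std::vector<double> diff = {3.0, 4.0};   // stand-in for all parameter gradients
  const double clip_gradients = 2.0;       // threshold from the solver prototxt
  double sumsq = 0;
  for (double d : diff) sumsq += d * d;
  const double l2norm = std::sqrt(sumsq);  // here: 5
  if (l2norm > clip_gradients) {
    const double scale = clip_gradients / l2norm;  // here: 0.4
    for (double& d : diff) d *= scale;     // gradients now have L2 norm exactly 2
  }
  std::printf("clipped: %.2f %.2f\n", diff[0], diff[1]);
}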
// Regularization: add the weight-decay term to each parameter's gradient
template <typename Dtype>
void SGDSolver<Dtype>::Regularize(int param_id) {
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  const vector<float>& net_params_weight_decay =
      this->net_->params_weight_decay();
  Dtype weight_decay = this->param_.weight_decay();
  string regularization_type = this->param_.regularization_type();
  Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
  switch (Caffe::mode()) {
  case Caffe::CPU: {
    if (local_decay) {
      if (regularization_type == "L2") {
        // Add the weight decay; if caffe_axpy is unfamiliar, revisit the
        // earlier math_functions post:
        // http://blog.csdn.net/langb2014/article/details/50986678
        caffe_axpy(net_params[param_id]->count(),
            local_decay,
            net_params[param_id]->cpu_data(),
            net_params[param_id]->mutable_cpu_diff());
      } else if (regularization_type == "L1") {
        caffe_cpu_sign(net_params[param_id]->count(),
            net_params[param_id]->cpu_data(),
            temp_[param_id]->mutable_cpu_data());
        caffe_axpy(net_params[param_id]->count(),
            local_decay,
            temp_[param_id]->cpu_data(),
            net_params[param_id]->mutable_cpu_diff());
      } else {
        LOG(FATAL) << "Unknown regularization type: " << regularization_type;
      }
    }
    break;
  }
  case Caffe::GPU: {
#ifndef CPU_ONLY
    if (local_decay) {
      if (regularization_type == "L2") {
        // add weight decay
        caffe_gpu_axpy(net_params[param_id]->count(),
            local_decay,
            net_params[param_id]->gpu_data(),
            net_params[param_id]->mutable_gpu_diff());
      } else if (regularization_type == "L1") {
        caffe_gpu_sign(net_params[param_id]->count(),
            net_params[param_id]->gpu_data(),
            temp_[param_id]->mutable_gpu_data());
        caffe_gpu_axpy(net_params[param_id]->count(),
            local_decay,
            temp_[param_id]->gpu_data(),
            net_params[param_id]->mutable_gpu_diff());
      } else {
        LOG(FATAL) << "Unknown regularization type: " << regularization_type;
      }
    }
#else
    NO_GPU;
#endif
    break;
  }
  default:
    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
  }
}
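In update terms, the L2 branch adds local_decay * w to each gradient (the caffe_axpy call), and the L1 branch adds local_decay * sign(w), with the sign computed into temp_ first. A standalone sketch of just that arithmetic, with hypothetical values:

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  std::vector<double> data = {0.5, -2.0};  // parameter values w
  std::vector<double> diff = {0.1, 0.1};   // gradients from backprop
  const double local_decay = 0.0005;       // weight_decay * per-param multiplier
  const bool use_l1 = false;               // the "L2" branch by default
  for (std::size_t i = 0; i < data.size(); ++i) {
    if (use_l1) {
      double sign = (data[i] > 0) - (data[i] < 0);
      diff[i] += local_decay * sign;       // L1: decay * sign(w)
    } else {
      diff[i] += local_decay * data[i];    // L2: decay * w (the caffe_axpy call)
    }
  }
  std::printf("regularized diffs: %.6f %.6f\n", diff[0], diff[1]);
}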
#ifndef CPU_ONLY
template <typename Dtype>
void sgd_update_gpu(int N, Dtype* g, Dtype* h, Dtype momentum,
    Dtype local_rate);
#endif

// Compute the update value
template <typename Dtype>
void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  const vector<float>& net_params_lr = this->net_->params_lr();
  Dtype momentum = this->param_.momentum();
  Dtype local_rate = rate * net_params_lr[param_id];
  // Compute the update to history, then copy it to the parameter diff.
  switch (Caffe::mode()) {
  case Caffe::CPU: {
    caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
              net_params[param_id]->cpu_diff(), momentum,
              history_[param_id]->mutable_cpu_data());
    caffe_copy(net_params[param_id]->count(),
        history_[param_id]->cpu_data(),
        net_params[param_id]->mutable_cpu_diff());
    break;
  }
  case Caffe::GPU: {
#ifndef CPU_ONLY
    sgd_update_gpu(net_params[param_id]->count(),
        net_params[param_id]->mutable_gpu_diff(),
        history_[param_id]->mutable_gpu_data(),
        momentum, local_rate);
#else
    NO_GPU;
#endif
    break;
  }
  default:
    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
  }
}
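The CPU branch leans on caffe_cpu_axpby(N, alpha, X, beta, Y), which computes Y = alpha * X + beta * Y elementwise; together with the caffe_copy, the two calls implement h = local_rate * g + momentum * h followed by g = h. A standalone sketch of that momentum update, with hypothetical values:

#include <cstddef>
#include <cstdio>
#include <vector>

// Y = alpha * X + beta * Y, elementwise: what caffe_cpu_axpby does via BLAS.
static void axpby(double alpha, const std::vector<double>& X,
                  double beta, std::vector<double>& Y) {
  for (std::size_t i = 0; i < X.size(); ++i) Y[i] = alpha * X[i] + beta * Y[i];
}

int main() {
  std::vector<double> diff = {0.2, -0.1};     // current gradient g
  std::vector<double> history = {0.05, 0.0};  // accumulated momentum h
  const double local_rate = 0.01, momentum = 0.9;
  axpby(local_rate, diff, momentum, history); // h = lr * g + momentum * h
  diff = history;                             // the caffe_copy: g <- h
  std::printf("update: %.4f %.4f\n", diff[0], diff[1]);
}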
// We won't cover the Snapshot code in detail
template <typename Dtype>
void SGDSolver<Dtype>::SnapshotSolverState(const string& model_filename) {
  switch (this->param_.snapshot_format()) {
  case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:
    SnapshotSolverStateToBinaryProto(model_filename);
    break;
  case caffe::SolverParameter_SnapshotFormat_HDF5:
    SnapshotSolverStateToHDF5(model_filename);
    break;
  default:
    LOG(FATAL) << "Unsupported snapshot format.";
  }
}

template <typename Dtype>
void SGDSolver<Dtype>::SnapshotSolverStateToBinaryProto(
    const string& model_filename) {
  SolverState state;
  state.set_iter(this->iter_);
  state.set_learned_net(model_filename);
  state.set_current_step(this->current_step_);
  state.clear_history();
  for (int i = 0; i < history_.size(); ++i) {
    // Add history
    BlobProto* history_blob = state.add_history();
    history_[i]->ToProto(history_blob);
  }
  string snapshot_filename = Solver<Dtype>::SnapshotFilename(".solverstate");
  LOG(INFO) << "Snapshotting solver state to binary proto file "
            << snapshot_filename;
  WriteProtoToBinaryFile(state, snapshot_filename.c_str());
}

template <typename Dtype>
void SGDSolver<Dtype>::SnapshotSolverStateToHDF5(
    const string& model_filename) {
  string snapshot_filename =
      Solver<Dtype>::SnapshotFilename(".solverstate.h5");
  LOG(INFO) << "Snapshotting solver state to HDF5 file " << snapshot_filename;
  hid_t file_hid = H5Fcreate(snapshot_filename.c_str(), H5F_ACC_TRUNC,
      H5P_DEFAULT, H5P_DEFAULT);
  CHECK_GE(file_hid, 0)
      << "Couldn't open " << snapshot_filename << " to save solver state.";
  hdf5_save_int(file_hid, "iter", this->iter_);
  hdf5_save_string(file_hid, "learned_net", model_filename);
  hdf5_save_int(file_hid, "current_step", this->current_step_);
  hid_t history_hid = H5Gcreate2(file_hid, "history", H5P_DEFAULT, H5P_DEFAULT,
      H5P_DEFAULT);
  CHECK_GE(history_hid, 0)
      << "Error saving solver state to " << snapshot_filename << ".";
  for (int i = 0; i < history_.size(); ++i) {
    ostringstream oss;
    oss << i;
    hdf5_save_nd_dataset<Dtype>(history_hid, oss.str(), *history_[i]);
  }
  H5Gclose(history_hid);
  H5Fclose(file_hid);
}

template <typename Dtype>
void SGDSolver<Dtype>::RestoreSolverStateFromBinaryProto(
    const string& state_file) {
  SolverState state;
  ReadProtoFromBinaryFile(state_file, &state);
  this->iter_ = state.iter();
  if (state.has_learned_net()) {
    NetParameter net_param;
    ReadNetParamsFromBinaryFileOrDie(state.learned_net().c_str(), &net_param);
    this->net_->CopyTrainedLayersFrom(net_param);
  }
  this->current_step_ = state.current_step();
  CHECK_EQ(state.history_size(), history_.size())
      << "Incorrect length of history blobs.";
  LOG(INFO) << "SGDSolver: restoring history";
  for (int i = 0; i < history_.size(); ++i) {
    history_[i]->FromProto(state.history(i));
  }
}

template <typename Dtype>
void SGDSolver<Dtype>::RestoreSolverStateFromHDF5(const string& state_file) {
  hid_t file_hid = H5Fopen(state_file.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT);
  CHECK_GE(file_hid, 0) << "Couldn't open solver state file " << state_file;
  this->iter_ = hdf5_load_int(file_hid, "iter");
  if (H5LTfind_dataset(file_hid, "learned_net")) {
    string learned_net = hdf5_load_string(file_hid, "learned_net");
    this->net_->CopyTrainedLayersFrom(learned_net);
  }
  this->current_step_ = hdf5_load_int(file_hid, "current_step");
  hid_t history_hid = H5Gopen2(file_hid, "history", H5P_DEFAULT);
  CHECK_GE(history_hid, 0) << "Error reading history from " << state_file;
  int state_history_size = hdf5_get_num_links(history_hid);
  CHECK_EQ(state_history_size, history_.size())
      << "Incorrect length of history blobs.";
  for (int i = 0; i < history_.size(); ++i) {
    ostringstream oss;
    oss << i;
    hdf5_load_nd_dataset<Dtype>(history_hid, oss.str().c_str(), 0,
        kMaxBlobAxes, history_[i].get());
  }
  H5Gclose(history_hid);
  H5Fclose(file_hid);
}

INSTANTIATE_CLASS(SGDSolver);
REGISTER_SOLVER_CLASS(SGD);

}  // namespace caffe

Next, we will work through each of the other solver methods in turn.