This file mainly implements the template class Solver, which is an abstract class. Its key data members are:
int iter_;  // Current iteration count. At test time, test_iter * batch_size
            // (of the test set) should equal the test set size; the test
            // batch size can be set in the prototxt file.
int current_step_;
shared_ptr<Net<Dtype> > net_;
vector<shared_ptr<Net<Dtype> > > test_nets_;  // There can be multiple test nets.
vector<Callback*> callbacks_;  // A nested class; see the sketch below.
// The root solver that holds root nets (actually containing shared layers)
// in data parallelism
const Solver* const root_solver_;
// A function that can be set by a client of the Solver to provide indication
// that it wants a snapshot saved and/or to exit early.
ActionCallback action_request_function_;
// True iff a request to stop early was received.
bool requested_early_exit_;
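The Callback nested class can be pinned down from the headers: it is declared in solver.hpp roughly as below (a paraphrased sketch, not a verbatim copy), and it is implemented by the multi-GPU synchronization code (P2PSync in parallel.hpp) so that Step() can notify it at specific points in each iteration.

// Paraphrased sketch of the nested Callback interface from solver.hpp.
class Callback {
 protected:
  // Called by Step() at the start of each iteration.
  virtual void on_start() = 0;
  // Called by Step() after backward, just before ApplyUpdate().
  virtual void on_gradients_ready() = 0;
  template <typename T>
  friend class Solver;
};

Step() walks callbacks_ and invokes on_start() before the forward/backward passes and on_gradients_ready() right before ApplyUpdate(), which is exactly where a data-parallel implementation would average gradients across GPUs.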
Next, let's look at the main member functions.
=========================== Constructor ===========================

The constructor calls the Init() method to do the initialization, i.e. the solver scaffolding:

template <typename Dtype>
Solver<Dtype>::Solver(const SolverParameter& param, const Solver* root_solver)
    : net_(), callbacks_(), root_solver_(root_solver),
      requested_early_exit_(false) {
  Init(param);
}
=========================== Init() method ===========================

Init() calls InitTrainNet() and InitTestNets() to initialize the train net and the test nets:

template <typename Dtype>
void Solver<Dtype>::Init(const SolverParameter& param) {
  CHECK(Caffe::root_solver() || root_solver_)
      << "root_solver_ needs to be set for all non-root solvers";
  LOG_IF(INFO, Caffe::root_solver()) << "Initializing solver from parameters: "
      << std::endl << param.DebugString();
  param_ = param;  // Assign to the solver's data member param_.
  CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative.";
  CheckSnapshotWritePermissions();
  if (Caffe::root_solver() && param_.random_seed() >= 0) {
    // Caffe::set_random_seed is a static method of the Caffe singleton class;
    // param_.random_seed() is the protobuf-generated accessor
    // (::google::protobuf::int64 random_seed()).
    Caffe::set_random_seed(param_.random_seed());
  }
  // Scaffolding code
  InitTrainNet();
  if (Caffe::root_solver()) {
    InitTestNets();
    LOG(INFO) << "Solver scaffolding done.";
  }
  iter_ = 0;
  current_step_ = 0;
}
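For reference, here is a minimal, hypothetical solver.prototxt exercising the fields that Init() and the two net-initialization methods below read (paths and values are made up for illustration):

net: "models/example/train_val.prototxt"  # one file defining train and test nets
test_iter: 100        # test batches per test pass
test_interval: 500    # run a test pass every 500 training iterations
max_iter: 10000
display: 100          # print the (smoothed) loss every 100 iterations
average_loss: 20      # smooth the displayed loss over the last 20 iterations
random_seed: 1701     # optional; seeds Caffe's RNG for reproducibility
snapshot: 5000
snapshot_prefix: "models/example/example"
base_lr: 0.01
lr_policy: "step"     # further fields (gamma, stepsize, ...) omitted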
=========================== InitTrainNet() method ===========================

template <typename Dtype>
void Solver<Dtype>::InitTrainNet() {
  const int num_train_nets = param_.has_net() + param_.has_net_param() +
      param_.has_train_net() + param_.has_train_net_param();
  const string& field_names = "net, net_param, train_net, train_net_param";
  // Only one train net is allowed.
  CHECK_GE(num_train_nets, 1) << "SolverParameter must specify a train net "
      << "using one of these fields: " << field_names;
  CHECK_LE(num_train_nets, 1) << "SolverParameter must not contain more than "
      << "one of these fields specifying a train_net: " << field_names;
  NetParameter net_param;
  if (param_.has_train_net_param()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net specified in train_net_param.";
    net_param.CopyFrom(param_.train_net_param());
  } else if (param_.has_train_net()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net from train_net file: " << param_.train_net();
    ReadNetParamsFromTextFileOrDie(param_.train_net(), &net_param);
  }
  if (param_.has_net_param()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net specified in net_param.";
    net_param.CopyFrom(param_.net_param());
  }
  if (param_.has_net()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net from net file: " << param_.net();
    ReadNetParamsFromTextFileOrDie(param_.net(), &net_param);
  }
  // Set the correct NetState. We start with the solver defaults (lowest
  // precedence); then, merge in any NetState specified by the net_param itself;
  // finally, merge in any NetState specified by the train_state (highest
  // precedence).
  NetState net_state;
  net_state.set_phase(TRAIN);
  // States are merged from lowest to highest precedence; train_state from the
  // SolverParameter has the highest precedence and overrides anything merged
  // before it.
  net_state.MergeFrom(net_param.state());
  // The merged state is written back into the NetParameter; each layer's
  // include/exclude rules (in LayerParameter) are later evaluated against it
  // to decide whether that layer belongs to this network.
  net_state.MergeFrom(param_.train_state());
  // This state-merging step is part of initializing the train net;
  // InitTestNets does the same.
  net_param.mutable_state()->CopyFrom(net_state);
  if (Caffe::root_solver()) {
    // Calls the Net template class constructor to initialize the net.
    net_.reset(new Net<Dtype>(net_param));
  } else {
    net_.reset(new Net<Dtype>(net_param, root_solver_->net_.get()));
  }
}
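To make the NetState filtering concrete: after the merge above, net_param.state().phase() is TRAIN, so a layer like the following hypothetical data layer is kept in the train net but dropped from any test net (and vice versa for an include { phase: TEST } rule):

layer {
  name: "data"
  type: "Data"
  top: "data"
  top: "label"
  include { phase: TRAIN }   # keep this layer only when the net's state is TRAIN
  data_param {
    source: "examples/mnist/mnist_train_lmdb"  # hypothetical path
    batch_size: 64
    backend: LMDB
  }
}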
=========================== InitTestNets() method ===========================

Note that there can be multiple test nets, while there is only one train net.

template <typename Dtype>
void Solver<Dtype>::InitTestNets() {
  CHECK(Caffe::root_solver());
  const bool has_net_param = param_.has_net_param();
  const bool has_net_file = param_.has_net();
  const int num_generic_nets = has_net_param + has_net_file;
  CHECK_LE(num_generic_nets, 1)
      << "Both net_param and net_file may not be specified.";
  const int num_test_net_params = param_.test_net_param_size();
  const int num_test_net_files = param_.test_net_size();
  const int num_test_nets = num_test_net_params + num_test_net_files;
  if (num_generic_nets) {
    CHECK_GE(param_.test_iter_size(), num_test_nets)
        << "test_iter must be specified for each test network.";
  } else {
    CHECK_EQ(param_.test_iter_size(), num_test_nets)
        << "test_iter must be specified for each test network.";
  }
  // If we have a generic net (specified by net or net_param, rather than
  // test_net or test_net_param), we may have an unlimited number of actual
  // test networks -- the actual number is given by the number of remaining
  // test_iters after any test nets specified by test_net_param and/or test_net
  // are evaluated.
  const int num_generic_net_instances = param_.test_iter_size() - num_test_nets;
  // num_test_net_instances = num_test_nets + num_generic_net_instances,
  // which is simply param_.test_iter_size().
  const int num_test_net_instances = num_test_nets + num_generic_net_instances;
  if (param_.test_state_size()) {
    CHECK_EQ(param_.test_state_size(), num_test_net_instances)
        << "test_state must be unspecified or specified once per test net.";
  }
  if (num_test_net_instances) {
    CHECK_GT(param_.test_interval(), 0);
  }
  int test_net_id = 0;
  vector<string> sources(num_test_net_instances);
  vector<NetParameter> net_params(num_test_net_instances);
  for (int i = 0; i < num_test_net_params; ++i, ++test_net_id) {
    sources[test_net_id] = "test_net_param";
    net_params[test_net_id].CopyFrom(param_.test_net_param(i));
  }
  for (int i = 0; i < num_test_net_files; ++i, ++test_net_id) {
    sources[test_net_id] = "test_net file: " + param_.test_net(i);
    ReadNetParamsFromTextFileOrDie(param_.test_net(i),
        &net_params[test_net_id]);
  }
  const int remaining_test_nets = param_.test_iter_size() - test_net_id;
  if (has_net_param) {
    for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
      sources[test_net_id] = "net_param";
      net_params[test_net_id].CopyFrom(param_.net_param());
    }
  }
  if (has_net_file) {
    for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
      sources[test_net_id] = "net file: " + param_.net();
      ReadNetParamsFromTextFileOrDie(param_.net(), &net_params[test_net_id]);
    }
  }
  test_nets_.resize(num_test_net_instances);
  for (int i = 0; i < num_test_net_instances; ++i) {
    // Set the correct NetState. We start with the solver defaults (lowest
    // precedence); then, merge in any NetState specified by the net_param
    // itself; finally, merge in any NetState specified by the test_state
    // (highest precedence).
    NetState net_state;
    net_state.set_phase(TEST);
    net_state.MergeFrom(net_params[i].state());
    if (param_.test_state_size()) {
      net_state.MergeFrom(param_.test_state(i));
    }
    net_params[i].mutable_state()->CopyFrom(net_state);
    LOG(INFO) << "Creating test net (#" << i << ") specified by " << sources[i];
    if (Caffe::root_solver()) {
      test_nets_[i].reset(new Net<Dtype>(net_params[i]));
    } else {
      test_nets_[i].reset(new Net<Dtype>(net_params[i],
          root_solver_->test_nets_[i].get()));
    }
    test_nets_[i]->set_debug_info(param_.debug_info());
  }
}
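A hypothetical solver fragment with two test net instances, showing how one test_iter entry must be supplied per instance. Per the loop above, explicitly specified test nets (test_net_param, then test_net files) are assigned test_net_ids first, and any remaining test_iter entries go to instances of the generic net:

net: "models/example/train_val.prototxt"    # generic net; yields one test net instance here
test_net: "models/example/extra_test.prototxt"
test_iter: 50    # for the explicit test_net above (assigned first)
test_iter: 100   # for the generic net instance
test_interval: 500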
=========================== Step() method ===========================

template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
  vector<Blob<Dtype>*> bottom_vec;
  const int start_iter = iter_;
  const int stop_iter = iter_ + iters;
  int average_loss = this->param_.average_loss();
  vector<Dtype> losses;
  Dtype smoothed_loss = 0;

  while (iter_ < stop_iter) {
    // zero-init the params
    net_->ClearParamDiffs();
    // test_initialization defaults to true, so by default a test pass also
    // runs at iteration 0.
    if (param_.test_interval() && iter_ % param_.test_interval() == 0
        && (iter_ > 0 || param_.test_initialization())
        && Caffe::root_solver()) {
      TestAll();
      if (requested_early_exit_) {
        // Break out of the while loop because stop was requested while testing.
        break;
      }
    }

    for (int i = 0; i < callbacks_.size(); ++i) {
      callbacks_[i]->on_start();
    }
    const bool display = param_.display() && iter_ % param_.display() == 0;
    net_->set_debug_info(display && param_.debug_info());
    // accumulate the loss and gradient
    Dtype loss = 0;
    for (int i = 0; i < param_.iter_size(); ++i) {
      loss += net_->ForwardBackward(bottom_vec);
    }
    // Gradients are accumulated over iter_size x batch_size instances.
    // By default iter_size = 1, i.e. one batch per iteration.
    loss /= param_.iter_size();
    // Average the loss across iterations for smoothed reporting.
    // average_loss [default = 1]: display the loss averaged over the last
    // average_loss iterations.
    if (losses.size() < average_loss) {
      losses.push_back(loss);
      int size = losses.size();
      smoothed_loss = (smoothed_loss * (size - 1) + loss) / size;
    } else {
      int idx = (iter_ - start_iter) % average_loss;
      smoothed_loss += (loss - losses[idx]) / average_loss;
      losses[idx] = loss;
    }
    if (display) {
      LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_
          << ", loss = " << smoothed_loss;
      const vector<Blob<Dtype>*>& result = net_->output_blobs();
      int score_index = 0;
      for (int j = 0; j < result.size(); ++j) {
        const Dtype* result_vec = result[j]->cpu_data();
        const string& output_name =
            net_->blob_names()[net_->output_blob_indices()[j]];
        const Dtype loss_weight =
            net_->blob_loss_weights()[net_->output_blob_indices()[j]];
        for (int k = 0; k < result[j]->count(); ++k) {
          ostringstream loss_msg_stream;
          if (loss_weight) {
            loss_msg_stream << " (* " << loss_weight
                            << " = " << loss_weight * result_vec[k] << " loss)";
          }
          LOG_IF(INFO, Caffe::root_solver()) << "    Train net output #"
              << score_index++ << ": " << output_name << " = "
              << result_vec[k] << loss_msg_stream.str();
        }
      }
    }
    for (int i = 0; i < callbacks_.size(); ++i) {
      callbacks_[i]->on_gradients_ready();
    }
    ApplyUpdate();

    // Increment the internal iter_ counter -- its value should always indicate
    // the number of times the weights have been updated.
    ++iter_;

    SolverAction::Enum request = GetRequestedAction();

    // Save a snapshot if needed.
    if ((param_.snapshot()
         && iter_ % param_.snapshot() == 0
         && Caffe::root_solver()) ||
         (request == SolverAction::SNAPSHOT)) {
      Snapshot();
    }
    if (SolverAction::STOP == request) {
      requested_early_exit_ = true;
      // Break out of training loop.
      break;
    }
  }
}
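The smoothed-loss bookkeeping is the subtlest part of Step(): losses acts as a running mean during warm-up and as a circular buffer of the last average_loss values afterwards. A minimal standalone sketch of just that logic, assuming average_loss = 3 and a made-up loss sequence:

#include <cstdio>
#include <vector>

int main() {
  const int average_loss = 3;                    // hypothetical setting
  const double raw[] = {2.0, 1.5, 1.0, 0.8, 0.6};  // made-up per-iteration losses
  std::vector<double> losses;
  double smoothed_loss = 0;
  for (int iter = 0; iter < 5; ++iter) {
    const double loss = raw[iter];
    if (losses.size() < static_cast<size_t>(average_loss)) {
      // Warm-up: running mean over however many losses we have so far.
      losses.push_back(loss);
      const int size = losses.size();
      smoothed_loss = (smoothed_loss * (size - 1) + loss) / size;
    } else {
      // Steady state: replace the oldest entry in the circular buffer and
      // adjust the mean incrementally.
      const int idx = iter % average_loss;
      smoothed_loss += (loss - losses[idx]) / average_loss;
      losses[idx] = loss;
    }
    std::printf("iter %d: loss = %.2f, smoothed = %.4f\n",
                iter, loss, smoothed_loss);
  }
  return 0;
}

Running this prints smoothed values 2.0, 1.75, 1.5, 1.1, 0.8, i.e. exactly the mean of the last (up to) three raw losses at each step.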
=========================== Test() method ===========================

template <typename Dtype>
void Solver<Dtype>::Test(const int test_net_id) {
  CHECK(Caffe::root_solver());
  LOG(INFO) << "Iteration " << iter_
            << ", Testing net (#" << test_net_id << ")";
  // Share the trained layers of the train net with this test net, so testing
  // uses the current weights.
  CHECK_NOTNULL(test_nets_[test_net_id].get())->
      ShareTrainedLayersWith(net_.get());
  vector<Dtype> test_score;
  vector<int> test_score_output_id;
  vector<Blob<Dtype>*> bottom_vec;
  const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
  Dtype loss = 0;
  for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
    SolverAction::Enum request = GetRequestedAction();
    // Check to see if stoppage of testing/training has been requested.
    while (request != SolverAction::NONE) {
      if (SolverAction::SNAPSHOT == request) {
        Snapshot();
      } else if (SolverAction::STOP == request) {
        requested_early_exit_ = true;
      }
      request = GetRequestedAction();
    }
    if (requested_early_exit_) {
      // break out of test loop.
      break;
    }
    Dtype iter_loss;
    const vector<Blob<Dtype>*>& result =
        test_net->Forward(bottom_vec, &iter_loss);
    if (param_.test_compute_loss()) {
      loss += iter_loss;
    }
    if (i == 0) {
      for (int j = 0; j < result.size(); ++j) {
        const Dtype* result_vec = result[j]->cpu_data();
        for (int k = 0; k < result[j]->count(); ++k) {
          test_score.push_back(result_vec[k]);
          test_score_output_id.push_back(j);
        }
      }
    } else {
      int idx = 0;
      for (int j = 0; j < result.size(); ++j) {
        const Dtype* result_vec = result[j]->cpu_data();
        for (int k = 0; k < result[j]->count(); ++k) {
          test_score[idx++] += result_vec[k];
        }
      }
    }
  }
  if (requested_early_exit_) {
    LOG(INFO) << "Test interrupted.";
    return;
  }
  if (param_.test_compute_loss()) {
    loss /= param_.test_iter(test_net_id);
    LOG(INFO) << "Test loss: " << loss;
  }
  for (int i = 0; i < test_score.size(); ++i) {
    const int output_blob_index =
        test_net->output_blob_indices()[test_score_output_id[i]];
    const string& output_name = test_net->blob_names()[output_blob_index];
    const Dtype loss_weight = test_net->blob_loss_weights()[output_blob_index];
    ostringstream loss_msg_stream;
    // Average the accumulated score over test_iter iterations, i.e. over
    // multiple batches, since each iteration processes one test batch.
    const Dtype mean_score = test_score[i] / param_.test_iter(test_net_id);
    if (loss_weight) {
      loss_msg_stream << " (* " << loss_weight
                      << " = " << loss_weight * mean_score << " loss)";
    }
    LOG(INFO) << "    Test net output #" << i << ": " << output_name << " = "
              << mean_score << loss_msg_stream.str();
  }
}
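This averaging is why, as noted for the iter_ member at the top, test_iter should satisfy test_iter x test batch_size = test set size, so each test pass covers the test set exactly once. A hypothetical example for a 10,000-sample test set:

# In the solver prototxt:
test_iter: 100        # 100 test iterations per test pass
# In the test net's Data layer:
#   batch_size: 100   # 100 iterations x 100 samples = 10,000 samples,
#                     # i.e. exactly one pass over the test set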
=========================== Solve() method ===========================

When you train a network (i.e. when you run Caffe to train a model), you are actually running the train() function in caffe.cpp; that function instantiates a Solver object and, after initialization, calls its Solve() method. Solve() trains the network, calling Step() to iterate param_.max_iter() - iter_ times.

template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
  CHECK(Caffe::root_solver());
  LOG(INFO) << "Solving " << net_->name();
  LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();

  // Initialize to false every time we start solving.
  requested_early_exit_ = false;

  if (resume_file) {
    LOG(INFO) << "Restoring previous solver status from " << resume_file;
    Restore(resume_file);
  }

  // For a network that is trained by the solver, no bottom or top vecs
  // should be given, and we will just provide dummy vecs.
  Step(param_.max_iter() - iter_);
  // If we haven't already, save a snapshot after optimization, unless
  // overridden by setting snapshot_after_train := false
  if (param_.snapshot_after_train()
      && (!param_.snapshot() || iter_ % param_.snapshot() != 0)) {
    Snapshot();
  }
  if (requested_early_exit_) {
    LOG(INFO) << "Optimization stopped early.";
    return;
  }
  // After the optimization is done, run an additional train and test pass to
  // display the train and test loss/outputs if appropriate (based on the
  // display and test_interval settings, respectively). Unlike in the rest of
  // training, for the train net we only run a forward pass as we've already
  // updated the parameters "max_iter" times -- this final pass is only done to
  // display the loss, which is computed in the forward pass.
  if (param_.display() && iter_ % param_.display() == 0) {
    Dtype loss;
    net_->ForwardPrefilled(&loss);
    LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss;
  }
  if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
    TestAll();
  }
  LOG(INFO) << "Optimization Done.";
}
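To tie this back to caffe.cpp, here is a condensed, hypothetical sketch of how its train() function drives the Solver (the real code uses gflags for the command-line flags and also handles weight copying, signals, and multi-GPU setup, all omitted here; argv indices are assumptions for illustration):

#include "caffe/caffe.hpp"  // assumed to pull in solver_factory and proto IO

int main(int argc, char** argv) {
  // Assume argv[1] is the solver prototxt and argv[2] an optional
  // .solverstate file to resume from.
  caffe::SolverParameter solver_param;
  caffe::ReadSolverParamsFromTextFileOrDie(argv[1], &solver_param);

  // The solver type (SGD, Adam, ...) is chosen by the registry from
  // solver_param; Solver itself is abstract.
  boost::shared_ptr<caffe::Solver<float> > solver(
      caffe::SolverRegistry<float>::CreateSolver(solver_param));

  if (argc > 2) {
    solver->Restore(argv[2]);  // resume from a saved solver state
  }
  solver->Solve();  // internally calls Step(param_.max_iter() - iter_)
  return 0;
}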
Snapshot() writes the current network state out to a file. Restore() reads a network state back in from a file, so training can resume from that state.
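From the command line, this pair is what powers resuming an interrupted run (filenames here are hypothetical; the .solverstate file is what Snapshot() produced alongside the .caffemodel):

caffe train --solver=solver.prototxt --snapshot=example_iter_10000.solverstate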