An earlier post introduced how the solver performs optimization; see also the official introduction (here) and its translated version.
A quick walkthrough of the solver.hpp header:
#ifndef CAFFE_SOLVER_HPP_
#define CAFFE_SOLVER_HPP_
#include <boost/function.hpp>
#include <string>
#include <vector>
#include "caffe/net.hpp"
#include "caffe/solver_factory.hpp"
namespace caffe {
/**
* @brief Enumeration of actions that a client of the Solver may request by
* implementing the Solver's action request function, which a
* client may optionally provide in order to request early termination
* or saving a snapshot without exiting. In the executable caffe, this
* mechanism is used to allow the snapshot to be saved when stopping
* execution with a SIGINT (Ctrl-C).
*/
//In short: pressing Ctrl-C saves a snapshot of the model being trained, so if the terminal is closed mid-training you can resume from that point later.
namespace SolverAction {
enum Enum {
NONE = 0, // Take no special action.
STOP = 1, // Stop training; training can be resumed later from a snapshot.
SNAPSHOT = 2 // Take a snapshot, and keep training.
};
}
/**
* @brief Type of a function that returns a Solver Action enumeration.
*/
//Java programmers can think of this like a transaction rollback: if a bank transfer is interrupted after one account is debited but before the other is credited, the system rolls back to the last consistent state.
//Likewise, if training is interrupted, we roll back to the last snapshot and resume from there.
typedef boost::function<SolverAction::Enum()> ActionCallback;
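As an aside, here is a minimal sketch of how a client could supply such a callback. This is hypothetical illustration code (the real caffe binary wires SIGINT/SIGHUP handling through its SignalHandler class in util/signal_handler.h), but it shows the shape of the contract:
#include <csignal>

// A signal handler sets a flag; the callback translates it into an action.
static volatile sig_atomic_t got_sigint = 0;
static void HandleSigint(int) { got_sigint = 1; }

caffe::SolverAction::Enum MyActionSource() {
  if (got_sigint) {
    got_sigint = 0;
    return caffe::SolverAction::SNAPSHOT;  // or STOP, as the user prefers
  }
  return caffe::SolverAction::NONE;
}

// Registration, given a caffe::Solver<float>* solver:
//   std::signal(SIGINT, HandleSigint);
//   solver->SetActionFunction(MyActionSource);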
/**
* @brief An interface for classes that perform optimization on Net%s.
*
* Requires implementation of ApplyUpdate to compute a parameter update
* given the current state of the Net parameters.
*/
template <typename Dtype>
class Solver {
public:
explicit Solver(const SolverParameter& param,
const Solver* root_solver = NULL);
explicit Solver(const string& param_file, const Solver* root_solver = NULL);
void Init(const SolverParameter& param);
void InitTrainNet();
void InitTestNets();
// Client of the Solver optionally may call this in order to set the function
// that the solver uses to see what action it should take (e.g. snapshot or
// exit training early).
void SetActionFunction(ActionCallback func);
SolverAction::Enum GetRequestedAction();
//The main entry point of the solver. By default iter will be zero; a non-zero iter (restored from a snapshot) resumes training of a pretrained net.
virtual void Solve(const char* resume_file = NULL);
inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
void Step(int iters);
// The Restore method simply dispatches to one of the
// RestoreSolverStateFrom___ protected methods. You should implement these
// methods to restore the state from the appropriate snapshot type.
//The snapshot functions store the solver state to a snapshot; the RestoreSolverStateFrom___ functions restore that state from a SolverState buffer.
void Restore(const char* resume_file);
//Solver::Snapshot implements the basic snapshotting utility that stores the learned net.
void Snapshot();
virtual ~Solver() {}
inline const SolverParameter& param() const { return param_; }
inline shared_ptr<Net<Dtype> > net() { return net_; }
inline const vector<shared_ptr<Net<Dtype> > >& test_nets() {
return test_nets_;
}
int iter() { return iter_; }
//Invoked at specific points during an iteration.
class Callback {
protected:
virtual void on_start() = 0;
virtual void on_gradients_ready() = 0;
template <typename T>
friend class Solver;
};
const vector<Callback*>& callbacks() const { return callbacks_; }
void add_callback(Callback* value) {
callbacks_.push_back(value);
}
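As an aside: the two hooks are protected, but a subclass can implement them, and since Solver is a friend it can invoke them around each iteration. A hypothetical logging callback (not part of Caffe) might look like this:
// Hypothetical callback that logs the two per-iteration hook points.
class LoggingCallback : public caffe::Solver<float>::Callback {
 protected:
  virtual void on_start() {
    LOG(INFO) << "iteration start: parameters are in sync";
  }
  virtual void on_gradients_ready() {
    LOG(INFO) << "gradients ready: update about to be applied";
  }
};

// Usage: LoggingCallback cb; solver->add_callback(&cb);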
void CheckSnapshotWritePermissions();
//Returns the solver type.
virtual inline const char* type() const { return ""; }
protected:
//Generate and apply the update values for the current iteration.
virtual void ApplyUpdate() = 0;
string SnapshotFilename(const string extension);
string SnapshotToBinaryProto();
string SnapshotToHDF5();
// Test routines.
void TestAll();
void Test(const int test_net_id = 0);
virtual void SnapshotSolverState(const string& model_filename) = 0;
virtual void RestoreSolverStateFromHDF5(const string& state_file) = 0;
virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;
void DisplayOutputBlobs(const int net_id);
void UpdateSmoothedLoss(Dtype loss, int start_iter, int average_loss);
SolverParameter param_;
int iter_;//current iteration count. (At test time, test_iter * test batch_size = size of the test set; the test batch size is set in the prototxt.)
int current_step_;
shared_ptr<Net<Dtype> > net_;
vector<shared_ptr<Net<Dtype> > > test_nets_;//there can be multiple test nets
vector<Callback*> callbacks_;//instances of the nested Callback class above (used, e.g., by the multi-GPU synchronization code)
vector<Dtype> losses_;
Dtype smoothed_loss_;
//In data parallelism, the root solver keeps the root nets (which contain the shared layers).
const Solver* const root_solver_;
//A function the client may set to indicate whether a snapshot should be saved or training should exit early.
ActionCallback action_request_function_;
// True iff a request to stop early was received.
bool requested_early_exit_;
DISABLE_COPY_AND_ASSIGN(Solver);
};
//Used in multi-GPU training: a WorkerSolver only computes gradients.
template <typename Dtype>
class WorkerSolver : public Solver<Dtype> {
public:
explicit WorkerSolver(const SolverParameter& param,
    const Solver<Dtype>* root_solver = NULL)
    : Solver<Dtype>(param, root_solver) {}
protected:
void ApplyUpdate() {}
void SnapshotSolverState(const string& model_filename) {
LOG(FATAL) << "Should not be called on worker solver.";
}
void RestoreSolverStateFromBinaryProto(const string& state_file) {
LOG(FATAL) << "Should not be called on worker solver.";
}
void RestoreSolverStateFromHDF5(const string& state_file) {
LOG(FATAL) << "Should not be called on worker solver.";
}
};
} // namespace caffe
#endif // CAFFE_SOLVER_HPP_
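Before the implementation, here is how these pieces typically fit together. This is a sketch modeled on what train() in caffe.cpp does (names simplified, error handling omitted):
#include "caffe/caffe.hpp"

void TrainSketch(const std::string& solver_file, const char* resume_file) {
  caffe::SolverParameter solver_param;
  caffe::ReadSolverParamsFromTextFileOrDie(solver_file, &solver_param);
  // The factory from solver_factory.hpp instantiates the right subclass
  // (SGDSolver, AdamSolver, ...) based on the solver prototxt's type field.
  boost::shared_ptr<caffe::Solver<float> >
      solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));
  // A .solverstate file resumes from a snapshot; NULL starts from scratch.
  solver->Solve(resume_file);
}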
The implementation (solver.cpp):
#include <cstdio>
#include <string>
#include <vector>
#include "caffe/solver.hpp"
#include "caffe/util/format.hpp"
#include "caffe/util/hdf5.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/upgrade_proto.hpp"
namespace caffe {
template <typename Dtype>
void Solver<Dtype>::SetActionFunction(ActionCallback func) {
action_request_function_ = func;
}
template <typename Dtype>
SolverAction::Enum Solver<Dtype>::GetRequestedAction() {
if (action_request_function_) {
// If the external request function has been set, call it.
return action_request_function_();
}
return SolverAction::NONE;
}
template <typename Dtype>
Solver<Dtype>::Solver(const SolverParameter& param, const Solver* root_solver)
: net_(), callbacks_(), root_solver_(root_solver),
requested_early_exit_(false) {
Init(param);
}
//Both constructors call Init() to do the setup, i.e. solver scaffolding.
template <typename Dtype>
Solver<Dtype>::Solver(const string& param_file, const Solver* root_solver)
: net_(), callbacks_(), root_solver_(root_solver),
requested_early_exit_(false) {
SolverParameter param;
ReadSolverParamsFromTextFileOrDie(param_file, &param);
Init(param);
}
/*
Purpose: initialize the nets.
Steps:
1. Set the random seed.
2. Allocate a Net and initialize it via the constructor below
   (param_file = train_net_); net_ points to this Net.
3. If test nets are specified, allocate a Net for each; test_nets_ points to them.
Input: a SolverParameter param
Output: none
*/
template <typename Dtype>
void Solver<Dtype>::Init(const SolverParameter& param) {
CHECK(Caffe::root_solver() || root_solver_)
<< "root_solver_ needs to be set for all non-root solvers";
LOG_IF(INFO, Caffe::root_solver()) << "Initializing solver from parameters: "
<< std::endl << param.DebugString();
param_ = param;//assign the solver's param_ member
CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative.";
CheckSnapshotWritePermissions();
if (Caffe::root_solver() && param_.random_seed() >= 0) {
Caffe::set_random_seed(param_.random_seed());
//this calls Caffe::set_random_seed, the static method on the Caffe class (not a solver member);
//param_.random_seed() is the generated protobuf accessor, which returns a ::google::protobuf::int64
}
// Scaffolding code
InitTrainNet();
if (Caffe::root_solver()) {
InitTestNets();
LOG(INFO) << "Solver scaffolding done.";
}
iter_ = 0;
current_step_ = 0;
}
template <typename Dtype>
void Solver<Dtype>::InitTrainNet() {
const int num_train_nets = param_.has_net() + param_.has_net_param() +
param_.has_train_net() + param_.has_train_net_param();
const string& field_names = "net, net_param, train_net, train_net_param";
//exactly one field may specify the train net
CHECK_GE(num_train_nets, 1) << "SolverParameter must specify a train net "
<< "using one of these fields: " << field_names;
CHECK_LE(num_train_nets, 1) << "SolverParameter must not contain more than "
<< "one of these fields specifying a train_net: " << field_names;
NetParameter net_param;
if (param_.has_train_net_param()) {
LOG_IF(INFO, Caffe::root_solver())
<< "Creating training net specified in train_net_param.";
net_param.CopyFrom(param_.train_net_param());
} else if (param_.has_train_net()) {
LOG_IF(INFO, Caffe::root_solver())
<< "Creating training net from train_net file: " << param_.train_net();
ReadNetParamsFromTextFileOrDie(param_.train_net(), &net_param);
}
if (param_.has_net_param()) {
LOG_IF(INFO, Caffe::root_solver())
<< "Creating training net specified in net_param.";
net_param.CopyFrom(param_.net_param());
}
if (param_.has_net()) {
LOG_IF(INFO, Caffe::root_solver())
<< "Creating training net from net file: " << param_.net();
ReadNetParamsFromTextFileOrDie(param_.net(), &net_param);
}
//Set the correct NetState: start from the defaults (phase TRAIN), merge in any state specified by the net itself, and finally merge in the solver's train_state, which has the highest precedence.
NetState net_state;
net_state.set_phase(TRAIN);
//States are merged from lowest to highest precedence; SolverParameter's train_state is merged last and overrides anything merged earlier.
net_state.MergeFrom(net_param.state());
//The resulting state is written into the NetParameter; each LayerParameter's include/exclude rules then decide whether that layer belongs in this net.
net_state.MergeFrom(param_.train_state());
//This is part of initializing the train net; InitTestNets works the same way.
net_param.mutable_state()->CopyFrom(net_state);
if (Caffe::root_solver()) {
//call the Net constructor (a class template) to initialize the net
net_.reset(new Net<Dtype>(net_param));
} else {
net_.reset(new Net<Dtype>(net_param, root_solver_->net_.get()));
}
}
//Note: there can be several test nets, but only one train net.
template <typename Dtype>
void Solver<Dtype>::InitTestNets() {
CHECK(Caffe::root_solver());
const bool has_net_param = param_.has_net_param();
const bool has_net_file = param_.has_net();
const int num_generic_nets = has_net_param + has_net_file;
CHECK_LE(num_generic_nets, 1)
<< "Both net_param and net_file may not be specified.";
const int num_test_net_params = param_.test_net_param_size();
const int num_test_net_files = param_.test_net_size();
const int num_test_nets = num_test_net_params + num_test_net_files;
if (num_generic_nets) {
CHECK_GE(param_.test_iter_size(), num_test_nets)
<< "test_iter must be specified for each test network.";
} else {
CHECK_EQ(param_.test_iter_size(), num_test_nets)
<< "test_iter must be specified for each test network.";
}
//there can be multiple test nets
const int num_generic_net_instances = param_.test_iter_size() - num_test_nets;
//num_test_net_instances = num_test_nets + num_generic_net_instances, which is simply param_.test_iter_size()
const int num_test_net_instances = num_test_nets + num_generic_net_instances;
if (param_.test_state_size()) {
CHECK_EQ(param_.test_state_size(), num_test_net_instances)
<< "test_state must be unspecified or specified once per test net.";
}
if (num_test_net_instances) {
CHECK_GT(param_.test_interval(), 0);
}
int test_net_id = 0;
vector<string> sources(num_test_net_instances);
vector<NetParameter> net_params(num_test_net_instances);
for (int i = 0; i < num_test_net_params; ++i, ++test_net_id) {
sources[test_net_id] = "test_net_param";
net_params[test_net_id].CopyFrom(param_.test_net_param(i));
}
for (int i = 0; i < num_test_net_files; ++i, ++test_net_id) {
sources[test_net_id] = "test_net file: " + param_.test_net(i);
ReadNetParamsFromTextFileOrDie(param_.test_net(i),
&net_params[test_net_id]);
}
const int remaining_test_nets = param_.test_iter_size() - test_net_id;
if (has_net_param) {
for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
sources[test_net_id] = "net_param";
net_params[test_net_id].CopyFrom(param_.net_param());
}
}
if (has_net_file) {
for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
sources[test_net_id] = "net file: " + param_.net();
ReadNetParamsFromTextFileOrDie(param_.net(), &net_params[test_net_id]);
}
}
test_nets_.resize(num_test_net_instances);
for (int i = 0; i < num_test_net_instances; ++i) {
//Set the correct NetState: start from the defaults (phase TEST), merge in the net's own state, then merge in the solver's test_state, which has the highest precedence.
NetState net_state;
net_state.set_phase(TEST);
net_state.MergeFrom(net_params[i].state());
if (param_.test_state_size()) {
net_state.MergeFrom(param_.test_state(i));
}
net_params[i].mutable_state()->CopyFrom(net_state);
LOG(INFO)
<< "Creating test net (#" << i << ") specified by " << sources[i];
if (Caffe::root_solver()) {
test_nets_[i].reset(new Net<Dtype>(net_params[i]));
} else {
test_nets_[i].reset(new Net<Dtype>(net_params[i],
    root_solver_->test_nets_[i].get()));
}
test_nets_[i]->set_debug_info(param_.debug_info());
}
}
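To make the counting concrete, here is a hypothetical configuration (built programmatically for illustration; in practice these fields live in the solver prototxt) that yields two test net instances, one explicit and one generic:
caffe::SolverParameter p;
p.set_net("train_val.prototxt");        // generic net (has_net_file == true)
p.add_test_net("extra_test.prototxt");  // one explicit test net
// One test_iter entry per test net instance. Explicit test nets are filled in
// first, then the remaining entries create instances of the generic net:
p.add_test_iter(100);  // for extra_test.prototxt
p.add_test_iter(50);   // for the generic-net instance
p.set_test_interval(500);  // required whenever there are test nets
// In InitTestNets: num_test_nets = 1, num_generic_net_instances = 2 - 1 = 1,
// so num_test_net_instances = param_.test_iter_size() = 2.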
template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
const int start_iter = iter_;
const int stop_iter = iter_ + iters;
int average_loss = this->param_.average_loss();
losses_.clear();
smoothed_loss_ = 0;
while (iter_ < stop_iter) {
// zero-init the parameter gradients
net_->ClearParamDiffs();
//test_initialization defaults to true
if (param_.test_interval() && iter_ % param_.test_interval() == 0
&& (iter_ > 0 || param_.test_initialization())
&& Caffe::root_solver()) {
TestAll();
if (requested_early_exit_) {
// Break out of the while loop because stop was requested while testing.
break;
}
}
for (int i = 0; i < callbacks_.size(); ++i) {
callbacks_[i]->on_start();
}
const bool display = param_.display() && iter_ % param_.display() == 0;
net_->set_debug_info(display && param_.debug_info());
// accumulate the loss and gradient
Dtype loss = 0;
for (int i = 0; i < param_.iter_size(); ++i) {
loss += net_->ForwardBackward();
}
loss /= param_.iter_size();//gradients accumulate over iter_size x batch_size instances; by default iter_size = 1, i.e. one batch per iteration
// average the loss across iterations for smoothed reporting
UpdateSmoothedLoss(loss, start_iter, average_loss);
// average_loss [default = 1] -> display the loss averaged over the last average_loss iterations
if (display) {
LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_
<< ", loss = " << smoothed_loss_;
const vector<Blob<Dtype>*>& result = net_->output_blobs();
int score_index = 0;
for (int j = 0; j < result.size(); ++j) {
const Dtype* result_vec = result[j]->cpu_data();
const string& output_name =
net_->blob_names()[net_->output_blob_indices()[j]];
const Dtype loss_weight =
net_->blob_loss_weights()[net_->output_blob_indices()[j]];
for (int k = 0; k < result[j]->count(); ++k) {
ostringstream loss_msg_stream;
if (loss_weight) {
loss_msg_stream << " (* " << loss_weight
<< " = " << loss_weight * result_vec[k] << " loss)";
}
LOG_IF(INFO, Caffe::root_solver()) << " Train net output #"
<< score_index++ << ": " << output_name << " = "
<< result_vec[k] << loss_msg_stream.str();
}
}
}
for (int i = 0; i < callbacks_.size(); ++i) {
callbacks_[i]->on_gradients_ready();
}
ApplyUpdate();
// Increment the internal iter_ counter -- its value should always indicate
// the number of times the weights have been updated.
++iter_;
SolverAction::Enum request = GetRequestedAction();
// Save a snapshot if needed.
if ((param_.snapshot()
&& iter_ % param_.snapshot() == 0
&& Caffe::root_solver()) ||
(request == SolverAction::SNAPSHOT)) {
Snapshot();
}
if (SolverAction::STOP == request) {
requested_early_exit_ = true;
// Break out of training loop.
break;
}
}
}
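A detail worth spelling out from the loop above: iter_size multiplies the effective batch size. Each iteration runs iter_size forward/backward passes and accumulates gradients before the single ApplyUpdate(), so with, say, a data-layer batch_size of 16 and iter_size of 4 (hypothetical numbers), every weight update is computed from 4 × 16 = 64 instances. This is the standard trick for simulating a large batch when GPU memory is tight.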
/*
When you train a model with Caffe, you are really running train() in caffe.cpp, which instantiates a Solver object, initializes it, and then calls its Solve() method.
Solve() trains the network, calling Step() to iterate for param_.max_iter() - iter_ iterations.
*/
template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
CHECK(Caffe::root_solver());
LOG(INFO) << "Solving " << net_->name();
LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();
//reset to false whenever we start solving
requested_early_exit_ = false;
if (resume_file) {
LOG(INFO) << "Restoring previous solver status from " << resume_file;
Restore(resume_file);
}
// For a network that is trained by the solver, no bottom or top vecs
// should be given, and we will just provide dummy vecs.
int start_iter = iter_;
Step(param_.max_iter() - iter_);
// If we haven't already, save a snapshot after optimization, unless
// overridden by setting snapshot_after_train := false
if (param_.snapshot_after_train()
&& (!param_.snapshot() || iter_ % param_.snapshot() != 0)) {
Snapshot();
}
if (requested_early_exit_) {
LOG(INFO) << "Optimization stopped early.";
return;
}
//After optimization, run an additional forward pass to display the training/test loss or outputs.
if (param_.display() && iter_ % param_.display() == 0) {
int average_loss = this->param_.average_loss();
Dtype loss;
net_->Forward(&loss);
UpdateSmoothedLoss(loss, start_iter, average_loss);
LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss_;
}
if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
TestAll();
}
LOG(INFO) << "Optimization Done.";
}
template <typename Dtype>
void Solver<Dtype>::TestAll() {
for (int test_net_id = 0;
test_net_id < test_nets_.size() && !requested_early_exit_;
++test_net_id) {
Test(test_net_id);
}
}
template <typename Dtype>
void Solver<Dtype>::Test(const int test_net_id) {
CHECK(Caffe::root_solver());
LOG(INFO) << "Iteration " << iter_
<< ", Testing net (#" << test_net_id << ")";
//share the trained layers of the train net with the test net (layers can be shared across nets)
CHECK_NOTNULL(test_nets_[test_net_id].get())->
ShareTrainedLayersWith(net_.get());
vector<Dtype> test_score;
vector<int> test_score_output_id;
const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
Dtype loss = 0;
for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
SolverAction::Enum request = GetRequestedAction();
//if a snapshot or stop request arrived during training or testing, honor it here: snapshot immediately, or mark an early exit
while (request != SolverAction::NONE) {
if (SolverAction::SNAPSHOT == request) {
Snapshot();
} else if (SolverAction::STOP == request) {
requested_early_exit_ = true;
}
request = GetRequestedAction();
}
if (requested_early_exit_) {
// break out of test loop.
break;
}
Dtype iter_loss;
const vector<Blob<Dtype>*>& result =
test_net->Forward(&iter_loss);
if (param_.test_compute_loss()) {
loss += iter_loss;
}
if (i == 0) {
for (int j = 0; j < result.size(); ++j) {
const Dtype* result_vec = result[j]->cpu_data();
for (int k = 0; k < result[j]->count(); ++k) {
test_score.push_back(result_vec[k]);
test_score_output_id.push_back(j);
}
}
} else {
int idx = 0;
for (int j = 0; j < result.size(); ++j) {
const Dtype* result_vec = result[j]->cpu_data();
for (int k = 0; k < result[j]->count(); ++k) {
test_score[idx++] += result_vec[k];
}
}
}
}
if (requested_early_exit_) {
LOG(INFO) << "Test interrupted.";
return;
}
if (param_.test_compute_loss()) {
loss /= param_.test_iter(test_net_id);
LOG(INFO) << "Test loss: " << loss;
}
for (int i = 0; i < test_score.size(); ++i) {
const int output_blob_index =
test_net->output_blob_indices()[test_score_output_id[i]];
const string& output_name = test_net->blob_names()[output_blob_index];
const Dtype loss_weight = test_net->blob_loss_weights()[output_blob_index];
ostringstream loss_msg_stream;
const Dtype mean_score = test_score[i] / param_.test_iter(test_net_id);//average over the test iterations, i.e. over several batches, since each iteration processes one test batch of images
if (loss_weight) {
loss_msg_stream << " (* " << loss_weight
<< " = " << loss_weight * mean_score << " loss)";
}
LOG(INFO) << " Test net output #" << i << ": " << output_name << " = "
<< mean_score << loss_msg_stream.str();
}
}
//Write the current network state out to a file.
template <typename Dtype>
void Solver<Dtype>::Snapshot() {
CHECK(Caffe::root_solver());
string model_filename;
switch (param_.snapshot_format()) {
case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:
model_filename = SnapshotToBinaryProto();
break;
case caffe::SolverParameter_SnapshotFormat_HDF5:
model_filename = SnapshotToHDF5();
break;
default:
LOG(FATAL) << "Unsupported snapshot format.";
}
SnapshotSolverState(model_filename);
}
//check write permissions for snapshots
template <typename Dtype>
void Solver<Dtype>::CheckSnapshotWritePermissions() {
if (Caffe::root_solver() && param_.snapshot()) {
CHECK(param_.has_snapshot_prefix())
<< "In solver params, snapshot is specified but snapshot_prefix is not";
string probe_filename = SnapshotFilename(".tempfile");
std::ofstream probe_ofs(probe_filename.c_str());
if (probe_ofs.good()) {
probe_ofs.close();
std::remove(probe_filename.c_str());
} else {
LOG(FATAL) << "Cannot write to snapshot prefix '"
<< param_.snapshot_prefix() << "'. Make sure "
<< "that the directory exists and is writeable.";
}
}
}
//build the snapshot filename
template <typename Dtype>
string Solver<Dtype>::SnapshotFilename(const string extension) {
return param_.snapshot_prefix() + "_iter_" + caffe::format_int(iter_)
+ extension;
}
//snapshot the model as a binary proto file
template <typename Dtype>
string Solver<Dtype>::SnapshotToBinaryProto() {
string model_filename = SnapshotFilename(".caffemodel");
LOG(INFO) << "Snapshotting to binary proto file " << model_filename;
NetParameter net_param;
net_->ToProto(&net_param, param_.snapshot_diff());
WriteProtoToBinaryFile(net_param, model_filename);
return model_filename;
}
//snapshot the model as an HDF5 file
template <typename Dtype>
string Solver<Dtype>::SnapshotToHDF5() {
string model_filename = SnapshotFilename(".caffemodel.h5");
LOG(INFO) << "Snapshotting to HDF5 file " << model_filename;
net_->ToHDF5(model_filename, param_.snapshot_diff());
return model_filename;
}
//Read the network state from a file and restore from it.
template <typename Dtype>
void Solver<Dtype>::Restore(const char* state_file) {
CHECK(Caffe::root_solver());
string state_filename(state_file);
if (state_filename.size() >= 3 &&
state_filename.compare(state_filename.size() - 3, 3, ".h5") == 0) {
RestoreSolverStateFromHDF5(state_filename);
} else {
RestoreSolverStateFromBinaryProto(state_filename);
}
}
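Usage note: the snapshot format is chosen purely by the file extension, so (with hypothetical filenames):
solver->Restore("lenet_iter_5000.solverstate");     // binary proto branch
solver->Restore("lenet_iter_5000.solverstate.h5");  // ".h5" suffix -> HDF5 branch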
//Update the smoothed loss, a running average reported during training.
template <typename Dtype>
void Solver<Dtype>::UpdateSmoothedLoss(Dtype loss, int start_iter,
int average_loss) {
if (losses_.size() < average_loss) {
losses_.push_back(loss);
int size = losses_.size();
smoothed_loss_ = (smoothed_loss_ * (size - 1) + loss) / size;
} else {
int idx = (iter_ - start_iter) % average_loss;
smoothed_loss_ += (loss - losses_[idx]) / average_loss;
losses_[idx] = loss;
}
}
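A quick numeric check of both branches, with average_loss = 3 and per-iteration losses 4, 2, 3, 9 (hypothetical numbers): while the buffer is filling, smoothed_loss_ is the cumulative mean, giving 4, then (4*1 + 2)/2 = 3, then (3*2 + 3)/3 = 3; at the fourth iteration the buffer is full, idx = 0, and smoothed_loss_ += (9 - 4)/3, giving 14/3 ≈ 4.67, which is exactly the mean of the last three losses (2, 3, 9). So once the window is full, the update maintains a sliding-window average in O(1) per iteration.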
INSTANTIATE_CLASS(Solver);
} // namespace caffe