/// @brief Initializes the solver from a SolverParameter.
///
/// Steps, in order:
///  1. Requires that either this is the root solver or root_solver_ is set.
///  2. On the root solver, logs the full parameter dump.
///  3. Copies @p param into param_ and validates average_loss (must be >= 1).
///  4. Calls CheckSnapshotWritePermissions().
///  5. On the root solver, calls Caffe::set_random_seed when random_seed >= 0.
///  6. Builds the training net; on the root solver also builds the test nets.
///  7. Resets iter_ and current_step_ to 0.
///
/// @param param Solver configuration; copied into param_.
template <typename Dtype>
void Solver<Dtype>::Init(const SolverParameter& param) {
  CHECK(Caffe::root_solver() || root_solver_)
      << "root_solver_ needs to be set for all non-root solvers";
  LOG_IF(INFO, Caffe::root_solver()) << "Initializing solver from parameters: "
      << std::endl << param.DebugString();
  param_ = param;
  // The enforced lower bound is 1, so the message must say so: a value of 0
  // is "non-negative" but is rejected by this check.
  CHECK_GE(param_.average_loss(), 1) << "average_loss should be at least 1.";
  CheckSnapshotWritePermissions();
  // A negative random_seed means "do not seed"; only the root solver seeds.
  if (Caffe::root_solver() && param_.random_seed() >= 0) {
    Caffe::set_random_seed(param_.random_seed());
  }
  InitTrainNet();
  // Test nets are only created on the root solver.
  if (Caffe::root_solver()) {
    InitTestNets();
    LOG(INFO) << "Solver scaffolding done.";
  }
  iter_ = 0;
  current_step_ = 0;
}
/// @brief Builds the training net (net_) from the solver parameters.
///
/// Exactly one of the fields net, net_param, train_net, or train_net_param
/// must be set; zero or more than one is a fatal CHECK failure. Inline
/// *_param fields are copied directly, while the file-name fields are parsed
/// with ReadNetParamsFromTextFileOrDie.
///
/// The resulting NetParameter's state is replaced by a NetState built by
/// layering, in order: phase TRAIN, the net's own state, then
/// param_.train_state() (protobuf MergeFrom: later singular fields override,
/// repeated fields are appended).
///
/// On the root solver the Net is constructed from the parameter alone; on any
/// other solver it also receives root_solver_'s net (presumably so parameters
/// can be shared with the root — TODO confirm against Net's constructor).
template <typename Dtype>
void Solver<Dtype>::InitTrainNet() {
  const string field_names = "net, net_param, train_net, train_net_param";
  const int num_train_nets = param_.has_train_net_param() +
      param_.has_train_net() + param_.has_net_param() + param_.has_net();

  // Both checks are fatal, so past this point exactly one source is set.
  CHECK_GE(num_train_nets, 1) << "SolverParameter must specify a train net "
      << "using one of these fields: " << field_names;
  CHECK_LE(num_train_nets, 1) << "SolverParameter must not contain more than "
      << "one of these fields specifying a train_net: " << field_names;

  NetParameter train_param;
  if (param_.has_train_net_param()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net specified in train_net_param.";
    train_param.CopyFrom(param_.train_net_param());
  } else if (param_.has_train_net()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net from train_net file: " << param_.train_net();
    ReadNetParamsFromTextFileOrDie(param_.train_net(), &train_param);
  } else if (param_.has_net_param()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net specified in net_param.";
    train_param.CopyFrom(param_.net_param());
  } else if (param_.has_net()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net from net file: " << param_.net();
    ReadNetParamsFromTextFileOrDie(param_.net(), &train_param);
  }

  NetState merged_state;
  merged_state.set_phase(TRAIN);
  merged_state.MergeFrom(train_param.state());
  merged_state.MergeFrom(param_.train_state());
  train_param.mutable_state()->CopyFrom(merged_state);

  net_.reset(Caffe::root_solver()
      ? new Net<Dtype>(train_param)
      : new Net<Dtype>(train_param, root_solver_->net_.get()));
}
以上就是 solver.cpp 中对网络进行初始化的主要内容，其他部分也大同小异。总结一下：solver.cpp 的 Init 主要是读入 train.prototxt 等网络定义文件，进行一些参数合法性检查（例如训练网络相关字段必须恰好指定一个），最后（通过构造 Net 对象）调用 net.cpp 的 Init，完成网络的初始化。
附 1：net_param 赋值前后对比（可以看到，赋值后 name_ 不再指向空字符串 kEmptyString，layer_ 的 current_size_ 从 0 变为 11，_has_bits_ 从 {0} 变为 {1}）
(gdb) p net_param
$27 = {<google::protobuf::Message> = {<No data fields>},
static kNameFieldNumber = 1, static kInputFieldNumber = 3,
static kInputShapeFieldNumber = 8, static kInputDimFieldNumber = 4,
static kForceBackwardFieldNumber = 5, static kStateFieldNumber = 6,
static kDebugInfoFieldNumber = 7, static kLayerFieldNumber = 100,
static kLayersFieldNumber = 2, _unknown_fields_ = {fields_ = 0x0},
name_ = 0x7ffff692e3a8 <google::protobuf::internal::kEmptyString>,
input_ = {<google::protobuf::internal::RepeatedPtrFieldBase> = {
static kInitialSize = 0, elements_ = 0x0, current_size_ = 0,
allocated_size_ = 0, total_size_ = 0}, <No data fields>},
input_shape_ = {<google::protobuf::internal::RepeatedPtrFieldBase> = {
static kInitialSize = 0, elements_ = 0x0, current_size_ = 0,
allocated_size_ = 0, total_size_ = 0}, <No data fields>}, input_dim_ = {
static kInitialSize = <optimized out>, elements_ = 0x0, current_size_ = 0,
total_size_ = 0}, state_ = 0x0,
layer_ = {<google::protobuf::internal::RepeatedPtrFieldBase> = {
static kInitialSize = 0, elements_ = 0x0, current_size_ = 0,
allocated_size_ = 0, total_size_ = 0}, <No data fields>},
layers_ = {<google::protobuf::internal::RepeatedPtrFieldBase> = {
static kInitialSize = 0, elements_ = 0x0, current_size_ = 0,
allocated_size_ = 0, total_size_ = 0}, <No data fields>},
force_backward_ = false, debug_info_ = false, _cached_size_ = 0,
_has_bits_ = {0}, static default_instance_ = 0x6a2d70}
(gdb) p net_param
$29 = {<google::protobuf::Message> = {<No data fields>},
static kNameFieldNumber = 1, static kInputFieldNumber = 3,
static kInputShapeFieldNumber = 8, static kInputDimFieldNumber = 4,
static kForceBackwardFieldNumber = 5, static kStateFieldNumber = 6,
static kDebugInfoFieldNumber = 7, static kLayerFieldNumber = 100,
static kLayersFieldNumber = 2, _unknown_fields_ = {fields_ = 0x0},
name_ = 0x69f260,
input_ = {<google::protobuf::internal::RepeatedPtrFieldBase> = {
static kInitialSize = 0, elements_ = 0x0, current_size_ = 0,
allocated_size_ = 0, total_size_ = 0}, <No data fields>},
input_shape_ = {<google::protobuf::internal::RepeatedPtrFieldBase> = {
static kInitialSize = 0, elements_ = 0x0, current_size_ = 0,
allocated_size_ = 0, total_size_ = 0}, <No data fields>}, input_dim_ = {
static kInitialSize = <optimized out>, elements_ = 0x0, current_size_ = 0,
total_size_ = 0}, state_ = 0x0,
layer_ = {<google::protobuf::internal::RepeatedPtrFieldBase> = {
static kInitialSize = 0, elements_ = 0x6be760, current_size_ = 11,
allocated_size_ = 11, total_size_ = 16}, <No data fields>},
layers_ = {<google::protobuf::internal::RepeatedPtrFieldBase> = {
static kInitialSize = 0, elements_ = 0x0, current_size_ = 0,
allocated_size_ = 0, total_size_ = 0}, <No data fields>},
force_backward_ = false, debug_info_ = false, _cached_size_ = 0,
_has_bits_ = {1}, static default_instance_ = 0x6a2d70}