Caffe: solver.cpp —— Init()

template <typename Dtype>
void Solver<Dtype>::Init(const SolverParameter& param) {
  // Initializes the solver from `param`:
  //   1. validates and stores the solver configuration in param_,
  //   2. seeds the RNG (root solver only, and only for a non-negative seed),
  //   3. builds the training net on every solver and the test nets on the
  //      root solver only,
  //   4. resets the iteration counters.
  //
  // A non-root solver must already have root_solver_ set: InitTrainNet()
  // passes the root solver's net when it constructs this solver's net.
  CHECK(Caffe::root_solver() || root_solver_)
      << "root_solver_ needs to be set for all non-root solvers";
  LOG_IF(INFO, Caffe::root_solver()) << "Initializing solver from parameters: "
      << '\n' << param.DebugString();
  param_ = param;
  // The check rejects 0, so the message must say "at least 1"; the previous
  // wording ("non-negative") implied that 0 was a legal value.
  CHECK_GE(param_.average_loss(), 1) << "average_loss should be at least 1.";
  // A check on the snapshot settings; by its name it presumably verifies the
  // snapshot location is writable up front — confirm in the helper.
  CheckSnapshotWritePermissions();
  if (Caffe::root_solver() && param_.random_seed() >= 0) {
    Caffe::set_random_seed(param_.random_seed());
  }
  // Scaffolding code: build the nets.
  InitTrainNet();
  if (Caffe::root_solver()) {
    InitTestNets();
    LOG(INFO) << "Solver scaffolding done.";
  }
  iter_ = 0;
  current_step_ = 0;
}
//这段代码在初始化网络之前先做一些检查（参数合法性、snapshot 写权限）并设置随机种子，基本上没什么需要特别关注的（也可能是我功力不够）；重点看它调用的函数 InitTrainNet()。

template <typename Dtype>
void Solver<Dtype>::InitTrainNet() {
  // Builds the training net net_ from the single SolverParameter field that
  // describes it. There are four candidate fields:
  //   train_net_param, net_param : a NetParameter embedded in the solver proto
  //   train_net, net             : a path to a text-format .prototxt file
  // Exactly one must be set; zero or several is a fatal CHECK failure.
  //
  // LeNet/MNIST example: only `net` is set. Printing param_ in gdb shows
  //   net_ = 0x69f850, net_param_ = 0x0, train_net_param_ = 0x0,
  //   train_net_ = 0x7ffff692e3a8 (protobuf's shared empty string, i.e. unset),
  //   and an empty test_net_ list.
  const string field_names = "net, net_param, train_net, train_net_param";
  const int num_train_nets = param_.has_net() + param_.has_net_param() +
      param_.has_train_net() + param_.has_train_net_param();
  CHECK_GE(num_train_nets, 1) << "SolverParameter must specify a train net "
      << "using one of these fields: " << field_names;
  CHECK_LE(num_train_nets, 1) << "SolverParameter must not contain more than "
      << "one of these fields specifying a train_net: " << field_names;

  const bool is_root = Caffe::root_solver();

  // Fill net_param (field layout: NetParameter in caffe.proto) from the one
  // source that is set. The CHECKs above make the sources mutually exclusive,
  // so a single if/else-if chain picks the same source as testing each field
  // independently.
  NetParameter net_param;
  if (param_.has_train_net_param()) {
    LOG_IF(INFO, is_root)
        << "Creating training net specified in train_net_param.";
    net_param.CopyFrom(param_.train_net_param());
  } else if (param_.has_train_net()) {
    LOG_IF(INFO, is_root)
        << "Creating training net from train_net file: " << param_.train_net();
    ReadNetParamsFromTextFileOrDie(param_.train_net(), &net_param);
  } else if (param_.has_net_param()) {
    LOG_IF(INFO, is_root)
        << "Creating training net specified in net_param.";
    net_param.CopyFrom(param_.net_param());
  } else if (param_.has_net()) {
    // e.g. ".../examples/mnist/lenet_train_test.prototxt": the file is parsed
    // into net_param (gdb before/after comparison: appendix 1).
    LOG_IF(INFO, is_root)
        << "Creating training net from net file: " << param_.net();
    ReadNetParamsFromTextFileOrDie(param_.net(), &net_param);
  }

  // NetState precedence, lowest to highest: the solver default (phase TRAIN),
  // then any state declared in net_param itself, then the solver's
  // train_state. The merged result is written back into net_param.
  NetState net_state;
  net_state.set_phase(TRAIN);
  net_state.MergeFrom(net_param.state());
  net_state.MergeFrom(param_.train_state());
  net_param.mutable_state()->CopyFrom(net_state);

  // Constructing the Net from net_param performs the network initialization
  // (Net's init in net.cpp). Non-root solvers also pass the root solver's
  // training net to the constructor.
  net_.reset(is_root
      ? new Net<Dtype>(net_param)
      : new Net<Dtype>(net_param, root_solver_->net_.get()));
}

以上差不多就是 solver.cpp 中对网络进行初始化的主要内容，其他部分也大同小异。
总结一下：solver.cpp 的 Init 主要就是读入网络定义文件（lenet 例子中是 lenet_train_test.prototxt），
做一些检查和参数设置，最后通过构造 Net 对象调用 net.cpp 的 init，完成网络的初始化。



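附1：net_param 赋值前后对比（用 gdb 打印。第一次 $27 是调用 ReadNetParamsFromTextFileOrDie 之前，第二次 $29 是之后。可以看到 name_ 被赋值，layer_ 的 current_size_ 从 0 变为 11，即读入了 11 层，_has_bits_ 也从 {0} 变为 {1}）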
(gdb) p net_param
$27 = {<google::protobuf::Message> = {<No data fields>}, 
  static kNameFieldNumber = 1, static kInputFieldNumber = 3, 
  static kInputShapeFieldNumber = 8, static kInputDimFieldNumber = 4, 
  static kForceBackwardFieldNumber = 5, static kStateFieldNumber = 6, 
  static kDebugInfoFieldNumber = 7, static kLayerFieldNumber = 100, 
  static kLayersFieldNumber = 2, _unknown_fields_ = {fields_ = 0x0}, 
  name_ = 0x7ffff692e3a8 <google::protobuf::internal::kEmptyString>, 
  input_ = {<google::protobuf::internal::RepeatedPtrFieldBase> = {
      static kInitialSize = 0, elements_ = 0x0, current_size_ = 0, 
      allocated_size_ = 0, total_size_ = 0}, <No data fields>}, 
  input_shape_ = {<google::protobuf::internal::RepeatedPtrFieldBase> = {
      static kInitialSize = 0, elements_ = 0x0, current_size_ = 0, 
      allocated_size_ = 0, total_size_ = 0}, <No data fields>}, input_dim_ = {
    static kInitialSize = <optimized out>, elements_ = 0x0, current_size_ = 0, 
    total_size_ = 0}, state_ = 0x0, 
  layer_ = {<google::protobuf::internal::RepeatedPtrFieldBase> = {
      static kInitialSize = 0, elements_ = 0x0, current_size_ = 0, 
      allocated_size_ = 0, total_size_ = 0}, <No data fields>}, 
  layers_ = {<google::protobuf::internal::RepeatedPtrFieldBase> = {
      static kInitialSize = 0, elements_ = 0x0, current_size_ = 0, 
      allocated_size_ = 0, total_size_ = 0}, <No data fields>}, 
  force_backward_ = false, debug_info_ = false, _cached_size_ = 0, 
  _has_bits_ = {0}, static default_instance_ = 0x6a2d70}
(gdb) p net_param
$29 = {<google::protobuf::Message> = {<No data fields>}, 
  static kNameFieldNumber = 1, static kInputFieldNumber = 3, 
  static kInputShapeFieldNumber = 8, static kInputDimFieldNumber = 4, 
  static kForceBackwardFieldNumber = 5, static kStateFieldNumber = 6, 
  static kDebugInfoFieldNumber = 7, static kLayerFieldNumber = 100, 
  static kLayersFieldNumber = 2, _unknown_fields_ = {fields_ = 0x0}, 
  name_ = 0x69f260, 
  input_ = {<google::protobuf::internal::RepeatedPtrFieldBase> = {
      static kInitialSize = 0, elements_ = 0x0, current_size_ = 0, 
      allocated_size_ = 0, total_size_ = 0}, <No data fields>}, 
  input_shape_ = {<google::protobuf::internal::RepeatedPtrFieldBase> = {
      static kInitialSize = 0, elements_ = 0x0, current_size_ = 0, 
      allocated_size_ = 0, total_size_ = 0}, <No data fields>}, input_dim_ = {
    static kInitialSize = <optimized out>, elements_ = 0x0, current_size_ = 0, 
    total_size_ = 0}, state_ = 0x0, 
  layer_ = {<google::protobuf::internal::RepeatedPtrFieldBase> = {
      static kInitialSize = 0, elements_ = 0x6be760, current_size_ = 11, 
      allocated_size_ = 11, total_size_ = 16}, <No data fields>}, 
  layers_ = {<google::protobuf::internal::RepeatedPtrFieldBase> = {
      static kInitialSize = 0, elements_ = 0x0, current_size_ = 0, 
      allocated_size_ = 0, total_size_ = 0}, <No data fields>}, 
  force_backward_ = false, debug_info_ = false, _cached_size_ = 0, 
  _has_bits_ = {1}, static default_instance_ = 0x6a2d70}

你可能感兴趣的:(caffe)