(Caffe, LeNet5) Entry Point for Training a Network (Part 2)

Original article: http://blog.csdn.net/mounty_fsc/article/details/51090114

The script train_lenet.sh, used to train LeNet, contains:

./build/tools/caffe train --solver=examples/mnist/lenet_solver.prototxt

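This shows that the network model is trained by the caffe tool, built from tools/caffe.cpp, running in train mode.
In summary, initialization proceeds as follows: train(), dispatched from main(), creates the Solver; the Solver creates the Net; and the Net creates its Layers.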
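The creation chain just described (train() creates the Solver, the Solver creates the Net, the Net creates the Layers) can be illustrated with a minimal sketch. ToyLayer, ToyNet, ToySolver and toy_train are hypothetical stand-ins that only mirror the ownership chain; they are not Caffe's real classes, and the layer names are illustrative rather than read from a prototxt.

```cpp
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, heavily simplified stand-ins for caffe::Layer, caffe::Net and
// caffe::Solver. They only model who creates whom, not the real APIs.
struct ToyLayer {
  std::string name;
  explicit ToyLayer(std::string n) : name(std::move(n)) {}
};

struct ToyNet {
  std::vector<std::unique_ptr<ToyLayer>> layers;
  // The net builds one layer object per entry in its (hard-coded) definition.
  explicit ToyNet(const std::vector<std::string>& layer_names) {
    for (const auto& n : layer_names) {
      layers.push_back(std::make_unique<ToyLayer>(n));
    }
  }
};

struct ToySolver {
  std::unique_ptr<ToyNet> net;
  // The solver builds its net inside its own constructor.
  explicit ToySolver(const std::vector<std::string>& layer_names)
      : net(std::make_unique<ToyNet>(layer_names)) {}
};

// Plays the role of train(): it only creates the solver; the net and its
// layers come into existence as a side effect of that construction.
std::unique_ptr<ToySolver> toy_train() {
  return std::make_unique<ToySolver>(
      std::vector<std::string>{"data", "conv1", "pool1", "ip1", "loss"});
}
```

The design point is that the caller never touches Net or Layer directly: constructing the top-level object is enough to build the whole tree.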

1 Program Entry Point

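  • In the main function of caffe.cpp, GetBrewFunction(caffe::string(argv[1]))() is called; since argv[1] is train, this runs the train() function.
  • In train(), the solver parameters are read from the file passed via --solver (examples/mnist/lenet_solver.prototxt) into solver_param.
  • Next, the solver pointer is defined and created through the solver registry (see Section 2):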

    shared_ptr<caffe::Solver<float> >
        solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));
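  • Call the solver's Solve() method. With more than one GPU, asynchronous processing between GPUs has to be handled, so a P2PSync object drives training instead (see Section 3):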

     if (gpus.size() > 1) {
       caffe::P2PSync<float> sync(solver, NULL, solver->param());
       sync.run(gpus);
     } else {
       LOG(INFO) << "Starting Optimization";
       solver->Solve();
     }
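The GetBrewFunction(caffe::string(argv[1]))() call is a table lookup: the first command-line argument selects a function from a map, and that function is then invoked. Below is a minimal sketch of that dispatch under stated assumptions; BrewFunction mirrors the idea of Caffe's typedef, while sketch_train, sketch_test, brew_map and GetBrewFunctionSketch are hypothetical simplifications, not copies of caffe.cpp (which fills its map through a registration macro and aborts on unknown actions rather than throwing).

```cpp
#include <map>
#include <stdexcept>
#include <string>

// Every action has the same signature, so one map can hold them all.
typedef int (*BrewFunction)();

// Hypothetical actions. A real train() would read solver_param, create the
// solver and call Solve(); here they just return distinct codes.
int sketch_train() { return 0; }
int sketch_test() { return 1; }

// Function-local static map from action name to function pointer.
std::map<std::string, BrewFunction>& brew_map() {
  static std::map<std::string, BrewFunction> m = {
      {"train", &sketch_train},
      {"test", &sketch_test},
  };
  return m;
}

// Look up the action named on the command line (e.g. argv[1] == "train").
BrewFunction GetBrewFunctionSketch(const std::string& name) {
  auto it = brew_map().find(name);
  if (it == brew_map().end()) {
    throw std::invalid_argument("Unknown action: " + name);
  }
  return it->second;
}
```

Usage mirrors the original line: GetBrewFunctionSketch("train")() looks up and immediately calls the train action.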

2 Creating the Solver

In Section 1, the Solver pointer solver is created by SolverRegistry::CreateSolver. The line worth noting in CreateSolver is return registry[type](param):

  // Get a solver using a SolverParameter.
  static Solver<Dtype>* CreateSolver(const SolverParameter& param) {
    const string& type = param.type();
    CreatorRegistry& registry = Registry();
    CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
        << " (known types: " << SolverTypeListString() << ")";
    return registry[type](param);
  }

Here:

registry is a map<string, Creator>: typedef std::map<string, Creator> CreatorRegistry
Creator is a function-pointer type: typedef Solver<Dtype>* (*Creator)(const SolverParameter&)
registry[type] is therefore a function-pointer value; for LeNet5, its concrete value here is caffe::Creator_SGDSolver<float>(caffe::SolverParameter const&)
Creator_SGDSolver is defined by the macro
REGISTER_SOLVER_CLASS(SGD)
Fully expanded, this macro becomes:

  template <typename Dtype>
  Solver<Dtype>* Creator_SGDSolver(const SolverParameter& param)
  {
    return new SGDSolver<Dtype>(param);
  }
  static SolverRegisterer<float> g_creator_f_SGD("SGD", Creator_SGDSolver<float>);
  static SolverRegisterer<double> g_creator_d_SGD("SGD", Creator_SGDSolver<double>);
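To see how the map, the Creator function pointer and the static registerer objects fit together, here is a minimal self-contained sketch of the same registry pattern. ToySolverParameter, ToySolver, ToySGDSolver, ToyRegistry and ToyRegisterer are hypothetical simplifications of the Caffe types shown above; one deliberate swap is that an unknown type throws std::invalid_argument here, whereas Caffe's CreateSolver fails a CHECK_EQ.

```cpp
#include <map>
#include <memory>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for SolverParameter; "SGD" is assumed as the default.
struct ToySolverParameter { std::string type = "SGD"; };

template <typename Dtype>
struct ToySolver {
  virtual ~ToySolver() = default;
  virtual std::string type() const = 0;
};

template <typename Dtype>
struct ToySGDSolver : ToySolver<Dtype> {
  explicit ToySGDSolver(const ToySolverParameter&) {}
  std::string type() const override { return "SGD"; }
};

template <typename Dtype>
struct ToyRegistry {
  typedef ToySolver<Dtype>* (*Creator)(const ToySolverParameter&);
  typedef std::map<std::string, Creator> CreatorRegistry;

  // Function-local static, so it exists before any registerer uses it.
  static CreatorRegistry& Registry() {
    static CreatorRegistry registry;
    return registry;
  }
  static void AddCreator(const std::string& type, Creator creator) {
    Registry()[type] = creator;
  }
  static ToySolver<Dtype>* CreateSolver(const ToySolverParameter& param) {
    CreatorRegistry& registry = Registry();
    if (registry.count(param.type) != 1) {
      throw std::invalid_argument("Unknown solver type: " + param.type);
    }
    return registry[param.type](param);  // calls Creator_ToySGDSolver<Dtype>
  }
};

// Its constructor runs during static initialization and fills the map.
template <typename Dtype>
struct ToyRegisterer {
  ToyRegisterer(const std::string& type,
                typename ToyRegistry<Dtype>::Creator creator) {
    ToyRegistry<Dtype>::AddCreator(type, creator);
  }
};

// Hand-expanded equivalent of a REGISTER_SOLVER_CLASS(SGD)-style macro.
template <typename Dtype>
ToySolver<Dtype>* Creator_ToySGDSolver(const ToySolverParameter& param) {
  return new ToySGDSolver<Dtype>(param);
}
static ToyRegisterer<float> g_creator_f_SGD("SGD", Creator_ToySGDSolver<float>);
static ToyRegisterer<double> g_creator_d_SGD("SGD", Creator_ToySGDSolver<double>);
```

Because registration happens in static initializers, adding a new solver type only requires defining it and invoking the macro in its own source file; CreateSolver never has to change.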

As shown above, registry[type](param) actually invokes the SGDSolver constructor; in fact, the network is initialized during that constructor call.
SGDSolver is defined as follows:

template <typename Dtype>
class SGDSolver : public Solver<Dtype> {
 public:
  explicit SGDSolver(const SolverParameter& param)
      : Solver<Dtype>(param) { PreSolve(); }
  explicit SGDSolver(const string& param_file)
      : Solver<Dtype>(param_file) { PreSolve(); }
......

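SGDSolver inherits from Solver<Dtype>, so new SGDSolver<Dtype>(param) first runs the Solver<Dtype> constructor and then the body of its own constructor, which calls PreSolve(). The whole network is initialized in this process (see Part 3 of this series for details).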
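The construction order described above is standard C++ behavior: the base-class constructor completes before the derived constructor body runs. A minimal sketch with hypothetical names (BaseSolverSketch, SGDSolverSketch, construction_log) records that order; the base constructor stands in for the place where Caffe's Solver sets up the Net, and the derived body calls PreSolve() as SGDSolver does.

```cpp
#include <string>
#include <vector>

// Records the order in which the constructors below run.
std::vector<std::string>& construction_log() {
  static std::vector<std::string> log;
  return log;
}

// Stand-in for Solver<Dtype>: in Caffe, the net is built during this
// base-class construction; here we only log that it happened.
struct BaseSolverSketch {
  explicit BaseSolverSketch(const std::string& param) {
    construction_log().push_back("BaseSolverSketch: init net from " + param);
  }
  virtual ~BaseSolverSketch() = default;
};

// Stand-in for SGDSolver<Dtype>: its body runs only after the base
// constructor has finished, then calls PreSolve().
struct SGDSolverSketch : BaseSolverSketch {
  explicit SGDSolverSketch(const std::string& param) : BaseSolverSketch(param) {
    PreSolve();
  }
  void PreSolve() { construction_log().push_back("SGDSolverSketch: PreSolve"); }
};
```

Constructing SGDSolverSketch("lenet_solver.prototxt") leaves two entries in the log: the base-class net initialization first, then PreSolve.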

3 The Solver::Solve() Function
