(Caffe)基本类Solver、Caffe、Batch(二)

本文从CSDN上转移过来:
http://blog.csdn.net/mounty_fsc/article/details/51088173

1 Solver

1.1 简介

其对网络进行求解,其作用有:

  1. 提供优化日志支持、创建用于学习的训练网络、创建用于评估的测试网络
  2. 通过调用forward / backward迭代地优化,更新权值
  3. 周期性地评估测试网络
  4. 通过优化了解model及solver的状态

1.2 源代码

/**
 * @brief An interface for classes that perform optimization on Net%s.
 *
 * Requires implementation of ApplyUpdate to compute a parameter update
 * given the current state of the Net parameters.
 */
template 
class Solver {
 public:
  explicit Solver(const SolverParameter& param,
      const Solver* root_solver = NULL);
  explicit Solver(const string& param_file, const Solver* root_solver = NULL);
  void Init(const SolverParameter& param);
  void InitTrainNet();
  void InitTestNets();
 ...
  // The main entry of the solver function. In default, iter will be zero. Pass
  // in a non-zero iter number to resume training for a pre-trained net.
  virtual void Solve(const char* resume_file = NULL);
  inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
  void Step(int iters);
...

 protected:
  // Make and apply the update value for the current iteration.
  virtual void ApplyUpdate() = 0;
  ...

  SolverParameter param_;
  int iter_;
  int current_step_;
  shared_ptr > net_;
  vector > > test_nets_;
  vector callbacks_;
  vector losses_;
  Dtype smoothed_loss_;

  // The root solver that holds root nets (actually containing shared layers)
  // in data parallelism
  const Solver* const root_solver_;
...
};

说明:

  1. shared_ptr> net_为训练网络的指针,vector>> test_nets为测试网络的指针组,可见测试网络可以有多个

  2. 一般来说训练网络跟测试网络在实现上会有区别,但是绝大部分网络层是相同的。

  3. 不同的模型训练方法通过重载函数ComputeUpdateValue( )实现计算update参数的核心功能

  4. caffe.cpp中的train( )函数训练模型,在这里实例化一个Solver对象,初始化后调用了Solver中的Solve( )方法。而这个Solve( )函数主要就是在迭代运行下面这两个函数。ComputeUpdateValue();
    net_->Update();

1.3 Solver的方法

  • Stochastic Gradient Descent (type: "SGD")
  • AdaDelta (type: "AdaDelta")
  • Adaptive Gradient (type: "AdaGrad")
  • Adam (type: "Adam")
  • Nesterov’s Accelerated Gradient (type: "Nesterov")
  • RMSprop (type: "RMSProp")

详细参见引用1

2 Caffe类

Caffe类为一个包含常用的caffe成员的单例类。如caffe使用的cuda库cublas,curand的句柄等,以及生成Caffe中的随机数等。


// common.hpp
// A singleton class to hold common caffe stuff, such as the handler that
// caffe is going to use for cublas, curand, etc.
class Caffe {
 public:
  ~Caffe();

  // Thread local context for Caffe. Moved to common.cpp instead of
  // including boost/thread.hpp to avoid a boost/NVCC issues (#1009, #1010)
  // on OSX. Also fails on Linux with CUDA 7.0.18.
  static Caffe& Get();

  enum Brew { CPU, GPU };
...

protected:
#ifndef CPU_ONLY
  cublasHandle_t cublas_handle_;
  curandGenerator_t curand_generator_;
#endif
  shared_ptr random_generator_;

  Brew mode_;
  int solver_count_;
  bool root_solver_;

 private:
  // The private constructor to avoid duplicate instantiation.
  Caffe();
  DISABLE_COPY_AND_ASSIGN(Caffe);
};
//common.cpp

namespace caffe {

// Make sure each thread can have different values.
static boost::thread_specific_ptr thread_instance_;

Caffe& Caffe::Get() {
  if (!thread_instance_.get()) {
    thread_instance_.reset(new Caffe());
  }
  return *(thread_instance_.get());
}

...
Caffe::Caffe()
    : cublas_handle_(NULL), curand_generator_(NULL), random_generator_(),
    mode_(Caffe::CPU), solver_count_(1), root_solver_(true) {
  // Try to create a cublas handler, and report an error if failed (but we will
  // keep the program running as one might just want to run CPU code).
  if (cublasCreate(&cublas_handle_) != CUBLAS_STATUS_SUCCESS) {
    LOG(ERROR) << "Cannot create Cublas handle. Cublas won't be available.";
  }
  // Try to create a curand handler.
  if (curandCreateGenerator(&curand_generator_, CURAND_RNG_PSEUDO_DEFAULT)
      != CURAND_STATUS_SUCCESS ||
      curandSetPseudoRandomGeneratorSeed(curand_generator_, cluster_seedgen())
      != CURAND_STATUS_SUCCESS) {
    LOG(ERROR) << "Cannot create Curand generator. Curand won't be available.";
  }
}
...

}  // namespace caffe

说明:

  1. Caffe类为一个单例类,构造方法私有
  2. 该单例由static boost::thread_specific_ptr thread_instance_维护,确保多线程环境下,不同的线程有不同的Caffe单例版本
  3. 获取该单例由Get()方法执行,即Caffe::Get()方法返回thread_instance_维护的单例,
  4. thread_instance_的初值为NULL,若是第一次访问,则new Caffe()
  5. new Caffe()执行构造方法,其实只是创建了cublas,curand的句柄
  6. 单步调试可发现cublasCreate()创建cublas的句柄,生成了额外的两个线程

3 Batch

template 
class Batch {
 public:
  Blob data_, label_;
};

说明:

  • Batch是对一个样本的封装,与Datum不同,Datum是面向数据库的,且一个Datum对应一个样本(图像、标签);而Batch是面向网络的,一个Batch对应一批样本

[1].http://caffe.berkeleyvision.org/tutorial/solver.html

你可能感兴趣的:((Caffe)基本类Solver、Caffe、Batch(二))