(Caffe)基本类DataReader、QueuePair、Body(四)

1 简介

QueuePair与Body是DataReader的内部类。一个DataReader对应一个任务,一个Body生成一个线程来读取数据库(如examples/mnist/mnist_train_lmdb)。QueuePair为前面两者之间的衔接、通信。

2 源代码

/**
 * @brief Reads data from a source to queues available to data layers.
 * A single reading thread is created per source, even if multiple solvers
 * are running in parallel, e.g. for multi-GPU training. This makes sure
 * databases are read sequentially, and that each solver accesses a different
 * subset of the database. Data is distributed to solvers in a round-robin
 * way to keep parallel training deterministic.
 */
class DataReader {
 public:
...
 protected:
  // Queue pairs are shared between a body and its readers
  class QueuePair {
   public:
    explicit QueuePair(int size);
    ~QueuePair();

    BlockingQueue<Datum*> free_;
    BlockingQueue<Datum*> full_;
  };

  // A single body is created per source
  class Body : public InternalThread {
   public:
...
   protected:
    void InternalThreadEntry();
    void read_one(db::Cursor* cursor, QueuePair* qp);

    const LayerParameter param_;
    BlockingQueue<shared_ptr<QueuePair> > new_queue_pairs_;
...
  };
...

  const shared_ptr<QueuePair> queue_pair_;
  shared_ptr<Body> body_;
  static map<const string, boost::weak_ptr<DataReader::Body> > bodies_;
};

2 类QueuePair

DataReader::QueuePair::QueuePair(int size) {
  // Initialize the free queue with requested number of datums
  for (int i = 0; i < size; ++i) {
    free_.push(new Datum());
  }
}

说明:
  1. 一个QueuePair对应一个任务队列,从数据库(如examples/mnist/mnist_train_lmdb)中读取size个样本
  2. BlockingQueue为一个线程安全的队列容器,其模板类型可能是Datum,Batch等。此处装的是Datum。
  3. BlockingQueue<Datum*> free_为Datum队列,均为新new出来的,没有包含原始数据(图像)信息
  4. BlockingQueue<Datum*> full_为从数据库读取信息后的队列,包含了原始数据(图像)信息
  5. Datum为一个样本单元,关于Datum的定义,参见caffe.proto文件,一般来说,Datum对应于一张图像(及其label)

3 类Body

DataReader::Body::Body(const LayerParameter& param)
    : param_(param),
      new_queue_pairs_() {
  StartInternalThread();
}

说明:

  1. Body类继承了InternalThread(详见博文)。在构造函数了开启这个线程
  2. Body类重载了 DataReader::Body::InternalThreadEntry() 函数,从数据库读取数据的操作在该函数中实现,见本文第5节

4 类DataReader

DataReader类的构造函数如下:

map<const string, weak_ptr<DataReader::Body> > DataReader::bodies_;
static boost::mutex bodies_mutex_;

DataReader::DataReader(const LayerParameter& param)
    : queue_pair_(new QueuePair(  //
        param.data_param().prefetch() * param.data_param().batch_size())) {
  // Get or create a body
  boost::mutex::scoped_lock lock(bodies_mutex_);
  string key = source_key(param);
  weak_ptr<Body>& weak = bodies_[key];
  body_ = weak.lock();
  if (!body_) {
    body_.reset(new Body(param));
    bodies_[key] = weak_ptr<Body>(body_);
  }
  body_->new_queue_pairs_.push(queue_pair_);
}
说明:
  1. 一个数据库只可能有Body对象,如examples/mnist/mnist_train_lmdb不管在任何线程的任何DataReader对象中,都只会有一个Body对象,因为bodies_是静态的
  2. 所以有,一个Body的对象也可以有多个DataReader对象
  3. 此外有,一个DataReader对象可以有多个Body对象,即map<string,weak_ptr<Body>> bodies_
  4. 由代码5,6行及16行可知,每一个DataReader对应一个读的任务,即从数据库(如examples/mnist/mnist_train_lmdb)中读取param.data_param().prefetch() * param.data_param().batch_size()(LeNet5中默认为4×64)个样本
  5. 由此可见,一个DataReader为一个任务,通过QueuePair(也对应于该任务)“通知”Body某个数据库中读去N个样本
  6. 由代码13行可知,某个数据库(如examples/mnist/mnist_train_lmdb)对应的Body若不存在,将新建一个Body来处理该数据库,也可以理解成新建一个唯一对应于该数据库的线程来处理该数据可。

5 函数DataReader::Body::InternalThreadEntry

void DataReader::Body::InternalThreadEntry() {
...
  vector<shared_ptr<QueuePair> > qps;
  try {
...
    // To ensure deterministic runs, only start running once all solvers
    // are ready. But solvers need to peek on one item during initialization,
    // so read one item, then wait for the next solver.
    for (int i = 0; i < solver_count; ++i) {
      shared_ptr<QueuePair> qp(new_queue_pairs_.pop());
      read_one(cursor.get(), qp.get());
      qps.push_back(qp);
    }
    // Main loop
    while (!must_stop()) {
      for (int i = 0; i < solver_count; ++i) {
        read_one(cursor.get(), qps[i].get());
      }
...
    }
  } catch (boost::thread_interrupted&) {
    // Interrupted exception is expected on shutdown
  }
}
说明:

  1. read_one()从QueuePair的free_中取出一个Datum,从数据库读入数据至Datum,然后放入full_中
  2. 由第4节16行可知,一个新的任务(DataReader)到来时,将把一个命令队列(QueuePair)放入到某个数据库(Body)的缓冲命令队列中(new_queue_pairs_)
  3. 9到13行从每个solver的任务中读取一个Datum,在15到18行从数据库中循环读出数据
  4. 该线程何时停止呢?

你可能感兴趣的:(C++,源码,深度学习,caffe)