Caffe源码解读(十二):自定义数据输入层

第1,3,4,5步跟上一节的自定义神经层的一样。
数据输入层需要重写三个函数:
1. DataLayerSetUp:定义好从prototxt读入的参数名和容器的规格(设好N,K,H,W)
2. ShuffleImages:打乱顺序
3. load_batch:把图片读入到内存

代码及解读如下:

#ifdef USE_OPENCV
#include 

#include   // NOLINT(readability/streams)
#include   // NOLINT(readability/streams)
#include 
#include 
#include 

#include "caffe/data_transformer.hpp"
#include "caffe/layers/base_data_layer.hpp"
#include "caffe/layers/image_data_layer.hpp"
#include "caffe/util/benchmark.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/util/rng.hpp"

namespace caffe {

template <typename Dtype>
ImageDataLayer::~ImageDataLayer() {
  this->StopInternalThread();
}

//DataLayerSetUp:定义好从prototxt读入的参数名和容器的规格(设好N, K, H, W)
template <typename Dtype>
void ImageDataLayer::DataLayerSetUp(const vector*>& bottom,
      const vector*>& top) {

    /*
        读取在prototxt里面的设置数据:
            1、caffe.proto中的LayerParameter中定义了ImageDataParameter类型的image_data_param变量
            2、ImageDataParameter类中,定义了new_height、new_width、is_color、root_folder
    */
    //layer_param_是Layer类中定义的protected变量;
    //Layer类在layer.hpp中定义,Layer没有继承任何其他类;
    //Layer类中定义了LayerParameter类型的变量layer_param_;
    //LayerParameter在caffe.proto定义,LayerParameter中定义了“optional ImageDataParameter image_data_param = 115;”
    //ImageDataParameter也在caffe.proto中定义,ImageDataParameter中定义了new_height、new_width、is_color、root_folder
  const int new_height = this->layer_param_.image_data_param().new_height();    
  const int new_width  = this->layer_param_.image_data_param().new_width();
  const bool is_color  = this->layer_param_.image_data_param().is_color();
  string root_folder = this->layer_param_.image_data_param().root_folder();

  CHECK((new_height == 0 && new_width == 0) ||
      (new_height > 0 && new_width > 0)) << "Current implementation requires "
      "new_height and new_width to be set at the same time.";

  // Read the file with filenames and labels
    /*
        读取lmdb文件,并把data和label配对,存储到lines_里面,lines_在image_data_layer.hpp由我们自己定义,不是由prototxt生成。
    */
    //source跟new_height、new_width一样在ImageDataParameter中定义
  const string& source = this->layer_param_.image_data_param().source();
  LOG(INFO) << "Opening file " << source;

    //读取lmdb文件
    //string类的c_str()函数,返回string的内含字符串,创建一个stream类对象infile,从硬盘读数据到内存。
  std::ifstream infile(source.c_str());
  string line;
  size_t pos;
  int label;
  while (std::getline(infile, line)) {
    pos = line.find_last_of(' ');   //find_last_of:查找最近一个空格的位置。每一行的data和label是由空格分开,pos就是空格的位置。
    label = atoi(line.substr(pos + 1).c_str()); //substr:取子字符串,从pos+1(也就是label的首字母)到行尾。取label的字符串表示。
    lines_.push_back(std::make_pair(line.substr(0, pos), label));   //lines_在ImageDataLayer类中由自己定义的变量;
                                                                    //类型为vector >:std::pair主要的作用是将两个数据组合成一个数据
                                                                    //make_pair:生成pair对象
                                                                    //push_back:vector的函数,把变量装入vector中。
  }

  CHECK(!lines_.empty()) << "File is empty";

    /*
        使用shuffle打乱顺序
    */
  if (this->layer_param_.image_data_param().shuffle()) {
    // randomly shuffle data
    LOG(INFO) << "Shuffling data";
    const unsigned int prefetch_rng_seed = caffe_rng_rand();    //生成一个随机数作为种子,caffe_rng_rand在math_functions.cpp中定义
    prefetch_rng_.reset(new Caffe::RNG(prefetch_rng_seed));     //prefetch_rng_在ImageDataLayer类中由自己定义,类型为shared_ptr
                                                                //shared_ptr请参见《笔记.doc》
    ShuffleImages();    //本类的函数,根据prefetch_rng_随机数打乱容器lines_的顺序
  }
  LOG(INFO) << "A total of " << lines_.size() << " images.";

    /*
        使用skip随机跳过一些图片
    */
  lines_id_ = 0;    //lines_id_:由在ImageDataLayer类中由自己定义
  // Check if we would need to randomly skip a few data points
  if (this->layer_param_.image_data_param().rand_skip()) {  //rand_skip指定随机跳过的间隔,跟new_height、new_width一样在ImageDataParameter中定义
    unsigned int skip = caffe_rng_rand() %
        this->layer_param_.image_data_param().rand_skip();  //
    LOG(INFO) << "Skipping first " << skip << " data points.";
    CHECK_GT(lines_.size(), skip) << "Not enough points to skip";
    lines_id_ = skip;
  }


  // Read an image, and use it to initialize the top blob.
    /*
        代码核心:加载图片
    */
  cv::Mat cv_img = ReadImageToCVMat(root_folder + lines_[lines_id_].first,  //root_folder:根目录; lines_:类型为的容器;first指的是data,也就是图片
                                    new_height, new_width, is_color);       //图片的高、宽、通道数
  CHECK(cv_img.data) << "Could not load " << lines_[lines_id_].first;
  // Use data_transformer to infer the expected blob shape from a cv_image.
  vector<int> top_shape = this->data_transformer_->InferBlobShape(cv_img);  //data_transformer_:在BaseDataLayer类中定义,类型为shared_ptr >
                                                                            //DataTransformer类:在data_transformer.hpp中定义,作用是将常用变换应用于输入数据,例如缩放,镜像,减去图像平均值。
                                                                            //InferBlobShape函数:在DataTransformer类中定义,推断Blob的shape
                                                                            //详情见http://blog.csdn.net/xizero00/article/details/50905685
                                                                            //top_shape:以mnist的图片为例,top_shape的值将是:[1,1,28,28]。即1张图,单通道,高28,宽28;
  this->transformed_data_.Reshape(top_shape);   //更改Blob的维度大小到图片的大小
  // Reshape prefetch_data and top[0] according to the batch_size.
  const int batch_size = this->layer_param_.image_data_param().batch_size();    //读取batch_size大小,
  CHECK_GT(batch_size, 0) << "Positive batch size required";
  top_shape[0] = batch_size;    //设置第一个维度的大小为batch_size,即每次迭代有batch_size个图片
  for (int i = 0; i < this->PREFETCH_COUNT; ++i) {  //PREFETCH_COUNT:静态变量,预取的batch数,默认为3
    this->prefetch_[i].data_.Reshape(top_shape);    //把prefetch_的data_做reshape到top_shape大小
                                                    //prefetch_[PREFETCH_COUNT]:类型Batch ,Batch类只有两个public的变量data_和label_
  }
  top[0]->Reshape(top_shape);   //top即ImageDataLayer要输出的数据,由组成,top[0]表示数据,top[1]表示label。data和label都是blob类型。

  LOG(INFO) << "output data size: " << top[0]->num() << ","
      << top[0]->channels() << "," << top[0]->height() << ","
      << top[0]->width();
  // label
  vector<int> label_shape(1, batch_size);   //label的空间尺寸:1表示1维空间
  top[1]->Reshape(label_shape);
  for (int i = 0; i < this->PREFETCH_COUNT; ++i) {
    this->prefetch_[i].label_.Reshape(label_shape);
  }
}

template <typename Dtype>
void ImageDataLayer::ShuffleImages() {
  caffe::rng_t* prefetch_rng =
      static_cast(prefetch_rng_->generator());   //generator():随机数生成器

  shuffle(lines_.begin(), lines_.end(), prefetch_rng);  //shuffle:根据随机数打乱vector的顺序,为什么
}

// This function is called on prefetch thread
//把图片读到内存
template <typename Dtype>
void ImageDataLayer::load_batch(Batch* batch) {
  CPUTimer batch_timer;
  batch_timer.Start();
  double read_time = 0;
  double trans_time = 0;
  CPUTimer timer;
  CHECK(batch->data_.count());
  CHECK(this->transformed_data_.count());
  //读取prototxt的文件配置,和SetUp函数操作一致。
  ImageDataParameter image_data_param = this->layer_param_.image_data_param();
  const int batch_size = image_data_param.batch_size();
  const int new_height = image_data_param.new_height();
  const int new_width = image_data_param.new_width();
  const bool is_color = image_data_param.is_color();
  string root_folder = image_data_param.root_folder();

  // Reshape according to the first image of each batch
  // on single input batches allows for inputs of varying dimension.
  cv::Mat cv_img = ReadImageToCVMat(root_folder + lines_[lines_id_].first,
      new_height, new_width, is_color);
  CHECK(cv_img.data) << "Could not load " << lines_[lines_id_].first;
  // Use data_transformer to infer the expected blob shape from a cv_img.
  vector<int> top_shape = this->data_transformer_->InferBlobShape(cv_img);  //计算cv_img大小,和SetUp函数操作一致。
  this->transformed_data_.Reshape(top_shape);   //设置blob空间尺寸,和SetUp函数操作一致。
  // Reshape batch according to the batch_size.
  top_shape[0] = batch_size;    //batch_size张图片,和SetUp函数操作一致。
  batch->data_.Reshape(top_shape);  //batch类只有两个public的变量data_和label_,都为Blob类型。

  Dtype* prefetch_data = batch->data_.mutable_cpu_data();   //mutable_cpu_data()返回data_的地址
  Dtype* prefetch_label = batch->label_.mutable_cpu_data();

  // datum scales
  const int lines_size = lines_.size(); //样本个数
  for (int item_id = 0; item_id < batch_size; ++item_id) {
    // get a blob
    timer.Start();
    CHECK_GT(lines_size, lines_id_);
    cv::Mat cv_img = ReadImageToCVMat(root_folder + lines_[lines_id_].first,    //读取第lines_id_个样本的数据,转化为Mat型
        new_height, new_width, is_color);
    CHECK(cv_img.data) << "Could not load " << lines_[lines_id_].first;
    read_time += timer.MicroSeconds();
    timer.Start();
    // Apply transformations (mirror, crop...) to the image
    int offset = batch->data_.offset(item_id);  //获取item_id个图像数据的偏移量
    this->transformed_data_.set_cpu_data(prefetch_data + offset);   //set_cpu_data指定数据地址为prefetch_data
    this->data_transformer_->Transform(cv_img, &(this->transformed_data_)); //把cv_img数据转换到transformed_data_指定的data地址prefetch_data + offset
    trans_time += timer.MicroSeconds();

    prefetch_label[item_id] = lines_[lines_id_].second;
    // go to the next iter
    lines_id_++;
    if (lines_id_ >= lines_size) {
      // We have reached the end. Restart from the first.
      DLOG(INFO) << "Restarting data prefetching from start.";
      lines_id_ = 0;
      if (this->layer_param_.image_data_param().shuffle()) {
        ShuffleImages();
      }
    }
  }
  batch_timer.Stop();
  DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms.";
  DLOG(INFO) << "     Read time: " << read_time / 1000 << " ms.";
  DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms.";
}

INSTANTIATE_CLASS(ImageDataLayer);
REGISTER_LAYER_CLASS(ImageData);

}  // namespace caffe
#endif  // USE_OPENCV

以ImageDataLayer层的使用:

    layer {  
      name: "data"  
      type: "ImageData"  //在ImageDataLayer.hpp中的type函数定义
      top: "data"  //由这两个top可知,ImageDataLayer会定义top[0]和top[1]两个输出,top[0]是数据,top[1]是label
      top: "label"  
      transform_param {  //定义了图像数据预处理的操作
        mirror: false  
        crop_size: 227  
        mean_file: "data/ilsvrc12/imagenet_mean.binaryproto"  
      }  
      image_data_param {  //这是我们要定义的source
        source: "examples/_temp/file_list.txt"  
        batch_size: 50  
        new_height: 256  
        new_width: 256  
      }  
    }  

你可能感兴趣的:(caffe学习)