caffe中DataTransformer类方法Transform中data_index的计算原理

问题抛出:

int top_index, data_index;
  for (int c = 0; c < datum_channels; ++c) {
    for (int h = 0; h < height; ++h) {
      for (int w = 0; w < width; ++w) {
        //data_index是如何计算的?
        data_index = (c * datum_height + h_off + h) * datum_width + w_off + w;
        if (do_mirror) {
          top_index = (c * height + h) * width + (width - 1 - w);
        } else {
          top_index = (c * height + h) * width + w;
        }
        if (has_uint8) {
          datum_element =
            static_cast(static_cast(data[data_index]));
        } else {
          datum_element = datum.float_data(data_index);
        }
        if (has_mean_file) {
          transformed_data[top_index] =
            (datum_element - mean[data_index]) * scale;
        } else {
          if (has_mean_values) {
            transformed_data[top_index] =
              (datum_element - mean_values_[c]) * scale;
          } else {
            transformed_data[top_index] = datum_element * scale;
          }
        }
      }
    }
  }
}

原理的分析:

首先,我们需要知道Datum的来龙去脉:

1、首先通过caffe的数据转换工具(如convert_mnist_data.cpp)将图像标签等写入数据库(LMDB等),代码如下所示:

  char label;
  char* pixels = new char[rows * cols];
  int count = 0;
  string value;

  Datum datum;
  datum.set_channels(1);
  datum.set_height(rows);
  datum.set_width(cols);
  LOG(INFO) << "A total of " << num_items << " items.";
  LOG(INFO) << "Rows: " << rows << " Cols: " << cols;
  for (int item_id = 0; item_id < num_items; ++item_id) {
    image_file.read(pixels, rows * cols);//将二进制数据流写入pixels
    label_file.read(&label, 1);
    datum.set_data(pixels, rows*cols);
    datum.set_label(label);
    string key_str = caffe::format_int(item_id, 8);
    datum.SerializeToString(&value);

    txn->Put(key_str, value);

2、然后在训练时,通过DataReader从数据库中读取数据(异步操作)

3、进而通过DataLayer的load_batch对数据进行加工处理并生成对应的batch数据。

4、在load_batch函数中调用了DataTransformer的Transform进行数据的预处理

data_index计算:

我们看如下的计算公式:
data_index = (c * datum_height + h_off + h) * datum_width + w_off + w;      

1. 当c=0时,且不考虑h_off时,就是熟悉的按行读取数据,当前数据位置索引可表示为:h * datum_width + w

2. 当考虑h_off时,就是在h基础上,偏移一个随机的裁剪位置(w_off同理),索引可表示为 (h_off + h) * datum_width + w_off + w;

3. 当c>1时,我们可以把计算公式分解一下,得到如下:

data_index = c * datum_height* datum_width  +  (h_off + h) * datum_width + w_off + w

其中绿色字体部分与情况2相同,表示当前通道的索引,蓝色字体部分表示前面所有通道的总体索引数。

 

你可能感兴趣的:(caffe,c++)