SSD 源码(1)

训练数据的加载与读取:主要由 AnnotatedDataLayer 承担

其构造函数如下:

template type>
AnnotatedDataLayertype>::AnnotatedDataLayer(const LayerParameter& param)
  : BasePrefetchingDataLayertype>(param),
    reader_(param) {
}

构造函数很简单,但是,创建了2个主要的数据读取线程:
1#基础数据(LMDB)读取线程
2#目标数据读取与分析线程。

第一个目标由成员变量reader_实现,第二个目标结合BasePrefetchingDataLayer层实现。为了了解实现原理,先梳理一下这些类之间的继承关系。

SSD 源码(1)_第1张图片

如上图所示,class InternalThread 中:

  • StartInternalThread()函数线程开启
  • InternalThreadEntry()具体实现线程功能,此函数为虚函数,根据需要进行重写。

class DataReader 中的成员变量body,重写了InternalThreadEntry()函数,实现基本数据读取功能。
class BasePrefetchingDataLayer 重写了InternalThreadEntry()函数,实现标注数据信息的读取功能。AnnotatedDataLayer类中的load_batch()函数,重写了数据加载功能。

template <typename Dtype>
void AnnotatedDataLayer::DataLayerSetUp(
    const vector*>& bottom, const vector*>& top) {
  const int batch_size = this->layer_param_.data_param().batch_size();
  const AnnotatedDataParameter& anno_data_param =
      this->layer_param_.annotated_data_param();
  for (int i = 0; i < anno_data_param.batch_sampler_size(); ++i) {
    batch_samplers_.push_back(anno_data_param.batch_sampler(i));
  }
  label_map_file_ = anno_data_param.label_map_file();

  // Read a data point, and use it to initialize the top blob.
  AnnotatedDatum& anno_datum = *(reader_.full().peek());

  // Use data_transformer to infer the expected blob shape from anno_datum.
  vector<int> top_shape =
      this->data_transformer_->InferBlobShape(anno_datum.datum());
  this->transformed_data_.Reshape(top_shape);
  // Reshape top[0] and prefetch_data according to the batch_size.
  top_shape[0] = batch_size;
  top[0]->Reshape(top_shape);
  for (int i = 0; i < this->PREFETCH_COUNT; ++i) {
    this->prefetch_[i].data_.Reshape(top_shape);
  }
  LOG(INFO) << "output data size: " << top[0]->num() << ","
      << top[0]->channels() << "," << top[0]->height() << ","
      << top[0]->width();
  // label
  if (this->output_labels_) {
    has_anno_type_ = anno_datum.has_type();
    vector<int> label_shape(4, 1);
    if (has_anno_type_) {
      anno_type_ = anno_datum.type();
      // Infer the label shape from anno_datum.AnnotationGroup().
      int num_bboxes = 0;
      if (anno_type_ == AnnotatedDatum_AnnotationType_BBOX) {
        // Since the number of bboxes can be different for each image,
        // we store the bbox information in a specific format. In specific:
        // All bboxes are stored in one spatial plane (num and channels are 1)
        // And each row contains one and only one box in the following format:
        // [item_id, group_label, instance_id, xmin, ymin, xmax, ymax, diff]
        // Note: Refer to caffe.proto for details about group_label and
        // instance_id.
        for (int g = 0; g < anno_datum.annotation_group_size(); ++g) {
          num_bboxes += anno_datum.annotation_group(g).annotation_size();
        }
        label_shape[0] = 1;
        label_shape[1] = 1;
        // BasePrefetchingDataLayer::LayerSetUp() requires to call
        // cpu_data and gpu_data for consistent prefetch thread. Thus we make
        // sure there is at least one bbox.
        label_shape[2] = std::max(num_bboxes, 1);
        label_shape[3] = 8;
      } else {
        LOG(FATAL) << "Unknown annotation type.";
      }
    } else {
      label_shape[0] = batch_size;
    }
    top[1]->Reshape(label_shape);
    for (int i = 0; i < this->PREFETCH_COUNT; ++i) {
      this->prefetch_[i].label_.Reshape(label_shape);
    }
  }
}

// This function is called on prefetch thread
template<typename Dtype>
void AnnotatedDataLayer::load_batch(Batch* batch) {
  CPUTimer batch_timer;
  batch_timer.Start();
  double read_time = 0;
  double trans_time = 0;
  CPUTimer timer;
  CHECK(batch->data_.count());
  CHECK(this->transformed_data_.count());

  // Reshape according to the first anno_datum of each batch
  // on single input batches allows for inputs of varying dimension.
  const int batch_size = this->layer_param_.data_param().batch_size();
  AnnotatedDatum& anno_datum = *(reader_.full().peek()); //加载一副图像
  // Use data_transformer to infer the expected blob shape from anno_datum.
  vector<int> top_shape =
      this->data_transformer_->InferBlobShape(anno_datum.datum());
  this->transformed_data_.Reshape(top_shape);
  // Reshape batch according to the batch_size.
  top_shape[0] = batch_size;
  batch->data_.Reshape(top_shape);

  Dtype* top_data = batch->data_.mutable_cpu_data();
  Dtype* top_label = NULL;  // suppress warnings about uninitialized variables
  if (this->output_labels_ && !has_anno_type_) {
    top_label = batch->label_.mutable_cpu_data();
  }

  // Store transformed annotation.
  map<int, vector > all_anno;  //// 
  int num_bboxes = 0;

  for (int item_id = 0; item_id < batch_size; ++item_id) {

      /*1 获取 anno_datum*/
    timer.Start();
    // get a anno_datum
    AnnotatedDatum& anno_datum = *(reader_.full().pop("Waiting for data"));
    read_time += timer.MicroSeconds();
    timer.Start();

    /*2 生成变换后的  sampled_datum */
    // 按照batch_samplers_, 随机生成sampled_datum:尺寸、位置都会变化,然后变尺寸至300*300
    AnnotatedDatum sampled_datum;
    if (batch_samplers_.size() > 0) {
      // Generate sampled bboxes from anno_datum.
        //anno_datum 里面到底有几个图,为什么有gannotation_group:anno_datum代表一副图像,gannotation_group代表类别种类,同一类为一个group
      vector sampled_bboxes;
      GenerateBatchSamples(anno_datum, batch_samplers_, &sampled_bboxes);
      if (sampled_bboxes.size() > 0) {
        // Randomly pick a sampled bbox and crop the anno_datum.
        int rand_idx = caffe_rng_rand() % sampled_bboxes.size();
        this->data_transformer_->CropImage(anno_datum, sampled_bboxes[rand_idx],
                                           &sampled_datum);
      } else {
        sampled_datum.CopyFrom(anno_datum);
      }
    } else {
      sampled_datum.CopyFrom(anno_datum);
    }

    /*3 对图像进行变换 尺度、反转、剪切等操作 生成 transformed_anno_vec*/
    //见 网络参数:transform_param
    // Apply data transformations (mirror, scale, crop...)
    int offset = batch->data_.offset(item_id);
    this->transformed_data_.set_cpu_data(top_data + offset);
    vector transformed_anno_vec;
    if (this->output_labels_) {
      if (has_anno_type_) {
        // Make sure all data have same annotation type.
        CHECK(sampled_datum.has_type()) << "Some datum misses AnnotationType.";
        CHECK_EQ(anno_type_, sampled_datum.type()) <<
            "Different AnnotationType.";
        // Transform datum and annotation_group at the same time
        transformed_anno_vec.clear();
        //1、对数据进行变换
        //2、对标记结果进行变换
        this->data_transformer_->Transform(sampled_datum,
                                           &(this->transformed_data_),
                                           &transformed_anno_vec);
        if (anno_type_ == AnnotatedDatum_AnnotationType_BBOX) {
          // Count the number of bboxes.
          for (int g = 0; g < transformed_anno_vec.size(); ++g) {
            num_bboxes += transformed_anno_vec[g].annotation_size();
          }
        } else {
          LOG(FATAL) << "Unknown annotation type.";
        }
        all_anno[item_id] = transformed_anno_vec;
      } else {
        this->data_transformer_->Transform(sampled_datum.datum(),
                                           &(this->transformed_data_));
        // Otherwise, store the label from datum.
        CHECK(sampled_datum.datum().has_label()) << "Cannot find any label.";
        top_label[item_id] = sampled_datum.datum().label();
      }
    } else {
      this->data_transformer_->Transform(sampled_datum.datum(),
                                         &(this->transformed_data_));
    }
    trans_time += timer.MicroSeconds();

    reader_.free().push(const_cast(&anno_datum));
  }



  ///保存最终的目标结果
  // Store "rich" annotation if needed.
  if (this->output_labels_ && has_anno_type_) {
    vector<int> label_shape(4);
    if (anno_type_ == AnnotatedDatum_AnnotationType_BBOX) {
      label_shape[0] = 1;
      label_shape[1] = 1;
      label_shape[3] = 8;
      if (num_bboxes == 0) {
        // Store all -1 in the label.
        label_shape[2] = 1;
        batch->label_.Reshape(label_shape);
        caffe_set(8, -1, batch->label_.mutable_cpu_data());
      } else {
        // Reshape the label and store the annotation.
        label_shape[2] = num_bboxes;
        batch->label_.Reshape(label_shape);
        top_label = batch->label_.mutable_cpu_data();
        int idx = 0;
        for (int item_id = 0; item_id < batch_size; ++item_id) {
          const vector& anno_vec = all_anno[item_id];
          for (int g = 0; g < anno_vec.size(); ++g) {
            const AnnotationGroup& anno_group = anno_vec[g];
            for (int a = 0; a < anno_group.annotation_size(); ++a) {
              const Annotation& anno = anno_group.annotation(a);
              const NormalizedBBox& bbox = anno.bbox();
              top_label[idx++] = item_id;                   //属于batch中的图像序号
              top_label[idx++] = anno_group.group_label();  //一个图像上,同一个类别的目标最为一个group
              top_label[idx++] = anno.instance_id();        //一个group内的目标序号
              top_label[idx++] = bbox.xmin();
              top_label[idx++] = bbox.ymin();
              top_label[idx++] = bbox.xmax();
              top_label[idx++] = bbox.ymax();
              top_label[idx++] = bbox.difficult();          //是否为困难样本 bool型
            }
          }
        }
      }
    } else {
      LOG(FATAL) << "Unknown annotation type.";
    }
  }
  timer.Stop();
  batch_timer.Stop();
  DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms.";
  DLOG(INFO) << "     Read time: " << read_time / 1000 << " ms.";
  DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms.";
}

你可能感兴趣的:(深度学习,SSD算法)