[caffe] Modifying the SSD solver to output recall and precision

The SSD Caffe source code already reports the test mAP, i.e., each class's average precision (precision averaged across the full range of score thresholds) averaged over all classes. But if we want to see the recall and precision at one specific score threshold, we need to modify the solver.cpp source.
Explanations of mAP, recall, and precision are not repeated here; introductions are easy to find online in both Chinese and English.

Modify caffe.proto

First, add a parameter rec_prec_thr to the SolverParameter message in src/caffe/proto/caffe.proto. This is the score threshold above which detections are counted as true positives (tp) or false positives (fp) when computing recall and precision; we give it a default value of 0.6. The code is as follows (note: the field number should be one greater than the last field number already used in your SolverParameter):

optional float rec_prec_thr = 45 [default = 0.6];
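
For context, here is a minimal sketch of where the field ends up in caffe.proto. The field number 45 and the comment about existing fields are illustrative; check which numbers your copy of SolverParameter already uses:

message SolverParameter {
  // ... existing solver fields; this sketch assumes the highest
  // field number already in use is 44 ...

  // Score threshold above which detections are counted when computing
  // the per-class recall and precision.
  optional float rec_prec_thr = 45 [default = 0.6];
}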

Modify solver.cpp

Next, add code that computes recall and precision to the void Solver<Dtype>::TestDetection(const int test_net_id) function in src/caffe/solver.cpp:

template <typename Dtype>
void Solver<Dtype>::TestDetection(const int test_net_id) {
  CHECK(Caffe::root_solver());
  LOG(INFO) << "Iteration " << iter_
            << ", Testing net (#" << test_net_id << ")";
  CHECK_NOTNULL(test_nets_[test_net_id].get())->
      ShareTrainedLayersWith(net_.get());
  map<int, map<int, vector<pair<float, int> > > > all_true_pos;
  map<int, map<int, vector<pair<float, int> > > > all_false_pos;
  map<int, map<int, int> > all_num_pos;
  const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
  Dtype loss = 0;
  for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
    SolverAction::Enum request = GetRequestedAction();
    // Check to see if stoppage of testing/training has been requested.
    while (request != SolverAction::NONE) {
        if (SolverAction::SNAPSHOT == request) {
          Snapshot();
        } else if (SolverAction::STOP == request) {
          requested_early_exit_ = true;
        }
        request = GetRequestedAction();
    }
    if (requested_early_exit_) {
      // break out of test loop.
      break;
    }

    Dtype iter_loss;
    const vector<Blob<Dtype>*>& result = test_net->Forward(&iter_loss);
    if (param_.test_compute_loss()) {
      loss += iter_loss;
    }
    for (int j = 0; j < result.size(); ++j) {
      CHECK_EQ(result[j]->width(), 5);
      const Dtype* result_vec = result[j]->cpu_data();
      int num_det = result[j]->height();
      for (int k = 0; k < num_det; ++k) {
        int item_id = static_cast<int>(result_vec[k * 5]);
        int label = static_cast<int>(result_vec[k * 5 + 1]);
        if (item_id == -1) {
          // Special row of storing number of positives for a label.
          if (all_num_pos[j].find(label) == all_num_pos[j].end()) {
            all_num_pos[j][label] = static_cast<int>(result_vec[k * 5 + 2]);
          } else {
            all_num_pos[j][label] += static_cast<int>(result_vec[k * 5 + 2]);
          }
        } else {
          // Normal row storing detection status.
          float score = result_vec[k * 5 + 2];
          int tp = static_cast<int>(result_vec[k * 5 + 3]);
          int fp = static_cast<int>(result_vec[k * 5 + 4]);
          if (tp == 0 && fp == 0) {
            // Ignore such case. It happens when a detection bbox is matched to
            // a difficult gt bbox and we don't evaluate on difficult gt bbox.
            continue;
          }
          all_true_pos[j][label].push_back(std::make_pair(score, tp));
          all_false_pos[j][label].push_back(std::make_pair(score, fp));
        }
      }
    }
  }
  if (requested_early_exit_) {
    LOG(INFO) << "Test interrupted.";
    return;
  }
  if (param_.test_compute_loss()) {
    loss /= param_.test_iter(test_net_id);
    LOG(INFO) << "Test loss: " << loss;
  }
  for (int i = 0; i < all_true_pos.size(); ++i) {
    if (all_true_pos.find(i) == all_true_pos.end()) {
      LOG(FATAL) << "Missing output_blob true_pos: " << i;
    }
    const map<int, vector<pair<float, int> > >& true_pos =
        all_true_pos.find(i)->second;
    if (all_false_pos.find(i) == all_false_pos.end()) {
      LOG(FATAL) << "Missing output_blob false_pos: " << i;
    }
    const map<int, vector<pair<float, int> > >& false_pos =
        all_false_pos.find(i)->second;
    if (all_num_pos.find(i) == all_num_pos.end()) {
      LOG(FATAL) << "Missing output_blob num_pos: " << i;
    }
    const map<int, int>& num_pos = all_num_pos.find(i)->second;
    map<int, float> APs;
    map<int, float> recalls;     // per-class recall at the score threshold
    map<int, float> precisions;  // per-class precision at the score threshold
    float mAP = 0.;
    // Sort true_pos and false_pos with descending scores.
    for (map<int, int>::const_iterator it = num_pos.begin();
         it != num_pos.end(); ++it) {
      int label = it->first;
      int label_num_pos = it->second;
      if (true_pos.find(label) == true_pos.end()) {
        LOG(WARNING) << "Missing true_pos for label: " << label;
        continue;
      }
      const vector<pair<float, int> >& label_true_pos =
          true_pos.find(label)->second;
      if (false_pos.find(label) == false_pos.end()) {
        LOG(WARNING) << "Missing false_pos for label: " << label;
        continue;
      }
      const vector<pair<float, int> >& label_false_pos =
          false_pos.find(label)->second;
      vector<float> prec, rec;
      ComputeAP(label_true_pos, label_num_pos, label_false_pos,
                param_.ap_version(), &prec, &rec, &(APs[label]));
      mAP += APs[label];

      // Compute recall and precision at the score threshold and print them.
      float thr = param_.rec_prec_thr();  // threshold set in the solver
      int tp_sum = 0;  // total true positives above the threshold
      int fp_sum = 0;  // total false positives above the threshold
      // Accumulate true positives whose score exceeds the threshold.
      for (int t = 0; t < label_true_pos.size(); ++t) {
        if (label_true_pos[t].first > thr) {
          tp_sum += label_true_pos[t].second;
        }
      }
      recalls[label] = label_num_pos > 0 ?
          static_cast<float>(tp_sum) / label_num_pos : 0;
      // Accumulate false positives whose score exceeds the threshold.
      for (int f = 0; f < label_false_pos.size(); ++f) {
        if (label_false_pos[f].first > thr) {
          fp_sum += label_false_pos[f].second;
        }
      }
      // Guard against division by zero when no detection passes the threshold.
      precisions[label] = (tp_sum + fp_sum) > 0 ?
          static_cast<float>(tp_sum) / (tp_sum + fp_sum) : 0;

      if (param_.show_per_class_result()) {
        LOG(INFO) << "class" << label << ": " << APs[label];
      }
    }
    mAP /= num_pos.size();
    const int output_blob_index = test_net->output_blob_indices()[i];
    const string& output_name = test_net->blob_names()[output_blob_index];
    LOG(INFO) << "    Test net output #" << i << ": " << output_name << " = "
              << mAP;
    // Print every class's recall at the chosen threshold.
    LOG(INFO) << "-------------recalls-----------";
    for(map<int, float>::const_iterator it = recalls.begin();
        it != recalls.end(); ++it) {
        int label = it->first;
        float recall = it->second;
        LOG(INFO) << "class" << label << ": " << recall;
    }
    LOG(INFO) << "-------------recalls-----------";

    // Print every class's precision at the chosen threshold.
    LOG(INFO) << "-------------precisions-----------";
    for(map<int, float>::const_iterator it = precisions.begin();
        it != precisions.end(); ++it) {
        int label = it->first;
        float precision = it->second;
        LOG(INFO) << "class" << label << ": " << precision;
    }
    LOG(INFO) << "-------------precisions-----------";
  }
}
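
To see what these two numbers mean, here is a small worked example with made-up values. Suppose that for some class the test set contains label_num_pos = 20 ground-truth boxes, and that among the detections whose score exceeds rec_prec_thr the tp flags sum to tp_sum = 12 and the fp flags sum to fp_sum = 3. Then:

recall = tp_sum / label_num_pos = 12 / 20 = 0.60
precision = tp_sum / (tp_sum + fp_sum) = 12 / 15 = 0.80

Raising rec_prec_thr usually trades recall for precision: fewer detections pass the threshold, so tp_sum drops (lowering recall), while fp_sum typically drops even faster (raising precision).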

Finally, rebuild Caffe:

make clean
make all

In use, adjust the threshold by setting the value of rec_prec_thr in solver.prototxt.
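
For example, the test-related part of a solver.prototxt might look like the following; every value except rec_prec_thr is a placeholder from a typical SSD setup:

# Test settings (placeholder values from a typical SSD solver).
test_iter: 619
test_interval: 10000
ap_version: "11point"
# Print per-class AP in addition to mAP.
show_per_class_result: true
# Score threshold for the recall/precision printout added above.
rec_prec_thr: 0.5

During testing, the per-class recall and precision then appear in the log between the "-------------recalls-----------" and "-------------precisions-----------" markers.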
