Xgboost C++预测模块线程安全修复

1 背景

Xgboost在各种排序场景中有广泛的应用,离线训练一般在Spark平台或者单机环境执行。训练好的模型用到线上预测时一般要根据自己的环境重新开发预测代码,例如,如果时Java环境,则需要用Java开发预测代码。主要原因是Xgboost提供的预测模块不支持多线程,本文介绍如何修改C++代码,使其符合线上预测要求。下面将根据实际探索过程逐步介绍如何将xgboost4j的Java接口修改成符合线上应用的多线程程序。

2 具体修改过程

2.1 Java接口去处锁

结果测试发现Java的预测接口加锁了,查看发现加锁位置为如下的synchronized

 private **synchronized** float[][] predict(DMatrix data,
                                         boolean outputMargin,
                                         int treeLimit,
                                         boolean predLeaf,
                                         boolean predContribs) throws XGBoostError {
    int optionMask = 0;
    if (outputMargin) {
      optionMask = 1;
    }
    if (predLeaf) {
      optionMask = 2;
    }
    if (predContribs) {
      optionMask = 4;
    }
    float[][] rawPredicts = new float[1][];
    XGBoostJNI.checkCall(XGBoostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask,
            treeLimit, rawPredicts));
    int row = (int) data.rowNum();
    int col = rawPredicts[0].length / row;
    float[][] predicts = new float[row][col];
    int r, c;
    for (int i = 0; i < rawPredicts[0].length; i++) {
      r = i / col;
      c = i % col;
      predicts[r][c] = rawPredicts[0][i];
    }
    return predicts;
  }
  • 首先将其去掉,不去掉是无法实现多线程调用的
  • 为什么人家需要加锁,猜测是不是C++中的代码线程不安全
  • 不管怎么说先去掉,测试
  • 测试出core,果然C++代码线程不安全

2.2 修改C++线程不安全代码-1

结果查看源码,C++中的预测方法调用关系如下
Xgboost C++预测模块线程安全修复_第1张图片
主要关心PredLoopSpecalize方法,该方法如下,在该方法中有一行语句InitThreadTemp(nthread, model.param.num_feature);,这是为了openmp运行时分配用于每次线程的存储变量thread_temp的存储空间的,thread_temp是类的成员变量,当多线程调用时,相当于是一个公共的资源被不同线程修改了,即多写问题,修改后的代码如下,将thread_temp改为方法内部的局部变量就可以了。结果这样的修改,确实不出core了,可是另一个问题又来了,见下一节。

 inline void PredLoopSpecalize(DMatrix* p_fmat,
                                std::vector* out_preds,
                                const gbm::GBTreeModel& model, int num_group,
                                unsigned tree_begin, unsigned tree_end) {
    const MetaInfo& info = p_fmat->info();
    const int nthread = omp_get_max_threads();
    //原始的线程不安全方法
    //InitThreadTemp(nthread, model.param.num_feature);
      /*
       * 为了处理多线程问题,此处把thread_temp全局变量修改为局部变量local_thread_temp
       */
      std::vector local_thread_temp;
      InitThreadTemp(nthread, model.param.num_feature, local_thread_temp);
      //============
    std::vector& preds = *out_preds;
    CHECK_EQ(model.param.size_leaf_vector, 0)
        << "size_leaf_vector is enforced to 0 so far";
    CHECK_EQ(preds.size(), p_fmat->info().num_row * num_group);
    // start collecting the prediction
    dmlc::DataIter* iter = p_fmat->RowIterator();
    iter->BeforeFirst();
      int b = 0;
    while (iter->BatchNext()) {
        b=1;
      const RowBatch& batch = iter->Value();
      // parallel over local batch
      const int K = 8;
      const bst_omp_uint nsize = static_cast(batch.size);
      const bst_omp_uint rest = nsize % K;
#pragma omp parallel for schedule(static)
      for (bst_omp_uint i = 0; i < nsize - rest; i += K) {
        const int tid = omp_get_thread_num();
        RegTree::FVec& feats = local_thread_temp[tid];

        int64_t ridx[K];
        RowBatch::Inst inst[K];
        for (int k = 0; k < K; ++k) {
          ridx[k] = static_cast(batch.base_rowid + i + k);
        }
        for (int k = 0; k < K; ++k) {
          inst[k] = batch[i + k];
        }
        for (int k = 0; k < K; ++k) {
          for (int gid = 0; gid < num_group; ++gid) {
            const size_t offset = ridx[k] * num_group + gid;
            preds[offset] += this->PredValue(
                inst[k], model.trees, model.tree_info, gid,
                info.GetRoot(ridx[k]), &feats, tree_begin, tree_end);
          }
        }
      }
      for (bst_omp_uint i = nsize - rest; i < nsize; ++i) {
        RegTree::FVec& feats = local_thread_temp[0];
        const int64_t ridx = static_cast(batch.base_rowid + i);
        const RowBatch::Inst inst = batch[i];
        for (int gid = 0; gid < num_group; ++gid) {
          const size_t offset = ridx * num_group + gid;
          preds[offset] +=
              this->PredValue(inst, model.trees, model.tree_info, gid,
                              info.GetRoot(ridx), &feats, tree_begin, tree_end);
        }
      }
        break;
    }
  }

2.3 修改C++线程不安全代码-2

经过2.2的修改,确实不出core了,但是当多次运行测试时发现,有时候预测的结果全是0,这又是为什么。出问题的代码部分如下

inline void PredLoopSpecalize(DMatrix* p_fmat,
                                std::vector<bst_float>* out_preds,
                                const gbm::GBTreeModel& model, int num_group,
                                unsigned tree_begin, unsigned tree_end) {
..........................
//======出问题的部分是如下两行=======
iter->BeforeFirst();
    while (iter->Next()) {
    }
//===============================
    ......................
    }

代码中的iter是SimpleCSRSource类的迭代器,在该类中实现了BeforeFirst和Next,如下。从这两个方法实现上可以看到,其用成员变量at_first作为判断标记,但是在这两个方法调用中间并不是原子行为,当多个线程调用时,会出现如下状态变化,导致thread_2在Next中判断到at_first_=false不执行预测导致退出
Xgboost C++预测模块线程安全修复_第2张图片

void SimpleCSRSource::BeforeFirst() {
  at_first_ = true;
}

bool SimpleCSRSource::Next() {
  if (!at_first_) return false;
  at_first_ = false;
  batch_.size = row_ptr_.size() - 1;
  batch_.base_rowid = 0;
  batch_.ind_ptr = dmlc::BeginPtr(row_ptr_);
  batch_.data_ptr = dmlc::BeginPtr(row_data_);
  return true;
}

解决方法如下
- 在DataIter中添加虚方法virtual bool BatchNext(void){ return false;};

class DataIter {
 public:
  /*! \brief destructor */
  virtual ~DataIter(void) {}
  /*! \brief set before first of the item */
  virtual void BeforeFirst(void) = 0;
  /*! \brief move to next item */
  virtual bool Next(void) = 0;
    virtual bool BatchNext(void){ return false;};
  /*! \brief get current data */
  virtual const DType &Value(void) const = 0;
};
  • 在SimpleCSRSource中实现BatchNext如下
bool SimpleCSRSource::BatchNext(){
        //LOG(CONSOLE)<<"SimpleCSRSource::Next()"<<"\n";
        //if (!at_first_) return false;
        //at_first_ = false;
        batch_.size = row_ptr_.size() - 1;
        batch_.base_rowid = 0;
        batch_.ind_ptr = dmlc::BeginPtr(row_ptr_);
        batch_.data_ptr = dmlc::BeginPtr(row_data_);
        return true;
    }
  • 将上面出问题的那两行改为
iter->BeforeFirst();
    while (iter->BatchNext()) {

2.4 修改C++线程不安全代码-3

经过上面两处修改总算把问题解决了,好了压力测试吧。仿佛启动多线程测试,有时在刚开始预测时会出core,一旦运行起来反而没问题。多事之秋,改别人的源代码就是麻烦,经过定位(定位过程略,整了一两天)发现问题如下,在c_api.cc的XGBoosterPredict中存在线程不安全问题,由bst->LazyInit();引起的。

XGB_DLL int XGBoosterPredict(BoosterHandle handle,
                             DMatrixHandle dmat,
                             int option_mask,
                             unsigned ntree_limit,
                             xgboost::bst_ulong *len,
                             const bst_float **out_result) {

  std::vector& preds = XGBAPIThreadLocalStore::Get()->ret_vec_float;
  API_BEGIN();
  Booster *bst = static_cast(handle);
  bst->LazyInit();
  bst->learner()->Predict(
      static_cast<std::shared_ptr*>(dmat)->get(),
      (option_mask & 1) != 0,
      &preds, ntree_limit,
      (option_mask & 2) != 0,
      (option_mask & 4) != 0,
      (option_mask & 8) != 0);
  *out_result = dmlc::BeginPtr(preds);
  *len = static_cast(preds.size());
  API_END();
}

下面是原始的LazyInit实现方式,一看就存在多线程问题,具体分析就不说了,这个简单,解决方法就是加锁呀,因为一旦有一个线程正常初始化一次,其它线程就安全了。加锁加锁快加锁哦

inline void LazyInit() {
      if (!configured_) {
        learner_->Configure(cfg_);
        configured_ = true;
      }
      if (!initialized_) {
        learner_->InitModel();
        initialized_ = true;
      }
  }

加锁完了的代码如下

inline void LazyInit() {
    //为了线程安全加锁
    if(!configured_ || !initialized_){
      pthread_mutex_lock(&lock_);
      if (!configured_) {
        learner_->Configure(cfg_);
        configured_ = true;
      }
      if (!initialized_) {
        learner_->InitModel();
        initialized_ = true;
      }
      pthread_mutex_unlock(&lock_);
    }

  }

2.5 哈哈哈

经过上面的修改,目前反复启动+随机多线程压测,跑几天了还没发现问题,吼吼吼。

3 总结

多线程问题一定要注意如下几点
- 避免公共资源多写问题,如果出现线程问题应该第一个想到这个。
- 注意空指针问题,在释放指针之前要判断是否为空。
- 类总如果有指针成员变量,且允许类的复制,一定要实现拷贝构造函数和重载赋值操作符,否则应该禁止类的对象拷贝(具体实现就不介绍了)。

4 吐槽

辜负了我对Xgboost代码实现者的崇拜了,很多地方设计的都不合理,预测方法中一点多线程都不考虑,哎哎。(不要说人家就不是为线上服务设计的,毕竟大牛写的呀)

5 后记

C++是改好了,怎么打包到jar包中哪,如何实现C++的夸平台哪,详见下一篇:http://blog.csdn.net/zc02051126/article/details/79427613

你可能感兴趣的:(机器学习计算工具)