A source-code walkthrough of the secretflow inference service

secretflow-serving (https://github.com/secretflow/serving) is the privacy-preserving inference service provided by SecretFlow (隐语). Its codebase is only about one percent the size of ClickHouse's (under ten thousand lines), but small as it is, it has all the essentials: the complete model-loading and inference pipeline, plus a monitoring service built on Prometheus.
secretflow-serving is written in C++17 and the code is clear and easy to follow. This article walks through the source code along the lines of its architecture. Since I am not a machine-learning specialist, corrections from readers are very welcome.

Architecture

The overall architecture of secretflow-serving has two phases: startup and prediction serving. Startup mainly loads the model and brings up several brpc RPC services (see https://brpc.apache.org/zh/docs/server/serve-grpc/ for details).

Startup phase

  • Source Adapter: the model-storage adapter, which talks to the different model storage backends (e.g. http, dm, filesystem).
  • Model Loader: the model loader, which loads the model in the format matching its type.
  • Executable: the model execution framework. Its interface serves the RPC layer; different model loaders produce different kinds of Executable.

Prediction phase

  • Prediction Service: the entry point for prediction RPC requests.
  • Scheduler: batches upstream requests according to the configuration.
  • Prediction Controller: translates the prediction RPC request, then assembles and dispatches prediction-execution requests to every participating party.
  • Execution Service: the entry point for model-execution RPC requests.
  • Feature Adapter: talks to the various feature-service SPIs; it fetches feature values and applies the feature-mapping rules (mapping online features onto the model's input features).
  • Executable: receives the model-execution request (executable->Run()) and returns the prediction scores.

Startup phase

Entry function

The entry function is at https://github.com/secretflow/serving/blob/475bb3356e3a246f444fc25f62f9619874870680/secretflow_serving/server/main.cc#L46,
where input arguments are handled with absl and gflags.
OpFactory is a singleton; secretflow implements a standard Meyers' Singleton here (https://github.com/secretflow/serving/blob/475bb3356e3a246f444fc25f62f9619874870680/secretflow_serving/core/singleton.h#L20):

    {
      auto op_def_list =
          secretflow::serving::op::OpFactory::GetInstance()->GetAllOps();
      std::vector<std::string> op_names;
      std::for_each(
          op_def_list.begin(), op_def_list.end(),
          [&](const std::shared_ptr<const secretflow::serving::op::OpDef>& o) {
            op_names.emplace_back(o->name());
          });

      SPDLOG_INFO("op list: {}",
                  fmt::join(op_names.begin(), op_names.end(), ", "));
    }
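singleton.h (linked above) boils down to the classic pattern: a function-local static whose initialization has been guaranteed thread-safe since C++11. A minimal sketch of the idea (not the exact serving code):

template <typename T>
class Singleton {
 public:
  static T* GetInstance() {
    // C++11 guarantees thread-safe initialization of function-local
    // statics, so no explicit locking is needed here.
    static T instance;
    return &instance;
  }

 protected:
  Singleton() = default;
  Singleton(const Singleton&) = delete;
  Singleton& operator=(const Singleton&) = delete;
};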

OpFactory itself is at https://github.com/secretflow/serving/blob/475bb3356e3a246f444fc25f62f9619874870680/secretflow_serving/ops/op_factory.h:

class OpFactory final : public Singleton<OpFactory> {
 public:
  void Register(const std::shared_ptr<OpDef>& op_def) {
    std::lock_guard<std::mutex> lock(mutex_);
    SERVING_ENFORCE(op_defs_.emplace(op_def->name(), op_def).second,
                    errors::ErrorCode::LOGIC_ERROR,
                    "duplicated op_def registered for {}", op_def->name());
  }

  const std::shared_ptr<OpDef> Get(const std::string& name) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto iter = op_defs_.find(name);
    SERVING_ENFORCE(iter != op_defs_.end(), errors::ErrorCode::UNEXPECTED_ERROR,
                    "no op_def registered for {}", name);
    return iter->second;
  }

  std::vector<std::shared_ptr<const OpDef>> GetAllOps() {
    std::vector<std::shared_ptr<const OpDef>> result;

    std::lock_guard<std::mutex> lock(mutex_);
    for (const auto& pair : op_defs_) {
      result.emplace_back(pair.second);
    }
    return result;
  }

 private:
  std::unordered_map<std::string, std::shared_ptr<OpDef>> op_defs_;
  std::mutex mutex_;
};

As we can see, ops are registered statically via REGISTER_OP, and they are looked up when execution-graph nodes are constructed (see https://github.com/secretflow/serving/blob/475bb3356e3a246f444fc25f62f9619874870680/secretflow_serving/ops/node.cc#L23):

#define REGISTER_OP(op_name, version, desc)     \
  static OpRegister const regist_op_##op_name = \
      OpRegister{} << internal::OpDefBuilderWrapper(#op_name, version, desc)
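The macro relies on a common C++ trick: a namespace-scope static object whose initializer runs before main() and registers the op as a side effect. A self-contained toy version of the mechanism (the real OpRegister/OpDefBuilderWrapper carry far more metadata):

#include <iostream>
#include <string>
#include <unordered_map>

// A toy registry standing in for OpFactory.
std::unordered_map<std::string, std::string>& Registry() {
  static std::unordered_map<std::string, std::string> r;
  return r;
}

struct OpDefBuilderWrapper {
  std::string name, version, desc;
};

struct OpRegister {
  // operator<< does the actual side effect: register into the factory.
  OpRegister& operator<<(const OpDefBuilderWrapper& b) {
    Registry().emplace(b.name, b.version);
    return *this;
  }
};

#define REGISTER_OP(op_name, version, desc)     \
  static OpRegister const regist_op_##op_name = \
      OpRegister{} << OpDefBuilderWrapper{#op_name, version, desc}

REGISTER_OP(MY_DEMO_OP, "0.0.1", "a demo op");  // registered before main()

int main() { std::cout << Registry().size() << "\n"; }  // prints 1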

After this, main reads the server's various parameters and starts the startup flow.
The full code:


// @hint entry function
int main(int argc, char* argv[]) {
    // Initialize the symbolizer to get a human-readable stack trace
    // handle input arguments with absl and gflags
    absl::InitializeSymbolizer(argv[0]);

    gflags::SetVersionString(SERVING_VERSION_STRING);
    gflags::AllowCommandLineReparsing();
    gflags::ParseCommandLineFlags(&argc, &argv, true);

    try {
        // init logger
        secretflow::serving::LoggingConfig log_config;
        if (!FLAGS_logging_config_file.empty()) {
            secretflow::serving::LoadPbFromJsonFile(FLAGS_logging_config_file,
                &log_config);
        }
        secretflow::serving::SetupLogging(log_config);

        SPDLOG_INFO("version: {}", SERVING_VERSION_STRING);

        {
            // OpFactory is a Meyers' Singleton
            // ops are registered statically via REGISTER_OP
            auto op_def_list =
            secretflow::serving::op::OpFactory::GetInstance()->GetAllOps();
            std::vector<std::string> op_names;
            std::for_each(
                op_def_list.begin(), op_def_list.end(),
                [&](const std::shared_ptr<const secretflow::serving::op::OpDef>& o) {
                    op_names.emplace_back(o->name());
                });

            SPDLOG_INFO("op list: {}",
                fmt::join(op_names.begin(), op_names.end(), ", "));
        }

        STRING_EMPTY_VALIDATOR(FLAGS_serving_config_file);

        // init server options
        secretflow::serving::Server::Options server_opts;
        if (FLAGS_config_mode == "kuscia") {
            secretflow::serving::kuscia::KusciaConfigParser config_parser(
            FLAGS_serving_config_file);
            server_opts.server_config = config_parser.server_config();
            server_opts.cluster_config = config_parser.cluster_config();
            server_opts.model_config = config_parser.model_config();
            server_opts.feature_source_config = config_parser.feature_config();
            server_opts.service_id = config_parser.service_id();
        } else {
            secretflow::serving::ServingConfig serving_conf;
            LoadPbFromJsonFile(FLAGS_serving_config_file, &serving_conf);

            server_opts.server_config = serving_conf.server_conf();
            server_opts.cluster_config = serving_conf.cluster_conf();
            server_opts.model_config = serving_conf.model_conf();
            if (serving_conf.has_feature_source_conf()) {
                server_opts.feature_source_config = serving_conf.feature_source_conf();
            }
            server_opts.service_id = serving_conf.id();
        }
    	// start the server
        secretflow::serving::Server server(std::move(server_opts));
        server.Start();
    	// run until the brpc server exits
        server.WaitForEnd();
    } catch (const secretflow::serving::Exception& e) {
        // TODO: custom status sink
        SPDLOG_ERROR("server startup failed, code: {}, msg: {}, stack: {}",
            e.code(), e.what(), e.stack_trace());
        return -1;
    } catch (const std::exception& e) {
        // TODO: custom status sink
        SPDLOG_ERROR("server startup failed, msg:{}", e.what());
        return -1;
    }

    return 0;
}

Model definition

Before getting into the startup flow, let's look at how an inference model is defined in secretflow:
https://github.com/secretflow/serving/blob/475bb3356e3a246f444fc25f62f9619874870680/secretflow_serving/protos/bundle.proto#L37
ModelBundle is a proto message that carries the complete model information.
GraphDef defines the execution graph: a set of data-carrying nodes (NodeDef) and a set of execution descriptions for the graph (ExecutionDef).

// Represents an exported secertflow model. It consists of a GraphDef and extra
// metadata required for serving.
message ModelBundle {
  string name = 1;

  string desc = 2;

  GraphDef graph = 3;
}

// The definition of a Graph. A graph consists of a set of nodes carrying data
// and a set of executions that describes the scheduling of the graph.
message GraphDef {
  // Version of the graph
  string version = 1;

  repeated NodeDef node_list = 2;

  repeated ExecutionDef execution_list = 3;
}

Next, NodeDef and ExecutionDef:


// The definition of a node.
message NodeDef {
  // Must be unique among all nodes of the graph.
  string name = 1;

  // The operator name.
  string op = 2;

  // The parent node names of the node. The order of the parent nodes should
  // match the order of the inputs of the node.
  repeated string parents = 3;
  // the attribute values of the node's op
  // The attribute values configed in the node. Note that this should include
  // all attrs defined in the corresponding OpDef.
  map<string, AttrValue> attr_values = 4;

  // The operator version.
  string op_version = 5;
}


// The value of an attribute
message AttrValue {
  oneof value {
    // INT
    int32 i32 = 1;
    int64 i64 = 2;
    // FLOAT
    float f = 3;
    double d = 4;
    // STRING
    string s = 5;
    // BOOL
    bool b = 6;
    // BYTES
    bytes by = 7;

    // Lists

    // INTS
    Int32List i32s = 11;
    Int64List i64s = 12;
    // FLOATS
    FloatList fs = 13;
    DoubleList ds = 14;
    // STRINGS
    StringList ss = 15;
    // BOOLS
    BoolList bs = 16;
    // BYTESS
    BytesList bys = 17;
  }
}

// The definition of a execution. A execution represents a subgraph within a
// graph that can be scheduled for execution in a specified pattern.
message ExecutionDef {
  // contains the nodes plus the runtime config
  // Represents the nodes contained in this execution. Note that
  // these node names should be findable and unique within the node
  // definitions. One node can only exist in one execution and must exist in
  // one.
  repeated string nodes = 1;

  // The runtime config of the execution.
  RuntimeConfig config = 2;
}
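To make these structures concrete, here is a hedged sketch that assembles a minimal two-node graph through the generated protobuf C++ API (the include path, op names, and the split into two executions are illustrative, not taken from a real exported model; RuntimeConfig is omitted):

#include "secretflow_serving/protos/graph.pb.h"  // assumed generated header

secretflow::serving::GraphDef BuildDemoGraph() {
  secretflow::serving::GraphDef graph;
  graph.set_version("0.0.1");

  // A local linear-score node feeding a merge node.
  auto* dot = graph.add_node_list();
  dot->set_name("dot_product_1");
  dot->set_op("DOT_PRODUCT");  // illustrative op name

  auto* merge = graph.add_node_list();
  merge->set_name("merge_y_1");
  merge->set_op("MERGE_Y");  // illustrative op name
  merge->add_parents("dot_product_1");

  // Every node must belong to exactly one execution.
  graph.add_execution_list()->add_nodes("dot_product_1");
  graph.add_execution_list()->add_nodes("merge_y_1");
  return graph;
}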


Startup flow

The startup code is at
https://github.com/secretflow/serving/blob/475bb3356e3a246f444fc25f62f9619874870680/secretflow_serving/server/server.cc#L58

Pulling the model file and initializing the execution graph

SourceFactory is another singleton; after initialization it pulls the model from its source:

/* SourceFactory initialization */

  // get model package
  auto source = SourceFactory::GetInstance()->Create(opts_.model_config,
                                                     opts_.service_id);
  // @hint pull the model; channels are initialized afterwards
  // this step reads from the file source
  auto package_path = source->PullModel();

/* PullModel looks like this */

std::string Source::PullModel() {
  auto dst_dir = std::filesystem::path(data_dir_).append(config_.model_id());
  if (!std::filesystem::exists(dst_dir)) {
    std::filesystem::create_directories(dst_dir);
  }

  auto dst_file_path = dst_dir.append(kModelFileName);
  const auto& source_sha256 = config_.source_sha256();
  if (std::filesystem::exists(dst_file_path)) {
    if (!source_sha256.empty()) {
      if (SysUtil::CheckSHA256(dst_file_path.string(), source_sha256)) {
        return dst_file_path;
      }
    }
    SPDLOG_INFO("remove tmp model file:{}", dst_file_path.string());
    std::filesystem::remove(dst_file_path);
  }
  // OnPullModel pulls the model from OSS
  OnPullModel(dst_file_path);
  if (!source_sha256.empty()) {
    SERVING_ENFORCE(SysUtil::CheckSHA256(dst_file_path.string(), source_sha256),
                    errors::ErrorCode::IO_ERROR,
                    "model({}) sha256 check failed", config_.source_path());
  }

  return dst_file_path;
}

Then the RPC channels are initialized from the participant information:

// build channels
  std::string self_address;
  std::vector<std::string> cluster_ids;
  // channels are used to talk to each of the other parties
  auto channels = std::make_shared<PartyChannelMap>();
  for (const auto& party : opts_.cluster_config.parties()) {
    cluster_ids.emplace_back(party.id());
    if (party.id() == self_party_id) {
      self_address = party.listen_address().empty() ? party.address()
                                                    : party.listen_address();
      continue;
    }
    channels->emplace(
        party.id(),
        CreateBrpcChannel(
            party.address(), opts_.cluster_config.channel_desc().protocol(),
            FLAGS_enable_peers_load_balancer,
            opts_.cluster_config.channel_desc().rpc_timeout_ms() > 0
                ? opts_.cluster_config.channel_desc().rpc_timeout_ms()
                : kPeerRpcTimeoutMs,
            opts_.cluster_config.channel_desc().connect_timeout_ms() > 0
                ? opts_.cluster_config.channel_desc().connect_timeout_ms()
                : kPeerConnectTimeoutMs,
            opts_.cluster_config.channel_desc().has_tls_config()
                ? &opts_.cluster_config.channel_desc().tls_config()
                : nullptr));
  }

Next, the package pulled from OSS is unpacked and its proto files are read. Here we focus on the model-loading process and the Graph constructor:

  // load model package
  auto loader = std::make_unique<ModelLoader>();
  loader->Load(package_path);
  const auto& model_bundle = loader->GetModelBundle();
  Graph graph(model_bundle->graph());
// model_bundle is a proto message carrying the complete model information
// Represents an exported secertflow model. It consists of a GraphDef and extra
// metadata required for serving.
// message ModelBundle {
//   string name = 1;
//   string desc = 2;
//   GraphDef graph = 3;
// }

/* the Load method */

void ModelLoader::Load(const std::string& file_path) {
  SPDLOG_INFO("begin load file: {}", file_path);

  auto model_dir =
      std::filesystem::path(file_path).parent_path().append("data");
  if (std::filesystem::exists(model_dir)) {
    // remove tmp model dir
    SPDLOG_WARN("remove tmp model dir: {}", model_dir.string());
    std::filesystem::remove_all(model_dir);
  }

  // unzip package file
  try {
    SysUtil::ExtractGzippedArchive(file_path, model_dir);
  } catch (const std::exception& e) {
    std::filesystem::remove_all(file_path);
    SERVING_THROW(errors::ErrorCode::IO_ERROR,
                  "failed to extract model package {}, detail: {}", file_path,
                  e.what());
  }

  auto manifest_path =
      std::filesystem::path(model_dir).append(kManifestFileName);
  SERVING_ENFORCE(
      std::filesystem::exists(manifest_path), errors::ErrorCode::IO_ERROR,
      "can not find manifest file {}, model package file is corrupted",
      manifest_path.string());

  // load manifest
  ModelManifest manifest;
    
  // deserialize the pb message from JSON
  LoadPbFromJsonFile(manifest_path.string(), &manifest);

  auto model_file_path = model_dir.append(manifest.bundle_path());

  auto model_bundle = std::make_shared<ModelBundle>();
  if (manifest.bundle_format() == FileFormatType::FF_PB) {
    LoadPbFromBinaryFile(model_file_path.string(), model_bundle.get());
  } else if (manifest.bundle_format() == FileFormatType::FF_JSON) {
    LoadPbFromJsonFile(model_file_path.string(), model_bundle.get());
  } else {
    SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR,
                  "found unknown bundle_format:{}",
                  FileFormatType_Name(manifest.bundle_format()));
  }
  model_bundle_ = std::move(model_bundle);

  SPDLOG_INFO("end load model bundle, name: {}, desc: {}, graph version: {}",
              model_bundle_->name(), model_bundle_->desc(),
              model_bundle_->graph().version());
}


/* the Graph constructor */

Graph::Graph(GraphDef graph_def) : def_(std::move(graph_def)) {
  // TODO: check version

  // TODO: consider not storing def_ to avoiding multiple copies of node_defs
  // and execution_defs

  graph_view_.set_version(def_.version());
  for (auto& node : def_.node_list()) {
    NodeView view;
    *(view.mutable_name()) = node.name();
    *(view.mutable_op()) = node.op();
    *(view.mutable_op_version()) = node.op_version();
    *(view.mutable_parents()) = node.parents();
    graph_view_.mutable_node_list()->Add(std::move(view));
  }
  *(graph_view_.mutable_execution_list()) = def_.execution_list();

  // create nodes
  // build the nodes; nodes_ is a std::unordered_map<std::string, std::shared_ptr<Node>>
  for (int i = 0; i < def_.node_list_size(); ++i) {
    const auto node_name = def_.node_list(i).name();
    auto node = std::make_shared<Node>(def_.node_list(i));
    SERVING_ENFORCE(nodes_.emplace(node_name, node).second,
                    errors::ErrorCode::LOGIC_ERROR, "found duplicate node:{}",
                    node_name);
  }

  // create edges
  // build the edges; edges_ is just a flat vector in insertion order
  for (const auto& [name, node] : nodes_) {
    const auto& input_nodes = node->GetInputNodeNames();
    if (input_nodes.empty()) {
      SERVING_ENFORCE(node->GetOpDef()->inputs_size() == 1,
                      errors::ErrorCode::LOGIC_ERROR,
                      "the entry op should only have one input to accept "
                      "the features, node:{}, op:{}",
                      name, node->node_def().op());
      entry_nodes_.emplace_back(node);
    }
    for (size_t i = 0; i < input_nodes.size(); ++i) {
      auto n_iter = nodes_.find(input_nodes[i]);
      SERVING_ENFORCE(n_iter != nodes_.end(), errors::ErrorCode::LOGIC_ERROR,
                      "can not found input node:{} for node:{}", input_nodes[i],
                      name);
      auto edge = std::make_shared<Edge>(n_iter->first, name, i);
      n_iter->second->AddOutEdge(edge);
      node->AddInEdge(edge);
      edges_.emplace_back(edge);
    }
  }

  // find exit node
  // there must be exactly one exit node in an execution graph
  size_t exit_node_count = 0;
  for (const auto& pair : nodes_) {
    if (pair.second->out_edges().empty()) {
      exit_node_ = pair.second;
      ++exit_node_count;
    }
  }
  SERVING_ENFORCE(!entry_nodes_.empty(), errors::ErrorCode::LOGIC_ERROR,
                  "can not found any entry node, please check graph def.");
  SERVING_ENFORCE(exit_node_count == 1, errors::ErrorCode::LOGIC_ERROR,
                  "found {} exit nodes, expect only 1 in graph",
                  exit_node_count);
  SERVING_ENFORCE(exit_node_->GetOpDef()->tag().returnable(),
                  errors::ErrorCode::LOGIC_ERROR,
                  "exit node({}) op({}) must returnable", exit_node_->GetName(),
                  exit_node_->GetOpDef()->name());

  CheckNodesReachability();
  CheckEdgeValidate();
// the Executions are built here
  BuildExecution();
  CheckExecutionValidate();
}

Building the Executor

Once the model's execution graph is initialized, the Executors and the Executable are built and used to initialize the ExecutionCore. Executable is a container class for Executors, so we mainly look at how an Executor is built.
An Executor is the smallest unit of execution; serving's dynamic scheduling of predictions is implemented on top of Executors.


// build the Executors
  std::vector<std::shared_ptr<Executor>> executors;
  for (const auto& execution : graph.GetExecutions()) {
    executors.emplace_back(std::make_shared<Executor>(execution));
  }
  ExecutionCore::Options exec_opts;
  exec_opts.id = opts_.service_id;
  exec_opts.party_id = self_party_id;

  exec_opts.executable = std::make_shared<Executable>(std::move(executors));
  if (opts_.server_config.op_exec_worker_num() > 0) {
    exec_opts.op_exec_workers_num = opts_.server_config.op_exec_worker_num();
  }

  if (!opts_.server_config.feature_mapping().empty()) {
    exec_opts.feature_mapping = {opts_.server_config.feature_mapping().begin(),
                                 opts_.server_config.feature_mapping().end()};
  }
  exec_opts.feature_source_config = opts_.feature_source_config;
// build the ExecutionCore
  auto execution_core = std::make_shared<ExecutionCore>(std::move(exec_opts));


What happens while an Executor is built? Let's look at the constructor.
First, an op_kernel is created for every node. Like the ops earlier, op_kernels have a factory class that also uses static registration; registration stores a creator function in creators_ for later callback. When the Executor runs, work is dispatched to the op_kernels.
https://github.com/secretflow/serving/blob/475bb3356e3a246f444fc25f62f9619874870680/secretflow_serving/ops/op_kernel_factory.h#L22
Create invokes a different creator depending on opts.op_def->name(). **OpKernel has three subclasses: ArrowProcessing, MergeY, and DotProduct; the first two work on Arrow-represented data, the last on Eigen-represented data.** Each builds its own schemas during construction.
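A condensed sketch of what such a creator-based factory looks like (simplified signatures; the real one is in op_kernel_factory.h linked above):

#include <functional>
#include <memory>
#include <string>
#include <unordered_map>

struct OpKernelOptions { std::string op_name; /* node_def, op_def, ... */ };
struct OpKernel { virtual ~OpKernel() = default; };

class OpKernelFactory {
 public:
  using Creator = std::function<std::unique_ptr<OpKernel>(OpKernelOptions)>;

  static OpKernelFactory* GetInstance() {
    static OpKernelFactory factory;  // Meyers' Singleton again
    return &factory;
  }

  // Called from the kernels' static registrars.
  void Register(const std::string& op_name, Creator creator) {
    creators_.emplace(op_name, std::move(creator));
  }

  // Dispatches on the op name to the registered creator.
  std::unique_ptr<OpKernel> Create(OpKernelOptions opts) {
    return creators_.at(opts.op_name)(std::move(opts));
  }

 private:
  std::unordered_map<std::string, Creator> creators_;
};

Back to the Executor constructor: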


Executor::Executor(const std::shared_ptr<Execution>& execution)
    : execution_(execution) {
  // create op_kernel
  auto nodes = execution_->nodes();
  node_items_ = std::make_shared<
      std::unordered_map<std::string, std::shared_ptr<NodeItem>>>();

  for (const auto& [node_name, node] : nodes) {
    op::OpKernelOptions ctx{node->node_def(), node->GetOpDef()};

    auto item = std::make_shared<NodeItem>();
    item->node = node;
    item->op_kernel =
        op::OpKernelFactory::GetInstance()->Create(std::move(ctx));
    node_items_->emplace(node_name, item);
  }

Let's take ArrowProcessing as an example and walk through its construction. It first calls the base-class constructor:


// the base-class constructor
explicit OpKernel(OpKernelOptions opts) : opts_(std::move(opts)) {
    num_inputs_ = opts_.op_def->inputs_size();
    if (opts_.op_def->tag().variable_inputs()) {
      // The actual number of inputs for op with variable parameters
      // depends on node's parents.
      num_inputs_ = opts_.node_def.parents_size();
    }
  }

It then runs BuildInputSchema() and BuildOutputSchema(). BuildInputSchema calls GetNodeBytesAttr, which is an inline function defined in a header. Marking it inline avoids ODR violations while keeping a single entity across translation units, which is why the usual advice is to prefer inline over static for header-defined functions (static would silently give every translation unit its own copy).

// invoked as GetNodeBytesAttr(opts_.node_def, "input_schema_bytes")
inline std::string GetNodeBytesAttr(const NodeDef& node_def,
                                    const std::string& attr_name) {
  std::string value;
  if (!GetNodeBytesAttr(node_def, attr_name, &value)) {
    SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR,
                  "can not get attr:{} from node:{}, op:{}", attr_name,
                  node_def.name(), node_def.op());
  }
  return value;
}
bool GetNodeBytesAttr(const NodeDef& node_def, const std::string& attr_name,
                      std::vector<std::string>* value) {
  AttrValue attr_value;
  if (!GetAttrValue(node_def, attr_name, &attr_value)) {
    return false;
  }
  SERVING_ENFORCE(
      attr_value.has_by(), errors::ErrorCode::LOGIC_ERROR,
      "attr_value({}) does not have expected type(bytes) value, node: {}",
      attr_name, node_def.name());
  SERVING_ENFORCE(!attr_value.bys().data().empty(),
                  errors::ErrorCode::INVALID_ARGUMENT,
                  "attr_value({}) type(BytesList) has empty value, node: {}",
                  attr_name, node_def.name());
  value->reserve(attr_value.bys().data().size());
  for (const auto& v : attr_value.bys().data()) {
    value->emplace_back(v);
  }
  return true;
}

From this we can also see that **input_schema_bytes is stored in the map<string, AttrValue> attr_values field**; compare the NodeDef proto definition above.
Deserializing those bytes with Arrow then yields the input schema. BuildOutputSchema() follows the same flow, so it is not repeated here.


std::shared_ptr<arrow::Schema> DeserializeSchema(const std::string& buf) {
  std::shared_ptr<arrow::Schema> result;

  std::shared_ptr<arrow::io::RandomAccessFile> buffer_reader =
      std::make_shared<arrow::io::BufferReader>(buf);

  arrow::ipc::DictionaryMemo tmp_memo;
  SERVING_GET_ARROW_RESULT(
      arrow::ipc::ReadSchema(
          std::static_pointer_cast<arrow::io::InputStream>(buffer_reader).get(),
          &tmp_memo),
      result);

  return result;
}
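For reference, the writing side presumably produces these bytes with arrow::ipc::SerializeSchema. A small round-trip sketch (Arrow C++ API; error handling collapsed to ValueOrDie for brevity, and the exact SerializeSchema signature varies slightly across Arrow versions):

#include <arrow/api.h>
#include <arrow/io/memory.h>
#include <arrow/ipc/api.h>

std::shared_ptr<arrow::Schema> RoundTrip(
    const std::shared_ptr<arrow::Schema>& schema) {
  // Schema -> IPC bytes, the inverse of DeserializeSchema above.
  std::shared_ptr<arrow::Buffer> buf =
      arrow::ipc::SerializeSchema(*schema).ValueOrDie();

  // IPC bytes -> Schema, same as DeserializeSchema above.
  auto reader = std::make_shared<arrow::io::BufferReader>(buf);
  arrow::ipc::DictionaryMemo memo;
  return arrow::ipc::ReadSchema(reader.get(), &memo).ValueOrDie();
}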

The code after this registers the function calls in arrow_processing, which I will come back to in the prediction phase:
https://github.com/secretflow/serving/blob/475bb3356e3a246f444fc25f62f9619874870680/secretflow_serving/ops/arrow_processing.cc#L143
Back in the Executor constructor, the remaining part builds the node_name-to-input_schema mapping, determines the input feature schema, and validates all the entry nodes' input schemas.


  // get input schema
  const auto& entry_nodes = execution_->GetEntryNodes();
  for (const auto& node : entry_nodes) {
    const auto& node_name = node->node_def().name();
    auto iter = node_items_->find(node_name);
    SERVING_ENFORCE(iter != node_items_->end(), errors::ErrorCode::LOGIC_ERROR);
    const auto& input_schema = iter->second->op_kernel->GetAllInputSchema();
    entry_node_names_.emplace_back(node_name);
    input_schema_map_.emplace(node_name, input_schema);
  }

  if (execution_->IsEntry()) {
    // build feature schema from entry execution
    auto iter = input_schema_map_.begin();
    const auto& first_input_schema_list = iter->second;
    SERVING_ENFORCE(first_input_schema_list.size() == 1,
                    errors::ErrorCode::LOGIC_ERROR);
    const auto& target_schema = first_input_schema_list.front();
    ++iter;
    for (; iter != input_schema_map_.end(); ++iter) {
      SERVING_ENFORCE_EQ(iter->second.size(), 1U,
                         "entry nodes should have only one input table");
      const auto& schema = iter->second.front();

      SERVING_ENFORCE_EQ(
          target_schema->num_fields(), schema->num_fields(),
          "entry nodes should have same shape inputs, expect: {}, found: {}",
          target_schema->num_fields(), schema->num_fields());
      CheckReferenceFields(schema, target_schema,
                           fmt::format("entry nodes should have same input "
                                       "schema, found node:{} mismatch",
                                       iter->first));
    }
    input_feature_schema_ = target_schema;
  }
}

Initializing the ExecutionCore

ExecutionCore is the core module at prediction time. The code:


ExecutionCore::ExecutionCore(Options opts)
    : opts_(std::move(opts)),
      stats_({{"handler", "ExecutionCore"}, {"party_id", opts_.party_id}}) {
  SERVING_ENFORCE(!opts_.id.empty(), errors::ErrorCode::INVALID_ARGUMENT);
  SERVING_ENFORCE(!opts_.party_id.empty(), errors::ErrorCode::INVALID_ARGUMENT);
  SERVING_ENFORCE(opts_.executable, errors::ErrorCode::INVALID_ARGUMENT);
  SERVING_ENFORCE(opts_.op_exec_workers_num > 0,
                  errors::ErrorCode::INVALID_ARGUMENT);
// start the thread pool
  ThreadPool::GetInstance()->Start(opts_.op_exec_workers_num);

  // key: model input feature name
  // value: source or predefined feature name
  // a mapping for predefined model features can be plugged in here; it is configured during the earlier init flow
  std::unordered_map<std::string, std::string> model_feature_mapping;
  valid_feature_mapping_flag_ = false;
  if (opts_.feature_mapping.has_value()) {
    for (const auto& pair : opts_.feature_mapping.value()) {
      if (pair.first != pair.second) {
        valid_feature_mapping_flag_ = true;
      }
      SERVING_ENFORCE(
          model_feature_mapping.emplace(pair.second, pair.first).second,
          errors::ErrorCode::INVALID_ARGUMENT,
          "found duplicate feature mapping value:{}", pair.second);
    }
  }
// read the Executable's input feature schema
  const auto& model_input_schema = opts_.executable->GetInputFeatureSchema();
// fill in source_schema_
  if (model_feature_mapping.empty()) {
    source_schema_ = model_input_schema;
  } else {
    arrow::SchemaBuilder builder;
    int num_fields = model_input_schema->num_fields();
    for (int i = 0; i < num_fields; ++i) {
      const auto& f = model_input_schema->field(i);
      auto iter = model_feature_mapping.find(f->name());
      SERVING_ENFORCE(iter != model_feature_mapping.end(),
                      errors::ErrorCode::INVALID_ARGUMENT,
                      "can not found {} in feature mapping rule", f->name());
      SERVING_CHECK_ARROW_STATUS(
          builder.AddField(arrow::field(iter->second, f->type())));
    }
    SERVING_GET_ARROW_RESULT(builder.Finish(), source_schema_);
  }
// initialize feature_adapter_
  if (opts_.feature_source_config.has_value()) {
    SPDLOG_INFO("create feature adapter, type:{}",
                static_cast<int>(opts_.feature_source_config->options_case()));
    feature_adapter_ = feature::FeatureAdapterFactory::GetInstance()->Create(
        *opts_.feature_source_config, opts_.id, opts_.party_id, source_schema_);
  }
}

FeatureSourceConfig is also a proto message:

// Config for a feature source
message FeatureSourceConfig {
  oneof options {
    MockOptions mock_opts = 1;
    HttpOptions http_opts = 2;
    CsvOptions csv_opts = 3;
  }
}

As you can see, three source types are supported (mock, http, csv); requests are served through the OnFetchFeature method, which we will get to in the prediction phase.

Instantiating the monitoring service

Two brpc servers are instantiated in total. The first one is the Prometheus metrics service; Prometheus is probably the most widely used monitoring framework today.


  // start mertrics server
  if (opts_.server_config.metrics_exposer_port() > 0) {
    std::vector<std::string> strs = absl::StrSplit(self_address, ':');
    SERVING_ENFORCE(strs.size() == 2, errors::ErrorCode::LOGIC_ERROR,
                    "invalid self address.");
    auto metrics_listen_address = fmt::format(
        "{}:{}", strs[0], opts_.server_config.metrics_exposer_port());

    brpc::ServerOptions metrics_server_options;
    if (opts_.server_config.has_tls_config()) {
      auto* ssl_opts = metrics_server_options.mutable_ssl_options();
      ssl_opts->default_cert.certificate =
          opts_.server_config.tls_config().certificate_path();
      ssl_opts->default_cert.private_key =
          opts_.server_config.tls_config().private_key_path();
      ssl_opts->verify.verify_depth = 1;
      ssl_opts->verify.ca_file_path =
          opts_.server_config.tls_config().ca_file_path();
    }
    // @hint register the Prometheus metrics service
    auto* metrics_service = new metrics::MetricsService();
    metrics_service->RegisterCollectable(metrics::GetDefaultRegistry());

    metrics_server_.set_version(SERVING_VERSION_STRING);
    if (metrics_server_.AddService(metrics_service,
                                   brpc::SERVER_OWNS_SERVICE) != 0) {
      SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR,
                    "fail to add metrics service into brpc server.");
    }

    if (metrics_server_.Start(metrics_listen_address.c_str(),
                              &metrics_server_options) != 0) {
      SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR,
                    "fail to start metrics server at {}", self_address);
    }

    SPDLOG_INFO("begin metrics service listen at {}, ", metrics_listen_address);
  }

Instantiating the model server - model service

The model server hosts three services. The first is the model service; for now there is no dynamic model registration, it only exposes things like status information about the current model.


// build model_info_collector
// @hint model info collection
  ModelInfoCollector::Options m_c_opts;
  m_c_opts.model_bundle = model_bundle;
  m_c_opts.service_id = opts_.service_id;
  m_c_opts.self_party_id = self_party_id;
  m_c_opts.remote_channel_map = channels;
  ModelInfoCollector model_info_collector(std::move(m_c_opts));
  {
    auto max_retry_cnt =
        opts_.cluster_config.channel_desc().handshake_max_retry_cnt();
    if (max_retry_cnt != 0) {
      model_info_collector.SetRetryCounts(max_retry_cnt);
    }
    auto retry_interval_ms =
        opts_.cluster_config.channel_desc().handshake_retry_interval_ms();
    if (retry_interval_ms != 0) {
      model_info_collector.SetRetryIntervalMs(retry_interval_ms);
    }
  }

  // add services
  auto* model_service = new ModelServiceImpl(
      {{opts_.service_id, model_info_collector.GetSelfModelInfo()}},
      self_party_id);
  // @hint register the service with brpc
  // the model service; no dynamic model registration
  if (service_server_.AddService(model_service, brpc::SERVER_OWNS_SERVICE) !=
      0) {
    SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR,
                  "fail to add model service into brpc server.");
  }

Its interface:


// model service entry point
class ModelServiceImpl : public apis::ModelService {
 public:
  explicit ModelServiceImpl(std::map<std::string, ModelInfo> model_infos,
                            const std::string& self_party_id);

  void GetModelInfo(::google::protobuf::RpcController* controller,
                    const apis::GetModelInfoRequest* request,
                    apis::GetModelInfoResponse* response,
                    ::google::protobuf::Closure* done) override;

 private:
  struct Stats {
    // for request api
    ::prometheus::Family<::prometheus::Counter>& api_request_counter_family;
    ::prometheus::Family<::prometheus::Summary>&
        api_request_duration_summary_family;

    explicit Stats(std::map<std::string, std::string> labels,
                   const std::shared_ptr<::prometheus::Registry>& registry =
                       metrics::GetDefaultRegistry());
  };

  void RecordMetrics(const apis::GetModelInfoRequest& request,
                     const apis::GetModelInfoResponse& response,
                     double duration_ms, const std::string& action);

Instantiating the model server - execution service

The execution service is initialized with the execution_core; it exposes status, execution, and metrics interfaces:

  auto* execution_service = new ExecutionServiceImpl(execution_core);
  if (service_server_.AddService(execution_service,
                                 brpc::SERVER_OWNS_SERVICE) != 0) {
    SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR,
                  "fail to add execution service into brpc server.");
  }
class ExecutionServiceImpl : public apis::ExecutionService {
 public:
  explicit ExecutionServiceImpl(
      const std::shared_ptr<ExecutionCore>& execution_core);

  void Execute(::google::protobuf::RpcController* controller,
               const apis::ExecuteRequest* request,
               apis::ExecuteResponse* response,
               ::google::protobuf::Closure* done) override;

 private:
  void RecordMetrics(const apis::ExecuteRequest& request,
                     const apis::ExecuteResponse& response, double duration_ms,
                     const std::string& action);
  struct Stats {
    // for service interface
    ::prometheus::Family<::prometheus::Counter>& api_request_counter_family;
    ::prometheus::Family<::prometheus::Summary>&
        api_request_duration_summary_family;

    Stats(std::map<std::string, std::string> labels,
          const std::shared_ptr<::prometheus::Registry>& registry =
              metrics::GetDefaultRegistry());
  };

Instantiating the model server - prediction service

The prediction service adds some preprocessing on top of the execution service but ultimately still calls into the execution_core.
After it is created, a number of server options are set up:

  auto* prediction_service = new PredictionServiceImpl(self_party_id);
  if (service_server_.AddService(prediction_service,
                                 brpc::SERVER_OWNS_SERVICE) != 0) {
    SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR,
                  "fail to add prediction service into brpc server.");
  }

  // build services server opts
  brpc::ServerOptions server_options;
  server_options.max_concurrency = opts_.server_config.max_concurrency();
  if (opts_.server_config.worker_num() > 0) {
    server_options.num_threads = opts_.server_config.worker_num();
  }
  if (opts_.server_config.brpc_builtin_service_port() > 0) {
    server_options.has_builtin_services = true;
    server_options.internal_port =
        opts_.server_config.brpc_builtin_service_port();
    SPDLOG_INFO("internal port: {}", server_options.internal_port);
  }
  if (opts_.server_config.has_tls_config()) {
    auto* ssl_opts = server_options.mutable_ssl_options();
    ssl_opts->default_cert.certificate =
        opts_.server_config.tls_config().certificate_path();
    ssl_opts->default_cert.private_key =
        opts_.server_config.tls_config().private_key_path();
    ssl_opts->verify.verify_depth = 1;
    ssl_opts->verify.ca_file_path =
        opts_.server_config.tls_config().ca_file_path();
  }
  health::ServingHealthReporter hr;
  server_options.health_reporter = &hr;

  // start services server
  service_server_.set_version(SERVING_VERSION_STRING);
  if (service_server_.Start(self_address.c_str(), &server_options) != 0) {
    SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR,
                  "fail to start brpc server at {}", self_address);
  }

After that the prediction_core is set up; it only keeps some bookkeeping information. From this we can tell that execution is single-party, while prediction is not only multi-party but also leader/follower: the prediction_core coordinates the inference flow across the parties.

PredictionCore::Options prediction_core_opts;
  prediction_core_opts.service_id = opts_.service_id;
  prediction_core_opts.party_id = self_party_id;
  prediction_core_opts.cluster_ids = std::move(cluster_ids);
  prediction_core_opts.predictor = predictor;
  auto prediction_core =
      std::make_shared<PredictionCore>(std::move(prediction_core_opts));
  prediction_service->Init(prediction_core);

That concludes the startup flow. Control returns to main, which runs until the brpc server exits.

Prediction phase

ExecuteRequest

Let's first look at the proto definition of the execute request:

// Execute request containing one or more requests.
message ExecuteRequest {
  // Custom data. The header will be passed to the downstream system which
  // implement the feature service spi.
  Header header = 1;

  // Represents the id of the requesting party
  string requester_id = 2;

  // Model service specification.
  // at execute time this only needs to match the id configured in the ExecutionCore.
  ServiceSpec service_spec = 3;

  // Represents the session of this execute.
  string session_id = 4;

  FeatureSource feature_source = 5;

  ExecutionTask task = 6;
}

FeatureSource defines the feature-fetching strategy:

// Support feature source type
enum FeatureSourceType {
  UNKNOWN_FS_TYPE = 0;

  // No need features.
  FS_NONE = 1;
  // Fetch features from feature service.
  FS_SERVICE = 2;
  // The feature is defined in the request.
  FS_PREDEFINED = 3;
}

// Descriptive feature source
message FeatureSource {
  // Identifies the source type of the features
  FeatureSourceType type = 1;

  // Custom parameter for fetch features from feature service or other systems.
  // Valid when `type==FeatureSourceType::FS_SERVICE`
  FeatureParam fs_param = 2;

  // Defined features.
  // Valid when `type==FeatureSourceType::FS_PREDEFINED`
  repeated Feature predefineds = 3;
}

ExecutionTask specifies the execution id and the execution's inputs; the inputs are passed in serialized as NodeIo.


// Execute request task.
message ExecutionTask {
  // Specified the execution id.
  int32 execution_id = 1;

  repeated NodeIo nodes = 2;
}

// The serialized data of the node input/output.
message IoData {
  repeated bytes datas = 1;
}

// Represents the node input/output data.
message NodeIo {
  // Node name.
  string name = 1;

  repeated IoData ios = 2;
}
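The IoData bytes are serialized Arrow record batches (see DeserializeRecordBatch in the task-initialization code below). A hedged sketch of such a deserialization using the Arrow IPC stream reader, assuming one batch per payload (the real helper may differ in details):

#include <arrow/api.h>
#include <arrow/io/memory.h>
#include <arrow/ipc/api.h>

std::shared_ptr<arrow::RecordBatch> DeserializeRecordBatchSketch(
    const std::string& bytes) {
  auto reader = arrow::ipc::RecordBatchStreamReader::Open(
                    std::make_shared<arrow::io::BufferReader>(bytes))
                    .ValueOrDie();
  std::shared_ptr<arrow::RecordBatch> batch;
  if (!reader->ReadNext(&batch).ok()) {
    return nullptr;  // malformed payload
  }
  return batch;  // the first (and only) batch in the stream
}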

Feature fetching

Since prediction also ends up calling Execute, let's look at the execution service first.
Features come in two flavors here: fetched remotely, or carried locally in the request.


    std::shared_ptr<arrow::RecordBatch> features;
    if (request->feature_source().type() ==
        apis::FeatureSourceType::FS_SERVICE) {
      SERVING_ENFORCE(
          !request->feature_source().fs_param().query_datas().empty(),
          errors::ErrorCode::INVALID_ARGUMENT,
          "get empty feature service query datas.");
      SERVING_ENFORCE(request->task().nodes().empty(),
                      errors::ErrorCode::LOGIC_ERROR);
      features = BatchFetchFeatures(request, response);
    } else if (request->feature_source().type() ==
               apis::FeatureSourceType::FS_PREDEFINED) {
      SERVING_ENFORCE(!request->feature_source().predefineds().empty(),
                      errors::ErrorCode::INVALID_ARGUMENT,
                      "get empty predefined features.");
      SERVING_ENFORCE(request->task().nodes().empty(),
                      errors::ErrorCode::LOGIC_ERROR);
      features = FeaturesToTable(request->feature_source().predefineds(),
                                 source_schema_);
    }

Remote fetching first:


std::shared_ptr<arrow::RecordBatch> ExecutionCore::BatchFetchFeatures(
    const apis::ExecuteRequest* request,
    apis::ExecuteResponse* response) const {
  SERVING_ENFORCE(feature_adapter_, errors::ErrorCode::INVALID_ARGUMENT,
                  "feature source is not set, please check config.");

  yacl::ElapsedTimer timer;
  try {
    feature::FeatureAdapter::Request fa_request;
    fa_request.header = &request->header();
    fa_request.fs_param = &request->feature_source().fs_param();
    feature::FeatureAdapter::Response fa_response;
    fa_response.header = response->mutable_header();
    // fetch the features
    feature_adapter_->FetchFeature(fa_request, &fa_response);

    RecordBatchFeatureMetrics(request->service_spec().id(),
                              request->requester_id(), errors::ErrorCode::OK,
                              timer.CountMs());
    return fa_response.features;
  } catch (Exception& e) {
    RecordBatchFeatureMetrics(request->service_spec().id(),
                              request->requester_id(), e.code(),
                              timer.CountMs());
    throw e;
  }
}

This is where feature_adapter_ does the actual feature fetching. FetchFeature calls the OnFetchFeature method implemented by the subclass. Taking HttpFeatureAdapter as the example, there is not much to say: it performs an HTTP fetch request via brpc.

void FeatureAdapter::FetchFeature(const Request& request, Response* response) {
  OnFetchFeature(request, response);

  CheckFeatureValid(request, response->features);
}
/** HttpFeatureAdapter's OnFetchFeature **/
void HttpFeatureAdapter::OnFetchFeature(const Request& request,
                                        Response* response) {
  auto request_body = SerializeRequest(request);

  yacl::ElapsedTimer timer;

  brpc::Controller cntl;
  cntl.http_request().uri() = spec_.http_opts().endpoint();
  cntl.http_request().set_method(brpc::HTTP_METHOD_POST);
  cntl.http_request().set_content_type("application/json");
  cntl.request_attachment().append(request_body);
  channel_->CallMethod(NULL, &cntl, NULL, NULL, NULL);
  SERVING_ENFORCE(!cntl.Failed(), errors::ErrorCode::NETWORK_ERROR,
                  "http request failed, endpoint:{}, detail:{}",
                  spec_.http_opts().endpoint(), cntl.ErrorText());

  DeserializeResponse(cntl.response_attachment().to_string(), response);
}

The other flavor is predefined features, where the request itself carries the features; they are simply read into Arrow:


std::shared_ptr<arrow::RecordBatch> FeaturesToTable(
    const ::google::protobuf::RepeatedPtrField<Feature>& features,
    const std::shared_ptr<const arrow::Schema>& target_schema) {
  arrow::SchemaBuilder schema_builder;
  std::vector<std::shared_ptr<arrow::Array>> arrays;
  int num_rows = -1;

  for (const auto& field : target_schema->fields()) {
    bool found = false;
    for (const auto& f : features) {
      if (f.field().name() == field->name()) {
        FeatureToArrayVisitor visitor{.target_field = field, .array = {}};
        FeatureVisit(visitor, f);

        if (num_rows >= 0) {
          SERVING_ENFORCE_EQ(
              num_rows, visitor.array->length(),
              "features must have same length value. {}:{}, others:{}",
              f.field().name(), visitor.array->length(), num_rows);
        }
        num_rows = visitor.array->length();
        arrays.emplace_back(visitor.array);
        found = true;
        break;
      }
    }
    SERVING_ENFORCE(found, errors::ErrorCode::UNEXPECTED_ERROR,
                    "can not found feature:{} in response", field->name());
  }
  return MakeRecordBatch(target_schema, num_rows, std::move(arrays));
}
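FeatureToArrayVisitor/FeatureVisit are not shown above; for a double-typed feature the conversion boils down to an Arrow builder append. A sketch under that assumption:

#include <arrow/api.h>

#include <memory>
#include <vector>

// Hedged sketch of the double branch of the feature-to-array conversion:
// copy one Feature's repeated ds values into an arrow::DoubleArray.
std::shared_ptr<arrow::Array> DoublesToArray(const std::vector<double>& ds) {
  arrow::DoubleBuilder builder;
  if (!builder.AppendValues(ds).ok()) return nullptr;
  std::shared_ptr<arrow::Array> array;
  if (!builder.Finish(&array).ok()) return nullptr;
  return array;
}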

Converting features into model inputs

ApplyFeatureMappingRule converts online features into model features:
it looks up each field's name in the configured feature_mapping, rebuilds the schema, and re-wraps the columns with arrow::RecordBatch::Make.


std::shared_ptr<arrow::RecordBatch> ExecutionCore::ApplyFeatureMappingRule(
    const std::shared_ptr<arrow::RecordBatch>& features) {
  if (features == nullptr || !valid_feature_mapping_flag_) {
    // no need mapping
    return features;
  }
  const auto& feature_mapping = opts_.feature_mapping.value();

  int num_cols = features->num_columns();
  const auto& old_schema = features->schema();
  arrow::SchemaBuilder builder;
  for (int i = 0; i < num_cols; ++i) {
    auto field = old_schema->field(i);
    auto iter = feature_mapping.find(field->name());
    if (iter != feature_mapping.end()) {
      field = arrow::field(iter->second, field->type());
    }
    SERVING_CHECK_ARROW_STATUS(builder.AddField(field));
  }

  std::shared_ptr<arrow::Schema> schema;
  SERVING_GET_ARROW_RESULT(builder.Finish(), schema);

  return MakeRecordBatch(schema, features->num_rows(), features->columns());
    // which calls arrow::RecordBatch::Make(schema, num_rows, std::move(columns));
}

Task initialization

First a Task is instantiated from the request arguments, converting the NodeIo payloads into op::OpComputeInputs.

   // executable run
    Executable::Task task;
    task.id = request->task().execution_id();
    task.features = features;
    task.node_inputs = std::make_shared<std::unordered_map<
        std::string, std::shared_ptr<op::OpComputeInputs>>>();
    for (const auto& n : request->task().nodes()) {
      auto compute_inputs = std::make_shared<op::OpComputeInputs>();
      for (const auto& io : n.ios()) {
        std::vector<std::shared_ptr<arrow::RecordBatch>> inputs;
        for (const auto& d : io.datas()) {
          inputs.emplace_back(DeserializeRecordBatch(d));
        }
        compute_inputs->emplace_back(std::move(inputs));
      }
      task.node_inputs->emplace(n.name(), std::move(compute_inputs));
    }
    opts_.executable->Run(task);

Then execution is invoked. If features are present, prediction uses the features; otherwise it uses node_inputs. Internally the features are wrapped into a std::unordered_map<std::string, std::shared_ptr<op::OpComputeInputs>> before the corresponding Run overload is called.

void Executable::Run(Task& task) {
  SERVING_ENFORCE(task.id < executors_.size(), errors::ErrorCode::LOGIC_ERROR);
  auto executor = executors_[task.id];
  if (task.features) {
    task.outputs = executor->Run(task.features);
  } else {
    SERVING_ENFORCE(!task.node_inputs->empty(), errors::ErrorCode::LOGIC_ERROR);
    task.outputs = executor->Run(*(task.node_inputs));
  }

  SPDLOG_DEBUG("Executable::Run end, task.outputs.size:{}",
               task.outputs->size());
}

Let's look at Executor's Run methods, starting with the feature-based one:

std::shared_ptr<std::vector<NodeOutput>> Executor::Run(
    std::shared_ptr<arrow::RecordBatch>& features) {
  SERVING_ENFORCE(execution_->IsEntry(), errors::ErrorCode::LOGIC_ERROR);
  auto inputs =
      std::unordered_map<std::string, std::shared_ptr<op::OpComputeInputs>>();
  for (size_t i = 0; i < entry_node_names_.size(); ++i) {
    auto op_inputs = std::make_shared<op::OpComputeInputs>();
    std::vector<std::shared_ptr<arrow::RecordBatch>> record_list = {features};
    op_inputs->emplace_back(std::move(record_list));
    inputs.emplace(entry_node_names_[i], std::move(op_inputs));
  }
  return Run(inputs);
}
// which calls the Run overload below

std::shared_ptr<std::vector<NodeOutput>> Executor::Run(
    std::unordered_map<std::string, std::shared_ptr<op::OpComputeInputs>>&
        inputs) {
  // entry nodes
  std::vector<std::shared_ptr<op::OpComputeInputs>> entry_node_inputs;
  for (const auto& node : execution_->GetEntryNodes()) {
    auto iter = inputs.find(node->node_def().name());
    SERVING_ENFORCE(iter != inputs.end(), errors::ErrorCode::INVALID_ARGUMENT,
                    "can not found inputs for node:{}",
                    node->node_def().name());
    entry_node_inputs.emplace_back(iter->second);
  }
  // instantiate a scheduler
  // each Executor holds its own execution_, passed in at construction
  auto sched = std::make_shared<ExecuteScheduler>(
      node_items_, execution_->GetExitNodeNum(), ThreadPool::GetInstance(),
      execution_);
  // feed the entry nodes into the scheduler
  const auto& entry_nodes = execution_->GetEntryNodes();
  for (size_t i = 0; i != execution_->GetEntryNodeNum(); ++i) {
    sched->AddEntryNode(entry_nodes[i], entry_node_inputs[i]);
  }

  sched->Schedule();

  auto task_exception = sched->GetTaskException();
  if (task_exception) {
    SPDLOG_ERROR("Execution {} run with exception.", execution_->id());
    std::rethrow_exception(task_exception);
  }
  SERVING_ENFORCE_EQ(sched->GetSchedCount(), execution_->nodes().size());
  return std::make_shared<std::vector<NodeOutput>>(sched->GetResults());
}


The key piece here is the scheduler. Its constructor arguments are:

  1. node_items_, a std::shared_ptr<std::unordered_map<std::string, std::shared_ptr<NodeItem>>> carrying the Executor's nodes and their corresponding op_kernels
  2. the number of exit nodes
  3. the thread-pool instance
  4. execution_, i.e. the model's original configuration for this executor

Next, the scheduler's code.

The scheduler

First, there is something interesting in the constructor:

  ExecuteScheduler(
      std::shared_ptr<
          std::unordered_map<std::string, std::shared_ptr<NodeItem>>>
          node_items,
      uint64_t res_cnt, const std::shared_ptr<ThreadPool>& thread_pool,
      std::shared_ptr<Execution> execution)
      : node_items_(std::move(node_items)),
        context_(res_cnt),
        thread_pool_(thread_pool),
        execution_(std::move(execution)),
        propagator_(execution_->nodes()),
        sched_count_(0) {}

Did you notice propagator_? Looking at its definition and implementation, it holds every node's state plus its inputs and outputs: one frame per node, used to track readiness while the graph is scheduled.


struct ComputeContext {
  // TODO: Session
  OpComputeInputs inputs;
  std::shared_ptr<arrow::RecordBatch> output;
};

struct FrameState {
  std::atomic<int> pending_count;

  op::ComputeContext compute_ctx;
};

class Propagator {
 public:
  explicit Propagator(
      const std::unordered_map<std::string, std::shared_ptr<Node>>& nodes);

  FrameState* GetFrame(const std::string& node_name);

 private:
  std::unordered_map<std::string, FrameState*> node_frame_map_;
  std::vector<FrameState> frame_pool_;
};
}  // namespace secretflow::serving

Propagator::Propagator(
    const std::unordered_map<std::string, std::shared_ptr<Node>>& nodes) {
  frame_pool_ = std::vector<FrameState>(nodes.size());
  size_t idx = 0;
  for (auto& [node_name, node] : nodes) {
    auto frame = &frame_pool_[idx++];
    frame->pending_count = node->GetInputNum();
    frame->compute_ctx.inputs.resize(frame->pending_count);

    SERVING_ENFORCE(node_frame_map_.emplace(node_name, std::move(frame)).second,
                    errors::ErrorCode::LOGIC_ERROR);
  }
}

FrameState* Propagator::GetFrame(const std::string& node_name) {
  auto iter = node_frame_map_.find(node_name);
  SERVING_ENFORCE(iter != node_frame_map_.end(), errors::ErrorCode::LOGIC_ERROR,
                  "can not found frame for node: {}", node_name);
  return iter->second;
}

With that question in mind, let's continue with the execution code.
In the scheduling loop below, a TODO says they would like to use bthread primitives so the worker could switch to other work, which is interesting; I have written about brpc before, see:
https://www.yuque.com/treblez/qksu6c/owbw5sm9xzmqv2qa?singleDoc# 《brpc:优秀代码鉴赏》
There is no switch-out mechanism yet, just a plain in-place wait. ready_nodes_ is a ThreadSafeQueue; it is not a lock-free queue, just simple mutex synchronization, so its code is not reproduced here.
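For completeness, a minimal sketch of what such a mutex-based WaitPop queue looks like (the real ThreadSafeQueue also supports timeouts and stop flags; GetReadyNode returning false suggests its WaitPop can return without an element):

#include <condition_variable>
#include <mutex>
#include <queue>

template <typename T>
class ThreadSafeQueueSketch {
 public:
  void Push(T v) {
    {
      std::lock_guard<std::mutex> lk(mu_);
      q_.push(std::move(v));
    }
    cv_.notify_one();
  }

  // Blocks until an element is available.
  void WaitPop(T* out) {
    std::unique_lock<std::mutex> lk(mu_);
    cv_.wait(lk, [this] { return !q_.empty(); });
    *out = std::move(q_.front());
    q_.pop();
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<T> q_;
};

The scheduling loop itself: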

  void Schedule() {
    while (!stop_flag_.load() && !context_.IsFinish()) {
      // TODO: consider use bthread::Mutex and bthread::ConditionVariable
      //       to make this worker can switch to others
      std::shared_ptr<NodeItem> node_item;
      // calls ready_nodes_.WaitPop(node_item);
      // ready_nodes_ is a ThreadSafeQueue
      if (!context_.GetReadyNode(node_item)) {
        continue;
      }
      SubmitExecuteOpTask(node_item);
    }
  }
/** which calls into **/
  void SubmitExecuteOpTask(std::shared_ptr<NodeItem>& node_item) {
    if (stop_flag_.load()) {
      return;
    }
    thread_pool_->SubmitTask(
        std::make_unique<ExecuteOpTask>(node_item, shared_from_this()));
  }
/** the task class submitted to the thread pool **/
  class ExecuteOpTask : public ThreadPool::Task {
   public:
    const char* Name() override { return "ExecuteOpTask"; }

    ExecuteOpTask(std::shared_ptr<NodeItem> node_item,
                  std::shared_ptr<ExecuteScheduler> sched)
        : node_item_(std::move(node_item)), sched_(std::move(sched)) {}
	// calls back into ExecuteOp
    void Exec() override { sched_->ExecuteOp(node_item_); }

    void OnException(std::exception_ptr e) noexcept override {
      sched_->SetTaskException(e);
    }

   private:
    std::shared_ptr<NodeItem> node_item_;
    std::shared_ptr<ExecuteScheduler> sched_;
  };

/** ExecuteOp **/

  void ExecuteOp(const std::shared_ptr<NodeItem>& node_item) {
    if (stop_flag_.load()) {
      return;
    }

    auto* frame = propagator_.GetFrame(node_item->node->node_def().name());
	// invoke the OpKernel to run the actual compute logic
    node_item->op_kernel->Compute(&(frame->compute_ctx));
    sched_count_++;

    if (execution_->IsExitNode(node_item->node->node_def().name())) {
      context_.AddResult(node_item->node->node_def().name(),
                         frame->compute_ctx.output);
    }

    const auto& edges = node_item->node->out_edges();
    for (const auto& edge : edges) {
      CompleteOutEdge(edge, frame->compute_ctx.output);
    }
  }

Execution

As introduced earlier, OpKernel has three subclasses: ArrowProcessing, MergeY, and DotProduct.
Let's go through each one's runtime compute method.

ArrowProcessing

Start with the most important one, ArrowProcessing.
ComputeTrace is defined in proto and describes the functions to be executed.

message FunctionTrace {
  // The Function name.
  string name = 1;

  // The serialized function options.
  bytes option_bytes = 2;

  // Inputs of this function.
  repeated FunctionInput inputs = 3;

  // Output of this function.
  FunctionOutput output = 4;
}

message ComputeTrace {
  // The name of this Compute.
  string name = 1;

  repeated FunctionTrace func_traces = 2;
}

// The function names are defined as follows:

enum ExtendFunctionName {
  // Placeholder for proto3 default value, do not use it
  UNKOWN_EX_FUNCTION_NAME = 0;

  // Get colunm from table(record_batch).
  // see
  // https://arrow.apache.org/docs/cpp/api/table.html#_CPPv4NK5arrow11RecordBatch6columnEi
  EFN_TB_COLUMN = 1;
  // Add colum to table(record_batch).
  // see
  // https://arrow.apache.org/docs/cpp/api/table.html#_CPPv4NK5arrow11RecordBatch9AddColumnEiNSt6stringERKNSt10shared_ptrI5ArrayEE
  EFN_TB_ADD_COLUMN = 2;
  // Remove colunm from table(record_batch).
  // see
  // https://arrow.apache.org/docs/cpp/api/table.html#_CPPv4NK5arrow11RecordBatch12RemoveColumnEi
  EFN_TB_REMOVE_COLUMN = 3;
  // Set colunm to table(record_batch).
  // see
  // https://arrow.apache.org/docs/cpp/api/table.html#_CPPv4NK5arrow11RecordBatch9SetColumnEiRKNSt10shared_ptrI5FieldEERKNSt10shared_ptrI5ArrayEE
  EFN_TB_SET_COLUMN = 4;
}

// FunctionInput is defined as follows

message FunctionInput {
  oneof value {
    // '0' means root input data
    int32 data_id = 1;
    Scalar custom_scalar = 2;
  }
}

As you can see, these few functions are simply CRUD operations on an Arrow record batch.
The construction logic below turns every function defined in compute_trace_ into an entry in the execution list func_list_, which is later invoked in order.


switch (ex_func_name) {
case compute::ExtendFunctionName::EFN_TB_COLUMN: {
    // fetch a column
  func_list_.emplace_back([](arrow::Datum& result_datum,
                             std::vector<arrow::Datum>& func_inputs) {
    result_datum = func_inputs[0].record_batch()->column(
        std::static_pointer_cast<arrow::Int64Scalar>(
            func_inputs[1].scalar())
            ->value);
  });
  break;
}
case compute::ExtendFunctionName::EFN_TB_ADD_COLUMN: {
    // add a column
  func_list_.emplace_back([](arrow::Datum& result_datum,
                             std::vector<arrow::Datum>& func_inputs) {
    int64_t index = std::static_pointer_cast<arrow::Int64Scalar>(
                        func_inputs[1].scalar())
                        ->value;
    std::string field_name(
        std::static_pointer_cast<arrow::StringScalar>(
            func_inputs[2].scalar())
            ->view());
    std::shared_ptr<arrow::RecordBatch> new_batch;
    SERVING_GET_ARROW_RESULT(
        func_inputs[0].record_batch()->AddColumn(
            index, std::move(field_name), func_inputs[3].make_array()),
        new_batch);
    result_datum = new_batch;
  });
  break;
}
case compute::ExtendFunctionName::EFN_TB_REMOVE_COLUMN: {
    // remove a column
  func_list_.emplace_back([](arrow::Datum& result_datum,
                             std::vector<arrow::Datum>& func_inputs) {
    std::shared_ptr<arrow::RecordBatch> new_batch;
    SERVING_GET_ARROW_RESULT(
        func_inputs[0].record_batch()->RemoveColumn(
            std::static_pointer_cast<arrow::Int64Scalar>(
                func_inputs[1].scalar())
                ->value),
        new_batch);
    result_datum = new_batch;
  });
  break;
}
case compute::ExtendFunctionName::EFN_TB_SET_COLUMN: {
    // set (replace) a column
  func_list_.emplace_back([](arrow::Datum& result_datum,
                             std::vector<arrow::Datum>& func_inputs) {
    int64_t index = std::static_pointer_cast<arrow::Int64Scalar>(
                        func_inputs[1].scalar())
                        ->value;
    std::string field_name(
        std::static_pointer_cast<arrow::StringScalar>(
            func_inputs[2].scalar())
            ->view());
    std::shared_ptr<arrow::Array> array = func_inputs[3].make_array();
    std::shared_ptr<arrow::RecordBatch> new_batch;
    SERVING_GET_ARROW_RESULT(
        func_inputs[0].record_batch()->SetColumn(
            index, arrow::field(std::move(field_name), array->type()),
            array),
        new_batch);
    result_datum = new_batch;
  });
  break;
}
default:
  SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR,
                "invalid ext func name enum: {}",
                static_cast<int>(ex_func_name));
}

Now the runtime logic, DoCompute. It is somewhat oddly split into two functions; together they just replay the recorded calls in order, so there is not much more to say.


void ArrowProcessing::DoCompute(ComputeContext* ctx) {
  // sanity check
  SERVING_ENFORCE(ctx->inputs.size() == 1, errors::ErrorCode::LOGIC_ERROR);
  SERVING_ENFORCE(ctx->inputs.front().size() == 1,
                  errors::ErrorCode::LOGIC_ERROR);

  if (dummy_flag_) {
    ctx->output = ctx->inputs.front().front();
    return;
  }

  SPDLOG_INFO("replay compute: {}", compute_trace_.name());

  ctx->output = ReplayCompute(ctx->inputs.front().front());
}



std::shared_ptr<arrow::RecordBatch> ArrowProcessing::ReplayCompute(
    const std::shared_ptr<arrow::RecordBatch>& input) {
  std::map<int32_t, arrow::Datum> datas = {{0, input}};

  arrow::Datum result_datum;
  for (int i = 0; i < compute_trace_.func_traces_size(); ++i) {
    const auto& func = compute_trace_.func_traces(i);
    SPDLOG_DEBUG("replay func: {}", func.ShortDebugString());
    auto func_inputs = BuildInputDatums(func.inputs(), datas);
    func_list_[i](result_datum, func_inputs);

    SERVING_ENFORCE(
        datas.emplace(func.output().data_id(), std::move(result_datum)).second,
        errors::ErrorCode::LOGIC_ERROR);
  }

  return datas[result_id_].record_batch();
}

DotProduct

As its name suggests, this OpKernel uses Eigen to compute a dot product, since Arrow does not offer one:


void DotProduct::DoCompute(ComputeContext* ctx) {
  SERVING_ENFORCE(ctx->inputs.size() == 1, errors::ErrorCode::LOGIC_ERROR);
  SERVING_ENFORCE(ctx->inputs.front().size() == 1,
                  errors::ErrorCode::LOGIC_ERROR);

  auto features = TableToMatrix(ctx->inputs.front().front());

  Double::ColVec score_vec = features * weights_;
  score_vec.array() += intercept_;

  std::shared_ptr<arrow::Array> array;
  arrow::DoubleBuilder builder;
  for (int i = 0; i < score_vec.rows(); ++i) {
    auto row = score_vec.row(i);
    SERVING_CHECK_ARROW_STATUS(builder.AppendValues(row.data(), 1));
  }
  SERVING_CHECK_ARROW_STATUS(builder.Finish(&array));
  ctx->output = MakeRecordBatch(output_schema_, score_vec.rows(), {array});
}
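TableToMatrix is not shown above; conceptually it copies the all-numeric record batch into a dense Eigen matrix. A sketch under the assumption that every column is float64:

#include <arrow/api.h>
#include <Eigen/Dense>

#include <memory>

Eigen::MatrixXd TableToMatrixSketch(
    const std::shared_ptr<arrow::RecordBatch>& batch) {
  Eigen::MatrixXd m(batch->num_rows(), batch->num_columns());
  for (int c = 0; c < batch->num_columns(); ++c) {
    auto col = std::static_pointer_cast<arrow::DoubleArray>(batch->column(c));
    for (int64_t r = 0; r < col->length(); ++r) {
      m(r, c) = col->Value(r);  // element-wise copy into the dense matrix
    }
  }
  return m;
}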

MergeY

MergeY merges the parties' partial results:

void MergeY::DoCompute(ComputeContext* ctx) {
  // santiy check
  SERVING_ENFORCE(ctx->inputs.size() == 1, errors::ErrorCode::LOGIC_ERROR);
  SERVING_ENFORCE(ctx->inputs.front().size() >= 1,
                  errors::ErrorCode::LOGIC_ERROR);

  // merge partial_y
  arrow::Datum incremented_datum(ctx->inputs.front()[0]->column(0));
  for (size_t i = 1; i < ctx->inputs.front().size(); ++i) {
    auto cur_array = ctx->inputs.front()[i]->column(0);
    SERVING_GET_ARROW_RESULT(arrow::compute::Add(incremented_datum, cur_array),
                             incremented_datum);
  }
  auto merged_array = std::static_pointer_cast<arrow::DoubleArray>(
      std::move(incremented_datum).make_array());

  // apply link func
  arrow::DoubleBuilder builder;
  SERVING_CHECK_ARROW_STATUS(builder.Resize(merged_array->length()));
  for (int64_t i = 0; i < merged_array->length(); ++i) {
    auto score =
        ApplyLinkFunc(merged_array->Value(i), link_function_) * yhat_scale_;
    SERVING_CHECK_ARROW_STATUS(builder.Append(score));
  }
  std::shared_ptr<arrow::Array> res_array;
  SERVING_CHECK_ARROW_STATUS(builder.Finish(&res_array));
  ctx->output =
      MakeRecordBatch(output_schema_, res_array->length(), {res_array});
}
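ApplyLinkFunc is not shown above; for a logistic-regression model the link function would be the sigmoid, turning the merged partial sums into probabilities. A sketch of just that case (the real implementation dispatches on link_function_ and supports more link types):

#include <cmath>

// Hedged sketch: only the logistic branch of a link function.
double ApplyLinkFuncSketch(double merged_y) {
  return 1.0 / (1.0 + std::exp(-merged_y));
}

The multiplication by yhat_scale_ afterwards presumably undoes any label scaling applied at training time.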

Post-processing

Once a node's kernel finishes, its output is propagated along the out-edges: CompleteOutEdge writes the output into the downstream frame's input slot, and when the downstream node's pending_count drops to zero the node is enqueued as ready:
  void CompleteOutEdge(const std::shared_ptr<Edge>& edge,
                       std::shared_ptr<arrow::RecordBatch> output) {
    std::shared_ptr<Node> dst_node;
    if (!execution_->TryGetNode(edge->dst_node(), &dst_node)) {
      return;
    }

    auto* child_frame = propagator_.GetFrame(dst_node->GetName());
    child_frame->compute_ctx.inputs[edge->dst_input_id()].emplace_back(
        std::move(output));

    if (child_frame->pending_count.fetch_sub(1) == 1) {
      context_.AddReadyNode(
          node_items_->find(dst_node->node_def().name())->second);
    }
  }

PredictRequest

// The value of a feature
message FeatureValue {
  // int list
  repeated int32 i32s = 1;
  repeated int64 i64s = 2;
  // float list
  repeated float fs = 3;
  repeated double ds = 4;
  // string list
  repeated string ss = 5;
  // bool list
  repeated bool bs = 6;
}

// The definition of a feature field.
message FeatureField {
  // Unique name of the feature
  string name = 1;

  // Field type of the feature
  FieldType type = 2;
}

// The definition of a feature
message Feature {
  FeatureField field = 1;

  FeatureValue value = 2;
}

message PredictRequest {
  // Custom data. The header will be passed to the downstream system which
  // implement the feature service spi.
  Header header = 1;

  // Model service specification.
  ServiceSpec service_spec = 2;

  // The params for fetch features. Note that this should include all the
  // parties involved in the prediction.
  // Key: party's id.
  // Value: params for fetch features.
  map<string, FeatureParam> fs_params = 3;

  // Optional.
  // If defined, the request party will no longer query for the feature but will
  // use defined fetures in `predefined_features` for the prediction.
  repeated Feature predefined_features = 4;
}

Prediction

Prediction is just Execute plus one extra round of multi-party communication.
It uses the RPC channels initialized during startup to kick off the peers' executes, runs its own part of the flow, and finally waits for the remaining parties' executes to finish.


void Predictor::Predict(const apis::PredictRequest* request,
                        apis::PredictResponse* response) {
  std::unordered_map<std::string, std::shared_ptr<apis::NodeIo>>
      prev_node_io_map;
  std::vector<std::shared_ptr<RemoteExecute>> async_running_execs;
  async_running_execs.reserve(opts_.channels->size());

  auto execute_locally =
      [&](const std::shared_ptr<Execution>& execution,
          std::unordered_map<std::string, std::shared_ptr<apis::NodeIo>>&
              prev_io_map,
          std::unordered_map<std::string, std::shared_ptr<apis::NodeIo>>&
              cur_io_map) {
        // exec locally
        auto local_exec = BuildLocalExecute(request, response, execution);
        local_exec->SetInputs(std::move(prev_io_map));
        local_exec->Run();
        local_exec->GetOutputs(&cur_io_map);
      };

  for (const auto& e : opts_.executions) {
    async_running_execs.clear();
    std::unordered_map<std::string, std::shared_ptr<apis::NodeIo>>
        new_node_io_map;
    if (e->GetDispatchType() == DispatchType::DP_ALL) {
      for (const auto& [party_id, channel] : *opts_.channels) {
        auto ctx = BuildRemoteExecute(request, response, e, party_id, channel);
        ctx->SetInputs(prev_node_io_map);
        ctx->Run();
        async_running_execs.emplace_back(ctx);
      }

      // exec locally
      if (execution_core_) {
        execute_locally(e, prev_node_io_map, new_node_io_map);
        for (auto& exec : async_running_execs) {
          exec->WaitToFinish();
          exec->GetOutputs(&new_node_io_map);
        }
      } else {
        // TODO: support no execution core scene
        SERVING_THROW(errors::ErrorCode::NOT_IMPLEMENTED, "not implemented");
      }

    } else if (e->GetDispatchType() == DispatchType::DP_ANYONE) {
      // exec locally
      if (execution_core_) {
        execute_locally(e, prev_node_io_map, new_node_io_map);
      } else {
        // TODO: support no execution core scene
        SERVING_THROW(errors::ErrorCode::NOT_IMPLEMENTED, "not implemented");
      }
    } else if (e->GetDispatchType() == DispatchType::DP_SPECIFIED) {
      if (e->SpecificToThis()) {
        SERVING_ENFORCE(execution_core_, errors::ErrorCode::UNEXPECTED_ERROR);
        execute_locally(e, prev_node_io_map, new_node_io_map);
      } else {
        auto iter = opts_.specific_party_map.find(e->id());
        SERVING_ENFORCE(iter != opts_.specific_party_map.end(),
                        serving::errors::LOGIC_ERROR,
                        "{} execution assign to no party", e->id());
        auto ctx = BuildRemoteExecute(request, response, e, iter->second,
                                      opts_.channels->at(iter->second));
        ctx->SetInputs(prev_node_io_map);
        ctx->Run();
        ctx->WaitToFinish();
        ctx->GetOutputs(&new_node_io_map);
      }
    } else {
      SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR,
                    "unsupported dispatch type: {}",
                    DispatchType_Name(e->GetDispatchType()));
    }
    prev_node_io_map.swap(new_node_io_map);
  }

  DealFinalResult(prev_node_io_map, response);
}

The calls on RemoteExecute ultimately land in the code below, which simply invokes Execute on the remote party:

void ExecuteContext::Execute(
    std::shared_ptr<::google::protobuf::RpcChannel> channel,
    brpc::Controller* cntl) {
  apis::ExecutionService_Stub stub(channel.get());
  stub.Execute(cntl, &exec_req_, &exec_res_, brpc::DoNothing());
}
