secretflow-serving (https://github.com/secretflow/serving) is an ABY3 inference service provided by SecretFlow (隐语). Its codebase is only about one percent the size of ClickHouse's (under 10,000 lines), yet despite its size it is complete: it covers the whole model-loading and inference pipeline, and it integrates Prometheus for monitoring.
secretflow-serving is written in C++17 and the code is clear and easy to follow. This article walks through its source code along the lines of its architecture. Since I am not a machine-learning professional, corrections are very welcome.
The overall architecture splits into two phases: startup and the prediction service. Startup mainly reads the model and brings up several brpc rpc services; see https://brpc.apache.org/zh/docs/server/serve-grpc/ for the latter.
The entry function is at https://github.com/secretflow/serving/blob/475bb3356e3a246f444fc25f62f9619874870680/secretflow_serving/server/main.cc#L46
Input arguments are handled here with absl and gflags.
OpFactory is a singleton; secretflow implements a standard Meyers' Singleton here (https://github.com/secretflow/serving/blob/475bb3356e3a246f444fc25f62f9619874870680/secretflow_serving/core/singleton.h#L20):
{
auto op_def_list =
secretflow::serving::op::OpFactory::GetInstance()->GetAllOps();
std::vector<std::string> op_names;
std::for_each(
op_def_list.begin(), op_def_list.end(),
[&](const std::shared_ptr<const secretflow::serving::op::OpDef>& o) {
op_names.emplace_back(o->name());
});
SPDLOG_INFO("op list: {}",
fmt::join(op_names.begin(), op_names.end(), ", "));
}
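The linked Singleton is the textbook Meyers' form. A minimal sketch of the idea (not the exact serving code):
template <typename T>
class Singleton {
 public:
  static T* GetInstance() {
    // C++11 guarantees thread-safe, lazy initialization of
    // function-local statics; that guarantee is the whole trick.
    static T instance;
    return &instance;
  }
  Singleton(const Singleton&) = delete;
  Singleton& operator=(const Singleton&) = delete;

 protected:
  Singleton() = default;
};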
OpFactory lives at https://github.com/secretflow/serving/blob/475bb3356e3a246f444fc25f62f9619874870680/secretflow_serving/ops/op_factory.h.
class OpFactory final : public Singleton<OpFactory> {
public:
void Register(const std::shared_ptr<OpDef>& op_def) {
std::lock_guard<std::mutex> lock(mutex_);
SERVING_ENFORCE(op_defs_.emplace(op_def->name(), op_def).second,
errors::ErrorCode::LOGIC_ERROR,
"duplicated op_def registered for {}", op_def->name());
}
const std::shared_ptr<OpDef> Get(const std::string& name) {
std::lock_guard<std::mutex> lock(mutex_);
auto iter = op_defs_.find(name);
SERVING_ENFORCE(iter != op_defs_.end(), errors::ErrorCode::UNEXPECTED_ERROR,
"no op_def registered for {}", name);
return iter->second;
}
std::vector<std::shared_ptr<const OpDef>> GetAllOps() {
std::vector<std::shared_ptr<const OpDef>> result;
std::lock_guard<std::mutex> lock(mutex_);
for (const auto& pair : op_defs_) {
result.emplace_back(pair.second);
}
return result;
}
private:
std::unordered_map<std::string, std::shared_ptr<OpDef>> op_defs_;
std::mutex mutex_;
};
As we can see, ops are registered statically via REGISTER_OP, and they are looked up when execution-graph nodes are constructed (see https://github.com/secretflow/serving/blob/475bb3356e3a246f444fc25f62f9619874870680/secretflow_serving/ops/node.cc#L23):
#define REGISTER_OP(op_name, version, desc) \
static OpRegister const regist_op_##op_name = \
OpRegister{} << internal::OpDefBuilderWrapper(#op_name, version, desc)
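The macro relies on the usual static-registration idiom: a namespace-scope object whose constructor runs before main and hands the finished OpDef to the factory. Roughly like this (a sketch; the Build() call is an assumption about the builder's interface):
struct OpRegister {
  // operator<< finalizes the builder and registers the resulting OpDef,
  // so registration happens during static initialization, before main.
  OpRegister& operator<<(internal::OpDefBuilderWrapper builder) {
    OpFactory::GetInstance()->Register(builder.Build());  // hypothetical Build()
    return *this;
  }
};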
After that, the server's various parameters are read and the startup flow begins.
The full code:
// @hint entry function
int main(int argc, char* argv[]) {
// Initialize the symbolizer to get a human-readable stack trace
// handle input arguments with absl and gflags
absl::InitializeSymbolizer(argv[0]);
gflags::SetVersionString(SERVING_VERSION_STRING);
gflags::AllowCommandLineReparsing();
gflags::ParseCommandLineFlags(&argc, &argv, true);
try {
// init logger
secretflow::serving::LoggingConfig log_config;
if (!FLAGS_logging_config_file.empty()) {
secretflow::serving::LoadPbFromJsonFile(FLAGS_logging_config_file,
&log_config);
}
secretflow::serving::SetupLogging(log_config);
SPDLOG_INFO("version: {}", SERVING_VERSION_STRING);
{
// OpFactory is a Meyers' Singleton
// ops are registered statically via REGISTER_OP
auto op_def_list =
secretflow::serving::op::OpFactory::GetInstance()->GetAllOps();
std::vector<std::string> op_names;
std::for_each(
op_def_list.begin(), op_def_list.end(),
[&](const std::shared_ptr<const secretflow::serving::op::OpDef>& o) {
op_names.emplace_back(o->name());
});
SPDLOG_INFO("op list: {}",
fmt::join(op_names.begin(), op_names.end(), ", "));
}
STRING_EMPTY_VALIDATOR(FLAGS_serving_config_file);
// init server options
secretflow::serving::Server::Options server_opts;
if (FLAGS_config_mode == "kuscia") {
secretflow::serving::kuscia::KusciaConfigParser config_parser(
FLAGS_serving_config_file);
server_opts.server_config = config_parser.server_config();
server_opts.cluster_config = config_parser.cluster_config();
server_opts.model_config = config_parser.model_config();
server_opts.feature_source_config = config_parser.feature_config();
server_opts.service_id = config_parser.service_id();
} else {
secretflow::serving::ServingConfig serving_conf;
LoadPbFromJsonFile(FLAGS_serving_config_file, &serving_conf);
server_opts.server_config = serving_conf.server_conf();
server_opts.cluster_config = serving_conf.cluster_conf();
server_opts.model_config = serving_conf.model_conf();
if (serving_conf.has_feature_source_conf()) {
server_opts.feature_source_config = serving_conf.feature_source_conf();
}
server_opts.service_id = serving_conf.id();
}
// start the server
secretflow::serving::Server server(std::move(server_opts));
server.Start();
// run until the brpc server exits
server.WaitForEnd();
} catch (const secretflow::serving::Exception& e) {
// TODO: custom status sink
SPDLOG_ERROR("server startup failed, code: {}, msg: {}, stack: {}",
e.code(), e.what(), e.stack_trace());
return -1;
} catch (const std::exception& e) {
// TODO: custom status sink
SPDLOG_ERROR("server startup failed, msg:{}", e.what());
return -1;
}
return 0;
}
Before walking through the startup flow, let's first look at how a model used for inference is defined in secretflow:
https://github.com/secretflow/serving/blob/475bb3356e3a246f444fc25f62f9619874870680/secretflow_serving/protos/bundle.proto#L37
ModelBundle is a proto message containing the complete model information.
GraphDef defines the execution graph: a set of data-carrying node definitions (NodeDef) and a set of execution descriptions for the graph (ExecutionDef).
// Represents an exported secertflow model. It consists of a GraphDef and extra
// metadata required for serving.
message ModelBundle {
string name = 1;
string desc = 2;
GraphDef graph = 3;
}
// The definition of a Graph. A graph consists of a set of nodes carrying data
// and a set of executions that describes the scheduling of the graph.
message GraphDef {
// Version of the graph
string version = 1;
repeated NodeDef node_list = 2;
repeated ExecutionDef execution_list = 3;
}
Let's continue with NodeDef and ExecutionDef:
// The definition of a node.
message NodeDef {
// Must be unique among all nodes of the graph.
string name = 1;
// The operator name.
string op = 2;
// The parent node names of the node. The order of the parent nodes should
// match the order of the inputs of the node.
repeated string parents = 3;
// attributes of the node's op
// The attribute values configed in the node. Note that this should include
// all attrs defined in the corresponding OpDef.
map<string, AttrValue> attr_values = 4;
// The operator version.
string op_version = 5;
}
// The value of an attribute
message AttrValue {
oneof value {
// INT
int32 i32 = 1;
int64 i64 = 2;
// FLOAT
float f = 3;
double d = 4;
// STRING
string s = 5;
// BOOL
bool b = 6;
// BYTES
bytes by = 7;
// Lists
// INTS
Int32List i32s = 11;
Int64List i64s = 12;
// FLOATS
FloatList fs = 13;
DoubleList ds = 14;
// STRINGS
StringList ss = 15;
// BOOLS
BoolList bs = 16;
// BYTESS
BytesList bys = 17;
}
}
// The definition of a execution. A execution represents a subgraph within a
// graph that can be scheduled for execution in a specified pattern.
message ExecutionDef {
// contains the runtime config and the nodes
// Represents the nodes contained in this execution. Note that
// these node names should be findable and unique within the node
// definitions. One node can only exist in one execution and must exist in
// one.
repeated string nodes = 1;
// The runtime config of the execution.
RuntimeConfig config = 2;
}
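To make these structures concrete, here is how a minimal two-node graph could be assembled through the generated proto API (illustrative only; real bundles are exported by the training side, and the op names simply mirror the kernels discussed later):
secretflow::serving::GraphDef graph_def;
graph_def.set_version("0.0.1");

// entry node: computes the local partial score
auto* dot = graph_def.add_node_list();
dot->set_name("dot_product_1");
dot->set_op("DOT_PRODUCT");

// exit node: merges partial scores; its input order follows `parents`
auto* merge = graph_def.add_node_list();
merge->set_name("merge_y_1");
merge->set_op("MERGE_Y");
merge->add_parents("dot_product_1");

// one execution (schedulable subgraph) per stage
graph_def.add_execution_list()->add_nodes("dot_product_1");
graph_def.add_execution_list()->add_nodes("merge_y_1");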
The startup code is at
https://github.com/secretflow/serving/blob/475bb3356e3a246f444fc25f62f9619874870680/secretflow_serving/server/server.cc#L58
SourceFactory is also a singleton; after initialization, it pulls the model from a file:
/* SourceFactory initialization */
// get model package
auto source = SourceFactory::GetInstance()->Create(opts_.model_config,
opts_.service_id);
// @hint pull the model; channel initialization follows
// this step reads from the local file system
auto package_path = source->PullModel();
/* PullModel looks like this */
std::string Source::PullModel() {
auto dst_dir = std::filesystem::path(data_dir_).append(config_.model_id());
if (!std::filesystem::exists(dst_dir)) {
std::filesystem::create_directories(dst_dir);
}
auto dst_file_path = dst_dir.append(kModelFileName);
const auto& source_sha256 = config_.source_sha256();
if (std::filesystem::exists(dst_file_path)) {
if (!source_sha256.empty()) {
if (SysUtil::CheckSHA256(dst_file_path.string(), source_sha256)) {
return dst_file_path;
}
}
SPDLOG_INFO("remove tmp model file:{}", dst_file_path.string());
std::filesystem::remove(dst_file_path);
}
// OnPullModel pulls the model, e.g. from OSS
OnPullModel(dst_file_path);
if (!source_sha256.empty()) {
SERVING_ENFORCE(SysUtil::CheckSHA256(dst_file_path.string(), source_sha256),
errors::ErrorCode::IO_ERROR,
"model({}) sha256 check failed", config_.source_path());
}
return dst_file_path;
}
Then the rpc channels are initialized from the participant information:
// build channels
std::string self_address;
std::vector<std::string> cluster_ids;
// communicate with each party through channels
auto channels = std::make_shared<PartyChannelMap>();
for (const auto& party : opts_.cluster_config.parties()) {
cluster_ids.emplace_back(party.id());
if (party.id() == self_party_id) {
self_address = party.listen_address().empty() ? party.address()
: party.listen_address();
continue;
}
channels->emplace(
party.id(),
CreateBrpcChannel(
party.address(), opts_.cluster_config.channel_desc().protocol(),
FLAGS_enable_peers_load_balancer,
opts_.cluster_config.channel_desc().rpc_timeout_ms() > 0
? opts_.cluster_config.channel_desc().rpc_timeout_ms()
: kPeerRpcTimeoutMs,
opts_.cluster_config.channel_desc().connect_timeout_ms() > 0
? opts_.cluster_config.channel_desc().connect_timeout_ms()
: kPeerConnectTimeoutMs,
opts_.cluster_config.channel_desc().has_tls_config()
? &opts_.cluster_config.channel_desc().tls_config()
: nullptr));
}
Then the package pulled from OSS is extracted and the proto files inside are read. Here we focus on the model-loading process and the Graph constructor:
// load model package
auto loader = std::make_unique<ModelLoader>();
loader->Load(package_path);
const auto& model_bundle = loader->GetModelBundle();
Graph graph(model_bundle->graph());
// model_bundle here is the ModelBundle proto shown above, which
// carries the complete model information.
/* the Load method */
void ModelLoader::Load(const std::string& file_path) {
SPDLOG_INFO("begin load file: {}", file_path);
auto model_dir =
std::filesystem::path(file_path).parent_path().append("data");
if (std::filesystem::exists(model_dir)) {
// remove tmp model dir
SPDLOG_WARN("remove tmp model dir: {}", model_dir.string());
std::filesystem::remove_all(model_dir);
}
// unzip package file
try {
SysUtil::ExtractGzippedArchive(file_path, model_dir);
} catch (const std::exception& e) {
std::filesystem::remove_all(file_path);
SERVING_THROW(errors::ErrorCode::IO_ERROR,
"failed to extract model package {}, detail: {}", file_path,
e.what());
}
auto manifest_path =
std::filesystem::path(model_dir).append(kManifestFileName);
SERVING_ENFORCE(
std::filesystem::exists(manifest_path), errors::ErrorCode::IO_ERROR,
"can not find manifest file {}, model package file is corrupted",
manifest_path.string());
// load manifest
ModelManifest manifest;
// deserialize the pb file
LoadPbFromJsonFile(manifest_path.string(), &manifest);
auto model_file_path = model_dir.append(manifest.bundle_path());
auto model_bundle = std::make_shared<ModelBundle>();
if (manifest.bundle_format() == FileFormatType::FF_PB) {
LoadPbFromBinaryFile(model_file_path.string(), model_bundle.get());
} else if (manifest.bundle_format() == FileFormatType::FF_JSON) {
LoadPbFromJsonFile(model_file_path.string(), model_bundle.get());
} else {
SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR,
"found unknown bundle_format:{}",
FileFormatType_Name(manifest.bundle_format()));
}
model_bundle_ = std::move(model_bundle);
SPDLOG_INFO("end load model bundle, name: {}, desc: {}, graph version: {}",
model_bundle_->name(), model_bundle_->desc(),
model_bundle_->graph().version());
}
/* the Graph constructor */
Graph::Graph(GraphDef graph_def) : def_(std::move(graph_def)) {
// TODO: check version
// TODO: consider not storing def_ to avoiding multiple copies of node_defs
// and execution_defs
graph_view_.set_version(def_.version());
for (auto& node : def_.node_list()) {
NodeView view;
*(view.mutable_name()) = node.name();
*(view.mutable_op()) = node.op();
*(view.mutable_op_version()) = node.op_version();
*(view.mutable_parents()) = node.parents();
graph_view_.mutable_node_list()->Add(std::move(view));
}
*(graph_view_.mutable_execution_list()) = def_.execution_list();
// create nodes
// build the nodes; nodes_ is a std::unordered_map<std::string, std::shared_ptr<Node>>
for (int i = 0; i < def_.node_list_size(); ++i) {
const auto node_name = def_.node_list(i).name();
auto node = std::make_shared<Node>(def_.node_list(i));
SERVING_ENFORCE(nodes_.emplace(node_name, node).second,
errors::ErrorCode::LOGIC_ERROR, "found duplicate node:{}",
node_name);
}
// create edges
// build the edges; edges_ is just a flat vector in insertion order
for (const auto& [name, node] : nodes_) {
const auto& input_nodes = node->GetInputNodeNames();
if (input_nodes.empty()) {
SERVING_ENFORCE(node->GetOpDef()->inputs_size() == 1,
errors::ErrorCode::LOGIC_ERROR,
"the entry op should only have one input to accept "
"the features, node:{}, op:{}",
name, node->node_def().op());
entry_nodes_.emplace_back(node);
}
for (size_t i = 0; i < input_nodes.size(); ++i) {
auto n_iter = nodes_.find(input_nodes[i]);
SERVING_ENFORCE(n_iter != nodes_.end(), errors::ErrorCode::LOGIC_ERROR,
"can not found input node:{} for node:{}", input_nodes[i],
name);
auto edge = std::make_shared<Edge>(n_iter->first, name, i);
n_iter->second->AddOutEdge(edge);
node->AddInEdge(edge);
edges_.emplace_back(edge);
}
}
// find exit node
// an execution graph has exactly one exit node
size_t exit_node_count = 0;
for (const auto& pair : nodes_) {
if (pair.second->out_edges().empty()) {
exit_node_ = pair.second;
++exit_node_count;
}
}
SERVING_ENFORCE(!entry_nodes_.empty(), errors::ErrorCode::LOGIC_ERROR,
"can not found any entry node, please check graph def.");
SERVING_ENFORCE(exit_node_count == 1, errors::ErrorCode::LOGIC_ERROR,
"found {} exit nodes, expect only 1 in graph",
exit_node_count);
SERVING_ENFORCE(exit_node_->GetOpDef()->tag().returnable(),
errors::ErrorCode::LOGIC_ERROR,
"exit node({}) op({}) must returnable", exit_node_->GetName(),
exit_node_->GetOpDef()->name());
CheckNodesReachability();
CheckEdgeValidate();
// the Executions are built here
BuildExecution();
CheckExecutionValidate();
}
Once the model's execution graph is initialized, the Executors and the Executable are built and used to initialize the ExecutionCore. Executable is a container class for Executors, so here we mainly look at how an Executor is built.
An Executor is the smallest unit that Execute runs; Executors are what give serving's prediction its dynamic scheduling ability.
// build the Executors
std::vector<std::shared_ptr<Executor>> executors;
for (const auto& execution : graph.GetExecutions()) {
executors.emplace_back(std::make_shared<Executor>(execution));
}
ExecutionCore::Options exec_opts;
exec_opts.id = opts_.service_id;
exec_opts.party_id = self_party_id;
exec_opts.executable = std::make_shared<Executable>(std::move(executors));
if (opts_.server_config.op_exec_worker_num() > 0) {
exec_opts.op_exec_workers_num = opts_.server_config.op_exec_worker_num();
}
if (!opts_.server_config.feature_mapping().empty()) {
exec_opts.feature_mapping = {opts_.server_config.feature_mapping().begin(),
opts_.server_config.feature_mapping().end()};
}
exec_opts.feature_source_config = opts_.feature_source_config;
// build the ExecutionCore
auto execution_core = std::make_shared<ExecutionCore>(std::move(exec_opts));
What happens while an Executor is built? Let's continue with its constructor.
First, an op_kernel is created for every node. Like ops, op_kernels have their own factory class and are registered statically; registration places a creator function into creators_ for later callback. At execution time, the Executor dispatches into the op_kernel.
https://github.com/secretflow/serving/blob/475bb3356e3a246f444fc25f62f9619874870680/secretflow_serving/ops/op_kernel_factory.h#L22
Create invokes a different creator depending on opts.op_def->name(). **OpKernel has three subclasses: ArrowProcessing, MergeY, and DotProduct; the first two work on Arrow-represented data, the last one on Eigen-represented data.** Each constructs its own schemas during construction.
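The factory keeps a name-to-creator map, and kernels register themselves into creators_ through static objects, just like ops. A condensed sketch of the shape (simplified; details differ from op_kernel_factory.h):
class OpKernelFactory final : public Singleton<OpKernelFactory> {
 public:
  using CreateFn = std::function<std::shared_ptr<OpKernel>(OpKernelOptions)>;

  // called from each kernel's static registration object
  void Register(const std::string& op_name, CreateFn creator) {
    std::lock_guard<std::mutex> lock(mutex_);
    creators_.emplace(op_name, std::move(creator));
  }

  // dispatches to the creator registered under opts.op_def->name()
  // (lookup error handling omitted)
  std::shared_ptr<OpKernel> Create(OpKernelOptions opts) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto iter = creators_.find(opts.op_def->name());
    return iter->second(std::move(opts));
  }

 private:
  std::unordered_map<std::string, CreateFn> creators_;
  std::mutex mutex_;
};
Back to the Executor constructor: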
Executor::Executor(const std::shared_ptr<Execution>& execution)
: execution_(execution) {
// create op_kernel
auto nodes = execution_->nodes();
node_items_ = std::make_shared<
std::unordered_map<std::string, std::shared_ptr<NodeItem>>>();
for (const auto& [node_name, node] : nodes) {
op::OpKernelOptions ctx{node->node_def(), node->GetOpDef()};
auto item = std::make_shared<NodeItem>();
item->node = node;
item->op_kernel =
op::OpKernelFactory::GetInstance()->Create(std::move(ctx));
node_items_->emplace(node_name, item);
}
Let's take ArrowProcessing as an example and look at its construction flow. First it calls the base-class constructor:
// the base-class constructor
explicit OpKernel(OpKernelOptions opts) : opts_(std::move(opts)) {
num_inputs_ = opts_.op_def->inputs_size();
if (opts_.op_def->tag().variable_inputs()) {
// The actual number of inputs for op with variable parameters
// depends on node's parents.
num_inputs_ = opts_.node_def.parents_size();
}
}
Then BuildInputSchema() and BuildOutputSchema() are executed. BuildInputSchema calls GetNodeBytesAttr, an inline function used to sidestep ODR problems; this is a common idiom, and the usual advice is to prefer inline over static.
// invoked as GetNodeBytesAttr(opts_.node_def, "input_schema_bytes")
inline std::string GetNodeBytesAttr(const NodeDef& node_def,
const std::string& attr_name) {
std::string value;
if (!GetNodeBytesAttr(node_def, attr_name, &value)) {
SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR,
"can not get attr:{} from node:{}, op:{}", attr_name,
node_def.name(), node_def.op());
}
return value;
}
bool GetNodeBytesAttr(const NodeDef& node_def, const std::string& attr_name,
std::vector<std::string>* value) {
AttrValue attr_value;
if (!GetAttrValue(node_def, attr_name, &attr_value)) {
return false;
}
SERVING_ENFORCE(
attr_value.has_by(), errors::ErrorCode::LOGIC_ERROR,
"attr_value({}) does not have expected type(bytes) value, node: {}",
attr_name, node_def.name());
SERVING_ENFORCE(!attr_value.bys().data().empty(),
errors::ErrorCode::INVALID_ARGUMENT,
"attr_value({}) type(BytesList) has empty value, node: {}",
attr_name, node_def.name());
value->reserve(attr_value.bys().data().size());
for (const auto& v : attr_value.bys().data()) {
value->emplace_back(v);
}
return true;
}
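Why inline rather than static for a function defined in a header? A quick illustration of the difference:
// util.h, included from multiple .cc files
//
// `inline` tells the linker that all definitions are one entity, so the
// one-definition rule (ODR) is satisfied and the binary holds a single
// copy. `static` would give every translation unit its own private
// copy: duplicated code, and any function-local static state would be
// duplicated too.
inline int Answer() { return 42; }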
From this we can also see that **input_schema_bytes is stored in the node's attr_values map (the map<string, AttrValue> in NodeDef)**.
Deserializing it with Arrow then yields the input schema. BuildOutputSchema() follows the same flow, so I won't repeat it.
std::shared_ptr<arrow::Schema> DeserializeSchema(const std::string& buf) {
std::shared_ptr<arrow::Schema> result;
std::shared_ptr<arrow::io::RandomAccessFile> buffer_reader =
std::make_shared<arrow::io::BufferReader>(buf);
arrow::ipc::DictionaryMemo tmp_memo;
SERVING_GET_ARROW_RESULT(
arrow::ipc::ReadSchema(
std::static_pointer_cast<arrow::io::InputStream>(buffer_reader).get(),
&tmp_memo),
result);
return result;
}
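For completeness, the forward direction (how a schema becomes bytes in the first place) is plain Arrow IPC as well. A minimal sketch with a recent Arrow:
#include <arrow/api.h>
#include <arrow/ipc/api.h>
#include <stdexcept>

// Serialize a schema into an IPC buffer; the bytes can then be stored
// in a NodeDef bytes attr and restored later with arrow::ipc::ReadSchema
// as shown above.
std::shared_ptr<arrow::Buffer> SerializeSchemaOrThrow(
    const arrow::Schema& schema) {
  auto result = arrow::ipc::SerializeSchema(schema);
  if (!result.ok()) {
    throw std::runtime_error(result.status().ToString());
  }
  return *result;
}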
The code after this registers the function calls in arrow_processing; I'll come back to it in the prediction phase.
https://github.com/secretflow/serving/blob/475bb3356e3a246f444fc25f62f9619874870680/secretflow_serving/ops/arrow_processing.cc#L143
Continuing with the Executor's construction flow: the remaining part builds a mapping from node_name to input_schema, finds the schema of the input features, and validates all the input_schemas across the nodes.
// get input schema
const auto& entry_nodes = execution_->GetEntryNodes();
for (const auto& node : entry_nodes) {
const auto& node_name = node->node_def().name();
auto iter = node_items_->find(node_name);
SERVING_ENFORCE(iter != node_items_->end(), errors::ErrorCode::LOGIC_ERROR);
const auto& input_schema = iter->second->op_kernel->GetAllInputSchema();
entry_node_names_.emplace_back(node_name);
input_schema_map_.emplace(node_name, input_schema);
}
if (execution_->IsEntry()) {
// build feature schema from entry execution
auto iter = input_schema_map_.begin();
const auto& first_input_schema_list = iter->second;
SERVING_ENFORCE(first_input_schema_list.size() == 1,
errors::ErrorCode::LOGIC_ERROR);
const auto& target_schema = first_input_schema_list.front();
++iter;
for (; iter != input_schema_map_.end(); ++iter) {
SERVING_ENFORCE_EQ(iter->second.size(), 1U,
"entry nodes should have only one input table");
const auto& schema = iter->second.front();
SERVING_ENFORCE_EQ(
target_schema->num_fields(), schema->num_fields(),
"entry nodes should have same shape inputs, expect: {}, found: {}",
target_schema->num_fields(), schema->num_fields());
CheckReferenceFields(schema, target_schema,
fmt::format("entry nodes should have same input "
"schema, found node:{} mismatch",
iter->first));
}
input_feature_schema_ = target_schema;
}
}
ExecutionCore is the core module at prediction time. Let's look at this part of the code:
ExecutionCore::ExecutionCore(Options opts)
: opts_(std::move(opts)),
stats_({{"handler", "ExecutionCore"}, {"party_id", opts_.party_id}}) {
SERVING_ENFORCE(!opts_.id.empty(), errors::ErrorCode::INVALID_ARGUMENT);
SERVING_ENFORCE(!opts_.party_id.empty(), errors::ErrorCode::INVALID_ARGUMENT);
SERVING_ENFORCE(opts_.executable, errors::ErrorCode::INVALID_ARGUMENT);
SERVING_ENFORCE(opts_.op_exec_workers_num > 0,
errors::ErrorCode::INVALID_ARGUMENT);
// thread pool initialization
ThreadPool::GetInstance()->Start(opts_.op_exec_workers_num);
// key: model input feature name
// value: source or predefined feature name
// predefined model features can be mapped in here; configurable in the init flow above
std::unordered_map<std::string, std::string> model_feature_mapping;
valid_feature_mapping_flag_ = false;
if (opts_.feature_mapping.has_value()) {
for (const auto& pair : opts_.feature_mapping.value()) {
if (pair.first != pair.second) {
valid_feature_mapping_flag_ = true;
}
SERVING_ENFORCE(
model_feature_mapping.emplace(pair.second, pair.first).second,
errors::ErrorCode::INVALID_ARGUMENT,
"found duplicate feature mapping value:{}", pair.second);
}
}
// read the Executable's input feature schema
const auto& model_input_schema = opts_.executable->GetInputFeatureSchema();
// feed it into source_schema_
if (model_feature_mapping.empty()) {
source_schema_ = model_input_schema;
} else {
arrow::SchemaBuilder builder;
int num_fields = model_input_schema->num_fields();
for (int i = 0; i < num_fields; ++i) {
const auto& f = model_input_schema->field(i);
auto iter = model_feature_mapping.find(f->name());
SERVING_ENFORCE(iter != model_feature_mapping.end(),
errors::ErrorCode::INVALID_ARGUMENT,
"can not found {} in feature mapping rule", f->name());
SERVING_CHECK_ARROW_STATUS(
builder.AddField(arrow::field(iter->second, f->type())));
}
SERVING_GET_ARROW_RESULT(builder.Finish(), source_schema_);
}
// initialize feature_adapter_
if (opts_.feature_source_config.has_value()) {
SPDLOG_INFO("create feature adapter, type:{}",
static_cast<int>(opts_.feature_source_config->options_case()));
feature_adapter_ = feature::FeatureAdapterFactory::GetInstance()->Create(
*opts_.feature_source_config, opts_.id, opts_.party_id, source_schema_);
}
}
FeatureSourceConfig is also a proto message:
// Config for a feature source
message FeatureSourceConfig {
oneof options {
MockOptions mock_opts = 1;
HttpOptions http_opts = 2;
CsvOptions csv_opts = 3;
}
}
As we can see, three source types are adapted, and requests are handled through the OnFetchFeature method, which we'll cover in the prediction phase.
Two brpc servers are instantiated in total. The first is the Prometheus metrics service; Prometheus is probably the most widely used monitoring framework today.
// start metrics server
if (opts_.server_config.metrics_exposer_port() > 0) {
std::vector<std::string> strs = absl::StrSplit(self_address, ':');
SERVING_ENFORCE(strs.size() == 2, errors::ErrorCode::LOGIC_ERROR,
"invalid self address.");
auto metrics_listen_address = fmt::format(
"{}:{}", strs[0], opts_.server_config.metrics_exposer_port());
brpc::ServerOptions metrics_server_options;
if (opts_.server_config.has_tls_config()) {
auto* ssl_opts = metrics_server_options.mutable_ssl_options();
ssl_opts->default_cert.certificate =
opts_.server_config.tls_config().certificate_path();
ssl_opts->default_cert.private_key =
opts_.server_config.tls_config().private_key_path();
ssl_opts->verify.verify_depth = 1;
ssl_opts->verify.ca_file_path =
opts_.server_config.tls_config().ca_file_path();
}
// @hint register the Prometheus metrics service
auto* metrics_service = new metrics::MetricsService();
metrics_service->RegisterCollectable(metrics::GetDefaultRegistry());
metrics_server_.set_version(SERVING_VERSION_STRING);
if (metrics_server_.AddService(metrics_service,
brpc::SERVER_OWNS_SERVICE) != 0) {
SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR,
"fail to add metrics service into brpc server.");
}
if (metrics_server_.Start(metrics_listen_address.c_str(),
&metrics_server_options) != 0) {
SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR,
"fail to start metrics server at {}", self_address);
}
SPDLOG_INFO("begin metrics service listen at {}, ", metrics_listen_address);
}
The model server comprises three services. The first is the model service; there is no dynamic model registration yet, it can only report status information about the current model and the like.
// build model_info_collector
// @hint model info collection
ModelInfoCollector::Options m_c_opts;
m_c_opts.model_bundle = model_bundle;
m_c_opts.service_id = opts_.service_id;
m_c_opts.self_party_id = self_party_id;
m_c_opts.remote_channel_map = channels;
ModelInfoCollector model_info_collector(std::move(m_c_opts));
{
auto max_retry_cnt =
opts_.cluster_config.channel_desc().handshake_max_retry_cnt();
if (max_retry_cnt != 0) {
model_info_collector.SetRetryCounts(max_retry_cnt);
}
auto retry_interval_ms =
opts_.cluster_config.channel_desc().handshake_retry_interval_ms();
if (retry_interval_ms != 0) {
model_info_collector.SetRetryIntervalMs(retry_interval_ms);
}
}
// add services
auto* model_service = new ModelServiceImpl(
{{opts_.service_id, model_info_collector.GetSelfModelInfo()}},
self_party_id);
// @hint register services with brpc
// model service; no dynamic model registration
if (service_server_.AddService(model_service, brpc::SERVER_OWNS_SERVICE) !=
0) {
SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR,
"fail to add model service into brpc server.");
}
The interface looks like this:
// model service entry point
class ModelServiceImpl : public apis::ModelService {
public:
explicit ModelServiceImpl(std::map<std::string, ModelInfo> model_infos,
const std::string& self_party_id);
void GetModelInfo(::google::protobuf::RpcController* controller,
const apis::GetModelInfoRequest* request,
apis::GetModelInfoResponse* response,
::google::protobuf::Closure* done) override;
private:
struct Stats {
// for request api
::prometheus::Family<::prometheus::Counter>& api_request_counter_family;
::prometheus::Family<::prometheus::Summary>&
api_request_duration_summary_family;
explicit Stats(std::map<std::string, std::string> labels,
const std::shared_ptr<::prometheus::Registry>& registry =
metrics::GetDefaultRegistry());
};
void RecordMetrics(const apis::GetModelInfoRequest& request,
const apis::GetModelInfoResponse& response,
double duration_ms, const std::string& action);
The execution service is initialized with the execution_core; it has status, execution, and metrics interfaces:
auto* execution_service = new ExecutionServiceImpl(execution_core);
if (service_server_.AddService(execution_service,
brpc::SERVER_OWNS_SERVICE) != 0) {
SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR,
"fail to add execution service into brpc server.");
}
class ExecutionServiceImpl : public apis::ExecutionService {
public:
explicit ExecutionServiceImpl(
const std::shared_ptr<ExecutionCore>& execution_core);
void Execute(::google::protobuf::RpcController* controller,
const apis::ExecuteRequest* request,
apis::ExecuteResponse* response,
::google::protobuf::Closure* done) override;
private:
void RecordMetrics(const apis::ExecuteRequest& request,
const apis::ExecuteResponse& response, double duration_ms,
const std::string& action);
struct Stats {
// for service interface
::prometheus::Family<::prometheus::Counter>& api_request_counter_family;
::prometheus::Family<::prometheus::Summary>&
api_request_duration_summary_family;
Stats(std::map<std::string, std::string> labels,
const std::shared_ptr<::prometheus::Registry>& registry =
metrics::GetDefaultRegistry());
};
The prediction service adds some preprocessing logic on top of the execution service, and in the end it also calls into execution_core.
After instantiating it, some server options are set up:
auto* prediction_service = new PredictionServiceImpl(self_party_id);
if (service_server_.AddService(prediction_service,
brpc::SERVER_OWNS_SERVICE) != 0) {
SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR,
"fail to add prediction service into brpc server.");
}
// build services server opts
brpc::ServerOptions server_options;
server_options.max_concurrency = opts_.server_config.max_concurrency();
if (opts_.server_config.worker_num() > 0) {
server_options.num_threads = opts_.server_config.worker_num();
}
if (opts_.server_config.brpc_builtin_service_port() > 0) {
server_options.has_builtin_services = true;
server_options.internal_port =
opts_.server_config.brpc_builtin_service_port();
SPDLOG_INFO("internal port: {}", server_options.internal_port);
}
if (opts_.server_config.has_tls_config()) {
auto* ssl_opts = server_options.mutable_ssl_options();
ssl_opts->default_cert.certificate =
opts_.server_config.tls_config().certificate_path();
ssl_opts->default_cert.private_key =
opts_.server_config.tls_config().private_key_path();
ssl_opts->verify.verify_depth = 1;
ssl_opts->verify.ca_file_path =
opts_.server_config.tls_config().ca_file_path();
}
health::ServingHealthReporter hr;
server_options.health_reporter = &hr;
// start services server
service_server_.set_version(SERVING_VERSION_STRING);
if (service_server_.Start(self_address.c_str(), &server_options) != 0) {
SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR,
"fail to start brpc server at {}", self_address);
}
After that, the prediction_core is set up; it only records some information. From this we can see that execution is single-party, while predict is not only multi-party but also has a leader/follower structure; prediction_core coordinates the inference flow across the parties.
PredictionCore::Options prediction_core_opts;
prediction_core_opts.service_id = opts_.service_id;
prediction_core_opts.party_id = self_party_id;
prediction_core_opts.cluster_ids = std::move(cluster_ids);
prediction_core_opts.predictor = predictor;
auto prediction_core =
std::make_shared<PredictionCore>(std::move(prediction_core_opts));
prediction_service->Init(prediction_core);
That wraps up the startup flow; control returns to main, which runs until the brpc server exits.
Now for serving requests. Let's first look at the proto definition of the execute request:
// Execute request containing one or more requests.
message ExecuteRequest {
// Custom data. The header will be passed to the downstream system which
// implement the feature service spi.
Header header = 1;
// Represents the id of the requesting party
string requester_id = 2;
// Model service specification.
// during execute this only needs to match the id defined in the ExecutionCore.
ServiceSpec service_spec = 3;
// Represents the session of this execute.
string session_id = 4;
FeatureSource feature_source = 5;
ExecutionTask task = 6;
}
FeatureSource defines the feature-fetching strategy:
// Support feature source type
enum FeatureSourceType {
UNKNOWN_FS_TYPE = 0;
// No need features.
FS_NONE = 1;
// Fetch features from feature service.
FS_SERVICE = 2;
// The feature is defined in the request.
FS_PREDEFINED = 3;
}
// Descriptive feature source
message FeatureSource {
// Identifies the source type of the features
FeatureSourceType type = 1;
// Custom parameter for fetch features from feature service or other systems.
// Valid when `type==FeatureSourceType::FS_SERVICE`
FeatureParam fs_param = 2;
// Defined features.
// Valid when `type==FeatureSourceType::FS_PREDEFINED`
repeated Feature predefineds = 3;
}
ExecutionTask specifies the execution id and the execution's inputs; the inputs are passed in serialized as NodeIo.
// Execute request task.
message ExecutionTask {
// Specified the execution id.
int32 execution_id = 1;
repeated NodeIo nodes = 2;
}
// The serialized data of the node input/output.
message IoData {
repeated bytes datas = 1;
}
// Represents the node input/output data.
message NodeIo {
// Node name.
string name = 1;
repeated IoData ios = 2;
}
Since prediction also calls into Execute, let's look at the Execution service first.
There are two ways to fetch features: remotely, or locally from the request.
std::shared_ptr<arrow::RecordBatch> features;
if (request->feature_source().type() ==
apis::FeatureSourceType::FS_SERVICE) {
SERVING_ENFORCE(
!request->feature_source().fs_param().query_datas().empty(),
errors::ErrorCode::INVALID_ARGUMENT,
"get empty feature service query datas.");
SERVING_ENFORCE(request->task().nodes().empty(),
errors::ErrorCode::LOGIC_ERROR);
features = BatchFetchFeatures(request, response);
} else if (request->feature_source().type() ==
apis::FeatureSourceType::FS_PREDEFINED) {
SERVING_ENFORCE(!request->feature_source().predefineds().empty(),
errors::ErrorCode::INVALID_ARGUMENT,
"get empty predefined features.");
SERVING_ENFORCE(request->task().nodes().empty(),
errors::ErrorCode::LOGIC_ERROR);
features = FeaturesToTable(request->feature_source().predefineds(),
source_schema_);
}
Remote fetching first:
std::shared_ptr<arrow::RecordBatch> ExecutionCore::BatchFetchFeatures(
const apis::ExecuteRequest* request,
apis::ExecuteResponse* response) const {
SERVING_ENFORCE(feature_adapter_, errors::ErrorCode::INVALID_ARGUMENT,
"feature source is not set, please check config.");
yacl::ElapsedTimer timer;
try {
feature::FeatureAdapter::Request fa_request;
fa_request.header = &request->header();
fa_request.fs_param = &request->feature_source().fs_param();
feature::FeatureAdapter::Response fa_response;
fa_response.header = response->mutable_header();
// fetch the features
feature_adapter_->FetchFeature(fa_request, &fa_response);
RecordBatchFeatureMetrics(request->service_spec().id(),
request->requester_id(), errors::ErrorCode::OK,
timer.CountMs());
return fa_response.features;
} catch (Exception& e) {
RecordBatchFeatureMetrics(request->service_spec().id(),
request->requester_id(), e.code(),
timer.CountMs());
throw e;
}
}
This is where feature_adapter_ performs the feature fetch. FetchFeature calls the OnFetchFeature implemented by the subclass. Taking HttpFeatureAdapter as the example: nothing special here, it simply performs an HTTP fetch via brpc:
void FeatureAdapter::FetchFeature(const Request& request, Response* response) {
OnFetchFeature(request, response);
CheckFeatureValid(request, response->features);
}
/** HttpFeatureAdapter's OnFetchFeature **/
void HttpFeatureAdapter::OnFetchFeature(const Request& request,
Response* response) {
auto request_body = SerializeRequest(request);
yacl::ElapsedTimer timer;
brpc::Controller cntl;
cntl.http_request().uri() = spec_.http_opts().endpoint();
cntl.http_request().set_method(brpc::HTTP_METHOD_POST);
cntl.http_request().set_content_type("application/json");
cntl.request_attachment().append(request_body);
channel_->CallMethod(NULL, &cntl, NULL, NULL, NULL);
SERVING_ENFORCE(!cntl.Failed(), errors::ErrorCode::NETWORK_ERROR,
"http request failed, endpoint:{}, detail:{}",
spec_.http_opts().endpoint(), cntl.ErrorText());
DeserializeResponse(cntl.response_attachment().to_string(), response);
}
The other kind is predefined features, where the request itself carries the features; they are read into Arrow. This is straightforward, so I won't dwell on it:
std::shared_ptr<arrow::RecordBatch> FeaturesToTable(
const ::google::protobuf::RepeatedPtrField<Feature>& features,
const std::shared_ptr<const arrow::Schema>& target_schema) {
arrow::SchemaBuilder schema_builder;
std::vector<std::shared_ptr<arrow::Array>> arrays;
int num_rows = -1;
for (const auto& field : target_schema->fields()) {
bool found = false;
for (const auto& f : features) {
if (f.field().name() == field->name()) {
FeatureToArrayVisitor visitor{.target_field = field, .array = {}};
FeatureVisit(visitor, f);
if (num_rows >= 0) {
SERVING_ENFORCE_EQ(
num_rows, visitor.array->length(),
"features must have same length value. {}:{}, others:{}",
f.field().name(), visitor.array->length(), num_rows);
}
num_rows = visitor.array->length();
arrays.emplace_back(visitor.array);
found = true;
break;
}
}
SERVING_ENFORCE(found, errors::ErrorCode::UNEXPECTED_ERROR,
"can not found feature:{} in response", field->name());
}
return MakeRecordBatch(target_schema, num_rows, std::move(arrays));
}
ApplyFeatureMappingRule converts between online features and model features.
It looks up each field's schema in the configured feature_mapping, then converts the features via arrow::RecordBatch::Make:
std::shared_ptr<arrow::RecordBatch> ExecutionCore::ApplyFeatureMappingRule(
const std::shared_ptr<arrow::RecordBatch>& features) {
if (features == nullptr || !valid_feature_mapping_flag_) {
// no need mapping
return features;
}
const auto& feature_mapping = opts_.feature_mapping.value();
int num_cols = features->num_columns();
const auto& old_schema = features->schema();
arrow::SchemaBuilder builder;
for (int i = 0; i < num_cols; ++i) {
auto field = old_schema->field(i);
auto iter = feature_mapping.find(field->name());
if (iter != feature_mapping.end()) {
field = arrow::field(iter->second, field->type());
}
SERVING_CHECK_ARROW_STATUS(builder.AddField(field));
}
std::shared_ptr<arrow::Schema> schema;
SERVING_GET_ARROW_RESULT(builder.Finish(), schema);
return MakeRecordBatch(schema, features->num_rows(), features->columns());
// MakeRecordBatch calls arrow::RecordBatch::Make(schema, num_rows, std::move(columns));
}
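MakeRecordBatch is a thin wrapper over that Arrow call, and the mapping is a metadata-only rebuild. Something like this sketch:
// Rebuild a RecordBatch with the same columns under a renamed schema.
// Columns are shared_ptr<arrow::Array>, so no data is copied; only the
// schema (the field names) changes.
std::shared_ptr<arrow::RecordBatch> RenameColumns(
    const std::shared_ptr<arrow::RecordBatch>& batch,
    const std::shared_ptr<arrow::Schema>& new_schema) {
  return arrow::RecordBatch::Make(new_schema, batch->num_rows(),
                                  batch->columns());
}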
First, a Task is instantiated from the request parameters; here the NodeIo arguments are converted into op::OpComputeInputs.
// executable run
Executable::Task task;
task.id = request->task().execution_id();
task.features = features;
task.node_inputs = std::make_shared<std::unordered_map<
std::string, std::shared_ptr<op::OpComputeInputs>>>();
for (const auto& n : request->task().nodes()) {
auto compute_inputs = std::make_shared<op::OpComputeInputs>();
for (const auto& io : n.ios()) {
std::vector<std::shared_ptr<arrow::RecordBatch>> inputs;
for (const auto& d : io.datas()) {
inputs.emplace_back(DeserializeRecordBatch(d));
}
compute_inputs->emplace_back(std::move(inputs));
}
task.node_inputs->emplace(n.name(), std::move(compute_inputs));
}
opts_.executable->Run(task);
Then execution is invoked: if features are present, prediction runs on them; otherwise node_inputs is used. Under the hood the features get converted into a std::unordered_map<std::string, std::shared_ptr<op::OpComputeInputs>> keyed by entry-node name:
void Executable::Run(Task& task) {
SERVING_ENFORCE(task.id < executors_.size(), errors::ErrorCode::LOGIC_ERROR);
auto executor = executors_[task.id];
if (task.features) {
task.outputs = executor->Run(task.features);
} else {
SERVING_ENFORCE(!task.node_inputs->empty(), errors::ErrorCode::LOGIC_ERROR);
task.outputs = executor->Run(*(task.node_inputs));
}
SPDLOG_DEBUG("Executable::Run end, task.outputs.size:{}",
task.outputs->size());
}
Let's look at Executor's Run methods first:
std::shared_ptr<std::vector<NodeOutput>> Executor::Run(
std::shared_ptr<arrow::RecordBatch>& features) {
SERVING_ENFORCE(execution_->IsEntry(), errors::ErrorCode::LOGIC_ERROR);
auto inputs =
std::unordered_map<std::string, std::shared_ptr<op::OpComputeInputs>>();
for (size_t i = 0; i < entry_node_names_.size(); ++i) {
auto op_inputs = std::make_shared<op::OpComputeInputs>();
std::vector<std::shared_ptr<arrow::RecordBatch>> record_list = {features};
op_inputs->emplace_back(std::move(record_list));
inputs.emplace(entry_node_names_[i], std::move(op_inputs));
}
return Run(inputs);
}
// 然后调用下面的run方法
std::shared_ptr<std::vector<NodeOutput>> Executor::Run(
std::unordered_map<std::string, std::shared_ptr<op::OpComputeInputs>>&
inputs) {
// entry nodes
std::vector<std::shared_ptr<op::OpComputeInputs>> entry_node_inputs;
for (const auto& node : execution_->GetEntryNodes()) {
auto iter = inputs.find(node->node_def().name());
SERVING_ENFORCE(iter != inputs.end(), errors::ErrorCode::INVALID_ARGUMENT,
"can not found inputs for node:{}",
node->node_def().name());
entry_node_inputs.emplace_back(iter->second);
}
// instantiate a scheduler
// every Executor holds an execution_, passed in at construction
auto sched = std::make_shared<ExecuteScheduler>(
node_items_, execution_->GetExitNodeNum(), ThreadPool::GetInstance(),
execution_);
// feed the entry nodes into the scheduler
const auto& entry_nodes = execution_->GetEntryNodes();
for (size_t i = 0; i != execution_->GetEntryNodeNum(); ++i) {
sched->AddEntryNode(entry_nodes[i], entry_node_inputs[i]);
}
sched->Schedule();
auto task_exception = sched->GetTaskException();
if (task_exception) {
SPDLOG_ERROR("Execution {} run with exception.", execution_->id());
std::rethrow_exception(task_exception);
}
SERVING_ENFORCE_EQ(sched->GetSchedCount(), execution_->nodes().size());
return std::make_shared<std::vector<NodeOutput>>(sched->GetResults());
}
The key piece here is the scheduler, whose constructor takes the node items, the expected result count, the thread pool, and the execution.
Let's read the scheduler's code next.
First, there is something interesting in the constructor:
ExecuteScheduler(
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<NodeItem>>>
node_items,
uint64_t res_cnt, const std::shared_ptr<ThreadPool>& thread_pool,
std::shared_ptr<Execution> execution)
: node_items_(std::move(node_items)),
context_(res_cnt),
thread_pool_(thread_pool),
execution_(std::move(execution)),
propagator_(execution_->nodes()),
sched_count_(0) {}
Did you notice propagator_? Looking at its definition and implementation, it appears to hold the state information and the inputs/outputs of every node.
Something like a per-node bookkeeping structure for managing the graph:
struct ComputeContext {
// TODO: Session
OpComputeInputs inputs;
std::shared_ptr<arrow::RecordBatch> output;
};
struct FrameState {
std::atomic<int> pending_count;
op::ComputeContext compute_ctx;
};
class Propagator {
public:
explicit Propagator(
const std::unordered_map<std::string, std::shared_ptr<Node>>& nodes);
FrameState* GetFrame(const std::string& node_name);
private:
std::unordered_map<std::string, FrameState*> node_frame_map_;
std::vector<FrameState> frame_pool_;
};
Propagator::Propagator(
const std::unordered_map<std::string, std::shared_ptr<Node>>& nodes) {
frame_pool_ = std::vector<FrameState>(nodes.size());
size_t idx = 0;
for (auto& [node_name, node] : nodes) {
auto frame = &frame_pool_[idx++];
frame->pending_count = node->GetInputNum();
frame->compute_ctx.inputs.resize(frame->pending_count);
SERVING_ENFORCE(node_frame_map_.emplace(node_name, std::move(frame)).second,
errors::ErrorCode::LOGIC_ERROR);
}
}
FrameState* Propagator::GetFrame(const std::string& node_name) {
auto iter = node_frame_map_.find(node_name);
SERVING_ENFORCE(iter != node_frame_map_.end(), errors::ErrorCode::LOGIC_ERROR,
"can not found frame for node: {}", node_name);
return iter->second;
}
Keeping that question in mind, let's continue with the execution code.
Here is the scheduler's scheduling loop. The TODO comment mentions wanting to use bthread primitives so that a waiting worker could switch to other work, which is interesting; I've written about brpc before, see https://www.yuque.com/treblez/qksu6c/owbw5sm9xzmqv2qa?singleDoc# 《brpc:优秀代码鉴赏》 if you're curious.
There is no swap-out mechanism here, just a simple wait in place. ready_nodes_ is a ThreadSafeQueue; it's no lock-free queue, just plain mutex synchronization, so I won't paste its code.
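Its shape is the textbook mutex plus condition_variable queue (a sketch, not the serving code, which differs in details such as timed waits and stop support):
#include <condition_variable>
#include <mutex>
#include <queue>

template <typename T>
class ThreadSafeQueue {
 public:
  void Push(T v) {
    {
      std::lock_guard<std::mutex> lk(mu_);
      q_.push(std::move(v));
    }
    cv_.notify_one();
  }

  // Blocks until an element is available.
  void WaitPop(T& out) {
    std::unique_lock<std::mutex> lk(mu_);
    cv_.wait(lk, [this] { return !q_.empty(); });
    out = std::move(q_.front());
    q_.pop();
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<T> q_;
};
Back to the scheduling loop: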
void Schedule() {
while (!stop_flag_.load() && !context_.IsFinish()) {
// TODO: consider use bthread::Mutex and bthread::ConditionVariable
// to make this worker can switch to others
std::shared_ptr<NodeItem> node_item;
// calls ready_nodes_.WaitPop(node_item);
// ready_nodes_ is a ThreadSafeQueue
if (!context_.GetReadyNode(node_item)) {
continue;
}
SubmitExecuteOpTask(node_item);
}
}
/** which calls into **/
void SubmitExecuteOpTask(std::shared_ptr<NodeItem>& node_item) {
if (stop_flag_.load()) {
return;
}
thread_pool_->SubmitTask(
std::make_unique<ExecuteOpTask>(node_item, shared_from_this()));
}
/** the following task class is submitted to the thread pool **/
class ExecuteOpTask : public ThreadPool::Task {
public:
const char* Name() override { return "ExecuteOpTask"; }
ExecuteOpTask(std::shared_ptr<NodeItem> node_item,
std::shared_ptr<ExecuteScheduler> sched)
: node_item_(std::move(node_item)), sched_(std::move(sched)) {}
// calls back into ExecuteOp
void Exec() override { sched_->ExecuteOp(node_item_); }
void OnException(std::exception_ptr e) noexcept override {
sched_->SetTaskException(e);
}
private:
std::shared_ptr<NodeItem> node_item_;
std::shared_ptr<ExecuteScheduler> sched_;
};
/** ExecuteOp **/
void ExecuteOp(const std::shared_ptr<NodeItem>& node_item) {
if (stop_flag_.load()) {
return;
}
auto* frame = propagator_.GetFrame(node_item->node->node_def().name());
// invoke the OpKernel's compute logic here
node_item->op_kernel->Compute(&(frame->compute_ctx));
sched_count_++;
if (execution_->IsExitNode(node_item->node->node_def().name())) {
context_.AddResult(node_item->node->node_def().name(),
frame->compute_ctx.output);
}
const auto& edges = node_item->node->out_edges();
for (const auto& edge : edges) {
CompleteOutEdge(edge, frame->compute_ctx.output);
}
}
As introduced earlier, OpKernel has three subclasses: ArrowProcessing, MergeY, and DotProduct.
Here we go through each one's runtime compute method.
Let's start with the most important, ArrowProcessing.
ComputeTrace is defined in proto and describes the functions to be executed.
message FunctionTrace {
// The Function name.
string name = 1;
// The serialized function options.
bytes option_bytes = 2;
// Inputs of this function.
repeated FunctionInput inputs = 3;
// Output of this function.
FunctionOutput output = 4;
}
message ComputeTrace {
// The name of this Compute.
string name = 1;
repeated FunctionTrace func_traces = 2;
}
// Function names are defined as follows:
enum ExtendFunctionName {
// Placeholder for proto3 default value, do not use it
UNKOWN_EX_FUNCTION_NAME = 0;
// Get colunm from table(record_batch).
// see
// https://arrow.apache.org/docs/cpp/api/table.html#_CPPv4NK5arrow11RecordBatch6columnEi
EFN_TB_COLUMN = 1;
// Add colum to table(record_batch).
// see
// https://arrow.apache.org/docs/cpp/api/table.html#_CPPv4NK5arrow11RecordBatch9AddColumnEiNSt6stringERKNSt10shared_ptrI5ArrayEE
EFN_TB_ADD_COLUMN = 2;
// Remove colunm from table(record_batch).
// see
// https://arrow.apache.org/docs/cpp/api/table.html#_CPPv4NK5arrow11RecordBatch12RemoveColumnEi
EFN_TB_REMOVE_COLUMN = 3;
// Set colunm to table(record_batch).
// see
// https://arrow.apache.org/docs/cpp/api/table.html#_CPPv4NK5arrow11RecordBatch9SetColumnEiRKNSt10shared_ptrI5FieldEERKNSt10shared_ptrI5ArrayEE
EFN_TB_SET_COLUMN = 4;
}
// FunctionInput is defined as follows:
message FunctionInput {
oneof value {
// '0' means root input data
int32 data_id = 1;
Scalar custom_scalar = 2;
}
}
As we can see, these functions are just CRUD operations on Arrow record batches.
The construction logic is shown below: every function defined in compute_trace_ is placed into the execution list func_list_, which is later invoked in order.
switch (ex_func_name) {
case compute::ExtendFunctionName::EFN_TB_COLUMN: {
// look up a column
func_list_.emplace_back([](arrow::Datum& result_datum,
std::vector<arrow::Datum>& func_inputs) {
result_datum = func_inputs[0].record_batch()->column(
std::static_pointer_cast<arrow::Int64Scalar>(
func_inputs[1].scalar())
->value);
});
break;
}
case compute::ExtendFunctionName::EFN_TB_ADD_COLUMN: {
// add a column
func_list_.emplace_back([](arrow::Datum& result_datum,
std::vector<arrow::Datum>& func_inputs) {
int64_t index = std::static_pointer_cast<arrow::Int64Scalar>(
func_inputs[1].scalar())
->value;
std::string field_name(
std::static_pointer_cast<arrow::StringScalar>(
func_inputs[2].scalar())
->view());
std::shared_ptr<arrow::RecordBatch> new_batch;
SERVING_GET_ARROW_RESULT(
func_inputs[0].record_batch()->AddColumn(
index, std::move(field_name), func_inputs[3].make_array()),
new_batch);
result_datum = new_batch;
});
break;
}
case compute::ExtendFunctionName::EFN_TB_REMOVE_COLUMN: {
// remove a column
func_list_.emplace_back([](arrow::Datum& result_datum,
std::vector<arrow::Datum>& func_inputs) {
std::shared_ptr<arrow::RecordBatch> new_batch;
SERVING_GET_ARROW_RESULT(
func_inputs[0].record_batch()->RemoveColumn(
std::static_pointer_cast<arrow::Int64Scalar>(
func_inputs[1].scalar())
->value),
new_batch);
result_datum = new_batch;
});
break;
}
case compute::ExtendFunctionName::EFN_TB_SET_COLUMN: {
// set (replace) a column
func_list_.emplace_back([](arrow::Datum& result_datum,
std::vector<arrow::Datum>& func_inputs) {
int64_t index = std::static_pointer_cast<arrow::Int64Scalar>(
func_inputs[1].scalar())
->value;
std::string field_name(
std::static_pointer_cast<arrow::StringScalar>(
func_inputs[2].scalar())
->view());
std::shared_ptr<arrow::Array> array = func_inputs[3].make_array();
std::shared_ptr<arrow::RecordBatch> new_batch;
SERVING_GET_ARROW_RESULT(
func_inputs[0].record_batch()->SetColumn(
index, arrow::field(std::move(field_name), array->type()),
array),
new_batch);
result_datum = new_batch;
});
break;
}
default:
SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR,
"invalid ext func name enum: {}",
static_cast<int>(ex_func_name));
}
Let's continue with the runtime logic, DoCompute. Somewhat oddly, it is split into two functions; together they just replay the calls in order, so there isn't much more to say.
void ArrowProcessing::DoCompute(ComputeContext* ctx) {
// sanity check
SERVING_ENFORCE(ctx->inputs.size() == 1, errors::ErrorCode::LOGIC_ERROR);
SERVING_ENFORCE(ctx->inputs.front().size() == 1,
errors::ErrorCode::LOGIC_ERROR);
if (dummy_flag_) {
ctx->output = ctx->inputs.front().front();
return;
}
SPDLOG_INFO("replay compute: {}", compute_trace_.name());
ctx->output = ReplayCompute(ctx->inputs.front().front());
}
std::shared_ptr<arrow::RecordBatch> ArrowProcessing::ReplayCompute(
const std::shared_ptr<arrow::RecordBatch>& input) {
std::map<int32_t, arrow::Datum> datas = {{0, input}};
arrow::Datum result_datum;
for (int i = 0; i < compute_trace_.func_traces_size(); ++i) {
const auto& func = compute_trace_.func_traces(i);
SPDLOG_DEBUG("replay func: {}", func.ShortDebugString());
auto func_inputs = BuildInputDatums(func.inputs(), datas);
func_list_[i](result_datum, func_inputs);
SERVING_ENFORCE(
datas.emplace(func.output().data_id(), std::move(result_datum)).second,
errors::ErrorCode::LOGIC_ERROR);
}
return datas[result_id_].record_batch();
}
True to its name, this OpKernel uses Eigen to compute a dot product; Arrow does not provide one:
void DotProduct::DoCompute(ComputeContext* ctx) {
SERVING_ENFORCE(ctx->inputs.size() == 1, errors::ErrorCode::LOGIC_ERROR);
SERVING_ENFORCE(ctx->inputs.front().size() == 1,
errors::ErrorCode::LOGIC_ERROR);
auto features = TableToMatrix(ctx->inputs.front().front());
Double::ColVec score_vec = features * weights_;
score_vec.array() += intercept_;
std::shared_ptr<arrow::Array> array;
arrow::DoubleBuilder builder;
for (int i = 0; i < score_vec.rows(); ++i) {
auto row = score_vec.row(i);
SERVING_CHECK_ARROW_STATUS(builder.AppendValues(row.data(), 1));
}
SERVING_CHECK_ARROW_STATUS(builder.Finish(&array));
ctx->output = MakeRecordBatch(output_schema_, score_vec.rows(), {array});
}
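An aside: since Double::ColVec is a dense Eigen column vector with contiguous storage, the per-row append loop could presumably be collapsed into one bulk append (an assumption about the layout, not a change the project makes):
// equivalent bulk append, relying on Eigen column vectors being
// contiguous in memory
SERVING_CHECK_ARROW_STATUS(
    builder.AppendValues(score_vec.data(), score_vec.rows()));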
MergeY merges the partial results of the parties:
void MergeY::DoCompute(ComputeContext* ctx) {
// sanity check
SERVING_ENFORCE(ctx->inputs.size() == 1, errors::ErrorCode::LOGIC_ERROR);
SERVING_ENFORCE(ctx->inputs.front().size() >= 1,
errors::ErrorCode::LOGIC_ERROR);
// merge partial_y
arrow::Datum incremented_datum(ctx->inputs.front()[0]->column(0));
for (size_t i = 1; i < ctx->inputs.front().size(); ++i) {
auto cur_array = ctx->inputs.front()[i]->column(0);
SERVING_GET_ARROW_RESULT(arrow::compute::Add(incremented_datum, cur_array),
incremented_datum);
}
auto merged_array = std::static_pointer_cast<arrow::DoubleArray>(
std::move(incremented_datum).make_array());
// apply link func
arrow::DoubleBuilder builder;
SERVING_CHECK_ARROW_STATUS(builder.Resize(merged_array->length()));
for (int64_t i = 0; i < merged_array->length(); ++i) {
auto score =
ApplyLinkFunc(merged_array->Value(i), link_function_) * yhat_scale_;
SERVING_CHECK_ARROW_STATUS(builder.Append(score));
}
std::shared_ptr<arrow::Array> res_array;
SERVING_CHECK_ARROW_STATUS(builder.Finish(&res_array));
ctx->output =
MakeRecordBatch(output_schema_, res_array->length(), {res_array});
}
Back in the scheduler: after a node's kernel finishes, ExecuteOp completes the node's out edges. CompleteOutEdge writes the output into the downstream node's frame and, once that node's pending_count drops to zero, pushes it onto the ready queue. This answers the earlier question about propagator_:
void CompleteOutEdge(const std::shared_ptr<Edge>& edge,
std::shared_ptr<arrow::RecordBatch> output) {
std::shared_ptr<Node> dst_node;
if (!execution_->TryGetNode(edge->dst_node(), &dst_node)) {
return;
}
auto* child_frame = propagator_.GetFrame(dst_node->GetName());
child_frame->compute_ctx.inputs[edge->dst_input_id()].emplace_back(
std::move(output));
if (child_frame->pending_count.fetch_sub(1) == 1) {
context_.AddReadyNode(
node_items_->find(dst_node->node_def().name())->second);
}
}
Finally, the prediction service. These are the feature-related proto definitions used by PredictRequest:
// The value of a feature
message FeatureValue {
// int list
repeated int32 i32s = 1;
repeated int64 i64s = 2;
// float list
repeated float fs = 3;
repeated double ds = 4;
// string list
repeated string ss = 5;
// bool list
repeated bool bs = 6;
}
// The definition of a feature field.
message FeatureField {
// Unique name of the feature
string name = 1;
// Field type of the feature
FieldType type = 2;
}
// The definition of a feature
message Feature {
FeatureField field = 1;
FeatureValue value = 2;
}
message PredictRequest {
// Custom data. The header will be passed to the downstream system which
// implement the feature service spi.
Header header = 1;
// Model service specification.
ServiceSpec service_spec = 2;
// The params for fetch features. Note that this should include all the
// parties involved in the prediction.
// Key: party's id.
// Value: params for fetch features.
map<string, FeatureParam> fs_params = 3;
// Optional.
// If defined, the request party will no longer query for the feature but will
// use defined fetures in `predefined_features` for the prediction.
repeated Feature predefined_features = 4;
}
Prediction is just Execute plus a round of multi-party communication.
It uses the rpc channels initialized during startup to kick off the peers' execute, runs its own part, and finally waits for the other parties' execute to finish.
void Predictor::Predict(const apis::PredictRequest* request,
apis::PredictResponse* response) {
std::unordered_map<std::string, std::shared_ptr<apis::NodeIo>>
prev_node_io_map;
std::vector<std::shared_ptr<RemoteExecute>> async_running_execs;
async_running_execs.reserve(opts_.channels->size());
auto execute_locally =
[&](const std::shared_ptr<Execution>& execution,
std::unordered_map<std::string, std::shared_ptr<apis::NodeIo>>&
prev_io_map,
std::unordered_map<std::string, std::shared_ptr<apis::NodeIo>>&
cur_io_map) {
// exec locally
auto local_exec = BuildLocalExecute(request, response, execution);
local_exec->SetInputs(std::move(prev_io_map));
local_exec->Run();
local_exec->GetOutputs(&cur_io_map);
};
for (const auto& e : opts_.executions) {
async_running_execs.clear();
std::unordered_map<std::string, std::shared_ptr<apis::NodeIo>>
new_node_io_map;
if (e->GetDispatchType() == DispatchType::DP_ALL) {
for (const auto& [party_id, channel] : *opts_.channels) {
auto ctx = BuildRemoteExecute(request, response, e, party_id, channel);
ctx->SetInputs(prev_node_io_map);
ctx->Run();
async_running_execs.emplace_back(ctx);
}
// exec locally
if (execution_core_) {
execute_locally(e, prev_node_io_map, new_node_io_map);
for (auto& exec : async_running_execs) {
exec->WaitToFinish();
exec->GetOutputs(&new_node_io_map);
}
} else {
// TODO: support no execution core scene
SERVING_THROW(errors::ErrorCode::NOT_IMPLEMENTED, "not implemented");
}
} else if (e->GetDispatchType() == DispatchType::DP_ANYONE) {
// exec locally
if (execution_core_) {
execute_locally(e, prev_node_io_map, new_node_io_map);
} else {
// TODO: support no execution core scene
SERVING_THROW(errors::ErrorCode::NOT_IMPLEMENTED, "not implemented");
}
} else if (e->GetDispatchType() == DispatchType::DP_SPECIFIED) {
if (e->SpecificToThis()) {
SERVING_ENFORCE(execution_core_, errors::ErrorCode::UNEXPECTED_ERROR);
execute_locally(e, prev_node_io_map, new_node_io_map);
} else {
auto iter = opts_.specific_party_map.find(e->id());
SERVING_ENFORCE(iter != opts_.specific_party_map.end(),
serving::errors::LOGIC_ERROR,
"{} execution assign to no party", e->id());
auto ctx = BuildRemoteExecute(request, response, e, iter->second,
opts_.channels->at(iter->second));
ctx->SetInputs(prev_node_io_map);
ctx->Run();
ctx->WaitToFinish();
ctx->GetOutputs(&new_node_io_map);
}
} else {
SERVING_THROW(errors::ErrorCode::UNEXPECTED_ERROR,
"unsupported dispatch type: {}",
DispatchType_Name(e->GetDispatchType()));
}
prev_node_io_map.swap(new_node_io_map);
}
DealFinalResult(prev_node_io_map, response);
}
Calls on a RemoteExecute eventually land in the code below, which simply invokes Execute on the remote party.
void ExecuteContext::Execute(
std::shared_ptr<::google::protobuf::RpcChannel> channel,
brpc::Controller* cntl) {
apis::ExecutionService_Stub stub(channel.get());
stub.Execute(cntl, &exec_req_, &exec_res_, brpc::DoNothing());
}