Aerosolve 简介

Aerosolve 是 Airbnb 公司开源的机器学习库, 更详细的内容请参见: 官方博客, GitHub.

不同于其他开源的机器学习库, Aerosolve 的口号不是"更快,更高,更强",
而是 built for human.
其含义是 用户(即房东)也看得懂模型给出的结果, 作者 Hector Yee 曾在 Spark 大会的讲座上提到过这一点.

用户不需要(也没必要)了解机器学习的相关知识, 即可很直观地从结果中看出:

  1. 哪些特征会影响房价
  2. 这些特征有多大影响

Aerosolve 是怎么办到的呢? 答案很简单: 线性模型.

  1. 按绝对值倒排权重, 这样可以知道哪些特征比较重要.
  2. 重要程度可通过权重值体现出来.

线性模型天生有着非常好的可解释的特性, 是一种应用相当广泛的白盒模型.
其相关资料多如牛毛, 本文只讲下我对 Aerosolve 如何应用该模型的一些探索.

数据结构

所有定义都在 这两个thrift文件 中,
我们需要关注的有4个: FeatureVector, Example, ModelHeader, ModelRecord.

FeatureVector

struct FeatureVector {
  1: optional map> stringFeatures;
  2: optional map> floatFeatures;
  3: optional map> denseFeatures;
}

FeatureVector 是描述特征最基础的单元,
3个 map 属性的 key 被称为 family of features (可以理解为字段名), value 则保存具体的值.

其中 stringFeatures 和 floatFeatures 分别对应"离散值"和"连续值"的概念, 使用频率最高,
而 denseFeatures 是 floatFeatures 的一种特例.

下面是一个 json 格式的示例, geo 和 location 都是 伐木累 的名字.

{
    "stringFeatures": {
        "geo": [
            "San Francisco",
            "CA",
            "USA"
        ]
    },
    "floatFeatures": {
        "location": {
            "lat": 37.7,
            "long": 40
        }
    }
}

Example

struct Example {
  1: optional list example;
  2: optional FeatureVector context;
}

Example 将有共同特征的 FeatureVector 聚合在一起.
相同的部分保存在 context 中, example 中保存了每个 FeatureVector 不同的特征.

之所以这么做的原因是:

  1. 节省存储或传输时占用的空间.
  2. 节省计算资源, 因为 Transform 时只需对 context 做一次计算.

Model相关

ModelHeader记录了模型的相关信息, ModelRecord记录了每个权重的信息和对应的特征.

特征变换(Transform)

FeatureVector 不会直接参与建模, 而是先经过一系列的 on-the-fly 的 Transform,
转换成离散值后以二元变量的形式参与训练.

目前已有几十种 Transform, 主要有两大类方法:

  1. 将连续值转换为离散值, 即 floatFeatures -> stringFeatures.
  2. 交叉多个离散值得到新特征, 即生成两种或多种特征的全排列.

详见 TransformFactory.java.

模型训练

通过交叉方法可以获得数以万计的特征, 同时也使特征矩阵变得稀疏, 那么训练时只需告诉模型样本有哪些特征即可.

最终得到的权重向量每一维会保存为一条 ModelRecord, 再加上一个 ModelHeader, 就完整地描述了训练出的模型.

LinearRankerTrainer.MAX_WEIGHTS 的值可以看出, 目前会保留绝对值最大的100万个权重.

预测 / 模型评估 / Debug

既然特征和模型都是用 thrift 描述的, 那么 Aerosolve 应该是通过 Thrift RPC 提供在线服务.

服务器在收到新样本后, 先根据样本数据构造 FeatureVector, 再应用相同的 Transform, 最后带入模型即可得到预测结果.

没研究 Aerosolve 是怎么做模型评估和 Debug 的.
MLSchema.thrift 中和此相关的有的 EvaluationRecord, DebugScoreRecord, DebugScoreDiffRecord,
搜索类名可以找到相关代码.

备注

  • Aerosolve 刚开源时只有2种模型, 当前版本(0.1.32)已有8种, 详见 ModelFactory.java.
  • 通过 Aerosolve 附带的 demo 可以很好地了解实现细节.
  • 作者推荐通过 Aerosolve 的 test 源码了解使用方法.

你可能感兴趣的:(Aerosolve 简介)