Alink is a machine learning algorithm platform developed by Alibaba on top of Flink, the real-time computing engine, and the first platform in the industry to support both batch and streaming algorithms. Binary classification evaluation assesses the quality of a binary classifier's predictions. This article dissects the corresponding implementation in Alink.
public class EvalBinaryClassExample {

    AlgoOperator getData(boolean isBatch) {
        Row[] rows = new Row[]{
            Row.of("prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}"),
            Row.of("prefix1", "{\"prefix1\": 0.8, \"prefix0\": 0.2}"),
            Row.of("prefix1", "{\"prefix1\": 0.7, \"prefix0\": 0.3}"),
            Row.of("prefix0", "{\"prefix1\": 0.75, \"prefix0\": 0.25}"),
            Row.of("prefix0", "{\"prefix1\": 0.6, \"prefix0\": 0.4}")
        };
        String[] schema = new String[]{"label", "detailInput"};
        if (isBatch) {
            return new MemSourceBatchOp(rows, schema);
        } else {
            return new MemSourceStreamOp(rows, schema);
        }
    }

    public static void main(String[] args) throws Exception {
        EvalBinaryClassExample test = new EvalBinaryClassExample();
        BatchOperator batchData = (BatchOperator) test.getData(true);
        BinaryClassMetrics metrics = new EvalBinaryClassBatchOp()
            .setLabelCol("label")
            .setPredictionDetailCol("detailInput")
            .linkFrom(batchData)
            .collectMetrics();
        System.out.println("RocCurve:" + metrics.getRocCurve());
        System.out.println("AUC:" + metrics.getAuc());
        System.out.println("KS:" + metrics.getKs());
        System.out.println("PRC:" + metrics.getPrc());
        System.out.println("Accuracy:" + metrics.getAccuracy());
        System.out.println("Macro Precision:" + metrics.getMacroPrecision());
        System.out.println("Micro Recall:" + metrics.getMicroRecall());
        System.out.println("Weighted Sensitivity:" + metrics.getWeightedSensitivity());
    }
}
Program output:
RocCurve:([0.0, 0.0, 0.0, 0.5, 0.5, 1.0, 1.0],[0.0, 0.3333333333333333, 0.6666666666666666, 0.6666666666666666, 1.0, 1.0, 1.0])
AUC:0.8333333333333333
KS:0.6666666666666666
PRC:0.9027777777777777
Accuracy:0.6
Macro Precision:0.3
Micro Recall:0.6
Weighted Sensitivity:0.6
In Alink, binary classification evaluation has both a batch implementation and a streaming implementation, introduced in turn below. (Part of Alink's complexity lies in its many finely-grained data structures, so the text below prints program variables liberally to aid understanding.)
Divide [0,1] into, say, 100000 buckets (bins). This yields two arrays of length 100000: positiveBin and negativeBin.
Populate positiveBin / negativeBin from the input. positiveBin corresponds to TP + FP and negativeBin to TN + FN. These are the basis of all subsequent computation.
Traverse every meaningful point in the bins, accumulate totalTrue and totalFalse, and at each point compute the confusion matrix, the tpr, and the corresponding data points of rocCurve, recallPrecisionCurve and liftChart.
From the curves, compute and store AUC / PRC / KS.
A detailed overview of the call relationships follows later.
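Before diving into the Flink operators, the whole bin-based pipeline can be condensed into a small standalone sketch. This is plain Java with no Flink, and the trapezoidal AUC accumulation and the single-pass KS are my own compact restatement of what the later sections derive step by step, not Alink's literal code:

```java
import java.util.Arrays;

public class BinPipelineSketch {
    static final int BIN_NUM = 100000;

    // Traverse the effective bins from the highest threshold downward (as the
    // extractMatrixThreCurve walkthrough below does), accumulating curTrue /
    // curFalse, and compute {AUC, KS} from the resulting (FPR, TPR) points.
    static double[] aucAndKs(long[] positiveBin, long[] negativeBin) {
        long totalTrue = Arrays.stream(positiveBin).sum();
        long totalFalse = Arrays.stream(negativeBin).sum();
        double auc = 0.0, ks = 0.0;
        long curTrue = 0, curFalse = 0;
        double prevFpr = 0.0, prevTpr = 0.0;
        for (int i = BIN_NUM - 1; i >= 0; i--) {
            if (positiveBin[i] == 0 && negativeBin[i] == 0) continue;
            curTrue += positiveBin[i];
            curFalse += negativeBin[i];
            double tpr = totalTrue == 0 ? 1.0 : 1.0 * curTrue / totalTrue;
            double fpr = totalFalse == 0 ? 1.0 : 1.0 * curFalse / totalFalse;
            auc += (fpr - prevFpr) * (tpr + prevTpr) / 2;  // trapezoid rule
            ks = Math.max(ks, tpr - fpr);                  // max TPR-FPR gap
            prevFpr = fpr;
            prevTpr = tpr;
        }
        return new double[]{auc, ks};
    }

    public static void main(String[] args) {
        long[] pos = new long[BIN_NUM], neg = new long[BIN_NUM];
        // The five samples of the example: probability of "prefix1" -> bin index
        pos[90000]++; pos[80000]++; pos[70000]++;   // actual label prefix1
        neg[75000]++; neg[60000]++;                 // actual label prefix0
        double[] r = aucAndKs(pos, neg);
        System.out.println("AUC=" + r[0] + " KS=" + r[1]);
    }
}
```

Running this on the five example samples reproduces the AUC (0.8333…) and KS (0.6666…) printed by the program above, which is a good sanity check before reading the real implementation.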
EvalBinaryClassBatchOp implements binary classification evaluation: it computes the evaluation metrics of a binary classifier.
There are two kinds of input (prediction result and prediction detail); our example uses prediction detail. In the example row
Row.of("prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}")
"prefix1" is the label and "{\"prefix1\": 0.9, \"prefix0\": 0.1}" is the predDetail.
The class itself is brief:
public class EvalBinaryClassBatchOp extends BaseEvalClassBatchOp<EvalBinaryClassBatchOp> implements BinaryEvaluationParams<EvalBinaryClassBatchOp>, EvaluationMetricsCollector<BinaryClassMetrics> {
    @Override
    public BinaryClassMetrics collectMetrics() {
        return new BinaryClassMetrics(this.collect().get(0));
    }
}
As you can see, the real work happens in the base class BaseEvalClassBatchOp, so we look at that class first.
As usual we start from the linkFrom function. It does several things; the code is as follows:
@Override
public T linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = checkAndGetFirst(inputs);
    String labelColName = this.get(MultiEvaluationParams.LABEL_COL);
    String positiveValue = this.get(BinaryEvaluationParams.POS_LABEL_VAL_STR);
    // Judge the evaluation type from params.
    ClassificationEvaluationUtil.Type type = ClassificationEvaluationUtil.judgeEvaluationType(this.getParams());
    DataSet<BaseMetricsSummary> res;
    switch (type) {
        case PRED_DETAIL: {
            String predDetailColName = this.get(MultiEvaluationParams.PREDICTION_DETAIL_COL);
            // Select the relevant columns from the input: "label", "detailInput"
            DataSet<Row> data = in.select(new String[] {labelColName, predDetailColName}).getDataSet();
            // Compute the evaluation metrics per partition
            res = calLabelPredDetailLocal(data, positiveValue, binary);
            break;
        }
        ......
    }
    // Reduce the partial results computed above
    DataSet<BaseMetricsSummary> metrics = res
        .reduce(new EvaluationUtil.ReduceBaseMetrics());
    // Write the final values to the output table
    this.setOutput(metrics.flatMap(new EvaluationUtil.SaveDataAsParams()),
        new String[] {DATA_OUTPUT}, new TypeInformation[] {Types.STRING});
    return (T)this;
}

// Some variables during execution:
labelColName = "label"
predDetailColName = "detailInput"
type = {ClassificationEvaluationUtil$Type@2532} "PRED_DETAIL"
binary = true
positiveValue = null
3.2.0 Overview of the call graph
Because the subsequent call relationships are complex, here is the call graph up front:
3.2.1 calLabelPredDetailLocal
This function computes the evaluation metrics per partition. Yes, the code is short, but one subtlety deserves attention; the simplest code is often the easiest to misread. The easy-to-miss point:
the result of the first statement, labels, is a parameter of the second statement, not its subject. The subject of the second statement, like that of the first, is data.
private static DataSet<BaseMetricsSummary> calLabelPredDetailLocal(DataSet<Row> data, final String positiveValue, boolean binary) {
    DataSet<Tuple2<Map<String, Integer>, String[]>> labels = data.flatMap(new FlatMapFunction<Row, String>() {
        @Override
        public void flatMap(Row row, Collector<String> collector) {
            TreeMap<String, Double> labelProbMap;
            if (EvaluationUtil.checkRowFieldNotNull(row)) {
                labelProbMap = EvaluationUtil.extractLabelProbMap(row);
                labelProbMap.keySet().forEach(collector::collect);
                collector.collect(row.getField(0).toString());
            }
        }
    }).reduceGroup(new EvaluationUtil.DistinctLabelIndexMap(binary, positiveValue));
    return data
        .rebalance()
        .mapPartition(new CalLabelDetailLocal(binary))
        .withBroadcastSet(labels, LABELS);
}
calLabelPredDetailLocal breaks down into three steps, examined below.
3.2.1.1 flatMap
The flatMap extracts all labels from the label column and the prediction column (note: it extracts the label names) and emits them downstream.
EvaluationUtil.extractLabelProbMap parses the input JSON and extracts the information inside detailInput.
The downstream operator is reduceGroup, so the Flink runtime deduplicates these labels automatically. If you are interested in this mechanism, see my earlier articles on reduce: CSDN: [源码解析] Flink的groupBy和reduce究竟做了什么; 博客园: [源码解析] Flink的groupBy和reduce究竟做了什么.
Program variables:
row = {Row@8922} "prefix1,{"prefix1": 0.9, "prefix0": 0.1}"
fields = {Object[2]@8925}
0 = "prefix1"
1 = "{"prefix1": 0.9, "prefix0": 0.1}"
labelProbMap = {TreeMap@9008} size = 2
"prefix0" -> {Double@9015} 0.1
"prefix1" -> {Double@9017} 0.9
labelProbMap.keySet().forEach(collector::collect); // emits "prefix0", "prefix1"
collector.collect(row.getField(0).toString()); // emits "prefix1"
// Since the next operator is a reduceGroup, these labels are deduplicated by the runtime
3.2.1.2 reduceGroup
Its main job is to deduplicate the labels via buildLabelIndexLabelArray and assign each label an ID; the final result is a Tuple2<Map<String, Integer>, String[]>.
reduceGroup(new EvaluationUtil.DistinctLabelIndexMap(binary, positiveValue));
DistinctLabelIndexMap extracts all distinct labels from the label column and the prediction column, returning a Tuple2<Map<String, Integer>, String[]>.
As mentioned above, its rows parameter has already been deduplicated.
public static class DistinctLabelIndexMap implements
    GroupReduceFunction<String, Tuple2<Map<String, Integer>, String[]>> {
    ......
    @Override
    public void reduce(Iterable<String> rows, Collector<Tuple2<Map<String, Integer>, String[]>> collector) throws Exception {
        HashSet<String> labels = new HashSet<>();
        rows.forEach(labels::add);
        collector.collect(buildLabelIndexLabelArray(labels, binary, positiveValue));
    }
}
// Variables:
labels = {HashSet@9008} size = 2
0 = "prefix1"
1 = "prefix0"
binary = true
buildLabelIndexLabelArray assigns each label an ID, producing a map from label to ID:
// Give each label an ID, return a map of label and ID.
public static Tuple2<Map<String, Integer>, String[]> buildLabelIndexLabelArray(HashSet<String> set, boolean binary, String positiveValue)
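To make the ID assignment concrete, here is a standalone sketch of what such a label-to-ID map looks like for our example. The descending-sort ordering and the positiveValue handling are illustrative assumptions (in the example run positiveValue is null, and the dumps show "prefix1" landing at index 0), not Alink's exact logic:

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class LabelIndexSketch {
    // Assign each distinct label an ID. Index 0 matters most: labels[0] is
    // treated as the positive label whose probability later selects the bin.
    static Map<String, Integer> buildLabelIndexMap(Set<String> labels, String positiveValue) {
        List<String> sorted = new ArrayList<>(labels);
        sorted.sort(Comparator.reverseOrder());   // "prefix1" before "prefix0"
        if (positiveValue != null) {              // force the positive label to ID 0
            sorted.remove(positiveValue);
            sorted.add(0, positiveValue);
        }
        Map<String, Integer> map = new HashMap<>();
        for (int i = 0; i < sorted.size(); i++) {
            map.put(sorted.get(i), i);
        }
        return map;
    }
}
```

For the example labels this yields {"prefix1" -> 0, "prefix0" -> 1}, matching the labels array dumped later ("prefix1" at 0, "prefix0" at 1).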
3.2.1.3 mapPartition
This step calls CalLabelDetailLocal on each partition to prepare for the later confusion matrix computation.
return data
    .rebalance()
    .mapPartition(new CalLabelDetailLocal(binary)) // the business logic lives here
    .withBroadcastSet(labels, LABELS);
The actual work is done by CalLabelDetailLocal, which calls getDetailStatistics per partition:
// Calculate the confusion matrix based on the label and predResult.
static class CalLabelDetailLocal extends RichMapPartitionFunction<Row, BaseMetricsSummary> {
    private Tuple2<Map<String, Integer>, String[]> map;
getDetailStatistics initializes the base classification evaluation metrics and accumulates the data needed for the confusion matrix. It essentially traverses the rows, extracts each item (e.g. "prefix1,{"prefix1": 0.8, "prefix0": 0.2}"), and accumulates the statistics the confusion matrix needs.
// Initialize the base classification evaluation metrics. There are two cases: BinaryClassMetrics and MultiClassMetrics.
private static BaseMetricsSummary getDetailStatistics(Iterable<Row> rows,
                                                      String positiveValue,
                                                      boolean binary,
                                                      Tuple2<Map<String, Integer>, String[]> tuple)
First, recall the confusion matrix:

| | Predicted 0 | Predicted 1 |
|---|---|---|
| Actual 0 | TN | FP |
| Actual 1 | FN | TP |
With respect to the confusion matrix, BinaryMetricsSummary saves the evaluation data for binary classification. The computation works as follows:
Divide [0,1] into ClassificationEvaluationUtil.DETAIL_BIN_NUMBER (100000) buckets (bins), so binaryMetricsSummary's positiveBin and negativeBin are two arrays of length 100000. If a sample's probability of being the positive value is p, its bin index is p * 100000. If the sample's actual label is the positive value, then positiveBin[index]++; otherwise its actual label is the negative value and negativeBin[index]++. Thus positiveBin holds TP + FP and negativeBin holds TN + FN.
So the input is traversed; taking the input "prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}" as an example, 0.9 is the probability of prefix1 (positive) and 0.1 the probability of prefix0 (negative).
The five samples of our example code are classified as follows:
Row.of("prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}"),  positiveBin +1 at 90000
Row.of("prefix1", "{\"prefix1\": 0.8, \"prefix0\": 0.2}"),  positiveBin +1 at 80000
Row.of("prefix1", "{\"prefix1\": 0.7, \"prefix0\": 0.3}"),  positiveBin +1 at 70000
Row.of("prefix0", "{\"prefix1\": 0.75, \"prefix0\": 0.25}"), negativeBin +1 at 75000
Row.of("prefix0", "{\"prefix1\": 0.6, \"prefix0\": 0.4}")   negativeBin +1 at 60000
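These bin indices can be checked directly with a standalone sketch of the idx formula (the same rule as updateBinaryMetricsSummary uses, including the p == 1.0 edge case that would otherwise overflow the array):

```java
public class BinIndexSketch {
    static final int DETAIL_BIN_NUMBER = 100000;

    // Map the probability of the positive label to a bucket index;
    // p == 1.0 goes into the last bucket instead of index 100000.
    static int binIndex(double p) {
        return p == 1.0 ? DETAIL_BIN_NUMBER - 1
                        : (int) Math.floor(p * DETAIL_BIN_NUMBER);
    }

    public static void main(String[] args) {
        double[] probs = {0.9, 0.8, 0.7, 0.75, 0.6};
        for (double p : probs) {
            System.out.println(p + " -> bin " + binIndex(p));
        }
    }
}
```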
The code:
public static void updateBinaryMetricsSummary(TreeMap<String, Double> labelProbMap,
                                              String label,
                                              BinaryMetricsSummary binaryMetricsSummary) {
    binaryMetricsSummary.total++;
    binaryMetricsSummary.logLoss += extractLogloss(labelProbMap, label);
    double d = labelProbMap.get(binaryMetricsSummary.labels[0]);
    int idx = d == 1.0 ? ClassificationEvaluationUtil.DETAIL_BIN_NUMBER - 1 :
        (int)Math.floor(d * ClassificationEvaluationUtil.DETAIL_BIN_NUMBER);
    if (idx >= 0 && idx < ClassificationEvaluationUtil.DETAIL_BIN_NUMBER) {
        if (label.equals(binaryMetricsSummary.labels[0])) {
            binaryMetricsSummary.positiveBin[idx] += 1;
        } else if (label.equals(binaryMetricsSummary.labels[1])) {
            binaryMetricsSummary.negativeBin[idx] += 1;
        } else {
            .....
        }
    }
}

private static double extractLogloss(TreeMap<String, Double> labelProbMap, String label) {
    Double prob = labelProbMap.get(label);
    prob = null == prob ? 0. : prob;
    return -Math.log(Math.max(Math.min(prob, 1 - LOG_LOSS_EPS), LOG_LOSS_EPS));
}
// Variables:
ClassificationEvaluationUtil.DETAIL_BIN_NUMBER=100000
// For input "prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}":
labelProbMap = {TreeMap@9305} size = 2
"prefix0" -> {Double@9331} 0.1
"prefix1" -> {Double@9333} 0.9
d = 0.9
idx = 90000
binaryMetricsSummary = {BinaryMetricsSummary@9262}
labels = {String[2]@9242}
0 = "prefix1"
1 = "prefix0"
total = 1
positiveBin = {long[100000]@9263} // +1 at index 90000
negativeBin = {long[100000]@9264}
logLoss = 0.10536051565782628
// For input "prefix0", "{\"prefix1\": 0.6, \"prefix0\": 0.4}":
labelProbMap = {TreeMap@9514} size = 2
"prefix0" -> {Double@9546} 0.4
"prefix1" -> {Double@9547} 0.6
d = 0.6
idx = 60000
binaryMetricsSummary = {BinaryMetricsSummary@9262}
labels = {String[2]@9242}
0 = "prefix1"
1 = "prefix0"
total = 2
positiveBin = {long[100000]@9263}
negativeBin = {long[100000]@9264} // +1 at index 60000
logLoss = 1.0216512475319812
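The logLoss values in these dumps can be reproduced with the extractLogloss formula. A standalone sketch follows; the LOG_LOSS_EPS value is an assumption (it only matters when the probability is clamped near 0 or 1, which never happens in this example):

```java
public class LogLossSketch {
    static final double LOG_LOSS_EPS = 1e-15;  // assumed clamp; Alink's constant may differ

    // Negative log of the probability assigned to the true label, clamped away
    // from 0 and 1 so the logarithm stays finite.
    static double logLoss(double probOfTrueLabel) {
        double p = Math.max(Math.min(probOfTrueLabel, 1 - LOG_LOSS_EPS), LOG_LOSS_EPS);
        return -Math.log(p);
    }

    public static void main(String[] args) {
        double l1 = logLoss(0.9);  // sample 1: true label prefix1, predicted with p = 0.9
        double l2 = logLoss(0.4);  // sample 5: true label prefix0, predicted with p = 0.4
        System.out.println(l1 + " then cumulative " + (l1 + l2));
    }
}
```

-log(0.9) gives the 0.10536… after the first sample, and adding -log(0.4) gives the cumulative 1.02165… after the second dump.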
3.2.2 ReduceBaseMetrics
ReduceBaseMetrics aggregates the locally computed BaseMetrics:
DataSet<BaseMetricsSummary> metrics = res
    .reduce(new EvaluationUtil.ReduceBaseMetrics());
ReduceBaseMetrics looks like this:
public static class ReduceBaseMetrics implements ReduceFunction<BaseMetricsSummary> {
    @Override
    public BaseMetricsSummary reduce(BaseMetricsSummary t1, BaseMetricsSummary t2) throws Exception {
        return null == t1 ? t2 : t1.merge(t2);
    }
}
The actual aggregation happens in BinaryMetricsSummary.merge, which merges the bins and adds up the logLoss:
@Override
public BinaryMetricsSummary merge(BinaryMetricsSummary binaryClassMetrics) {
    for (int i = 0; i < this.positiveBin.length; i++) {
        this.positiveBin[i] += binaryClassMetrics.positiveBin[i];
    }
    for (int i = 0; i < this.negativeBin.length; i++) {
        this.negativeBin[i] += binaryClassMetrics.negativeBin[i];
    }
    this.logLoss += binaryClassMetrics.logLoss;
    this.total += binaryClassMetrics.total;
    return this;
}

// Program variables:
this = {BinaryMetricsSummary@9316}
labels = {String[2]@9322}
0 = "prefix1"
1 = "prefix0"
total = 2
positiveBin = {long[100000]@9320}
negativeBin = {long[100000]@9323}
logLoss = 1.742969305058623
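The merge is simply element-wise addition of the bins plus summing the two scalars, which is what makes it a valid associative combine for Flink's reduce. A minimal sketch of that structure (with a small bin count and without the labels field, for brevity):

```java
public class MergeSketch {
    long[] positiveBin, negativeBin;
    double logLoss;
    long total;

    MergeSketch(int binNum) {
        positiveBin = new long[binNum];
        negativeBin = new long[binNum];
    }

    // Mirrors the shape of BinaryMetricsSummary.merge: add bins element-wise,
    // sum logLoss and total, mutate and return this.
    MergeSketch merge(MergeSketch other) {
        for (int i = 0; i < positiveBin.length; i++) {
            positiveBin[i] += other.positiveBin[i];
            negativeBin[i] += other.negativeBin[i];
        }
        logLoss += other.logLoss;
        total += other.total;
        return this;
    }
}
```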
3.2.3 SaveDataAsParams
this.setOutput(metrics.flatMap(new EvaluationUtil.SaveDataAsParams()),
new String[] {DATA_OUTPUT}, new TypeInformation[] {Types.STRING});
Once all BaseMetrics have been merged into the total BaseMetrics, the indexes are computed and stored into params.
public static class SaveDataAsParams implements FlatMapFunction<BaseMetricsSummary, Row> {
    @Override
    public void flatMap(BaseMetricsSummary t, Collector<Row> collector) throws Exception {
        collector.collect(t.toMetrics().serialize());
    }
}
The real work is done in BinaryMetricsSummary.toMetrics: from the bin information it computes the confusionMatrix array, the threshold array, rocCurve/recallPrecisionCurve/liftChart and so on, and stores them into params.
public BinaryClassMetrics toMetrics() {
    Params params = new Params();
    // Generate the curves: rocCurve / recallPrecisionCurve / liftChart
    Tuple3<ConfusionMatrix[], double[], EvaluationCurve[]> matrixThreCurve =
        extractMatrixThreCurve(positiveBin, negativeBin, total);
    // Compute and store AUC / PRC / KS from the curves
    setCurveAreaParams(params, matrixThreCurve.f2);
    // Sample the generated rocCurve / recallPrecisionCurve / liftChart output
    Tuple3<ConfusionMatrix[], double[], EvaluationCurve[]> sampledMatrixThreCurve = sample(
        PROBABILITY_INTERVAL, matrixThreCurve);
    // Store RocCurve / RecallPrecisionCurve / LiftChart from the sampled output
    setCurvePointsParams(params, sampledMatrixThreCurve);
    ConfusionMatrix[] matrices = sampledMatrixThreCurve.f0;
    // Store the metrics of the positive samples
    setComputationsArrayParams(params, sampledMatrixThreCurve.f1, sampledMatrixThreCurve.f0);
    // Store the logLoss
    setLoglossParams(params, logLoss, total);
    // Pick the middle point where threshold is 0.5.
    int middleIndex = getMiddleThresholdIndex(sampledMatrixThreCurve.f1);
    setMiddleThreParams(params, matrices[middleIndex], labels);
    return new BinaryClassMetrics(params);
}
extractMatrixThreCurve is the heart of this article. It extracts the bins that are not empty (keeping the middle threshold 0.5), initializes the RocCurve, Recall-Precision Curve and Lift Curve, and computes the ConfusionMatrix array, the threshold array, and rocCurve/recallPrecisionCurve/liftChart.
/**
* Extract the bins who are not empty, keep the middle threshold 0.5.
* Initialize the RocCurve, Recall-Precision Curve and Lift Curve.
* RocCurve: (FPR, TPR), starts with (0,0). Recall-Precision Curve: (recall, precision), starts with (0, p), p is the precision with the lowest. LiftChart: (TP+FP/total, TP), starts with (0,0). confusion matrix = [TP FP][FN * TN].
*
* @param positiveBin positiveBins.
* @param negativeBin negativeBins.
* @param total sample number
* @return ConfusionMatrix array, threshold array, rocCurve/recallPrecisionCurve/LiftChart.
*/
static Tuple3<ConfusionMatrix[], double[], EvaluationCurve[]> extractMatrixThreCurve(long[] positiveBin, long[] negativeBin, long total) {
ArrayList<Integer> effectiveIndices = new ArrayList<>();
long totalTrue = 0, totalFalse = 0;
// Compute totalTrue, totalFalse and effectiveIndices
for (int i = 0; i < ClassificationEvaluationUtil.DETAIL_BIN_NUMBER; i++) {
if (0L != positiveBin[i] || 0L != negativeBin[i]
|| i == ClassificationEvaluationUtil.DETAIL_BIN_NUMBER / 2) {
effectiveIndices.add(i);
totalTrue += positiveBin[i];
totalFalse += negativeBin[i];
}
}
// In our example this yields:
effectiveIndices = {ArrayList@9273} size = 6
0 = {Integer@9277} 50000 // the middle point, always included
1 = {Integer@9278} 60000
2 = {Integer@9279} 70000
3 = {Integer@9280} 75000
4 = {Integer@9281} 80000
5 = {Integer@9282} 90000
totalTrue = 3
totalFalse = 2
// Continue: initialize the curves
final int length = effectiveIndices.size();
final int newLen = length + 1;
final double m = 1.0 / ClassificationEvaluationUtil.DETAIL_BIN_NUMBER;
EvaluationCurvePoint[] rocCurve = new EvaluationCurvePoint[newLen];
EvaluationCurvePoint[] recallPrecisionCurve = new EvaluationCurvePoint[newLen];
EvaluationCurvePoint[] liftChart = new EvaluationCurvePoint[newLen];
ConfusionMatrix[] data = new ConfusionMatrix[newLen];
double[] threshold = new double[newLen];
long curTrue = 0;
long curFalse = 0;
// In our example:
length = 6
newLen = 7
m = 1.0E-5
// The main loop: rocCurve, recallPrecisionCurve and liftChart all follow directly from the code
for (int i = 1; i < newLen; i++) {
int index = effectiveIndices.get(length - i);
curTrue += positiveBin[index];
curFalse += negativeBin[index];
threshold[i] = index * m;
// Build the confusion matrix at this point
data[i] = new ConfusionMatrix(
new long[][] {{curTrue, curFalse}, {totalTrue - curTrue, totalFalse - curFalse}});
double tpr = (totalTrue == 0 ? 1.0 : 1.0 * curTrue / totalTrue);
// E.g. at the 90000 point: curTrue = 1, curFalse = 0, i = 1, index = 90000, totalTrue = 3, totalFalse = 2.
// Since TPR = TP / (TP + FN), we get tpr = 1 / 3 = 0.3333333333333333.
rocCurve[i] = new EvaluationCurvePoint(totalFalse == 0 ? 1.0 : 1.0 * curFalse / totalFalse, tpr, threshold[i]);
recallPrecisionCurve[i] = new EvaluationCurvePoint(tpr, curTrue + curFalse == 0 ? 1.0 : 1.0 * curTrue / (curTrue + curFalse), threshold[i]);
liftChart[i] = new EvaluationCurvePoint(1.0 * (curTrue + curFalse) / total, curTrue, threshold[i]);
}
// In our example, the final values are:
curTrue = 3
curFalse = 2
threshold = {double[7]@9349}
0 = 0.0
1 = 0.9
2 = 0.8
3 = 0.7500000000000001
4 = 0.7000000000000001
5 = 0.6000000000000001
6 = 0.5
rocCurve = {EvaluationCurvePoint[7]@9315}
1 = {EvaluationCurvePoint@9440}
x = 0.0
y = 0.3333333333333333
p = 0.9
2 = {EvaluationCurvePoint@9448}
x = 0.0
y = 0.6666666666666666
p = 0.8
3 = {EvaluationCurvePoint@9449}
x = 0.5
y = 0.6666666666666666
p = 0.7500000000000001
4 = {EvaluationCurvePoint@9450}
x = 0.5
y = 1.0
p = 0.7000000000000001
5 = {EvaluationCurvePoint@9451}
x = 1.0
y = 1.0
p = 0.6000000000000001
6 = {EvaluationCurvePoint@9452}
x = 1.0
y = 1.0
p = 0.5
recallPrecisionCurve = {EvaluationCurvePoint[7]@9320}
1 = {EvaluationCurvePoint@9444}
x = 0.3333333333333333
y = 1.0
p = 0.9
2 = {EvaluationCurvePoint@9453}
x = 0.6666666666666666
y = 1.0
p = 0.8
3 = {EvaluationCurvePoint@9454}
x = 0.6666666666666666
y = 0.6666666666666666
p = 0.7500000000000001
4 = {EvaluationCurvePoint@9455}
x = 1.0
y = 0.75
p = 0.7000000000000001
5 = {EvaluationCurvePoint@9456}
x = 1.0
y = 0.6
p = 0.6000000000000001
6 = {EvaluationCurvePoint@9457}
x = 1.0
y = 0.6
p = 0.5
liftChart = {EvaluationCurvePoint[7]@9325}
1 = {EvaluationCurvePoint@9458}
x = 0.2
y = 1.0
p = 0.9
2 = {EvaluationCurvePoint@9459}
x = 0.4
y = 2.0
p = 0.8
3 = {EvaluationCurvePoint@9460}
x = 0.6
y = 2.0
p = 0.7500000000000001
4 = {EvaluationCurvePoint@9461}
x = 0.8
y = 3.0
p = 0.7000000000000001
5 = {EvaluationCurvePoint@9462}
x = 1.0
y = 3.0
p = 0.6000000000000001
6 = {EvaluationCurvePoint@9463}
x = 1.0
y = 3.0
p = 0.5
data = {ConfusionMatrix[7]@9339}
0 = {ConfusionMatrix@9486}
longMatrix = {LongMatrix@9488}
matrix = {long[2][]@9491}
0 = {long[2]@9492}
0 = 0
1 = 0
1 = {long[2]@9493}
0 = 3
1 = 2
rowNum = 2
colNum = 2
labelCnt = 2
total = 5
actualLabelFrequency = {long[2]@9489}
0 = 3
1 = 2
predictLabelFrequency = {long[2]@9490}
0 = 0
1 = 5
tpCount = 2.0
tnCount = 2.0
fpCount = 3.0
fnCount = 3.0
1 = {ConfusionMatrix@9435}
longMatrix = {LongMatrix@9469}
matrix = {long[2][]@9472}
0 = {long[2]@9474}
0 = 1
1 = 0
1 = {long[2]@9475}
0 = 2
1 = 2
rowNum = 2
colNum = 2
labelCnt = 2
total = 5
actualLabelFrequency = {long[2]@9470}
0 = 3
1 = 2
predictLabelFrequency = {long[2]@9471}
0 = 1
1 = 4
tpCount = 3.0
tnCount = 3.0
fpCount = 2.0
fnCount = 2.0
......
threshold[0] = 1.0;
data[0] = new ConfusionMatrix(new long[][] {{0, 0}, {totalTrue, totalFalse}});
rocCurve[0] = new EvaluationCurvePoint(0, 0, threshold[0]);
recallPrecisionCurve[0] = new EvaluationCurvePoint(0, recallPrecisionCurve[1].getY(), threshold[0]);
liftChart[0] = new EvaluationCurvePoint(0, 0, threshold[0]);
return Tuple3.of(data, threshold, new EvaluationCurve[] {new EvaluationCurve(rocCurve),
new EvaluationCurve(recallPrecisionCurve), new EvaluationCurve(liftChart)});
}
3.2.4 Computing the confusion matrix
Let us now look at how the confusion matrix is computed; the reasoning here takes a few turns.
3.2.4.1 The raw matrix
The call site is:
// Call site
data[i] = new ConfusionMatrix(
    new long[][] {{curTrue, curFalse}, {totalTrue - curTrue, totalFalse - curFalse}});
// Values at the time of the call
i = 1
index = 90000
totalTrue = 3
totalFalse = 2
curTrue = 1
curFalse = 0
This gives the raw matrix; every entry containing cur refers to the current point only.

| curTrue = 1 | curFalse = 0 |
|---|---|
| totalTrue - curTrue = 2 | totalFalse - curFalse = 2 |
3.2.4.2 Computing the label frequencies
In the subsequent ConfusionMatrix computation we then get
actualLabelFrequency = longMatrix.getColSums();
predictLabelFrequency = longMatrix.getRowSums();
actualLabelFrequency = {long[2]@9322}
0 = 3
1 = 2
predictLabelFrequency = {long[2]@9323}
0 = 1
1 = 4
So in Alink's scheme, each column sum relates to the actual labels, and each row sum relates to the predicted labels.
This yields the extended matrix:

| | | predictLabelFrequency |
|---|---|---|
| curTrue = 1 | curFalse = 0 | 1 = curTrue + curFalse |
| totalTrue - curTrue = 2 | totalFalse - curFalse = 2 | 4 = total - curTrue - curFalse |
| actualLabelFrequency: 3 = totalTrue | 2 = totalFalse | |
The subsequent computation builds on these values,
using the diagonal of longMatrix, i.e. longMatrix(0)(0) and longMatrix(1)(1). Keep in mind: everything below refers to the current state (worth emphasizing).
longMatrix(0)(0): curTrue
longMatrix(1)(1): totalFalse - curFalse
totalFalse: (TN + FN)
totalTrue: (TP + FP)
double numTrueNegative(Integer labelIndex) {
// For labelIndex 0: return 1 + 5 - 1 - 3 = 2;
// For labelIndex 1: return 2 + 5 - 4 - 2 = 1;
return null == labelIndex ? tnCount : longMatrix.getValue(labelIndex, labelIndex) + total - predictLabelFrequency[labelIndex] - actualLabelFrequency[labelIndex];
}
double numTruePositive(Integer labelIndex) {
// For labelIndex 0: return 1. This is curTrue: the actual label is True and the prediction is True, i.e. TP.
// For labelIndex 1: return 2. This is totalFalse - curFalse: all misclassifications minus those already surfaced at the current point. They are "wrong overall but not yet discovered at the current point", so in the current state they also count as TP.
return null == labelIndex ? tpCount : longMatrix.getValue(labelIndex, labelIndex);
}
double numFalseNegative(Integer labelIndex) {
// For labelIndex 0: return 3 - 1;
// actualLabelFrequency[0] = totalTrue, so this returns totalTrue - curTrue: the actual positives not yet predicted positive at the current point, i.e. "wrong and predicted negative"
// For labelIndex 1: return 2 - 2;
// actualLabelFrequency[1] = totalFalse, so this returns totalFalse - (totalFalse - curFalse) = curFalse
return null == labelIndex ? fnCount : actualLabelFrequency[labelIndex] - longMatrix.getValue(labelIndex, labelIndex);
}
double numFalsePositive(Integer labelIndex) {
// For labelIndex 0: return 1 - 1;
// predictLabelFrequency[0] = curTrue + curFalse,
// so this returns curTrue + curFalse - curTrue = curFalse = current (TN + FN), i.e. predicted positive but actually the other label
// For labelIndex 1: return 4 - 2;
// predictLabelFrequency[1] = total - curTrue - curFalse,
// so this returns total - curTrue - curFalse - (totalFalse - curFalse) = totalTrue - curTrue = (TP + FP) - currentTP = currentFP
return null == labelIndex ? fpCount : predictLabelFrequency[labelIndex] - longMatrix.getValue(labelIndex, labelIndex);
}
// Finally we get
tpCount = 3.0
tnCount = 3.0
fpCount = 2.0
fnCount = 2.0
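The four accessors can be verified against the example matrix [[1, 0], [2, 2]] with a standalone sketch of the same row-sum/column-sum arithmetic:

```java
public class ConfusionCountSketch {
    // For a square matrix m[predicted][actual], return {tp, tn, fp, fn}
    // accumulated over all label indices, following the same per-label formulas
    // as the accessors above.
    static double[] counts(long[][] m) {
        int n = m.length;
        long total = 0;
        long[] rowSums = new long[n], colSums = new long[n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                rowSums[i] += m[i][j];   // predictLabelFrequency
                colSums[j] += m[i][j];   // actualLabelFrequency
                total += m[i][j];
            }
        }
        double tp = 0, tn = 0, fp = 0, fn = 0;
        for (int i = 0; i < n; i++) {
            tp += m[i][i];
            tn += m[i][i] + total - rowSums[i] - colSums[i];
            fn += colSums[i] - m[i][i];
            fp += rowSums[i] - m[i][i];
        }
        return new double[]{tp, tn, fp, fn};
    }

    public static void main(String[] args) {
        long[][] m = {{1, 0}, {2, 2}};  // the matrix at threshold 0.9 in the example
        double[] c = counts(m);
        System.out.println("tp=" + c[0] + " tn=" + c[1] + " fp=" + c[2] + " fn=" + c[3]);
    }
}
```

It reproduces tpCount = 3, tnCount = 3, fpCount = 2, fnCount = 2, matching the values dumped for the matrix at index 1 above.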
3.2.4.3 The code
// The actual computation
public ConfusionMatrix(LongMatrix longMatrix) {
longMatrix = {LongMatrix@9297}
0 = {long[2]@9324}
0 = 1
1 = 0
1 = {long[2]@9325}
0 = 2
1 = 2
this.longMatrix = longMatrix;
labelCnt = this.longMatrix.getRowNum();
// Here the frequencies are computed
actualLabelFrequency = longMatrix.getColSums();
predictLabelFrequency = longMatrix.getRowSums();
actualLabelFrequency = {long[2]@9322}
0 = 3
1 = 2
predictLabelFrequency = {long[2]@9323}
0 = 1
1 = 4
labelCnt = 2
total = 5
total = longMatrix.getTotal();
for (int i = 0; i < labelCnt; i++) {
tnCount += numTrueNegative(i);
tpCount += numTruePositive(i);
fnCount += numFalseNegative(i);
fpCount += numFalsePositive(i);
}
}
In Alink's original Python examples, the Stream part produces no output, because MemSourceStreamOp is not tied to time and Alink provides no time-based StreamOperator source. So I wrote one myself, modeled on MemSourceBatchOp. The code is somewhat ugly, but it at least produces output, which makes debugging possible.
4.1.1 Main class
public class EvalBinaryClassExampleStream {

    AlgoOperator getData(boolean isBatch) {
        Row[] rows = new Row[]{
            Row.of("prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}")
        };
        String[] schema = new String[]{"label", "detailInput"};
        if (isBatch) {
            return new MemSourceBatchOp(rows, schema);
        } else {
            return new TimeMemSourceStreamOp(rows, schema, new EvalBinaryStreamSource());
        }
    }

    public static void main(String[] args) throws Exception {
        EvalBinaryClassExampleStream test = new EvalBinaryClassExampleStream();
        StreamOperator streamData = (StreamOperator) test.getData(false);
        StreamOperator sOp = new EvalBinaryClassStreamOp()
            .setLabelCol("label")
            .setPredictionDetailCol("detailInput")
            .setTimeInterval(1)
            .linkFrom(streamData);
        sOp.print();
        StreamOperator.execute();
    }
}
4.1.2 TimeMemSourceStreamOp
This class I concocted myself, borrowing from MemSourceStreamOp.
public final class TimeMemSourceStreamOp extends StreamOperator<TimeMemSourceStreamOp> {

    public TimeMemSourceStreamOp(Row[] rows, String[] colNames, EvalBinaryStreamSource source) {
        super(null);
        init(source, Arrays.asList(rows), colNames);
    }

    private void init(EvalBinaryStreamSource source, List<Row> rows, String[] colNames) {
        Row first = rows.iterator().next();
        int arity = first.getArity();
        TypeInformation<?>[] types = new TypeInformation[arity];
        for (int i = 0; i < arity; ++i) {
            types[i] = TypeExtractor.getForObject(first.getField(i));
        }
        init(source, colNames, types);
    }

    private void init(EvalBinaryStreamSource source, String[] colNames, TypeInformation<?>[] colTypes) {
        DataStream<Row> dastr = MLEnvironmentFactory.get(getMLEnvironmentId())
            .getStreamExecutionEnvironment().addSource(source);
        StringBuilder sbd = new StringBuilder();
        sbd.append(colNames[0]);
        for (int i = 1; i < colNames.length; i++) {
            sbd.append(",").append(colNames[i]);
        }
        this.setOutput(dastr, colNames, colTypes);
    }

    @Override
    public TimeMemSourceStreamOp linkFrom(StreamOperator<?>... inputs) {
        return null;
    }
}
4.1.3 Source
It emits Rows periodically, with a random component so the probabilities vary.
class EvalBinaryStreamSource extends RichSourceFunction[Row] {

  override def run(ctx: SourceFunction.SourceContext[Row]) = {
    while (true) {
      val rdm = Math.random() // randomness so the probabilities vary
      val rows: Array[Row] = Array[Row](
        Row.of("prefix1", "{\"prefix1\": " + rdm + ", \"prefix0\": " + (1 - rdm) + "}"),
        Row.of("prefix1", "{\"prefix1\": 0.8, \"prefix0\": 0.2}"),
        Row.of("prefix1", "{\"prefix1\": 0.7, \"prefix0\": 0.3}"),
        Row.of("prefix0", "{\"prefix1\": 0.75, \"prefix0\": 0.25}"),
        Row.of("prefix0", "{\"prefix1\": 0.6, \"prefix0\": 0.4}"))
      for (row <- rows) {
        println(s"current row: $row")
        ctx.collect(row)
      }
      Thread.sleep(1000)
    }
  }

  override def cancel() = ???
}
Alink's stream-processing class is EvalBinaryClassStreamOp, but the main work happens in its base class BaseEvalClassStreamOp, so we focus on the latter.
public class BaseEvalClassStreamOp<T extends BaseEvalClassStreamOp<T>> extends StreamOperator<T> {
    @Override
    public T linkFrom(StreamOperator<?>... inputs) {
        StreamOperator<?> in = checkAndGetFirst(inputs);
        String labelColName = this.get(MultiEvaluationStreamParams.LABEL_COL);
        String positiveValue = this.get(BinaryEvaluationStreamParams.POS_LABEL_VAL_STR);
        Integer timeInterval = this.get(MultiEvaluationStreamParams.TIME_INTERVAL);
        ClassificationEvaluationUtil.Type type = ClassificationEvaluationUtil.judgeEvaluationType(this.getParams());
        DataStream<BaseMetricsSummary> statistics;
        switch (type) {
            case PRED_RESULT: {
                ......
            }
            case PRED_DETAIL: {
                String predDetailColName = this.get(MultiEvaluationStreamParams.PREDICTION_DETAIL_COL);
                PredDetailLabel eval = new PredDetailLabel(positiveValue, binary);
                // Fetch the input data; the key point is timeWindowAll
                statistics = in.select(new String[] {labelColName, predDetailColName})
                    .getDataStream()
                    .timeWindowAll(Time.of(timeInterval, TimeUnit.SECONDS))
                    .apply(eval);
                break;
            }
        }
        // Accumulate the per-window data into totalStatistics; note that this is a new stream.
        DataStream<BaseMetricsSummary> totalStatistics = statistics
            .map(new EvaluationUtil.AllDataMerge())
            .setParallelism(1); // parallelism set to 1
        // Compute & serialize from the bins: the current window's statistics
        DataStream<Row> windowOutput = statistics.map(
            new EvaluationUtil.SaveDataStream(ClassificationEvaluationUtil.WINDOW.f0));
        // Compute & serialize from the bins: the accumulated totalStatistics
        DataStream<Row> allOutput = totalStatistics.map(
            new EvaluationUtil.SaveDataStream(ClassificationEvaluationUtil.ALL.f0));
        // Union the "current" and "accumulated" streams; this is the final output
        DataStream<Row> union = windowOutput.union(allOutput);
        this.setOutput(union,
            new String[] {ClassificationEvaluationUtil.STATISTICS_OUTPUT, DATA_OUTPUT},
            new TypeInformation[] {Types.STRING, Types.STRING});
        return (T)this;
    }
}
The concrete pieces are described below.
4.2.1 PredDetailLabel
static class PredDetailLabel implements AllWindowFunction<Row, BaseMetricsSummary, TimeWindow> {
    @Override
    public void apply(TimeWindow timeWindow, Iterable<Row> rows, Collector<BaseMetricsSummary> collector) throws Exception {
        HashSet<String> labels = new HashSet<>();
        // First, again collect the label names
        for (Row row : rows) {
            if (EvaluationUtil.checkRowFieldNotNull(row)) {
                labels.addAll(EvaluationUtil.extractLabelProbMap(row).keySet());
                labels.add(row.getField(0).toString());
            }
        }
        // labels = {HashSet@9757} size = 2
        //   0 = "prefix1"
        //   1 = "prefix0"
        // As introduced earlier, buildLabelIndexLabelArray dedups the label names and gives each label an ID, producing a Map.
        // getDetailStatistics traverses the rows and accumulates the confusion matrix data ("TP + FN" / "TN + FP").
        if (labels.size() > 0) {
            collector.collect(
                getDetailStatistics(rows, binary, buildLabelIndexLabelArray(labels, binary, positiveValue)));
        }
    }
}
4.2.2 AllDataMerge
EvaluationUtil.AllDataMerge accumulates the data from the individual windows.
/**
 * Merge data from different windows.
 */
public static class AllDataMerge implements MapFunction<BaseMetricsSummary, BaseMetricsSummary> {
    private BaseMetricsSummary statistics;

    @Override
    public BaseMetricsSummary map(BaseMetricsSummary value) {
        this.statistics = (null == this.statistics ? value : this.statistics.merge(value));
        return this.statistics;
    }
}
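AllDataMerge is a stateful map: each incoming window summary is folded into a running total, and each output element is the cumulative statistics so far. The pattern in plain Java looks like this (a sketch with Long standing in for BaseMetricsSummary):

```java
import java.util.ArrayList;
import java.util.List;

public class AllDataMergeSketch {
    private Long statistics;  // running total, like the statistics field above

    // Mirrors AllDataMerge.map: fold the new value into the state, emit the state.
    Long map(Long value) {
        statistics = (statistics == null) ? value : statistics + value;
        return statistics;
    }

    public static void main(String[] args) {
        AllDataMergeSketch merge = new AllDataMergeSketch();
        List<Long> out = new ArrayList<>();
        for (long windowCount : new long[]{5, 3, 7}) {  // per-window totals
            out.add(merge.map(windowCount));
        }
        System.out.println(out);  // cumulative: [5, 8, 15]
    }
}
```

This also explains the setParallelism(1) in linkFrom: with a single parallel instance there is exactly one statistics field, so the accumulation is global rather than per-subtask.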
4.2.3 SaveDataStream
The function SaveDataStream calls was already covered in the batch section: the real work is in BinaryMetricsSummary.toMetrics, i.e. computing from the bin information and storing the results into params.
The difference from batch processing is that the constructed metrics are returned to the user directly.
public static class SaveDataStream implements MapFunction<BaseMetricsSummary, Row> {
    @Override
    public Row map(BaseMetricsSummary baseMetricsSummary) throws Exception {
        BaseMetricsSummary metrics = baseMetricsSummary;
        BaseMetrics baseMetrics = metrics.toMetrics();
        Row row = baseMetrics.serialize();
        return Row.of(funtionName, row.getField(0));
    }
}

// The resulting row is the metrics information finally returned to the user
row = {Row@10008} "{"PRC":"0.9164636268708667","SensitivityArray":"[0.38461538461538464,0.6923076923076923,0.6923076923076923,1.0,1.0,1.0]","ConfusionMatrix":"[[13,8],[0,0]]","MacroRecall":"0.5","MacroSpecificity":"0.5","FalsePositiveRateArray":"[0.0,0.0,0.5,0.5,1.0,1.0]" ...... and many more
4.2.4 Union
DataStream windowOutput = statistics.map(
new EvaluationUtil.SaveDataStream(ClassificationEvaluationUtil.WINDOW.f0));
DataStream allOutput = totalStatistics.map(
new EvaluationUtil.SaveDataStream(ClassificationEvaluationUtil.ALL.f0));
DataStream union = windowOutput.union(allOutput);
Finally, two kinds of statistics are returned:
4.2.4.1 allOutput
all|{"PRC":"0.7341146115890359","SensitivityArray":"[0.3333333333333333,0.3333333333333333,0.6666666666666666,0.7333333333333333,0.8,0.8,0.8666666666666667,0.8666666666666667,0.9333333333333333,1.0]","ConfusionMatrix":"[[13,10],[2,0]]","MacroRecall":"0.43333333333333335","MacroSpecificity":"0.43333333333333335","FalsePositiveRateArray":"[0.0,0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.0]","TruePositiveRateArray":"[0.3333333333333333,0.3333333333333333,0.6666666666666666,0.7333333333333333,0.8,0.8,0.8666666666666667,0.8666666666666667,0.9333333333333333,1.0]","AUC":"0.5666666666666667","MacroAccuracy":"0.52", ......
4.2.4.2 windowOutput
window|{"PRC":"0.7638888888888888","SensitivityArray":"[0.3333333333333333,0.3333333333333333,0.6666666666666666,1.0,1.0,1.0]"