Finally, here is the complete code; interested readers can run the experiment themselves.
Note that because the example demonstrates intermediate results at several points, it contains quite a few print/execute calls. I have commented out the code that invokes these methods; readers can uncomment individual lines to see the corresponding output.
package com.alibaba.alink;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.classification.LogisticRegressionTrainBatchOp;
import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
import com.alibaba.alink.operator.stream.StreamOperator;
import com.alibaba.alink.operator.stream.dataproc.JsonValueStreamOp;
import com.alibaba.alink.operator.stream.dataproc.SplitStreamOp;
import com.alibaba.alink.operator.stream.evaluation.EvalBinaryClassStreamOp;
import com.alibaba.alink.operator.stream.onlinelearning.FtrlPredictStreamOp;
import com.alibaba.alink.operator.stream.onlinelearning.FtrlTrainStreamOp;
import com.alibaba.alink.operator.stream.source.CsvSourceStreamOp;
import com.alibaba.alink.pipeline.Pipeline;
import com.alibaba.alink.pipeline.PipelineModel;
import com.alibaba.alink.pipeline.dataproc.StandardScaler;
import com.alibaba.alink.pipeline.feature.FeatureHasher;
public class FTRLExample {
    public static void main(String[] args) throws Exception {
        //new TextSourceBatchOp()
        //    .setFilePath("http://alink-release.oss-cn-beijing.aliyuncs.com/data-files/avazu-small.csv")
        //    .firstN(10)
        //    .print();
        String schemaStr
            = "id string, click string, dt string, C1 string, banner_pos int, site_id string, site_domain string, "
            + "site_category string, app_id string, app_domain string, app_category string, device_id string, "
            + "device_ip string, device_model string, device_type string, device_conn_type string, C14 int, C15 int, "
            + "C16 int, C17 int, C18 int, C19 int, C20 int, C21 int";
        CsvSourceBatchOp trainBatchData = new CsvSourceBatchOp()
            .setFilePath("http://alink-release.oss-cn-beijing.aliyuncs.com/data-files/avazu-small.csv")
            .setSchemaStr(schemaStr);
        //trainBatchData.firstN(10).print();
        String labelColName = "click";
        String[] selectedColNames = new String[] {
            "C1", "banner_pos", "site_category", "app_domain",
            "app_category", "device_type", "device_conn_type",
            "C14", "C15", "C16", "C17", "C18", "C19", "C20", "C21",
            "site_id", "site_domain", "device_id", "device_model"};
        String[] categoryColNames = new String[] {
            "C1", "banner_pos", "site_category", "app_domain",
            "app_category", "device_type", "device_conn_type",
            "site_id", "site_domain", "device_id", "device_model"};
        String[] numericalColNames = new String[] {
            "C14", "C15", "C16", "C17", "C18", "C19", "C20", "C21"};
        // result column name of feature engineering
        String vecColName = "vec";
        int numHashFeatures = 30000;
        // set up the feature engineering pipeline
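        // StandardScaler standardizes the numerical columns; FeatureHasher then hashes all
        // selected columns (categorical + numerical) into a 30000-dimensional sparse vector "vec".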
        Pipeline feature_pipeline = new Pipeline()
            .add(
                new StandardScaler()
                    .setSelectedCols(numericalColNames)
            )
            .add(
                new FeatureHasher()
                    .setSelectedCols(selectedColNames)
                    .setCategoricalCols(categoryColNames)
                    .setOutputCol(vecColName)
                    .setNumFeatures(numHashFeatures)
            );
        // fit and save feature pipeline model
        String FEATURE_PIPELINE_MODEL_FILE = "/Users/yangxu/alink/data/temp/feature_pipe_model.csv";
        //feature_pipeline.fit(trainBatchData).save(FEATURE_PIPELINE_MODEL_FILE);
        //
        //BatchOperator.execute();
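        // NOTE: run the two commented-out lines above (fit/save plus BatchOperator.execute()) once,
        // so that PipelineModel.load(...) below can find the saved model; adjust
        // FEATURE_PIPELINE_MODEL_FILE to a valid path on your own machine.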
        // prepare stream train data
        CsvSourceStreamOp data = new CsvSourceStreamOp()
            .setFilePath("http://alink-release.oss-cn-beijing.aliyuncs.com/data-files/avazu-ctr-train-8M.csv")
            .setSchemaStr(schemaStr)
            .setIgnoreFirstLine(true);
        // split stream to train and eval data
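        // SplitStreamOp routes roughly the given fraction (0.5 here) of the records to its main
        // output, used for online training; the rest is exposed via getSideOutput(0) and is used
        // for evaluation.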
        SplitStreamOp splitter = new SplitStreamOp().setFraction(0.5).linkFrom(data);
        StreamOperator train_stream_data = splitter;
        StreamOperator test_stream_data = splitter.getSideOutput(0);
        // load pipeline model
        PipelineModel feature_pipelineModel = PipelineModel.load(FEATURE_PIPELINE_MODEL_FILE);
        // train initial batch model
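        // A logistic regression model is first trained in batch mode on the transformed batch
        // data; it provides the starting point for the online FTRL trainer below.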
        LogisticRegressionTrainBatchOp lr = new LogisticRegressionTrainBatchOp()
            .setVectorCol(vecColName)
            .setLabelCol(labelColName)
            .setWithIntercept(true)
            .setMaxIter(10);
        BatchOperator initModel = feature_pipelineModel.transform(trainBatchData).link(lr);
        // ftrl train
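        // Online FTRL training: alpha/beta control the per-coordinate learning rate, L1/L2 are the
        // regularization weights, setVectorSize must match numHashFeatures, and setTimeInterval(10)
        // makes the trainer emit an updated model (roughly) every 10 seconds.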
        FtrlTrainStreamOp model = new FtrlTrainStreamOp(initModel)
            .setVectorCol(vecColName)
            .setLabelCol(labelColName)
            .setWithIntercept(true)
            .setAlpha(0.1)
            .setBeta(0.1)
            .setL1(0.01)
            .setL2(0.01)
            .setTimeInterval(10)
            .setVectorSize(numHashFeatures)
            .linkFrom(feature_pipelineModel.transform(train_stream_data));
        // ftrl predict
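        // The predictor takes two stream inputs: the model stream emitted by the FTRL trainer and
        // the transformed test stream. The batch model passed to the constructor serves as the
        // initial model; predictions then track the models periodically emitted by the trainer.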
        FtrlPredictStreamOp predResult = new FtrlPredictStreamOp(initModel)
            .setVectorCol(vecColName)
            .setPredictionCol("pred")
            .setReservedCols(new String[] {labelColName})
            .setPredictionDetailCol("details")
            .linkFrom(model, feature_pipelineModel.transform(test_stream_data));
        //predResult.sample(0.0001).print();
        //
        //StreamOperator.execute();
        // ftrl eval
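        // Every 10 seconds the stream evaluator emits binary-classification metrics as a JSON string
        // in the "Data" column; JsonValueStreamOp then extracts Accuracy, AUC and the confusion
        // matrix from that JSON via the given JSON paths.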
        predResult.link(
            new EvalBinaryClassStreamOp()
                .setLabelCol(labelColName)
                .setPredictionCol("pred")
                .setPredictionDetailCol("details")
                .setTimeInterval(10)
        ).link(
            new JsonValueStreamOp()
                .setSelectedCol("Data")
                .setReservedCols(new String[] {"Statistics"})
                .setOutputCols(new String[] {"Accuracy", "AUC", "ConfusionMatrix"})
                .setJsonPath(new String[] {"$.Accuracy", "$.AUC", "$.ConfusionMatrix"})
        ).print();
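        // Note: print() on a stream only registers a console sink; uncomment the execute() call
        // below to actually launch the streaming job.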
        //StreamOperator.execute();
    }
}