Recently I have been working on my company's POI domain. POI stands for "point of interest"; the term is most widely used in the map industry, where every labeled address on a map is a POI. At our company the concept has been narrowed to fit our own business: the points of interest we care about are merchants in the restaurant and new-retail space, such as eateries and supermarkets.
On the surface this sounds like nothing more than collecting and correcting merchant data, so why does machine learning come into it? The real reason is a bit awkward and not something I can go into here. In short, we have obtained data for a batch of merchants but cannot obtain their categories, and the category matters a great deal to our current business because it touches the interests of different business lines. So we need some other means to identify the category of each merchant.
Scenario
Given the merchant data we already have, including the merchant name and its dishes, determine which category the merchant belongs to (Sichuan/Hunan cuisine, Japanese food, and so on).
Possible solutions
- Ask the algorithm team to compute merchant categories for us: a similar project already exists, so category computation could be built on top of it and delivered quickly.
- Work it out ourselves and study the algorithms: learning the algorithms and taking them to production would be considerably more time-consuming.
Final decision
The algorithm team ran into staffing problems, so we had no choice but to do the computation ourselves. Fortunately the related algorithm project was handed over to our team, so we could use it as a reference.
Weighing the need for a quick result now against the longer-term direction, we decided to pursue two tracks in parallel: a Python script for a fast short-term implementation, and distributed computation with Spark ML as the long-term goal.
Algorithm approach
We treat all of a merchant's dish information together with its basic information as one piece of text, which turns the task into a "text similarity" problem, for which algorithms such as TF-IDF, LSI, and LDA can all be consulted. The Python script uses TF-IDF and LSI for the similarity computation (see the simple example for reference). Spark uses TF-IDF plus cosine similarity as a first, verifiable implementation (the algorithm will be refined later). Since the long-term plan is to run this kind of machine-learning computation on Spark, the rest of this article focuses on the Spark side.
Algorithms
TF-IDF
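TF-IDF weights a term by how often it appears in one document (term frequency, TF) and by how rare it is across all documents (inverse document frequency, IDF). A common smoothed form, which is also what Spark's IDF estimator computes, is

tf-idf(t, d) = tf(t, d) × log((N + 1) / (df(t) + 1))

where tf(t, d) is the number of times term t occurs in document d, N is the total number of documents, and df(t) is the number of documents containing t. In our case a "document" is the word-segmented dish list of one merchant, so dish words that appear often in one shop's menu but rarely across all shops get a high weight.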
Cosine similarity
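Cosine similarity measures the angle between two vectors rather than their magnitude: cos(A, B) = (A · B) / (||A|| × ||B||). For non-negative TF-IDF vectors the value lies in [0, 1], and the closer it is to 1 the more similar the term distributions of the two texts are. A minimal plain-Java sketch of the formula (the Spark job below computes the same quantity with BLAS.dot and Vectors.norm on sparse vectors):

public class CosineSimilarityExample {
    // cosine(a, b) = dot(a, b) / (|a| * |b|); assumes a and b have the same length
    public static double cosine(double[] a, double[] b) {
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }
}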
Spark ML implementation
The implementation is split into two Spark jobs. The first job computes TF-IDF vectors for the merchants & dishes that already exist online and are correctly matched (our reference data), and saves the computed values into a Hive table.
Job one
Data preprocessing and TF-IDF computation over the reference data.
First, a Hive job flattens each merchant's dish data into a single row; this step is straightforward. The flattened data looks like this:
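For illustration only (the id, name, and dishes below are made up), a flattened row of dw.category_and_foodname looks roughly like

id: 10001, shop_name: 老王川菜馆, category_id: 220, food_name: 回锅肉|麻婆豆腐|夫妻肺片|水煮鱼

that is, one row per shop, with all of the shop's dish names concatenated into a single '|'-separated string.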
A separate Spark job then runs TF-IDF over the merchants' dishes and saves the result into the table below. vector_indices and vector_values are both arrays of the same length; together they describe each merchant's sparse feature vector.
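Judging from the insert statement in job one and the query in job two, dw.poi_category_pre_data has roughly these columns:

shop_id, shop_name, category, vector_size, vector_indices, vector_values

vector_indices and vector_values are written out as stringified arrays (via Arrays.toString) and parsed back with JSON.parseArray in job two; together with vector_size they are enough to rebuild each shop's SparseVector.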
The TF-IDF utility class:
public class TfidfUtil {

    /**
     * Tokenize the segmented goods text, hash the words into term-frequency vectors
     * and weight them with IDF.
     * @see Spark入门:特征抽取: TF-IDF
     * @param dataset input dataset, must contain a "goodsSegment" column
     * @return dataset with an extra "features" column holding the TF-IDF vectors
     */
    public static Dataset<Row> tfidf(Dataset<Row> dataset) {
        Tokenizer tokenizer = new Tokenizer().setInputCol("goodsSegment").setOutputCol("words");
        Dataset<Row> wordsData = tokenizer.transform(dataset);
        HashingTF hashingTF = new HashingTF()
                .setInputCol("words")
                .setOutputCol("rawFeatures");
        Dataset<Row> featurizedData = hashingTF.transform(wordsData);
        IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features");
        IDFModel idfModel = idf.fit(featurizedData);
        return idfModel.transform(featurizedData);
    }
}
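Both jobs go through this same utility. HashingTF hashes every word into a fixed-size vector (2^18 = 262,144 buckets by default for the DataFrame-based HashingTF), so the reference vectors produced by job one and the vectors computed for the unknown merchants in job two live in the same hashed feature space, which is what makes the dot products between them meaningful. Note that the IDF weights are re-fit on whatever dataset is passed in.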
The Spark preprocessing job is shown below. The main steps are: load the merchants and their flattened dish data, run word segmentation and TF-IDF, and save the result to the Hive table.
public class CategorySuggestionTrainning {

    private static SparkSession spark;
    private static final String YESTERDAY = DateTimeUtil.getYesterdayStr();

    public static final String TRAINNING_DATA_SQL = "select id, coalesce(shop_name,'') as name,coalesce(category_id,0) as category_id, coalesce(food_name,'') as food " +
            "from dw.category_and_foodname where dt='%s' limit 100000";

    public static void main(String[] args) {
        spark = initSpark();
        try {
            Dataset<Row> rawTranningDataset = getTrainDataSet();
            Dataset<Row> trainningTfidfDataset = TfidfUtil.tfidf(rawTranningDataset);
            JavaRDD<TrainningFeature> trainningFeatureRdd = getTrainningFeatureRDD(trainningTfidfDataset);
            Dataset<Row> trainningFeaturedataset = spark.createDataFrame(trainningFeatureRdd, TrainningFeature.class);
            saveToHive(trainningFeaturedataset);
            System.out.println("poi suggest trainning stopped");
            spark.stop();
        } catch (Exception e) {
            System.out.println("main method has error " + e.getMessage());
            e.printStackTrace();
        }
    }
    /**
     * Load the origin ele shop data, including category and goods (dish names separated by '|'),
     * then segment the goods text into words.
     * @return Dataset built from TrainningData beans
     */
    private static Dataset<Row> getTrainDataSet() {
        String trainningSql = String.format(TRAINNING_DATA_SQL, YESTERDAY);
        System.out.println("trainning data sql is " + trainningSql);
        spark.sql("use dw");
        Dataset<Row> rowRdd = spark.sql(trainningSql);
        JavaRDD<TrainningData> trainningDataJavaRDD = rowRdd.javaRDD().map((row) -> {
            String goods = (String) row.getAs("food");
            String shopName = (String) row.getAs("name");
            // skip shops whose name is missing or whose goods text is too short to be meaningful
            if (StringUtil.isBlank(shopName) || StringUtil.isBlank(goods) || goods.length() < 50) {
                System.out.println("some field is null " + row.toString());
                return null;
            }
            TrainningData data = new TrainningData();
            data.setShopId((Long) row.getAs("id"));
            data.setShopName(shopName);
            data.setCategory((Long) row.getAs("category_id"));
            // word segmentation of the concatenated dish names
            List<Word> words = WordSegmenter.seg(goods);
            StringBuilder wordsOfGoods = new StringBuilder();
            for (Word word : words) {
                wordsOfGoods.append(word.getText()).append(" ");
            }
            data.setGoodsSegment(wordsOfGoods.toString());
            return data;
        }).filter((data) -> data != null);
        return spark.createDataFrame(trainningDataJavaRDD, TrainningData.class);
    }
    private static JavaRDD<TrainningFeature> getTrainningFeatureRDD(Dataset<Row> trainningTfidfDataset) {
        return trainningTfidfDataset.javaRDD().map(new Function<Row, TrainningFeature>() {
            @Override
            public TrainningFeature call(Row row) throws Exception {
                TrainningFeature data = new TrainningFeature();
                data.setCategory(row.getAs("category"));
                data.setShopId(row.getAs("shopId"));
                data.setShopName(row.getAs("shopName"));
                // store the sparse TF-IDF vector as size + indices + values so it can be written to Hive
                SparseVector vector = row.getAs("features");
                data.setVectorSize(vector.size());
                data.setVectorIndices(Arrays.toString(vector.indices()));
                data.setVectorValues(Arrays.toString(vector.values()));
                return data;
            }
        });
    }
    private static SparkSession initSpark() {
        return SparkSession
                .builder()
                .appName("poi-spark-trainning")
                .enableHiveSupport()
                .getOrCreate();
    }
    private static void saveToHive(Dataset<Row> trainningTfidfDataset) {
        try {
            trainningTfidfDataset.createTempView("trainData");
            String sqlInsert = "insert overwrite table dw.poi_category_pre_data " +
                    "select shopId,shopName,category,vectorSize,vectorIndices,vectorValues from trainData ";
            spark.sql("use dw");
            System.out.println(spark.sql(sqlInsert).count());
        } catch (AnalysisException e) {
            System.out.println("save trainning data to hive failed");
            e.printStackTrace();
        }
    }
}
Job two
Load the preprocessed reference data together with the merchants whose category still needs to be determined, compute the cosine similarity between the two sets, and take the category of the most similar reference merchant as the suggested category for each unknown merchant.
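This is a brute-force nearest-neighbour search: the reference vectors produced by job one (about 20,000 rows, per the limit in the SQL) are collected to the driver and broadcast to every executor, and each unknown merchant is compared against all of them, so the cost grows linearly with the size of both sides. At the current data volume this is acceptable.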
The code:
public class CategorySuggestion {
private static SparkSession spark;
private static final String YESTERDAY = DateTimeUtil.getYesterdayStr();
private static boolean CALCULATE_ALL = false;
private static long MT_SHOP_COUNT = 2000;
public static final String TRAINNING_DATA_SQL = "select shop_id, coalesce(shop_name,'') as shop_name,coalesce(category,0) as category_id, " +
"vector_size, coalesce(vector_indices,'[]') as vector_indices, coalesce(vector_values,'[]') as vector_values " +
"from dw.poi_category_pre_data limit %s ";
public static final String COMPETITOR_DATA_SQL = "select id,coalesce(name,'') as name,coalesce(food,'') as food from dw.unknow_category_restaurant " +
"where dt='%s' and id is not null limit %s ";
    public static void main(String[] args) {
        spark = initSpark();
        try {
            MiniTrainningData[] miniTrainningDataArray = getTrainData();
            // broadcast the reference TF-IDF vectors so every executor can compare against them
            final Broadcast<MiniTrainningData[]> trainningData = spark.sparkContext().broadcast(miniTrainningDataArray, ClassTag$.MODULE$.apply(MiniTrainningData[].class));
            System.out.println("broadcast success and list is " + trainningData.value().length);
            Dataset<Row> rawMeituanDataset = getMeituanDataSet();
            Dataset<Row> meituanTfidDataset = TfidfUtil.tfidf(rawMeituanDataset);
            Dataset<SimilartyData> similartyDataList = pickupTheTopSimilarShop(meituanTfidDataset, trainningData);
            saveToHive(similartyDataList);
            System.out.println("poi suggest stopped");
            spark.stop();
        } catch (Exception e) {
            System.out.println("main method has error " + e.getMessage());
            e.printStackTrace();
        }
    }
    private static SparkSession initSpark() {
        return SparkSession
                .builder()
                .appName("poi-spark")
                .enableHiveSupport()
                .getOrCreate();
    }
    /**
     * Load the reference data prepared by job one (category id plus the stored TF-IDF vector)
     * and rebuild the SparseVector of every reference shop.
     * @return array of MiniTrainningData, small enough to be broadcast
     */
    private static MiniTrainningData[] getTrainData() {
        String trainningSql = String.format(TRAINNING_DATA_SQL, 20001);
        System.out.println("trainning data sql is " + trainningSql);
        spark.sql("use dw");
        Dataset<Row> rowRdd = spark.sql(trainningSql);
        List<MiniTrainningData> trainningDataList = rowRdd.javaRDD().map((row) -> {
            MiniTrainningData data = new MiniTrainningData();
            data.setEleShopId(row.getAs("shop_id"));
            data.setCategory(row.getAs("category_id"));
            Long vectorSize = row.getAs("vector_size");
            // vector_indices and vector_values are stored as stringified arrays, parse them back
            List<Integer> vectorIndices = JSON.parseArray(row.getAs("vector_indices"), Integer.class);
            List<Double> vectorValues = JSON.parseArray(row.getAs("vector_values"), Double.class);
            SparseVector vector = new SparseVector(vectorSize.intValue(), integerListToArray(vectorIndices), doubleListToArray(vectorValues));
            data.setFeatures(vector);
            return data;
        }).collect();
        MiniTrainningData[] miniTrainningDataArray = new MiniTrainningData[trainningDataList.size()];
        return trainningDataList.toArray(miniTrainningDataArray);
    }
    private static int[] integerListToArray(List<Integer> integerList) {
        int[] intArray = new int[integerList.size()];
        for (int i = 0; i < integerList.size(); i++) {
            intArray[i] = integerList.get(i).intValue();
        }
        return intArray;
    }

    private static double[] doubleListToArray(List<Double> doubleList) {
        double[] doubleArray = new double[doubleList.size()];
        for (int i = 0; i < doubleList.size(); i++) {
            // use doubleValue() here; intValue() would truncate the TF-IDF weights
            doubleArray[i] = doubleList.get(i).doubleValue();
        }
        return doubleArray;
    }
    private static Dataset<Row> getMeituanDataSet() {
        String meituanSql = String.format(COMPETITOR_DATA_SQL, YESTERDAY, 10000);
        System.out.println("meituan sql is " + meituanSql);
        spark.sql("use dw");
        Dataset<Row> rowRdd = spark.sql(meituanSql);
        JavaRDD<MeiTuanData> meituanDataJavaRDD = rowRdd.javaRDD().map((row) -> {
            MeiTuanData data = new MeiTuanData();
            String goods = (String) row.getAs("food");
            String shopName = (String) row.getAs("name");
            data.setShopId((Long) row.getAs("id"));
            data.setShopName(shopName);
            if (StringUtil.isBlank(goods)) {
                return null;
            }
            StringBuilder wordsOfGoods = new StringBuilder();
            try {
                // dish names are separated by '|', replace the separator before word segmentation
                List<Word> words = WordSegmenter.seg(goods.replace("|", " "));
                for (Word word : words) {
                    wordsOfGoods.append(word.getText()).append(" ");
                }
            } catch (Exception e) {
                System.out.println("exception in segment " + data);
            }
            data.setGoodsSegment(wordsOfGoods.toString());
            return data;
        }).filter((data) -> data != null);
        System.out.println("meituan data count is " + meituanDataJavaRDD.count());
        return spark.createDataFrame(meituanDataJavaRDD, MeiTuanData.class);
    }
    private static Dataset<SimilartyData> pickupTheTopSimilarShop(Dataset<Row> meituanTfidDataset, Broadcast<MiniTrainningData[]> trainningData) {
        return meituanTfidDataset.map(new MapFunction<Row, SimilartyData>() {
            @Override
            public SimilartyData call(Row row) throws Exception {
                SimilartyData similartyData = new SimilartyData();
                Long mtShopId = row.getAs("shopId");
                Vector meituanFeatures = row.getAs("features");
                similartyData.setMtShopId(mtShopId);
                MiniTrainningData[] trainDataArray = trainningData.value();
                if (ArrayUtils.isEmpty(trainDataArray)) {
                    return similartyData;
                }
                double maxSimilarty = 0;
                long maxSimilarCategory = 0L;
                long maxSimilarEleShopId = 0;
                // brute-force scan: compute cosine similarity against every broadcast reference vector
                // and keep the reference shop with the highest score
                for (MiniTrainningData trainData : trainDataArray) {
                    Vector trainningFeatures = trainData.getFeatures();
                    long categoryId = trainData.getCategory();
                    long eleShopId = trainData.getEleShopId();
                    // cosine similarity = dot(a, b) / (||a|| * ||b||)
                    double dot = BLAS.dot(meituanFeatures.toSparse(), trainningFeatures.toSparse());
                    double v1 = Vectors.norm(meituanFeatures.toSparse(), 2.0);
                    double v2 = Vectors.norm(trainningFeatures.toSparse(), 2.0);
                    double similarty = dot / (v1 * v2);
                    if (similarty > maxSimilarty) {
                        maxSimilarty = similarty;
                        maxSimilarCategory = categoryId;
                        maxSimilarEleShopId = eleShopId;
                    }
                }
                similartyData.setEleShopId(maxSimilarEleShopId);
                similartyData.setSimilarty(maxSimilarty);
                similartyData.setCategoryId(maxSimilarCategory);
                return similartyData;
            }
        }, Encoders.bean(SimilartyData.class));
    }
    private static void saveToHive(Dataset<SimilartyData> similartyDataset) {
        try {
            similartyDataset.createTempView("records");
            String sqlInsert = "insert overwrite table dw.poi_category_suggest PARTITION (dt = '" + DateTimeUtil.getYesterdayStr() + "') \n" +
                    "select mtShopId,eleShopId,shopName,similarty,categoryId from records ";
            System.out.println(spark.sql(sqlInsert).count());
        } catch (AnalysisException e) {
            System.out.println("save similarty data to hive failed");
            e.printStackTrace();
        }
    }
}