First Steps with Machine Learning (Spark Text Similarity in Production)

I have recently been working on the company's POI domain. POI stands for point of interest; the term is used most widely in the mapping industry, where every labeled address on a map is a POI. Our company has narrowed the concept to fit its own business: our points of interest are concentrated on restaurants and new-retail merchants such as shops and supermarkets.
On the surface this sounds like nothing more than collecting and correcting merchant data, so why does machine learning come into it? The real reason is a little awkward and best not dwelt on: we have obtained data for a batch of merchants but cannot get their categories, and category matters a great deal to our current business because it touches the interests of different business lines. So we need some other means of identifying the category of these merchants.

Scenario

Using the merchant data we already have, namely the shop name and its dish names, identify which category the shop belongs to (Sichuan/Hunan cuisine, Japanese food, and so on).

Options

  1. Ask the algorithm team to compute the shop categories for us.
    A similar project already exists, so category computation could be built on top of it and delivered quickly.
  2. Work it out ourselves and study the algorithms.
    Learning the algorithms and applying them in production would take noticeably longer.

Final decision

The algorithm team ran into staffing constraints, so we had to do the computation ourselves. Fortunately the related algorithm project was handed over to our team, so we could use it as a reference.
Weighing the need for a quick implementation now against the longer-term direction, we chose to move along two tracks in parallel: a Python script for a fast short-term solution, and distributed computation with Spark ML as the long-term goal.

Approach

We treat all of a shop's dish names together with its shop information as one piece of text, which turns the task into a text-similarity problem. A whole family of algorithms applies here, including TF-IDF, LSI, and LDA. The Python script uses TF-IDF plus LSI for the similarity computation (following a simple reference example); Spark uses TF-IDF plus cosine similarity as a first, verification-oriented implementation (the algorithm will be refined later). Since the long-term plan is to run machine-learning computation on Spark, the rest of this post focuses on how to do it there.

Algorithm overview

TF-IDF, which turns each shop's text into a weighted feature vector
Cosine similarity, which measures how close two of those vectors are
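
For reference, these are the textbook definitions (with Spark's smoothed IDF); nothing here is specific to this project:

tf-idf(t, d) = tf(t, d) * idf(t),   where idf(t) = log((N + 1) / (df(t) + 1))

N is the number of documents (shops), df(t) is the number of documents containing term t, and tf(t, d) is how often t occurs in document d. Rare, distinctive words (e.g. a signature dish name) get large weights; words that appear everywhere get weights near zero.

cos(a, b) = (a · b) / (||a|| * ||b||)

Two shops whose TF-IDF vectors point in similar directions score close to 1; shops with no vocabulary overlap score 0.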

Spark ML implementation

The work is split into two Spark jobs. The first computes TF-IDF vectors for the shops (and their dishes) that already exist online with correctly matched categories and saves those vectors to a Hive table; the second compares the shops with unknown categories against them (described under Task 2 below).

Task 1

Data preprocessing and TF-IDF computation over the reference data.
First, a Hive job flattens each shop's dish data into one row per shop. This step is straightforward; the flattened data looks like this:


(Figure: flattened shop and dish data)
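
The flattening SQL itself is not shown here, so the following is only a rough sketch, kept as a string constant in the same style as the jobs below. The source table dw.shop_food_detail and its columns are assumptions; only the target table dw.category_and_foodname appears in the training job.

// hypothetical flattening step: one row per (shop, dish) in, one '|'-separated row per shop out
public static final String FLATTEN_FOOD_SQL =
        "insert overwrite table dw.category_and_foodname partition (dt='%s') " +
        "select shop_id as id, max(shop_name) as shop_name, max(category_id) as category_id, " +
        "concat_ws('|', collect_list(food_name)) as food_name " +
        "from dw.shop_food_detail where dt='%s' group by shop_id";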

Then a separate Spark job runs TF-IDF over the shop/dish text and saves the result to the table below (dw.poi_category_pre_data). vector_indices and vector_values are arrays of equal length; together with vector_size they encode one sparse feature vector per shop.


(Figure: Hive table schema of dw.poi_category_pre_data)
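
To make the storage format concrete, here is a minimal sketch (the dimensions and weights are made up) of how one TF-IDF SparseVector maps onto the vector_size / vector_indices / vector_values columns and back:

import java.util.Arrays;
import org.apache.spark.ml.linalg.SparseVector;

public class VectorStorageDemo {
    public static void main(String[] args) {
        // a 262,144-dimensional TF-IDF vector with two non-zero entries
        SparseVector v = new SparseVector(262144, new int[]{1057, 80423}, new double[]{0.52, 1.94});

        // what task 1 writes to the Hive table
        int size = v.size();                              // vector_size    = 262144
        String indices = Arrays.toString(v.indices());    // vector_indices = "[1057, 80423]"
        String values  = Arrays.toString(v.values());     // vector_values  = "[0.52, 1.94]"
        System.out.println(size + " " + indices + " " + values);

        // what task 2 rebuilds before computing cosine similarity
        // (in the real job the two arrays are parsed back from the JSON-like strings)
        SparseVector rebuilt = new SparseVector(size, v.indices().clone(), v.values().clone());
        System.out.println(rebuilt);
    }
}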

The TF-IDF utility class:

import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.IDF;
import org.apache.spark.ml.feature.IDFModel;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

public class TfidfUtil {

    /**
     * Turns the "goodsSegment" column (space-separated words) into a TF-IDF
     * feature vector stored in the "features" column.
     * See the Spark documentation on feature extraction (TF-IDF) for background.
     * @param dataset input dataset with a "goodsSegment" string column
     * @return the input dataset with "words", "rawFeatures" and "features" columns appended
     */
    public static Dataset<Row> tfidf(Dataset<Row> dataset) {
        // split the pre-segmented text on whitespace
        Tokenizer tokenizer = new Tokenizer().setInputCol("goodsSegment").setOutputCol("words");
        Dataset<Row> wordsData = tokenizer.transform(dataset);
        // hash each word into a fixed-length term-frequency vector
        HashingTF hashingTF = new HashingTF()
                .setInputCol("words")
                .setOutputCol("rawFeatures");
        Dataset<Row> featurizedData = hashingTF.transform(wordsData);
        // reweight term frequencies by inverse document frequency
        IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features");
        IDFModel idfModel = idf.fit(featurizedData);
        return idfModel.transform(featurizedData);
    }
}
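
One detail worth noting, based on Spark's defaults rather than anything stated above: HashingTF hashes each term into a fixed-length vector of 2^18 = 262,144 buckets, so unrelated words can occasionally collide. If collisions ever become a concern, the dimensionality can be raised explicitly as a drop-in change to the builder above:

HashingTF hashingTF = new HashingTF()
        .setInputCol("words")
        .setOutputCol("rawFeatures")
        .setNumFeatures(1 << 20);  // 1,048,576 buckets to reduce hash collisions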

The Spark preprocessing job follows. Its main steps are: load the shops and their flattened dish data, run TF-IDF, and save the result to the Hive table.

public class CategorySuggestionTrainning {

    private static SparkSession spark;

    private static final String YESTERDAY = DateTimeUtil.getYesterdayStr();

    public static final String TRAINNING_DATA_SQL = "select id, coalesce(shop_name,'') as name,coalesce(category_id,0) as category_id, coalesce(food_name,'') as food " +
            "from dw.category_and_foodname where dt='%s' limit 100000";

    public static void main(String[] args){
        spark = initSpark();
        try{
            Dataset<Row> rawTranningDataset = getTrainDataSet();
            Dataset<Row> trainningTfidfDataset = TfidfUtil.tfidf(rawTranningDataset);
            JavaRDD<TrainningFeature> trainningFeatureRdd = getTrainningFeatureRDD(trainningTfidfDataset);
            Dataset<Row> trainningFeaturedataset = spark.createDataFrame(trainningFeatureRdd, TrainningFeature.class);

            saveToHive(trainningFeaturedataset);
            System.out.println("poi suggest trainning stopped");
            spark.stop();
        } catch (Exception e) {
            System.out.println("main method has error " + e.getMessage());
            e.printStackTrace();
        }
    }

    /**
     * Load the original ele shop data (category plus the goods string, dishes separated by '|')
     * and segment the goods into words.
     * @return Dataset<Row> of TrainningData rows
     */
    private static Dataset<Row> getTrainDataSet(){
        String trainningSql = String.format(TRAINNING_DATA_SQL,YESTERDAY);
        System.out.println("tranningData sql is "+trainningSql);
        spark.sql("use dw");
        Dataset<Row> rowRdd = spark.sql(trainningSql);
        JavaRDD<TrainningData> trainningDataJavaRDD = rowRdd.javaRDD().map((row) -> {
            String goods = (String) row.getAs("food");
            String shopName = (String) row.getAs("name");
            if (StringUtil.isBlank(shopName) || StringUtil.isBlank(goods) || goods.length() < 50) {
                System.out.println("some field is null " + row.toString());
                return null;
            }
            TrainningData data = new TrainningData();
            data.setShopId((Long) row.getAs("id"));
            data.setShopName(shopName);
            data.setCategory((Long) row.getAs("category_id"));

            List<Word> words = WordSegmenter.seg(goods);
            StringBuilder wordsOfGoods = new StringBuilder();
            for (Word word : words) {
                wordsOfGoods.append(word.getText()).append(" ");
            }
            data.setGoodsSegment(wordsOfGoods.toString());
            return data;
        }).filter((data) -> data != null);
        return spark.createDataFrame(trainningDataJavaRDD, TrainningData.class);
    }

    private static JavaRDD<TrainningFeature> getTrainningFeatureRDD(Dataset<Row> trainningTfidfDataset){
        return trainningTfidfDataset.javaRDD().map(new Function<Row, TrainningFeature>(){
            @Override
            public TrainningFeature call(Row row) throws Exception {
                TrainningFeature data = new TrainningFeature();
                data.setCategory(row.getAs("category"));
                data.setShopId(row.getAs("shopId"));
                data.setShopName(row.getAs("shopName"));
                SparseVector vector = row.getAs("features");
                data.setVectorSize(vector.size());
                data.setVectorIndices(Arrays.toString(vector.indices()));
                data.setVectorValues(Arrays.toString(vector.values()));
                return data;
            }
        });
    }

    private static SparkSession initSpark(){
        return SparkSession
                .builder()
                .appName("poi-spark-trainning")
                .enableHiveSupport()
                .getOrCreate();
    }

    private static void saveToHive(Dataset<Row> trainningTfidfDataset){
        try {
            trainningTfidfDataset.createTempView("trainData");
            String sqlInsert = "insert overwrite table dw.poi_category_pre_data " +
                    "select shopId,shopName,category,vectorSize,vectorIndices,vectorValues from trainData ";

            spark.sql("use dw");
            System.out.println(spark.sql(sqlInsert).count());
        } catch (AnalysisException e) {
            System.out.println("save tranning data to hive failed");
            e.printStackTrace();
        }
    }
}

Task 2

Load the preprocessed vectors along with the shops whose category is still unknown, compute the cosine similarity between each unknown shop and every reference shop, and take the category of the most similar reference shop as the suggested category.
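
Before the full job, here is a minimal standalone sketch of that cosine step using org.apache.spark.ml.linalg (the production code below uses BLAS.dot instead of the hand-rolled sparse dot product shown here):

import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors;

public class CosineSimilarityDemo {
    /** cos(a, b) = a·b / (||a|| * ||b||); returns 0 when either norm is 0. */
    public static double cosine(Vector a, Vector b) {
        SparseVector sa = a.toSparse();
        SparseVector sb = b.toSparse();
        double dot = 0.0;
        // walk both sorted index arrays to compute the sparse dot product
        int i = 0, j = 0;
        while (i < sa.indices().length && j < sb.indices().length) {
            if (sa.indices()[i] == sb.indices()[j]) {
                dot += sa.values()[i] * sb.values()[j];
                i++; j++;
            } else if (sa.indices()[i] < sb.indices()[j]) {
                i++;
            } else {
                j++;
            }
        }
        double normA = Vectors.norm(a, 2.0);
        double normB = Vectors.norm(b, 2.0);
        return (normA == 0 || normB == 0) ? 0.0 : dot / (normA * normB);
    }

    public static void main(String[] args) {
        Vector a = Vectors.sparse(5, new int[]{0, 3}, new double[]{1.0, 2.0});
        Vector b = Vectors.sparse(5, new int[]{0, 4}, new double[]{3.0, 4.0});
        System.out.println(cosine(a, b)); // only index 0 overlaps -> 3 / (sqrt(5) * 5) ≈ 0.268
    }
}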
The full job:

public class CategorySuggestion {

    private static SparkSession spark;

    private static final String YESTERDAY = DateTimeUtil.getYesterdayStr();

    private static boolean CALCULATE_ALL = false;

    private static long MT_SHOP_COUNT = 2000;

    public static final String TRAINNING_DATA_SQL = "select shop_id, coalesce(shop_name,'') as shop_name,coalesce(category,0) as category_id, " +
            "vector_size, coalesce(vector_indices,'[]') as vector_indices, coalesce(vector_values,'[]') as vector_values " +
            "from dw.poi_category_pre_data limit %s ";

    public static final String COMPETITOR_DATA_SQL = "select id,coalesce(name,'') as name,coalesce(food,'') as food from dw.unknow_category_restaurant " +
            "where dt='%s' and id is not null limit %s ";

    public static void main(String[] args){
        spark = initSpark();
        try{
            MiniTrainningData[] miniTrainningDataArray = getTrainData();
            final Broadcast<MiniTrainningData[]> trainningData = spark.sparkContext().broadcast(miniTrainningDataArray, ClassTag$.MODULE$.apply(MiniTrainningData[].class));
            System.out.println("broadcast success and list is " + trainningData.value().length);

            Dataset<Row> rawMeituanDataset = getMeituanDataSet();
            Dataset<Row> meituanTfidDataset = TfidfUtil.tfidf(rawMeituanDataset);
            Dataset<SimilartyData> similartyDataList = pickupTheTopSimilarShop(meituanTfidDataset, trainningData);

            saveToHive(similartyDataList);
            System.out.println("poi suggest stopped");
            spark.stop();
        } catch (Exception e) {
            System.out.println("main method has error " + e.getMessage());
            e.printStackTrace();
        }
    }

    private static SparkSession initSpark(){
        return SparkSession
                .builder()
                .appName("poi-spark")
                .enableHiveSupport()
                .getOrCreate();
    }

    /**
     * Load the precomputed TF-IDF vectors written by task 1 from dw.poi_category_pre_data
     * and rebuild a SparseVector for each reference shop.
     * @return an array of MiniTrainningData, small enough to broadcast to every executor
     */
    private static MiniTrainningData[] getTrainData(){
        String trainningSql = String.format(TRAINNING_DATA_SQL,20001);
        System.out.println("tranningData sql is "+trainningSql);
        spark.sql("use dw");
        Dataset<Row> rowRdd = spark.sql(trainningSql);
        List<MiniTrainningData> trainningDataList = rowRdd.javaRDD().map((row) -> {
            MiniTrainningData data = new MiniTrainningData();
            data.setEleShopId( row.getAs("shop_id"));
            data.setCategory( row.getAs("category_id"));
            Long vectorSize = row.getAs("vector_size");
            List<Integer> vectorIndices = JSON.parseArray(row.getAs("vector_indices"), Integer.class);
            List<Double> vectorValues = JSON.parseArray(row.getAs("vector_values"), Double.class);
            SparseVector vector = new SparseVector(vectorSize.intValue(),integerListToArray(vectorIndices),doubleListToArray(vectorValues));
            data.setFeatures(vector);
            return data;
        }).collect();
        MiniTrainningData[] miniTrainningDataArray = new MiniTrainningData[trainningDataList.size()];
        return trainningDataList.toArray(miniTrainningDataArray);
    }

    private static int[] integerListToArray(List<Integer> integerList){
        int[] intArray = new int[integerList.size()];
        for (int i = 0; i < integerList.size(); i++) {
            intArray[i] = integerList.get(i).intValue();
        }
        return intArray;
    }

    private static double[] doubleListToArray(List<Double> doubleList){
        double[] doubleArray = new double[doubleList.size()];
        for (int i = 0; i < doubleList.size(); i++) {
            // doubleValue(), not intValue(): truncating here would destroy the TF-IDF weights
            doubleArray[i] = doubleList.get(i).doubleValue();
        }
        return doubleArray;
    }

    private static Dataset<Row> getMeituanDataSet() {
        String meituanSql = String.format(COMPETITOR_DATA_SQL, YESTERDAY, 10000);
        System.out.println("meituan sql is " + meituanSql);
        spark.sql("use dw");
        Dataset<Row> rowRdd = spark.sql(meituanSql);
        JavaRDD<MeiTuanData> meituanDataJavaRDD = rowRdd.javaRDD().map((row) -> {
            MeiTuanData data = new MeiTuanData();
            String goods = (String) row.getAs("food");
            String shopName = (String) row.getAs("name");
            data.setShopId((Long) row.getAs("id"));
            data.setShopName(shopName);
            if (StringUtil.isBlank(goods)) {
                return null;
            }
            StringBuilder wordsOfGoods = new StringBuilder();
            try {
                List<Word> words = WordSegmenter.seg(goods.replace("|", " "));
                for (Word word : words) {
                    wordsOfGoods.append(word.getText()).append(" ");
                }
            } catch (Exception e) {
                System.out.println("exception in segment " + data);
            }
            data.setGoodsSegment(wordsOfGoods.toString());
            return data;
        }).filter((data) -> data != null);
        System.out.println("meituan data count is " + meituanDataJavaRDD.count());
        return spark.createDataFrame(meituanDataJavaRDD, MeiTuanData.class);
    }

    private static Dataset<SimilartyData> pickupTheTopSimilarShop(Dataset<Row> meituanTfidDataset, Broadcast<MiniTrainningData[]> trainningData){
        return meituanTfidDataset.map(new MapFunction<Row, SimilartyData>() {
            @Override
            public SimilartyData call(Row row) throws Exception {
                SimilartyData similartyData = new SimilartyData();
                Long mtShopId = row.getAs("shopId");
                Vector meituanfeatures = row.getAs("features");
                similartyData.setMtShopId(mtShopId);
                MiniTrainningData[] trainDataArray = trainningData.value();
                if(ArrayUtils.isEmpty(trainDataArray)){
                    return similartyData;
                }
                double maxSimilarty = 0;
                long maxSimilarCategory = 0L;
                long maxSimilareleShopId = 0;
                for (MiniTrainningData trainData : trainDataArray) {
                    Vector trainningFeatures = trainData.getFeatures();
                    long categoryId = trainData.getCategory();
                    long eleShopId = trainData.getEleShopId();
                    double dot = BLAS.dot(meituanfeatures.toSparse(), trainningFeatures.toSparse());
                    double v1 = Vectors.norm(meituanfeatures.toSparse(), 2.0);
                    double v2 = Vectors.norm(trainningFeatures.toSparse(), 2.0);
                    double similarty = dot / (v1 * v2);
                    if(similarty>maxSimilarty){
                        maxSimilarty = similarty;
                        maxSimilarCategory = categoryId;
                        maxSimilareleShopId = eleShopId;
                    }
                }
                similartyData.setEleShopId(maxSimilareleShopId);
                similartyData.setSimilarty(maxSimilarty);
                similartyData.setCategoryId(maxSimilarCategory);
                return similartyData;
            }
        }, Encoders.bean(SimilartyData.class));
    }

    private static void saveToHive(Dataset<SimilartyData> similartyDataset){
        try {
            similartyDataset.createTempView("records");
            String sqlInsert = "insert overwrite table dw.poi_category_suggest  PARTITION (dt = '"+DateTimeUtil.getYesterdayStr()+"') \n" +
                    "select mtShopId,eleShopId,shopName,similarty,categoryId from records ";
            System.out.println(spark.sql(sqlInsert).count());
        } catch (AnalysisException e) {
            System.out.println("create SimilartyData dataFrame failed");
            e.printStackTrace();
        }
        spark.sql("use platform_dw");
    }
}
