Kaggle Text Mining Winners' Code Walkthrough (Part 1): Data Preprocessing

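I. Problem Background

The three Kaggle competitions below all deal with text similarity: each asks you to build a model that scores how related two pieces of text are.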

Two e-commerce scenarios:
https://www.kaggle.com/c/crowdflower-search-relevance
https://www.kaggle.com/c/home-depot-product-search-relevance
The task is to score the relevance between the search term (query) a user enters and the product returned.

One forum scenario:
https://www.kaggle.com/c/quora-question-pairs
The task is to judge how similar two questions are.

Although the application scenarios differ, a common set of techniques applies to all of them.

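II. Third-Place Solution Code Walkthrough

All code in this article comes from this GitHub repository: https://github.com/ChenglongChen/Kaggle_HomeDepot

III. Code Walkthrough

1. Pattern-replace base class

Initialization: takes a pattern_replace_pair_list; every part of the text that matches a pattern is converted to the corresponding replacement.
Transform: applies each (pattern, replace) pair (a tuple) in turn, then collapses whitespace and removes leading and trailing spaces (strip()).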

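import regex  # third-party "regex" package; unlike re, it supports variable-length lookbehind

class BaseReplacer:
    def __init__(self, pattern_replace_pair_list=None):
        # avoid sharing one mutable default list across instances
        self.pattern_replace_pair_list = pattern_replace_pair_list or []

    def transform(self, text):
        for pattern, replace in self.pattern_replace_pair_list:
            try:
                text = regex.sub(pattern, replace, text)
            except regex.error:
                # skip patterns that fail to compile
                pass
        # collapse whitespace runs and strip both ends
        return regex.sub(r"\s+", " ", text).strip()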
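To see what the base class does in practice, here is a minimal sketch of the same loop written as a plain function. It uses the standard-library re module instead of regex (an assumption made so the snippet runs without extra installs), and the two pattern pairs are made up for illustration, not taken from the repository.

```python
import re

def apply_pairs(text, pairs):
    """Apply (pattern, replacement) pairs in order, then normalize whitespace."""
    for pattern, replace in pairs:
        text = re.sub(pattern, replace, text)
    # collapse runs of whitespace and trim both ends, as transform() does
    return re.sub(r"\s+", " ", text).strip()

# hypothetical pairs, for illustration only
pairs = [(r"&", " and "), (r"w/", "with ")]
result = apply_pairs("  sink w/faucet &  drain ", pairs)
print(result)  # sink with faucet and drain
```

The pairs run sequentially, so each pattern sees the output of the previous one. That is why the order of the list matters in the subclasses that follow.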

2. Converting capitalized words to lowercase

Overrides the parent class's transform method.

class LowerCaseConverter(BaseReplacer):
    def transform(self, text):
        return text.lower()

3. Splitting run-together words where the second word starts with a capital letter

Input: hidden from viewDurable rich finishLimited lifetime warrantyEncapsulated panels
Output: hidden from view Durable rich finish Limited lifetime warranty Encapsulated panels

class LowerUpperCaseSplitter(BaseReplacer):
    def __init__(self):
        self.pattern_replace_pair_list = [
            (r"(\w)[\.?!]([A-Z])", r"\1 \2"),
            (r"(?<=( ))([a-z]+)([A-Z]+)", r"\2 \3"),
        ]

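Regex note: () defines a capture group, and \1, \2, \3 in the replacement refer to the first, second, and third group. In the second pattern above, the space inside the lookbehind is group 1, which is why the replacement uses \2 \3.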
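The two patterns can be tried out directly. The sketch below copies them from LowerUpperCaseSplitter and runs them with the standard-library re module (an assumption; the repository uses regex, which behaves the same for these fixed-width patterns). The function name split_lower_upper is hypothetical.

```python
import re

PAIRS = [
    # word char + sentence punctuation + capital letter: drop the punctuation
    (r"(\w)[\.?!]([A-Z])", r"\1 \2"),
    # lowercase run directly followed by capitals: insert a space between them
    (r"(?<=( ))([a-z]+)([A-Z]+)", r"\2 \3"),
]

def split_lower_upper(text):
    for pattern, replace in PAIRS:
        text = re.sub(pattern, replace, text)
    return re.sub(r"\s+", " ", text).strip()

example = "hidden from viewDurable rich finishLimited lifetime warrantyEncapsulated panels"
print(split_lower_upper(example))
```

Note that the second pattern requires a space before the lowercase run, so a run-together word at the very start of the string is left untouched.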

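4. Word replacement

Given a dictionary, replace each entry with its counterpart; this is used to correct typos.
The subclass constructor only builds up pattern_replace_pair_list; it does not perform any replacement.
The replacement happens only when transform is called.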

class WordReplacer(BaseReplacer):
    def __init__(self, replace_fname):
        self.replace_fname = replace_fname  # path to the dictionary file
        self.pattern_replace_pair_list = []
        # requires "import csv"; the with-block closes the file when done
        with open(self.replace_fname) as f:
            for line in csv.reader(f):
                if len(line) == 1 and line[0].startswith("#"):
                    continue  # skip comment lines
                try:
                    pattern = r"(?<=\W|^)%s(?=\W|$)"%line[0]
                    replace = line[1]
                    self.pattern_replace_pair_list.append( (pattern, replace) )
                except IndexError:
                    # malformed row (fewer than two columns)
                    print(line)
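As a rough illustration of how the dictionary turns into pattern pairs, the sketch below reads a small in-memory CSV. The typo entries are hypothetical, and the snippet makes two substitutions so that it runs with the standard-library re module: (?<!\w) and (?!\w) stand in for (?<=\W|^) and (?=\W|$), which need the regex package, and any row starting with # is treated as a comment.

```python
import csv
import io
import re

# hypothetical dictionary content; the real file is config.WORD_REPLACER_DATA
DICT_CSV = "# typo,correction\ntoliet,toilet\nvynal,vinyl\n"

def load_pairs(handle):
    pairs = []
    for row in csv.reader(handle):
        if len(row) < 2 or row[0].startswith("#"):
            continue  # skip comments and malformed rows
        # whole-word match: no word character directly before or after
        pattern = r"(?<!\w)%s(?!\w)" % row[0]
        pairs.append((pattern, row[1]))
    return pairs

def replace_words(text, pairs):
    for pattern, replace in pairs:
        text = re.sub(pattern, replace, text)
    return text

pairs = load_pairs(io.StringIO(DICT_CSV))
print(replace_words("vynal toliet seat", pairs))  # vinyl toilet seat
```

Because the dictionary key is inserted into the pattern unescaped, an entry containing regex metacharacters is interpreted as a regex, both here and in the original class.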

5. Splitting words joined by a connector

This uses () groups in the same way as above.

class LetterLetterSplitter(BaseReplacer):
    """
    For letter and letter
    /:
    Cleaner/Conditioner -> Cleaner Conditioner

    -:
    Vinyl-Leather-Rubber -> Vinyl Leather Rubber

    For digit and digit, we keep it as we will generate some features via math operations,
    such as approximate height/width/area etc.
    /:
    3/4 -> 3/4

    -:
    1-1/4 -> 1-1/4
    """
    def __init__(self):
        self.pattern_replace_pair_list = [
            (r"([a-zA-Z]+)[/\-]([a-zA-Z]+)", r"\1 \2"),
        ]

6. Digits next to letters

Digits, then optional . or -, then letters
Letters, then optional . or -, then digits

class DigitLetterSplitter(BaseReplacer):
    """
    x:
    1x1x1x1x1 -> 1 x 1 x 1 x 1 x 1
    19.875x31.5x1 -> 19.875 x 31.5 x 1

    -:
    1-Gang -> 1 Gang
    48-Light -> 48 Light

    .:
    includes a tile flange to further simplify installation.60 in. L x 36 in. W x 20 in. ->
    includes a tile flange to further simplify installation. 60 in. L x 36 in. W x 20 in.
    """
    def __init__(self):
        self.pattern_replace_pair_list = [
            (r"(\d+)[\.\-]*([a-zA-Z]+)", r"\1 \2"),
            (r"([a-zA-Z]+)[\.\-]*(\d+)", r"\1 \2"),
        ]
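A quick way to check the docstring examples is to run the two patterns by hand. This sketch copies them from DigitLetterSplitter and uses the standard-library re module (an assumption; the repository uses regex). The function name is hypothetical.

```python
import re

PAIRS = [
    (r"(\d+)[\.\-]*([a-zA-Z]+)", r"\1 \2"),  # digits then letters
    (r"([a-zA-Z]+)[\.\-]*(\d+)", r"\1 \2"),  # letters then digits
]

def split_digit_letter(text):
    for pattern, replace in PAIRS:
        text = re.sub(pattern, replace, text)
    return re.sub(r"\s+", " ", text).strip()

for sample in ["1x1x1", "48-Light", "19.875x31.5x1"]:
    print(sample, "->", split_digit_letter(sample))
```

The first pattern alone turns 1x1x1 into "1 x1 x1"; it takes the second pattern to separate the remaining letter-digit boundaries, which is why both directions are needed.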

7. Removing commas from numbers

class DigitCommaDigitMerger(BaseReplacer):
    """
    1,000,000 -> 1000000
    """
    def __init__(self):
        self.pattern_replace_pair_list = [
            (r"(?<=\d+),(?=000)", r""),
        ]
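The pattern above relies on a variable-length lookbehind, (?<=\d+), which the regex package accepts but the standard-library re module rejects. The sketch below uses the fixed-width form (?<=\d) instead; "preceded by one or more digits" and "preceded by a digit" accept exactly the same positions, so this is a stand-in for the original and not the original itself.

```python
import re

def merge_digit_commas(text):
    # remove a comma that sits between a digit and a following "000"
    return re.sub(r"(?<=\d),(?=000)", "", text)

print(merge_digit_commas("1,000,000"))  # 1000000
print(merge_digit_commas("1,234"))      # 1,234 (unchanged)
```

As written, the rule only fires when the comma is followed by 000, so a number such as 1,234 keeps its comma.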

8. Converting number words to Arabic numerals

class NumberDigitMapper(BaseReplacer):
    """
    one -> 1
    two -> 2
    """
    def __init__(self):
        numbers = [
            "zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten",
            "eleven", "twelve", "thirteen", "fourteen", "fifteen", "sixteen", "seventeen", "eighteen",
            "nineteen", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety"
        ]
        digits = [
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
            16, 17, 18, 19, 20, 30, 40, 50, 60, 70, 80, 90
        ]
        self.pattern_replace_pair_list = [
            (r"(?<=\W|^)%s(?=\W|$)"%n, str(d)) for n,d in zip(numbers, digits)
        ]
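The word-boundary lookarounds are what keep this mapping from firing inside longer words. The sketch below uses a subset of the number words and, as an assumption for portability, replaces (?<=\W|^) and (?=\W|$) with the equivalent (?<!\w) and (?!\w) so that it runs with the standard-library re module.

```python
import re

# subset of the number words, for brevity
NUMBER_WORDS = {"one": 1, "two": 2, "seven": 7, "seventeen": 17, "twenty": 20}

PAIRS = [
    (r"(?<!\w)%s(?!\w)" % word, str(digit))
    for word, digit in NUMBER_WORDS.items()
]

def map_numbers(text):
    for pattern, replace in PAIRS:
        text = re.sub(pattern, replace, text)
    return text

print(map_numbers("two gallons"))    # 2 gallons
print(map_numbers("network cable"))  # unchanged: "two" sits inside "network"
print(map_numbers("twenty one"))     # 20 1
```

Each word is mapped independently, so compound numbers such as "twenty one" become "20 1" and are not added up to 21.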

9. Unifying units

class UnitConverter(BaseReplacer):
    """
    shadeMature height: 36 in. - 48 in.Mature width
    PUT one UnitConverter before LowerUpperCaseSplitter
    """
    def __init__(self):
        self.pattern_replace_pair_list = [
            (r"([0-9]+)( *)(inches|inch|in|in.|')\.?", r"\1 in. "),
            (r"([0-9]+)( *)(pounds|pound|lbs|lb|lb.)\.?", r"\1 lb. "),
            (r"([0-9]+)( *)(foot|feet|ft|ft.|'')\.?", r"\1 ft. "),
            (r"([0-9]+)( *)(square|sq|sq.) ?\.?(inches|inch|in|in.|')\.?", r"\1 sq.in. "),
            (r"([0-9]+)( *)(square|sq|sq.) ?\.?(feet|foot|ft|ft.|'')\.?", r"\1 sq.ft. "),
            (r"([0-9]+)( *)(cubic|cu|cu.) ?\.?(inches|inch|in|in.|')\.?", r"\1 cu.in. "),
            (r"([0-9]+)( *)(cubic|cu|cu.) ?\.?(feet|foot|ft|ft.|'')\.?", r"\1 cu.ft. "),
            (r"([0-9]+)( *)(gallons|gallon|gal)\.?", r"\1 gal. "),
            (r"([0-9]+)( *)(ounces|ounce|oz)\.?", r"\1 oz. "),
            (r"([0-9]+)( *)(centimeters|cm)\.?", r"\1 cm. "),
            (r"([0-9]+)( *)(milimeters|mm)\.?", r"\1 mm. "),
            (r"([0-9]+)( *)(minutes|minute)\.?", r"\1 min. "),
            (r"([0-9]+)( *)(°|degrees|degree)\.?", r"\1 deg. "),
            (r"([0-9]+)( *)(v|volts|volt)\.?", r"\1 volt. "),
            (r"([0-9]+)( *)(wattage|watts|watt)\.?", r"\1 watt. "),
            (r"([0-9]+)( *)(amperes|ampere|amps|amp)\.?", r"\1 amp. "),
            (r"([0-9]+)( *)(qquart|quart)\.?", r"\1 qt. "),
            (r"([0-9]+)( *)(hours|hour|hrs.)\.?", r"\1 hr "),
            (r"([0-9]+)( *)(gallons per minute|gallon per minute|gal per minute|gallons/min.|gallons/min)\.?", r"\1 gal. per min. "),
            (r"([0-9]+)( *)(gallons per hour|gallon per hour|gal per hour|gallons/hour|gallons/hr)\.?", r"\1 gal. per hr "),
        ]
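To get a feel for how these rules behave, here is a small sketch that runs just the inch and pound patterns, copied from the list above, with the standard-library re module (an assumption; the repository uses regex). The function name normalize_units is hypothetical.

```python
import re

PAIRS = [
    (r"([0-9]+)( *)(inches|inch|in|in.|')\.?", r"\1 in. "),
    (r"([0-9]+)( *)(pounds|pound|lbs|lb|lb.)\.?", r"\1 lb. "),
]

def normalize_units(text):
    for pattern, replace in PAIRS:
        text = re.sub(pattern, replace, text)
    # same whitespace cleanup as BaseReplacer.transform
    return re.sub(r"\s+", " ", text).strip()

print(normalize_units("height: 36 in. - 48 in.Mature width"))
print(normalize_units("5 pounds"))
print(normalize_units("2 insulated pipes"))
```

The last call shows a side effect worth knowing about: because the unit alternatives have no trailing word boundary, "2 insulated" becomes "2 in. sulated". The first call shows why the docstring asks for one UnitConverter before LowerUpperCaseSplitter: the period in "48 in.Mature" is consumed as part of the unit and is not left to be treated as a sentence boundary.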

10. Removing HTML tags

class HtmlCleaner:
    def __init__(self, parser):
        self.parser = parser

    def transform(self, text):
        bs = BeautifulSoup(text, self.parser)
        text = bs.get_text(separator=" ")
        return text

11. Replacing garbled punctuation (leftover HTML codes) with the original punctuation

class QuartetCleaner(BaseReplacer):
    def __init__(self):
        self.pattern_replace_pair_list = [
            (r"<.+?>", r""),
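            # html codes
            (r"&nbsp;", r" "),
            (r"&amp;", r"&"),
            (r"&#39;", r"'"),
            (r"/>/Agt/>", r""),
            (r"</a<gt/", r""),
            (r"gt/>", r""),
            (r"/>", r""),
            (r"<br", r""),
            # do not remove [".", "/", "-", "%"] as they are useful in numbers, e.g., 1.97, 1-1/2, 10%, etc.
            (r"[ &<>)(_,;:!?\+^~@#\$]+", r" "),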
            ("'s\\b", r""),
            (r"[']+", r""),
            (r"[\"]+", r""),
        ]

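12. Lemmatization and stemming

The difference: a stemmer strips suffixes by rule and may produce a string that is not a real word, while a lemmatizer maps a word to its dictionary form (lemma). A commonly quoted illustration (the exact output depends on the algorithm used):
Stemmers: having->hav
Lemmatizers: having->have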

class Lemmatizer:
    def __init__(self):
        self.Tokenizer = nltk.tokenize.TreebankWordTokenizer()
        self.Lemmatizer = nltk.stem.wordnet.WordNetLemmatizer()

    def transform(self, text):
        tokens = [self.Lemmatizer.lemmatize(token) for token in self.Tokenizer.tokenize(text)]
        return " ".join(tokens)


## stemming
class Stemmer:
    def __init__(self, stemmer_type="snowball"):
        self.stemmer_type = stemmer_type
        if self.stemmer_type == "porter":
            self.stemmer = nltk.stem.PorterStemmer()
        elif self.stemmer_type == "snowball":
            self.stemmer = nltk.stem.SnowballStemmer("english")

    def transform(self, text):
        tokens = [self.stemmer.stem(token) for token in text.split(" ")]
        return " ".join(tokens)

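13. Query expansion

Because the title and the search term are both short texts, some extra processing is done to enrich them.
title_ngram: apply n-grams to the product title, joining every n consecutive words with a separator.
search_term_alt: for each search term, take the most common title_ngram among the products it is paired with.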

class QueryExpansion:
    def __init__(self, df, ngram=3, stopwords_threshold=0.9, base_stopwords=set()):
        self.df = df[["search_term", "product_title"]].copy()
        self.ngram = ngram
        self.stopwords_threshold = stopwords_threshold
        self.stopwords = set(base_stopwords).union(self._get_customized_stopwords())
        
    def _get_customized_stopwords(self):
        words = " ".join(list(self.df["product_title"].values)).split(" ")
        counter = Counter(words)
        num_uniq = len(list(counter.keys()))
        num_stop = int((1.-self.stopwords_threshold)*num_uniq)
        stopwords = set()
        for e,(w,c) in enumerate(sorted(counter.items(), key=lambda x: x[1])):
            if e == num_stop:
                break
            stopwords.add(w)
        return stopwords

    def _ngram(self, text):
        tokens = text.split(" ")
        tokens = [token for token in tokens if token not in self.stopwords]
        return ngram_utils._ngrams(tokens, self.ngram, " ")

    def _get_alternative_query(self, df):
        res = []
        for v in df:
            res += v
        c = Counter(res)
        value, count = c.most_common()[0]
        return value

    def build(self):
        self.df["title_ngram"] = self.df["product_title"].apply(self._ngram)
        corpus = self.df.groupby("search_term").apply(lambda df: self._get_alternative_query(df["title_ngram"]))
        corpus = corpus.reset_index()
        corpus.columns = ["search_term", "search_term_alt"]
        self.df = pd.merge(self.df, corpus, on="search_term", how="left")
        return self.df["search_term_alt"].values

The n-gram helper it relies on, taking n=2 as an example:

def _bigrams(words, join_string, skip=0):
    """
       Input: a list of words, e.g., ["I", "am", "Denny"]
       Output: a list of bigram, e.g., ["I_am", "am_Denny"]
       I use _ as join_string for this example.
    """
    assert type(words) == list
    L = len(words)
    if L > 1:
        lst = []
        for i in range(L-1):
            for k in range(1,skip+2):
                if i+k < L:
                    lst.append( join_string.join([words[i], words[i+k]]) )
    else:
        # set it as unigram
        lst = _unigrams(words)
    return lst
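Putting the two pieces together, the idea behind QueryExpansion can be sketched without pandas. This is a simplified illustration under stated assumptions, not the repository's implementation: the helper names ngrams and expand_queries are hypothetical, stopword filtering is left out, and the n-gram helper is a plain sliding window with no skip or fallback to shorter n-grams.

```python
from collections import Counter, defaultdict

def ngrams(tokens, n, join_string=" "):
    """Join every run of n consecutive tokens (plain sliding window)."""
    return [join_string.join(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]

def expand_queries(rows, n=2):
    """rows: iterable of (search_term, product_title) pairs."""
    grams_by_query = defaultdict(list)
    for search_term, title in rows:
        grams_by_query[search_term] += ngrams(title.split(" "), n)
    # keep the most common title n-gram for each search term
    return {
        query: Counter(grams).most_common(1)[0][0]
        for query, grams in grams_by_query.items()
        if grams
    }

# made-up rows for illustration
rows = [
    ("drill", "cordless drill kit"),
    ("drill", "cordless drill driver"),
    ("drill", "hammer drill"),
]
print(ngrams(["I", "am", "Denny"], 2, "_"))  # ['I_am', 'am_Denny']
print(expand_queries(rows))                  # {'drill': 'cordless drill'}
```

The short query "drill" is thereby enriched with the phrase that shows up most often in the titles it was matched against.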

14. Processing product names

Before transform runs, the text must first be normalized with the classes above.

class ProductNameExtractor(BaseReplacer):
    def __init__(self):
        self.pattern_replace_pair_list = [
            # Remove descriptions (text between parentheses/brackets)
            ("[ ]?[[(].+?[])]", r""),
            # Remove "made in..."
            ("made in [a-z]+\\b", r""),
            # Remove descriptions (hyphen or comma followed by space then at most 2 words, repeated)
            ("([,-]( ([a-zA-Z0-9]+\\b)){1,2}[ ]?){1,}$", r""),
            # Remove descriptions (prepositions starting with: with, for, by, in )
            ("\\b(with|for|by|in|w/) .+$", r""),
            # colors & sizes
            ("size: .+$", r""),
            ("size [0-9]+[.]?[0-9]+\\b", r""),
            (COLORS_PATTERN, r""),
            # dimensions
            (DIM_PATTERN_NxNxN, r""),
            (DIM_PATTERN_NxN, r""),
            # measurement units
            (UNITS_PATTERN, r""),
            # others
            ("(value bundle|warranty|brand new|excellent condition|one size|new in box|authentic|as is)", r""),
            # stop words
            ("\\b(in)\\b", r""),
            # hyphenated words
            ("([a-zA-Z])-([a-zA-Z])", r"\1\2"),
            # special characters
            ("[ &<>)(_,.;:!?/+#*-]+", r" "),
            # numbers that are not part of a word
            ("\\b[0-9]+\\b", r""),
        ]
        
    def preprocess(self, text):
        pattern_replace_pair_list = [
            # Remove single & double apostrophes
            ("[\"]+", r""),
            # Remove product codes (long words (>5 characters) that are all caps, numbers or mix of both)
            # don't use raw string format
            ("[ ]?\\b[0-9A-Z-]{5,}\\b", ""),
        ]
        text = BaseReplacer(pattern_replace_pair_list).transform(text)
        text = LowerCaseConverter().transform(text)
        text = DigitLetterSplitter().transform(text)
        text = UnitConverter().transform(text)
        text = DigitCommaDigitMerger().transform(text)
        text = NumberDigitMapper().transform(text)
        text = UnitConverter().transform(text)
        return text
        
    def transform(self, text):
        text = super().transform(self.preprocess(text))
        text = Lemmatizer().transform(text)
        text = Stemmer(stemmer_type="snowball").transform(text)
        # last two words in product
        text = " ".join(text.split(" ")[-2:])
        return text

15. Processing product attributes

Two functions, one for each output format.

def _split_attr_to_text(text):
    attrs = text.split(config.ATTR_SEPARATOR)
    return " ".join(attrs)

def _split_attr_to_list(text):
    attrs = text.split(config.ATTR_SEPARATOR)        
    if len(attrs) == 1:
        # missing
        return [[attrs[0], attrs[0]]]
    else:
        return [[n,v] for n,v in zip(attrs[::2], attrs[1::2])]
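For a concrete picture of the two output formats, the sketch below assumes the separator is " | ", based on the "attr_name1 | attr_value1 | ..." comment in the repository's main function. The actual value of config.ATTR_SEPARATOR is not shown in this article, so treat it as a guess; the attribute strings are made up.

```python
ATTR_SEPARATOR = " | "  # assumption: stands in for config.ATTR_SEPARATOR

def split_attr_to_text(text, sep=ATTR_SEPARATOR):
    # flatten names and values into one space-separated string
    return " ".join(text.split(sep))

def split_attr_to_list(text, sep=ATTR_SEPARATOR):
    attrs = text.split(sep)
    if len(attrs) == 1:
        # no separator found: treat the value as missing
        return [[attrs[0], attrs[0]]]
    # pair up alternating names and values
    return [[n, v] for n, v in zip(attrs[::2], attrs[1::2])]

raw = "Color | Red | Material | Steel"
print(split_attr_to_text(raw))  # Color Red Material Steel
print(split_attr_to_list(raw))  # [['Color', 'Red'], ['Material', 'Steel']]
```

The text form suits bag-of-words style features, while the list form keeps each name paired with its value.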

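16. Handling inputs and outputs with different data structures

The data to process may live in a list or in a DataFrame, and the two cases are handled slightly differently.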

class ListProcessor:
    """
    WARNING: This class will operate on the original input list itself
    """
    def __init__(self, processors):
        self.processors = processors

    def process(self, lst):
        for i in range(len(lst)):
            for processor in self.processors:
                lst[i] = ProcessorWrapper(processor).transform(lst[i])
        return lst


class DataFrameProcessor:
    """
    WARNING: This class will operate on the original input dataframe itself
    """
    def __init__(self, processors):
        self.processors = processors

    def process(self, df):
        for processor in self.processors:
            df = df.apply(ProcessorWrapper(processor).transform)
        return df


class DataFrameParallelProcessor:
    """
    WARNING: This class will operate on the original input dataframe itself

    https://stackoverflow.com/questions/26520781/multiprocessing-pool-whats-the-difference-between-map-async-and-imap
    """
    def __init__(self, processors, n_jobs=4):
        self.processors = processors
        self.n_jobs = n_jobs

    def process(self, dfAll, columns):
        df_processor = DataFrameProcessor(self.processors)
        p = multiprocessing.Pool(self.n_jobs)
        dfs = p.imap(df_processor.process, [dfAll[col] for col in columns])
        for col,df in zip(columns, dfs):
            dfAll[col] = df
        return dfAll
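The classes above call ProcessorWrapper, which is not shown in this article. A plausible sketch, assuming its job is to let a string-only processor also accept numbers and nested lists, might look like the following. The Lower and Strip processors and the process_list helper are hypothetical stand-ins, and unlike ListProcessor this helper works on a copy.

```python
class ProcessorWrapper:
    """Hypothetical adapter: lets a string processor accept numbers and lists."""
    def __init__(self, processor):
        self.processor = processor

    def transform(self, value):
        if isinstance(value, str):
            return self.processor.transform(value)
        if isinstance(value, (int, float)):
            return self.processor.transform(str(value))
        if isinstance(value, list):
            return [self.transform(v) for v in value]
        raise ValueError("unsupported type: %s" % type(value).__name__)

class Lower:
    def transform(self, text):
        return text.lower()

class Strip:
    def transform(self, text):
        return text.strip()

def process_list(items, processors):
    """Apply each processor in order to a copy of every item."""
    out = list(items)
    for i, item in enumerate(out):
        for processor in processors:
            item = ProcessorWrapper(processor).transform(item)
        out[i] = item
    return out

print(process_list(["  Drill ", ["A", " B"], 3], [Lower(), Strip()]))
```

Any object with a transform method can join the chain, which is what lets main() mix BaseReplacer subclasses with HtmlCleaner, Lemmatizer, and Stemmer in one list.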

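IV. Summary

Putting all of the code above together gives the main function. A few things in this expert's code are well worth learning from:

  1. Repeated logic is factored into classes and class inheritance (the various replacers), which cuts duplication and keeps related logic together so it is easy to follow.
  2. Fields in the config file control whether parts of the code run (the various `if config.UPPERCASE_NAME` checks), which saves time compared with repeatedly commenting and uncommenting code.
  3. All settings live in a single config file, so you do not waste time hunting for them while debugging.
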
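Points 2 and 3 can be sketched in a few lines. The Config class below is a hypothetical stand-in for the project's config module; the flag names are the ones used in main(), but the values and the planned_steps helper are made up for illustration.

```python
class Config:
    """Hypothetical stand-in for the project's config module."""
    TASK = "sample"
    QUERY_EXPANSION = True
    AUTO_CORRECTING_QUERY = False

def planned_steps(config):
    # each optional stage is switched on or off by a single flag
    steps = ["clean"]
    if config.QUERY_EXPANSION:
        steps.append("query_expansion")
    if config.AUTO_CORRECTING_QUERY:
        steps.append("auto_correct")
    return steps

print(planned_steps(Config))  # ['clean', 'query_expansion']
```

Changing one value in one place changes what the pipeline does, with no edits to the pipeline code itself.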
def main():

    ###########
    ## Setup ##
    ###########
    logname = "data_processor_%s.log"%now
    logger = logging_utils._get_logger(config.LOG_DIR, logname)

    # put product_attribute_list, product_attribute and product_description first as they are
    # quite time consuming to process
    columns_to_proc = [
        # # product_attribute_list is very time consuming to process
        # # so we just process product_attribute which is of the form 
        # # attr_name1 | attr_value1 | attr_name2 | attr_value2 | ...
        # # and split it into a list afterwards
        # "product_attribute_list",
        "product_attribute_concat",
        "product_description",
        "product_brand", 
        "product_color",
        "product_title",
        "search_term", 
    ]
    if config.PLATFORM == "Linux":
        config.DATA_PROCESSOR_N_JOBS = len(columns_to_proc)

    # clean using a list of processors
    processors = [
        LowerCaseConverter(), 
        # See LowerUpperCaseSplitter and UnitConverter for why we put UnitConverter here
        UnitConverter(),
        LowerUpperCaseSplitter(), 
        WordReplacer(replace_fname=config.WORD_REPLACER_DATA), 
        LetterLetterSplitter(),
        DigitLetterSplitter(), 
        DigitCommaDigitMerger(), 
        NumberDigitMapper(),
        UnitConverter(), 
        QuartetCleaner(), 
        HtmlCleaner(parser="html.parser"), 
        Lemmatizer(),
    ]
    stemmers = [
        Stemmer(stemmer_type="snowball"), 
        Stemmer(stemmer_type="porter")
    ][0:1]

    ## simple test
    text = "1/2 inch rubber lep tips Bullet07"
    print("Original:")
    print(text)
    list_processor = ListProcessor(processors)
    print("After:")
    print(list_processor.process([text]))

    #############
    ## Process ##
    #############
    ## load raw data
    dfAll = pkl_utils._load(config.ALL_DATA_RAW)
    columns_to_proc = [col for col in columns_to_proc if col in dfAll.columns]


    ## extract product name from search_term and product_title
    ext = ProductNameExtractor()
    dfAll["search_term_product_name"] = dfAll["search_term"].apply(ext.transform)
    dfAll["product_title_product_name"] = dfAll["product_title"].apply(ext.transform)
    if config.TASK == "sample":
        print(dfAll[["search_term", "search_term_product_name", "product_title_product_name"]])


    ## clean using GoogleQuerySpellingChecker
    # MUST BE IN FRONT OF ALL THE PROCESSING
    if config.GOOGLE_CORRECTING_QUERY:
        logger.info("Run GoogleQuerySpellingChecker at search_term")
        checker = GoogleQuerySpellingChecker()
        dfAll["search_term"] = dfAll["search_term"].apply(checker.correct)


    ## clean using a list of processors
    df_processor = DataFrameParallelProcessor(processors, config.DATA_PROCESSOR_N_JOBS)
    df_processor.process(dfAll, columns_to_proc)
    # split product_attribute_concat into product_attribute and product_attribute_list
    dfAll["product_attribute"] = dfAll["product_attribute_concat"].apply(_split_attr_to_text)
    dfAll["product_attribute_list"] = dfAll["product_attribute_concat"].apply(_split_attr_to_list)
    if config.TASK == "sample":
        print(dfAll[["product_attribute", "product_attribute_list"]])
    # query expansion
    if config.QUERY_EXPANSION:
        list_processor = ListProcessor(processors)
        base_stopwords = set(list_processor.process(list(config.STOP_WORDS)))
        qe = QueryExpansion(dfAll, ngram=3, stopwords_threshold=0.9, base_stopwords=base_stopwords)
        dfAll["search_term_alt"] = qe.build()
        if config.TASK == "sample":
            print(dfAll[["search_term", "search_term_alt"]])
    # save data
    logger.info("Save to %s"%config.ALL_DATA_LEMMATIZED)
    columns_to_save = [col for col in dfAll.columns if col != "product_attribute_concat"]
    pkl_utils._save(config.ALL_DATA_LEMMATIZED, dfAll[columns_to_save])


    ## auto correcting query
    if config.AUTO_CORRECTING_QUERY:
        logger.info("Run AutoSpellingChecker at search_term")
        checker = AutoSpellingChecker(dfAll, exclude_stopwords=False, min_len=4)
        dfAll["search_term_auto_corrected"] = list(dfAll["search_term"].apply(checker.correct))
        columns_to_proc += ["search_term_auto_corrected"]
        if config.TASK == "sample":
            print(dfAll[["search_term", "search_term_auto_corrected"]])
        # save query_correction_map and spelling checker
        fname = "%s/auto_spelling_checker_query_correction_map_%s.log"%(config.LOG_DIR, now)
        checker.save_query_correction_map(fname)
        # save data
        logger.info("Save to %s"%config.ALL_DATA_LEMMATIZED)
        columns_to_save = [col for col in dfAll.columns if col != "product_attribute_concat"]
        pkl_utils._save(config.ALL_DATA_LEMMATIZED, dfAll[columns_to_save])


    ## clean using stemmers
    df_processor = DataFrameParallelProcessor(stemmers, config.DATA_PROCESSOR_N_JOBS)
    df_processor.process(dfAll, columns_to_proc)
    # split product_attribute_concat into product_attribute and product_attribute_list
    dfAll["product_attribute"] = dfAll["product_attribute_concat"].apply(_split_attr_to_text)
    dfAll["product_attribute_list"] = dfAll["product_attribute_concat"].apply(_split_attr_to_list)
    # query expansion
    if config.QUERY_EXPANSION:
        list_processor = ListProcessor(stemmers)
        base_stopwords = set(list_processor.process(list(config.STOP_WORDS)))
        qe = QueryExpansion(dfAll, ngram=3, stopwords_threshold=0.9, base_stopwords=base_stopwords)
        dfAll["search_term_alt"] = qe.build()
        if config.TASK == "sample":
            print(dfAll[["search_term", "search_term_alt"]])
    # save data
    logger.info("Save to %s"%config.ALL_DATA_LEMMATIZED_STEMMED)
    columns_to_save = [col for col in dfAll.columns if col != "product_attribute_concat"]
    pkl_utils._save(config.ALL_DATA_LEMMATIZED_STEMMED, dfAll[columns_to_save])
