NLP Text Deduplication with a Python Implementation (MinHash and MinHashLSH)

  • Preface
  • Code Implementation (with detailed comments)

Preface

Large-scale text deduplication has become a hot topic: with the rise of large language models, everyone urgently needs more high-quality datasets.

How do we actually deduplicate text?

The most intuitive approach is to deduplicate with Python regular expressions.
Recommended reading: 1. re — Regular expression operations 2. Regular Expressions - Tutorial

The second approach is to deduplicate based on the similarity between texts.
This post focuses on the second approach.
Recommended reading: 1. Zhang Zhenhu's blog 2. the implementation source code on GitHub 3. MinHash as a method for computing text similarity 4. MinHashLSH in the Python datasketch library

OK! I won't go over the basics of MinHash here; interested readers can easily find them online.
Talk is cheap. Show me the code.
Still, before reading the code, it helps to walk through how the algorithm works.

To deduplicate, we need a query operation: given a text, find all texts in the data that are similar to it, and then act on that set to remove the duplicates. MinHash on its own only estimates the Jaccard similarity and cardinality between two texts, so we also need MinHashLSH: it buckets the MinHash signatures and hashes those buckets, so that once documents have been inserted we can query for the ones similar to a given document.
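
For a quick feel of that workflow, here is a minimal, self-contained sketch using datasketch. The toy documents, the whitespace tokenization, and the 0.5 threshold are made up purely for illustration; the full script below uses jieba tokens over character n-grams and a 0.9 threshold.

from datasketch import MinHash, MinHashLSH

# toy corpus, purely for illustration
texts = {
    "doc1": "minhash estimates jaccard similarity between sets",
    "doc2": "minhash estimates the jaccard similarity between two sets",
    "doc3": "locality sensitive hashing buckets similar signatures together",
}

lsh = MinHashLSH(threshold=0.5, num_perm=128)  # buckets signatures; query returns candidates above the threshold
signatures = {}

for doc_id, text in texts.items():
    m = MinHash(num_perm=128)
    for token in text.split():                 # toy whitespace tokens; the script below hashes jieba tokens
        m.update(token.encode("utf8"))
    signatures[doc_id] = m
    lsh.insert(doc_id, m)                      # once inserted, the document becomes queryable

# query: which inserted documents are similar to doc1? (typically ['doc1', 'doc2'])
print(lsh.query(signatures["doc1"]))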

  1. First, read the file line by line, compute a MinHash for each line of text, and insert it into the LSH index (lsh.insert(doc_id, m)). [Note: lsh = MinHashLSH(threshold=0.9, num_perm=256) and minhashes = {} are both global.]
  2. Then operate on a keep/drop dictionary: a text is marked 1 if it is kept and 0 if not. There are three algorithms for this:
    • Random algorithm: first mark every text as kept, then for each similarity set returned by a query pick one text at random and mark the rest as not kept.
    • Cluster algorithm: merge all mutually similar texts into one cluster, then apply the random algorithm within the cluster.
    • Minimum-index algorithm: always keep the text with the smallest index in its similarity set, i.e. keep the current text only if it has the smallest index among the texts similar to it. Its advantage is that no global dictionary is needed; only the current document has to be considered (see the sketch after this list).
  3. Finally, write the deduplicated texts to one file and the discarded texts to another, so the results can be compared.
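
As a concrete illustration of the minimum-index rule, here is a minimal sketch; the helper name keep_doc is hypothetical, and it assumes lsh and minhashes have already been filled as in step 1 (the full script below keeps an equivalent variant commented out).

def keep_doc(index):
    # all inserted documents similar to doc `index` (the query result includes the doc itself)
    similar = lsh.query(minhashes[index])
    # keep the doc only if no similar document has a smaller index
    return all(other >= index for other in similar)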

Now, on to the code!

Code Implementation (with detailed comments)

from datasketch import MinHash, MinHashLSH
import pickle
import tarfile
import os
import re
from simhash import Simhash
import json
from datetime import datetime
import numpy as np
from collections import defaultdict
import nltk
from nltk.util import ngrams
from unidecode import unidecode
import jieba
import random

nltk.download('punkt')  # Download NLTK data
width = 5
hash_k = 5
max_hash_len = 0

# lsh = MinHashLSH(threshold=0.5, num_perm=9000, params=(450, 20))
lsh = MinHashLSH(threshold=0.9, num_perm=256)
minhashes = {}

# Preprocess Chinese text: lowercase, strip punctuation, and split into character n-grams with a sliding window of size `width`
def preprocess(s):
    s = s.lower()   # convert to lowercase
    s = re.sub(r'[^\w]+', '', s)    # remove characters that are not letters, digits, or underscores
    return [s[i:i + width] for i in range(max(len(s) - width + 1, 1))]

# Preprocess, segment with jieba, compute the MinHash, and insert it into the LSH index
def add_to_lsh(doc_id, doc_text):
    # tokens = nltk.word_tokenize(preprocess(doc_text)) # for English text
    tokens = jieba.lcut("".join(preprocess(doc_text))) # Chinese word segmentation with jieba
    # while len(tokens) < 5:
    #     tokens.append('')
    # tokens = preprocess(doc_text)
    # ngram_set = set(ngrams(tokens, 5))  # 5-grams
    # m = MinHash(num_perm=9000)
    m = MinHash(num_perm=256)
    for token in tokens:
        m.update(token.encode('utf8'))
    minhashes[doc_id] = m
    lsh.insert(doc_id, m)

# Read the file line by line, append each record to members, and record its MinHash in minhashes
def index_minhash(num):
    # hashes = []
    members = []
    print("Starting part_02%0.3d"%(num), len(members))
    with open("/home/zikang/project/datastory/redpajama-data/data_prep/book/data/part-02%0.3d"%(num), "r") as f:
        lines = f.readlines()
        for idx, i in enumerate(lines):
            if idx % 5000 == 0:
                print("This is part_02%0.3d"%(num), idx)
            member = json.loads(i)
            members.append(member)
            try:
                if max_hash_len == 0:
                    add_to_lsh(idx, member['content'])
                else:
                    add_to_lsh(idx, member['content'][:max_hash_len])
            except:
                continue          
    # print("Finishing part_02%0.3d"%(num), len(hashes), len(members))
    print("Finishing part_02%0.3d"%(num), len(members))
    # return (hashes, members)
    return members


# [Random algorithm] First mark every document as kept, then pick one document at random from each query result set and mark the rest as not kept.
def get_match(n):
    value_dict = {}
    for index in range(n):
        value_dict[index] = 1
    for index in range(n):
        try:
            if value_dict[index]:
                results = lsh.query(minhashes[index])
                random_element = random.choice(results)
                value_dict[random_element] = 1
            else:
                continue
            for i in results:
                if i != random_element:
                    value_dict[i] = 0
                    lsh.remove(i)
        except:
            value_dict[index] = 0
    return value_dict

# [Cluster algorithm] Merge all mutually similar documents into one cluster, then apply the random algorithm to the cluster
# def get_match(n):
    # value_dict = {}
    # for index in range(n):
    #     value_dict[index] = 1
    # for index in range(n):
    #     try:
    #         if value_dict[index]:
    #             list1 = lsh.query(minhashes[index])
    #             # convert the query results to sets and accumulate their union into one cluster
    #             merged_set = set(list1)
    #             for l in list1:
    #                 list2 = lsh.query(minhashes[l])
    #                 merged_set = merged_set.union(set(list2))
    #             # convert the set back to a list
    #             merged_list = list(merged_set)
    #             random_element = random.choice(merged_list)
    #             value_dict[random_element] = 1
    #         else:
    #             continue
    #         for i in merged_list:
    #             if i != random_element:
    #                 value_dict[i] = 0
    #                 lsh.remove(i)
    #     except:
    #         value_dict[index] = 0
    # return value_dict

# [Minimum-index algorithm] Always keep the document with the smallest index in its similarity set
# def get_match(n):
#     value_dict = {}
#     for index in range(n):
#         flag = 1
#         try:
#             results = lsh.query(minhashes[index])
#         for x in results: # if a similar doc has a smaller index than the current one, set flag = 0
#                 if x >= index:
#                     continue
#                 flag = 0
#                 break
#             value_dict[index] = flag
#         except:
#             value_dict[index] = flag
#     return value_dict

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()

    parser.add_argument('-w', type=int, default=5, help='the window size')
    parser.add_argument('-num_perm', type=int, default=128, help='the number of permutation functions used in MinHash.')
    parser.add_argument('-k', type=int, default=1, help='find K nearest region')
    parser.add_argument('-l', type=int, default=90, help='the max length of the text for hashing, 0 means no limit')
    parser.add_argument('-n', type=int, default=2, help='the number of processes to run')

    args = parser.parse_args()

    num_perm = args.num_perm
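    # note: lsh and the MinHash objects above are created with num_perm=256 at module level,
    # so this parsed value is not applied to them; keep the two consistent if you change it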
    width = args.w
    hash_k = args.k
    max_hash_len = args.l
    n_jobs = args.n

    outfile = "path_to_store_deduplicated_texts/yyyy.jsonl"
    reject_file = "path_to_store_discarded_texts/nnnn.jsonl"
    get_members = index_minhash(471)
    # hashes_members.extend(get_book(n_jobs))
    # print("Finish getting hashes and members!")
    print("Finish getting lsh and members!")

    # import itertools
    # hashes = list(itertools.chain(*[item[0] for item in hashes_members]))
    # print("hashes长度: ",len(hashes))
    # import itertools
    # members = list(itertools.chain(*[item for item in get_members]))
    print("members长度: ",len(get_members))

    # Print some similar documents for inspection
    # results = lsh.query(minhashes[0])
    # with open(reject_file, 'w') as f: # write out the filtered data
    #     for i in results:
    #         nnct = {"content": get_members[i]["content"]}
    #         f.write(json.dumps(nnct, ensure_ascii=False) + '\n')

    n = len(get_members)
    temp_dict = get_match(n)
    # print(temp_dict)
    with open(outfile, 'w') as f: # write the kept (deduplicated) data
        count = 0
        for i in range(n):          
            if temp_dict[i] == 1:
                # meta = {}
                # for feature in get_members[i]:
                #     if feature != "content":
                #         meta[feature] = get_members[i][feature]
                # new = {"meta": meta, "content": mem["content"]}
                
                try:
                    new = {"content": get_members[i]["content"]}
                    count += 1
                    f.write(json.dumps(new, ensure_ascii=False) + '\n')
                except:
                    continue
        print("count:",count)
        print("去重的数量", n-count)

    with open(reject_file, 'w') as f: # write the discarded (duplicate) data
        for i in range(n):
            if temp_dict[i] == 0:
                try:
                    nnct = {"content": get_members[i]["content"]}
                    f.write(json.dumps(nnct, ensure_ascii=False) + '\n')
                except:
                    continue

References:
[1]: https://github.com/ekzhu/datasketch
