Large-scale text deduplication has become a hot topic: with the rise of large language models, everyone urgently needs more high-quality datasets.
So how do we deduplicate text?
The first, most intuitive approach is to deduplicate with Python regular expressions (a minimal sketch follows below).
Recommended reading: 1. re — Regular expression operations (Python docs) 2. Regular Expressions Tutorial
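As a quick illustration of that first idea (the normalization pattern and the sample strings here are made up for illustration, not taken from a real pipeline), regex-based normalization followed by exact matching looks roughly like this:

import re

def normalize(s):
    # lowercase, then keep only letters, digits and underscores
    return re.sub(r'[^\w]+', '', s.lower())

docs = ["今天 天气真好!", "今天天气真好", "另一句话"]
seen, unique = set(), []
for d in docs:
    key = normalize(d)
    if key not in seen:        # keep only the first occurrence of each normalized form
        seen.add(key)
        unique.append(d)
print(unique)                  # the first two strings collapse into one entry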
The second approach is to deduplicate based on the similarity between texts.
This post focuses on the second approach.
Recommended reading: 1. Zhang Zhenhu's blog 2. The implementation source code on GitHub 3. Computing text similarity with MinHash 4. MinHashLSH in Python's datasketch library
OK! I won't rehash the basics of MinHash here; interested readers can easily find them online.
Talk is cheap. Show me the code.
Before looking at the code, it helps to walk through how the algorithm works. To deduplicate, we need a query operation: given a text, find which texts in the dataset are similar to it, then decide what to keep within that set. MinHash on its own only estimates the Jaccard similarity and cardinality between two texts, so we also need MinHashLSH, which splits each MinHash signature into bands and hashes them into buckets; previously inserted texts whose signatures collide in some bucket can then be retrieved, which is exactly the query operation we need. A minimal sketch of this flow is shown below.
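Here is a minimal, self-contained sketch of that flow with datasketch (the toy documents and the character-level tokenization are only for illustration; the full script below tokenizes with jieba instead):

from datasketch import MinHash, MinHashLSH

def make_minhash(text, num_perm=256):
    m = MinHash(num_perm=num_perm)
    for token in text:                 # character-level tokens, purely for illustration
        m.update(token.encode('utf8'))
    return m

docs = {0: "今天天气真好", 1: "今天天气真好啊", 2: "完全不同的一句话"}
lsh = MinHashLSH(threshold=0.5, num_perm=256)
for doc_id, text in docs.items():
    lsh.insert(doc_id, make_minhash(text))

m0, m1 = make_minhash(docs[0]), make_minhash(docs[1])
print(m0.jaccard(m1))    # MinHash alone: estimated Jaccard similarity of docs 0 and 1
print(lsh.query(m0))     # MinHashLSH query: contains 0, very likely 1, but not 2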
Documents are registered with lsh.insert(doc_id, m). (Note: lsh = MinHashLSH(threshold=0.9, num_perm=256) and minhashes = {} are both globals.) Now for the code!
from datasketch import MinHash, MinHashLSH
import pickle
import tarfile
import os
import re
from simhash import Simhash
import json
from datetime import datetime
import numpy as np
from collections import defaultdict
import nltk
from nltk.util import ngrams
from unidecode import unidecode
import jieba
import random
nltk.download('punkt') # Download NLTK data
width = 5
hash_k = 5
max_hash_len = 0
# lsh = MinHashLSH(threshold=0.5, num_perm=9000, params=(450, 20))
lsh = MinHashLSH(threshold=0.9, num_perm=256)
minhashes = {}
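# NOTE: every MinHash must be built with the same num_perm as the MinHashLSH index
# (256 here), otherwise the LSH index will reject the insert. The -num_perm command-line
# flag parsed below is never applied to these globals or to the MinHash objects built
# in add_to_lsh, which also hard-code 256.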
# Preprocess Chinese data with a sliding window of size `width`
def preprocess(s):
    s = s.lower()                     # lowercase
    s = re.sub(r'[^\w]+', '', s)      # keep only letters, digits and underscores
    return [s[i:i + width] for i in range(max(len(s) - width + 1, 1))]
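# For illustration (an assumed example, with the default width = 5):
#   preprocess("Hello, 世界!123")
#   -> ['hello', 'ello世', 'llo世界', 'lo世界1', 'o世界12', '世界123']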
# Preprocess, tokenize with jieba, build a MinHash and add it to the global LSH index
def add_to_lsh(doc_id, doc_text):
    # tokens = nltk.word_tokenize(preprocess(doc_text))  # English variant
    tokens = jieba.lcut("".join(preprocess(doc_text)))  # Chinese word segmentation with jieba (the joined windows overlap)
    # while len(tokens) < 5:
    #     tokens.append('')
    # tokens = preprocess(doc_text)
    # ngram_set = set(ngrams(tokens, 5))  # 5-grams variant
    # m = MinHash(num_perm=9000)
    m = MinHash(num_perm=256)
    for token in tokens:
        m.update(token.encode('utf8'))
    minhashes[doc_id] = m
    lsh.insert(doc_id, m)
# Read every line of one data file, collect the records into `members`,
# and index each record's content in the global lsh / minhashes
def index_minhash(num):
    # hashes = []
    members = []
    print("Starting part_02%0.3d" % (num), len(members))
    with open("/home/zikang/project/datastory/redpajama-data/data_prep/book/data/part-02%0.3d" % (num), "r") as f:
        lines = f.readlines()
        for idx, i in enumerate(lines):
            if idx % 5000 == 0:
                print("This is part_02%0.3d" % (num), idx)
            member = json.loads(i)
            members.append(member)
            try:
                if max_hash_len == 0:
                    add_to_lsh(idx, member['content'])
                else:
                    add_to_lsh(idx, member['content'][:max_hash_len])
            except Exception:
                # skip records whose content is missing or fails to hash
                continue
    # print("Finishing part_02%0.3d" % (num), len(hashes), len(members))
    print("Finishing part_02%0.3d" % (num), len(members))
    # return (hashes, members)
    return members
# [Random algorithm] Mark every document as kept first; then, for each document that is
# still kept, query its near-duplicate set, randomly keep one member and drop the rest
# (removing them from the LSH index so they are not returned by later queries).
def get_match(n):
    value_dict = {}
    for index in range(n):
        value_dict[index] = 1
    for index in range(n):
        try:
            if value_dict[index]:
                results = lsh.query(minhashes[index])
                random_element = random.choice(results)
                value_dict[random_element] = 1
            else:
                continue
            for i in results:
                if i != random_element:
                    value_dict[i] = 0
                    lsh.remove(i)
        except Exception:
            # documents that failed to hash (no entry in minhashes) are dropped
            value_dict[index] = 0
    return value_dict
# [Cluster algorithm] Merge all mutually similar documents into one cluster,
# then apply the random algorithm to the whole cluster.
# def get_match(n):
#     value_dict = {}
#     for index in range(n):
#         value_dict[index] = 1
#     for index in range(n):
#         try:
#             if value_dict[index]:
#                 list1 = lsh.query(minhashes[index])
#                 # convert the lists to sets and take their union
#                 set1 = set(list1)
#                 for l in list1:
#                     list2 = lsh.query(minhashes[l])
#                     set2 = set(list2)
#                     merged_set = set1.union(set2)
#                 # convert the set back to a list
#                 merged_list = list(merged_set)
#                 random_element = random.choice(merged_list)
#                 value_dict[random_element] = 1
#             else:
#                 continue
#             for i in merged_list:
#                 if i != random_element:
#                     value_dict[i] = 0
#                     lsh.remove(i)
#         except:
#             value_dict[index] = 0
#     return value_dict
# [Smallest-index algorithm] From each set of near-duplicates, keep only the
# document with the smallest index.
# def get_match(n):
#     value_dict = {}
#     for index in range(n):
#         flag = 1
#         try:
#             results = lsh.query(minhashes[index])
#             for x in results:  # if the current document is not the smallest index, set flag = 0
#                 if x >= index:
#                     continue
#                 flag = 0
#                 break
#             value_dict[index] = flag
#         except:
#             value_dict[index] = flag
#     return value_dict
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('-w', type=int, default=5, help='the window size')
parser.add_argument('-num_perm', type=int, default=128, help='the number of permutation functions used in MinHash.')
parser.add_argument('-k', type=int, default=1, help='find K nearest region')
parser.add_argument('-l', type=int, default=90, help='the max length of the text for hashing, 0 means no limit')
parser.add_argument('-n', type=int, default=2, help='the number of processes to run')
args = parser.parse_args()
num_perm = args.num_perm
width = args.w
hash_k = args.k
max_hash_len = args.l
n_jobs = args.n
    outfile = "path_for_deduplicated_texts/yyyy.jsonl"      # where the kept texts are written
    reject_file = "path_for_discarded_texts/nnnn.jsonl"     # where the discarded near-duplicates are written
    get_members = index_minhash(471)
    # hashes_members.extend(get_book(n_jobs))
    # print("Finish getting hashes and members!")
    print("Finish getting lsh and members!")
    # import itertools
    # hashes = list(itertools.chain(*[item[0] for item in hashes_members]))
    # print("number of hashes: ", len(hashes))
    # members = list(itertools.chain(*[item for item in get_members]))
    print("number of members:", len(get_members))
    # Print some near-duplicate examples
    # results = lsh.query(minhashes[0])
    # with open(reject_file, 'w') as f:  # write out the rejected documents
    #     for i in results:
    #         nnct = {"content": get_members[i]["content"]}
    #         f.write(json.dumps(nnct, ensure_ascii=False) + '\n')
    n = len(get_members)
    temp_dict = get_match(n)
    # print(temp_dict)
    with open(outfile, 'w') as f:  # write the kept (deduplicated) documents
        count = 0
        for i in range(n):
            if temp_dict[i] == 1:
                # meta = {}
                # for feature in get_members[i]:
                #     if feature != "content":
                #         meta[feature] = get_members[i][feature]
                # new = {"meta": meta, "content": mem["content"]}
                try:
                    new = {"content": get_members[i]["content"]}
                    count += 1
                    f.write(json.dumps(new, ensure_ascii=False) + '\n')
                except Exception:
                    continue
    print("count:", count)
    print("number of removed duplicates:", n - count)
    with open(reject_file, 'w') as f:  # write the discarded near-duplicates
        for i in range(n):
            if temp_dict[i] == 0:
                try:
                    nnct = {"content": get_members[i]["content"]}
                    f.write(json.dumps(nnct, ensure_ascii=False) + '\n')
                except Exception:
                    continue
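The original post does not name the script file, so dedup_minhash.py below is just a placeholder; with the argparse options defined above, a run might look like this (remember that -num_perm is parsed but never applied to the global index, which is hard-coded to 256):

python dedup_minhash.py -w 5 -num_perm 256 -k 1 -l 90 -n 2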
References:
[1]: https://github.com/ekzhu/datasketch