数据集下载地址:http://thuctc.thunlp.org/
本文主要记录了清华大学 THUCNews 数据集用于文本分类时的数据预处理方法,比较简单,直接上代码。
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
"""
清华大学的文本分类数据集的处理
@Author:MaCan
@Time:2019/9/17 11:40
@File:thunews_process.py
@Mail:[email protected]
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import re
import pickle
def load_file(path, out_path='thu_news_3w_v1.txt', max_per_class=30000,
              skip_classes=('星座',)):
    """Flatten a THUCNews directory tree into one labelled text file.

    ``path`` is expected to contain one sub-directory per class, each holding
    one UTF-8 text file per article.  Every article becomes a single output
    line of the form ``__<class>__<TAB><text>``.

    Per article:
      * if the first line is shorter than 30 characters (newline included) it
        is treated as the title and followed by the marker ``|||``;
      * articles shorter than 10 characters (before cleaning) are skipped;
      * emoji / pictographs and the range U+2600-U+2B55 are removed;
      * newlines, carriage returns, tabs, spaces and backslashes are removed,
        so the article fits on one line.

    :param path: root directory of the dataset
    :param out_path: file to write (overwritten if it exists)
    :param max_per_class: maximum number of articles written per class
    :param skip_classes: class directory names to ignore
    :return: dict mapping class name to the number of articles written
    """
    # Python 3 strings are sequences of code points, so emoji have to be
    # matched as astral code points; a UTF-16 surrogate-pair pattern never
    # matches text decoded from UTF-8.
    emoji_re = re.compile(
        '[\U0001F300-\U0001F64F\U0001F680-\U0001F6FF\u2600-\u2B55]')
    junk_re = re.compile(r'[\n\r\t \\]+')

    class_file_num = {}
    with open(out_path, 'w', encoding='utf-8') as fw:
        for sub_dir in os.listdir(path):
            print(sub_dir)
            if sub_dir in skip_classes:
                continue
            curr_dir = os.path.join(path, sub_dir)
            if not os.path.isdir(curr_dir):
                continue

            written = 0
            for file_name in os.listdir(curr_dir):
                # Check the cap before writing so that at most
                # ``max_per_class`` articles are emitted, not one more.
                if written >= max_per_class:
                    break
                curr_file = os.path.join(curr_dir, file_name)
                with open(curr_file, 'r', encoding='utf-8') as fd:
                    lines = fd.readlines()

                # A short first line is taken to be the title; the marker
                # lets later steps separate it from the body.
                if lines and len(lines[0]) < 30:
                    lines[0] += '|||'
                context = ''.join(lines)
                if len(context) < 10:
                    continue

                context = emoji_re.sub('', context.strip())
                context = junk_re.sub('', context)
                fw.write('__' + sub_dir + '__\t' + context + '\n')
                written += 1

            if written:
                class_file_num[sub_dir] = written
    print(class_file_num)
    return class_file_num
def cut_sentence(line, max_seq_len, mode):
    """Truncate one document to at most ``max_seq_len`` characters.

    :param line: document text; an optional title is separated from the body
        by the marker ``|||``
    :param max_seq_len: maximum length of the returned text, in characters
    :param mode: truncation strategy
        0 - keep the beginning, drop the end;
        1 - keep the end, drop the beginning;
        2 - title + leading sentences + trailing sentences;
        any other value - leading + trailing sentences, title dropped
    :return: the truncated text (surrounding whitespace stripped)
    """
    line = line.strip()
    if mode == 0:
        return line[:max_seq_len]
    if mode == 1:
        if len(line) > max_seq_len:
            # Index from the front so that max_seq_len == 0 yields '' rather
            # than the whole string (which ``line[-0:]`` would return).
            line = line[len(line) - max_seq_len:]
        return line

    parts = line.split('|||')
    title = ''
    if len(parts) >= 2:
        if mode == 2:
            title = parts[0]
        body = ''.join(parts[1:])
    else:
        body = parts[0]

    title = title[:max(max_seq_len, 0)]
    budget = max(max_seq_len - len(title), 0)

    # Split on sentence punctuation / whitespace, keeping each delimiter
    # attached to the text before it.  The full-width forms of ! ? ; , are
    # given as escapes; the original class listed each of those characters
    # twice, presumably because the full-width variants were lost in
    # transcription -- TODO confirm.
    pieces = re.split(r"([.。!?;,\uff01\uff1f\uff1b\uff0c\s])", body)
    pieces.append("")
    sentences = [text + sep for text, sep in zip(pieces[0::2], pieces[1::2])]
    sentences = [sent for sent in sentences if sent]

    # Take sentences alternately from the front and the back.  Stop when the
    # next one would not fit, or when the two ends meet, so that no sentence
    # is emitted twice and the budget is never exceeded.
    head = []
    tail = []
    used = 0
    lo, hi = 0, len(sentences) - 1
    while lo <= hi:
        sent = sentences[lo]
        if used + len(sent) > budget:
            break
        head.append(sent)
        used += len(sent)
        lo += 1
        if lo > hi:
            break
        sent = sentences[hi]
        if used + len(sent) > budget:
            break
        tail.append(sent)
        used += len(sent)
        hi -= 1

    if not head and not tail:
        # Not even the first sentence fits: fall back to a plain character
        # cut instead of returning the title alone.
        return title + body[:budget]

    tail.reverse()
    return title + ''.join(head) + ''.join(tail)
# def test_load_file():
# load_file(path)
def _save_list(datas, labels, path):
    """Write ``label<TAB>text`` to *path*, one sample per line."""
    print(path + ' data size:{}, {}'.format(len(datas), len(labels)))
    with open(path, 'w', encoding='utf-8') as fd:
        for data, label in zip(datas, labels):
            fd.write(label + '\t' + data + '\n')


def _read_samples(path, keep_classes=None, max_per_class=None):
    """Read ``label<TAB>text`` lines from *path*.

    Lines that do not consist of exactly two tab-separated fields are
    ignored.  When ``keep_classes`` is given, only samples whose label (with
    underscores removed) is in it are kept, and at most ``max_per_class``
    samples are taken for each of those classes.

    :return: ``(texts, labels)`` as two parallel lists; labels keep their
        original ``__name__`` form
    """
    datas = []
    labels = []
    taken = {}
    with open(path, 'r', encoding='utf-8') as fd:
        for line in fd:
            fields = line.rstrip('\n').split('\t')
            if len(fields) != 2:
                continue
            tag, text = fields
            if keep_classes is not None:
                name = re.sub('[_]+', '', tag)
                if name not in keep_classes:
                    continue
                if max_per_class is not None and taken.get(name, 0) >= max_per_class:
                    continue
                taken[name] = taken.get(name, 0) + 1
            labels.append(tag)
            datas.append(text)
    return datas, labels


def _split_and_save(datas, labels, prefix=''):
    """Split the samples 60/20/20 into train/dev/test and save them.

    Files are written as ``<prefix>train.txt``, ``<prefix>dev.txt`` and
    ``<prefix>test.txt``.
    """
    from sklearn.model_selection import train_test_split
    print('all data size:{}, {}'.format(len(datas), len(labels)))
    train_x, last_x, train_y, last_y = train_test_split(datas, labels, test_size=0.4)
    dev_x, test_x, dev_y, test_y = train_test_split(last_x, last_y, test_size=0.5)
    _save_list(train_x, train_y, prefix + 'train.txt')
    _save_list(dev_x, dev_y, prefix + 'dev.txt')
    _save_list(test_x, test_y, prefix + 'test.txt')


if __name__ == '__main__':
    # Raw string: the path contains backslashes that must not be read as
    # escape sequences.
    path = r'F:\THUCNews\THUCNews'

    # File names are derived from one place so the steps below always read
    # what the previous step wrote.
    RAW_FILE = 'thu_news_3w_v1.txt'  # the name load_file writes to
    CUT_MODE = 2
    MAX_SEQ_LEN = 200
    CUT_FILE = 'thu_news_3w_v1_cutmode_{}.txt'.format(CUT_MODE)
    MINI_CLASSES = ['体育', '财经']
    MINI_TOTAL = 10000

    # NOTE: this rebuilds RAW_FILE on every run, whichever step is selected.
    load_file(path)

    # 0: truncate every document; 1: split the full data set;
    # 2: build and split a small two-class subset.
    mode = 2
    if mode == 0:
        with open(CUT_FILE, 'w', encoding='utf-8') as fw:
            with open(RAW_FILE, 'r', encoding='utf-8') as fd:
                for line in fd:
                    fields = line.split('\t')
                    if len(fields) < 2:
                        continue
                    new_context = cut_sentence(''.join(fields[1:]), MAX_SEQ_LEN, CUT_MODE)
                    fw.write(fields[0] + '\t' + new_context + '\n')
    elif mode == 1:
        datas, labels = _read_samples(CUT_FILE)
        _split_and_save(datas, labels)
    elif mode == 2:
        # The input is grouped by class, so a single global cap would fill
        # the subset with whichever class comes first; cap each class
        # separately to keep both classes represented.
        datas, labels = _read_samples(
            CUT_FILE,
            keep_classes=MINI_CLASSES,
            max_per_class=MINI_TOTAL // len(MINI_CLASSES),
        )
        _split_and_save(datas, labels, prefix='mini_')
在 main 中,mode:0 表示将文本进行切分,具体的切分方法在 cut_sentence 中;mode:1 表示将数据划分为训练集、验证集和测试集;mode:2 表示只抽取"体育"和"财经"两个类别的数据,生成小规模(mini)的训练集、验证集和测试集。