Processing the Tsinghua University THUCNews Dataset

Dataset download: http://thuctc.thunlp.org/

This post records the preprocessing used to turn the Tsinghua THUCNews dataset into text-classification training data. The dataset unpacks into one subdirectory per news category, each holding one UTF-8 .txt file per article (the first line is usually the title), and the code below walks that tree directly. The steps are simple, so straight to the code:

#!/usr/bin/env python3
# -*- coding:utf-8 -*-

"""
 Preprocessing for the Tsinghua University text-classification dataset (THUCNews)
 @Author:MaCan
 @Time:2019/9/17 11:40
 @File:thunews_process.py
 @Mail:[email protected]
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import re


def load_file(path):
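    """
    Walk the THUCNews root (one subdirectory per category) and write every
    document to thu_news_3w_v1.txt as one line of __<label>__\t<text>, where
    a detected title is marked off from the body by '|||'. At most ~30,000
    documents are kept per category.
    """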
    class_file_num = {}
    with open('thu_news_3w_v1.txt', 'w', encoding='utf-8') as fw:
        for sub_dir in os.listdir(path):
            print(sub_dir)
            if sub_dir in ['星座']:  # skip the 星座 (horoscope) category
                continue
            curr_dir = os.path.join(path, sub_dir)
            if os.path.isdir(curr_dir):
                for file in os.listdir(curr_dir):
                    curr_file = os.path.join(curr_dir, file)
                    # print(curr_file)
                    context = ''
                    with open(curr_file, 'r', encoding='utf-8') as fd:
                        idx = 0
                        for line in fd:
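                            # a short first line is assumed to be the title;
                            # mark it off with a '|||' separator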
                            if idx == 0 and len(line) < 30:
                                context += line + '|||'
                            else:
                                context += line
                            idx += 1
                    if len(context) < 10:
                        continue
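                    # strip emoji (expressed as surrogate pairs) and misc symbols,
                    # then remove line breaks, tabs, spaces and stray backslashes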
                    context = re.sub('(\ud83c[\udf00-\udfff])|(\ud83d[\udc00-\ude4f\ude80-\udeff])|[\u2600-\u2B55]', '',
                                     context.strip())
                    context = re.sub('[\n\r\t    \\\]+', '', context)
                    context = '__' + sub_dir + '__\t' + context
                    # print(context)
                    fw.write(context + '\n')
                    class_file_num[sub_dir] = class_file_num.get(sub_dir, 0) + 1
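                    # stop once this category has ~30,000 documents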
                    if class_file_num.get(sub_dir) > 30000:
                        break
    print(class_file_num)


def cut_sentence(line, max_seq_len, mode):
    """
    Truncate a document to roughly max_seq_len characters.
    :param line: raw text; an optional title is marked off by '|||'
    :param max_seq_len: maximum number of characters to keep
    :param mode: truncation strategy, one of
        [0: drop the tail, 1: drop the head, 2: title + top + tail, 3: top + tail without title]
    :return: the truncated text
    """
    line = line.strip()
    if mode == 0:
        if len(line) > max_seq_len:
            line = line[0:max_seq_len]
        return line
    elif mode == 1:
        if len(line) > max_seq_len:
            line = line[:-max_seq_len]
        return line
    else:
        line = line.split('|||')
        context = ''
        if len(line) >= 2:  # a title is present: title + top + tail
            if mode == 2:
                context += line[0]
            # elif mode == 3:
            #     context += line[1:]
            line = ''.join(line[1:])
        else:
            line = ''.join(line)

        # if the whole text (title included) already fits within the budget,
        # return it directly; otherwise the head/tail loop below would end up
        # collecting every sentence twice
        if len(context) + len(line) <= max_seq_len:
            return context + line

        # split into sentences on punctuation, keeping each delimiter attached
        # to the sentence before it
        sentences = re.split(r"([.。!!??;;,,\s+])", line)
        sentences.append("")
        sentences = ["".join(i) for i in zip(sentences[0::2], sentences[1::2])]
        head = []
        tail = []
        if mode == 2:
            head.append(context)

        def check_curr_sen_len():
            if len(''.join(head)) + len(''.join(tail)) > max_seq_len:
                return True
            return False

        # grow the selection from both ends: take the next sentence from the
        # front and the next from the back until the length budget is used up
        for idx in range(len(sentences)):
            if check_curr_sen_len():
                break
            head.append(sentences[idx])
            if check_curr_sen_len():
                break
            tail.insert(0, sentences[-(idx + 1)])
        # merge the two halves back into one string
        return ''.join(head) + ''.join(tail)

# def test_load_file():
#     load_file(path)


if __name__ == '__main__':
    path = r'F:\THUCNews\THUCNews'  # root of the unpacked THUCNews dataset
    load_file(path)
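    # mode 0: cut each document with cut_sentence
    # mode 1: split the cut corpus into train/dev/test
    # mode 2: build a small two-class (体育/财经) subset and split it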
    mode = 2
    if mode == 0:
        # cut every document with cut_sentence; cut-mode 2 keeps the title
        # plus head and tail sentences, producing the *_cutmode_2 file that
        # the mode 1 and mode 2 branches below consume
        cut_mode = 2
        with open('thu_news_3w_v1_cutmode_{}.txt'.format(cut_mode), 'w', encoding='utf-8') as fw:
            with open('thu_news_3w_v1.txt', 'r', encoding='utf-8') as fd:
                for line in fd:
                    line = line.split('\t')
                    new_context = cut_sentence(''.join(line[1:]), 200, cut_mode)
                    fw.write(line[0] + '\t' + new_context + '\n')
    elif mode == 1:
        datas = []
        labels = []
        with open('thu_news_3w_v1_cutmode_2.txt', 'r', encoding='utf-8') as fd:
            for line in fd:
                line = line.split('\t')
                if len(line) == 2:
                    labels.append(line[0])
                    datas.append(line[1])
        from sklearn.model_selection import train_test_split
        print('all data size:{}, {}'.format(len(datas), len(labels)))
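        # 60% train; the remaining 40% is split evenly into dev and test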
        train_x, last_x, train_y, last_y = train_test_split(datas, labels, test_size=0.4)
        dev_x, test_x, dev_y, test_y = train_test_split(last_x, last_y, test_size=0.5)

        def save_list(datas, labels, path):
            print(path + ' data size:{}, {}'.format(len(datas), len(labels)))
            with open(path, 'w', encoding='utf-8') as fd:
                for data, label in zip(datas, labels):
                    fd.write(label + '\t' + data)

        save_list(train_x, train_y, 'train.txt')
        save_list(dev_x, dev_y, 'dev.txt')
        save_list(test_x, test_y, 'test.txt')

    elif mode == 2:
        datas = []
        labels = []
        per_class = {}  # samples kept so far for each of the two classes
        with open('thu_news_3w_v1_cutmode_2.txt', 'r', encoding='utf-8') as fd:
            for line in fd:
                line = line.split('\t')
                if len(line) == 2:
                    label = re.sub('[_]+', '', line[0])
                    # cap each class at 10,000 samples for the mini dataset
                    if label in ['体育', '财经'] and per_class.get(label, 0) < 10000:
                        per_class[label] = per_class.get(label, 0) + 1
                        labels.append(line[0])
                        datas.append(line[1])
        from sklearn.model_selection import train_test_split
        print('all data size:{}, {}'.format(len(datas), len(labels)))
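        # 60% train; the remaining 40% is split evenly into dev and test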
        train_x, last_x, train_y, last_y = train_test_split(datas, labels, test_size=0.4)
        dev_x, test_x, dev_y, test_y = train_test_split(last_x, last_y, test_size=0.5)

        def save_list(datas, labels, path):
            print(path + ' data size:{}, {}'.format(len(datas), len(labels)))
            with open(path, 'w', encoding='utf-8') as fd:
                for data, label in zip(datas, labels):
                    fd.write(label + '\t' + data)

        save_list(train_x, train_y, 'mini_train.txt')
        save_list(dev_x, dev_y, 'mini_dev.txt')
        save_list(test_x, test_y, 'mini_test.txt')

In main, mode 0 cuts each document down to length (the cutting strategies live in cut_sentence), mode 1 splits the cut corpus into training, validation, and test sets, and mode 2 builds a smaller two-class (体育/财经) subset and splits it the same way.
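As a quick sanity check of cut_sentence, here is a minimal usage sketch (it assumes the script above is saved as thunews_process.py, per its @File header; the sample string is made up for illustration):

from thunews_process import cut_sentence

# a made-up document in the '<title>|||<body>' form that load_file produces
sample = '标题|||句子一。句子二。句子三。句子四。句子五。句子六。'
# cut-mode 2 keeps the title, then fills the budget with sentences taken
# alternately from the head and the tail of the body
print(cut_sentence(sample, max_seq_len=12, mode=2))
# -> 标题句子一。句子二。句子六。 (the middle sentences are dropped)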
