文本分类CNN源程序 github 地址:https://github.com/gaussic/text-classification-cnn-rnn
数据集链接:链接: https://pan.baidu.com/s/11yIwZkd2yI7yeUjxu6UChA 提取码: hgfx
在进行中文文本分类的过程中,克隆文件后进行测试,记录下我所遇到的问题:
如果观察run_cnn,py的代码就会发现,要想执行训练程序,需要在命令行中输入:
python run_cnn.py train
同理,要想执行测试程序,在命令行中输入:
python run_cnn.py test
如果要想对自己的数据集进行测试,只有一个txt文件,那就需要对其进行训练集,验证集和测试集的划分。
为什么要对数据集进行训练集,验证集和测试集的划分
这篇博客写的很清楚,建议读一读,然后开始我们对数据集的划分(这里train:val:test = 6:2:2)
首先导入模块,初始化训练集、验证集和测试集的列表为空
import os
import random
L_train = []
L_val = []
L_test = []
定义函数ReadFileDatas()和WriteDatasToFile(),可以方便我们读取txt的内容和将内容保存到txt文件中去,列表是不能使用write()函数的,需要先将其转换为string类型
# 读取文件中的内容,并写入列表FileNameList
def ReadFileDatas(original_filename):
FileNameList = []
file = open(original_filename, 'r+', encoding='utf-8-sig')
for line in file:
FileNameList.append(line) # 写入文件内容到列表中去
print('数据集总量:', len(FileNameList))
file.close()
return FileNameList
# 将获取的列表中的内容转为 str ,再写入到txt文件中去
# listInfo为 ReadFileDatas 的列表
def WriteDatasToFile(listInfo, new_filename):
file_handle = open(new_filename, mode='a', encoding='utf-8-sig')
for idx in range(len(listInfo)):
str = listInfo[idx] # 列表指针
str_Result = str
file_handle.write(str_Result)
file_handle.close()
print('写入 %s 文件成功.' % new_filename)
对数据集进行train:val:test = 6:2:2划分,再定义数据保存的格式
"""
将划分数据集用函数表示
划分数据集(train, val, test)的区间,(new.txt) 为随机打乱好的文件数据集
数据集列表集合
打开文件引用上一函数保存的文件
"""
def TrainValTestFile(new_filename):
# L_train = []
# L_val = []
# L_test = []
i = 0 # counter
j = 9352 # all lines
file_divide = open(new_filename, 'r', encoding='utf-8-sig')
lines = file_divide.readlines()
for line in lines:
if i < (j *0.6):
i += 1
L_train.append(line)
elif i < (j*0.8):
i += 1
L_val.append(line)
elif i < j:
i += 1
L_test.append(line)
print("总数据量:%d , 此时创建train, val, test数据集" % i)
return L_train, L_val, L_test
# 保存数据集(train, val, test)
def text_save(filename, data): #filename为写入CSV文件的路径,data为要写入数据列表
file = open(filename, 'a', encoding='utf-8-sig')
for i in range(len(data)):
s = str(data[i]).replace('[','').replace(']','') #去除[],这两行按数据不同,可以选择
# s = s.replace("'",'').replace(',','') +'\n' #去除单引号,逗号,每行末尾追加换行符
file.write(s)
file.close()
print("保存数据集(路径)成功:%s" % filename)
最后调用函数,完成训练集,测试集和验证集的划分,并保存在指定目录
# 调用函数
if __name__ == "__main__":
listFileInfo = ReadFileDatas('data.txt') # 读取文件
random.shuffle(listFileInfo) # 打乱顺序
WriteDatasToFile(listFileInfo,'new_data.txt') # 保存新的文件
# 划分数据集并保存
TrainValTestFile('new_data.txt')
text_save('./data/data_train.txt', L_train)
text_save('./data/data_val.txt', L_val)
text_save('./data/data_test.txt', L_test)