- 由于模型需要使用训练集、验证集和测试集,而我只有一个总的数据集,因此用Python实现了数据集的划分,特此小记一下。
- 同时也是为了记录这个过程中用到的Python的一些知识。
数据集划分的代码我写在一个函数里面,函数为:split_data(before_dataset_filepath, output_dir, split_prop)。
参数解释:
- before_dataset_filepath:待划分的数据集文件的路径,是相对路径。
- output_dir:训练集、测试集、验证集文件所在的目录,是相对路径。
- split_prop:划分比例的列表。这里参照了文献的划分比例,即[训练集:验证集:测试集]=[3:1:2]。该列表共三个元素,例如[3,1,2],表示[训练集:验证集:测试集]=[3:1:2]。

首先用 numpy.loadtxt 函数以字符串的形式将数据读取进来,用 numpy.ndarray
进行存储。data_list = np.loadtxt(before_dataset_filepath, dtype="str", comments=None, delimiter="\n", encoding="utf-8-sig")
# Each line of the input file becomes one string element of data_list (delimiter="\n");
# comments=None stops loadtxt from treating "#" as a comment marker, and "utf-8-sig"
# drops a leading byte-order mark if the file has one.
# NOTE(review): numpy >= 1.23 no longer accepts a newline character as loadtxt's
# delimiter, so this call raises on current numpy. Reading the file with open() and
# splitting it into lines gives the same list of strings without that restriction.
#### Determine the total number of samples and the size of each split
# The "- 1" excludes the first line of the file, which is the column header.
total_num = len(data_list) - 1
print("total_num:", total_num)
# int() simply drops the fractional part.
# NOTE(review): the float product can land just below a whole number
# (e.g. 29 / 100 * 100 == 28.999999999999996), in which case int() undercounts by one;
# split_prop[0] * total_num // sum(split_prop) is exact for integer proportions.
train_num=int(split_prop[0]/sum(split_prop)*total_num)
dev_num = int(split_prop[1] / sum(split_prop)*total_num)
# The test set takes whatever remains, so the three sizes always add up to total_num.
test_num = total_num-train_num-dev_num
print("train_num:", train_num)
print("dev_num:", dev_num)
print("test_num:", test_num)
#### Example output from my run:
#total_num: 10
#train_num: 5
#dev_num: 1
#test_num: 4
确定了各数据集的数量之后,就需要从待划分数据集中抽取数据了。但在这之前需要先把第一行的标签行(表头)去掉,这里使用了 numpy.delete() 函数。
# Remove the first line (the column header). data_list is one-dimensional here,
# so index [0] refers to its first element. Uses the same `np` alias as the
# np.loadtxt call that created data_list.
data_list = np.delete(data_list, [0])
这里采用 random.shuffle() 随机打乱待划分数据集数组的索引(而不是直接打乱数据本身),比较方便。
#### Shuffle before splitting: permute the row indices rather than the rows themselves
index = [*range(total_num)]
print("打乱前索引:", index)
random.shuffle(index)  # permutes the list in place
print("打乱后索引:", index)
#### Example output from my run:
#打乱前索引: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
#打乱后索引: [2, 3, 8, 0, 6, 9, 1, 5, 4, 7]
这里有两点需要注意:
1. 写入之前要先清空文件中原有的内容。我这里是以 'w' 的方式打开文件并写入一个空字符串(其实以 'w' 模式打开文件时,原有内容就已经被清空了)。官方文档也说了,f.truncate() 函数不常用,而且我用了之后好像也没成功~
2. 注意读取 index
索引列表时候的下标范围。####读取并写入训练数据集文件
# Path of the training set file. output_dir is concatenated directly, so it
# must end with "/" (e.g. "data/split_data/").
train_txt = output_dir + "train.txt"
# Clear the file first. Opening in 'w' mode already truncates it; the empty
# write just makes the intent explicit.
with open(train_txt, 'w', encoding="utf-8-sig") as f1:
    f1.write("")
# Then append the header line and the training samples
with open(train_txt, 'a', encoding="utf-8-sig") as f1:
    f1.write("text_a\tlabel")
    f1.write("\n")
    # Positions [0, train_num) of the shuffled index list form the training set
    for i in range(0, train_num):
        f1.write(data_list[index[i]])
        f1.write("\n")
#### Read and write the dev set file
dev_txt = output_dir + "dev.txt"
# Clear the file first
with open(dev_txt, 'w', encoding="utf-8-sig") as f2:
    f2.write("")
with open(dev_txt, 'a', encoding="utf-8-sig") as f2:
    f2.write("text_a\tlabel")
    f2.write("\n")
    # Positions [train_num, train_num + dev_num) form the dev set
    for i in range(train_num, train_num + dev_num):
        f2.write(data_list[index[i]])
        f2.write("\n")
#### Read and write the test set file
test_txt = output_dir + "test.txt"
# Clear the file first
with open(test_txt, 'w', encoding="utf-8-sig") as f3:
    f3.write("")
with open(test_txt, 'a', encoding="utf-8-sig") as f3:
    f3.write("text_a\tlabel")
    f3.write("\n")
    # The remaining positions [train_num + dev_num, total_num) form the test set
    for i in range(train_num + dev_num, total_num):
        f3.write(data_list[index[i]])
        f3.write("\n")
split_data()
def _read_dataset_lines(filepath):
    """Return the non-blank lines of a text file, without their line endings.

    The "utf-8-sig" codec also strips a leading byte-order mark, if present.
    """
    with open(filepath, encoding="utf-8-sig") as f:
        return [line.rstrip("\r\n") for line in f if line.strip()]


def _write_dataset_split(path, header, samples):
    """Write ``header`` followed by each sample, one per line, replacing ``path``."""
    # Mode "w" truncates an existing file, so no separate clearing pass is needed.
    with open(path, "w", encoding="utf-8-sig") as f:
        f.write(header + "\n")
        f.writelines(sample + "\n" for sample in samples)


def split_data(before_dataset_filepath, output_dir, split_prop, *,
               header="text_a\tlabel", seed=None):
    """Randomly split a one-sample-per-line dataset into train/dev/test files.

    Blank lines are ignored. The first remaining line of the input is treated
    as the column header and is not sampled. The other lines are shuffled and
    written to ``train.txt``, ``dev.txt`` and ``test.txt`` inside
    ``output_dir``, each file starting with ``header``.

    Args:
        before_dataset_filepath: Path of the dataset file to split.
        output_dir: Directory for the output files; created if it does not
            exist. A trailing "/" is optional.
        split_prop: Relative proportions ``[train, dev, test]``, e.g.
            ``[3, 1, 2]``. The train and dev sizes are rounded down and the
            test set receives the remainder, so the three sizes always sum to
            the number of samples. Integer proportions are computed exactly.
        header: Header line written at the top of every output file.
        seed: Optional seed for a reproducible split. When ``None`` the
            module-level ``random`` generator is used, so a prior
            ``random.seed(...)`` call still controls the result.

    Raises:
        ZeroDivisionError: If ``sum(split_prop)`` is zero.
        OSError: If the input cannot be read or an output file cannot be written.
    """
    import os
    import random

    # Plain-Python reading: numpy.loadtxt rejects "\n" as a delimiter since
    # numpy 1.23, and returned a 0-d array (no len()) for a one-line file.
    lines = _read_dataset_lines(before_dataset_filepath)
    samples = lines[1:]  # drop the header line
    total_num = len(samples)
    print("total_num:", total_num)

    # Floor division on the product instead of int(a / b * n): the float form
    # can land just below a whole number (29 / 100 * 100 == 28.999999999999996).
    total_prop = sum(split_prop)
    train_num = int(split_prop[0] * total_num // total_prop)
    dev_num = int(split_prop[1] * total_num // total_prop)
    test_num = total_num - train_num - dev_num
    print("train_num:", train_num)
    print("dev_num:", dev_num)
    print("test_num:", test_num)

    # Shuffle the sample indices, then cut consecutive non-overlapping slices.
    index = list(range(total_num))
    if seed is None:
        random.shuffle(index)
    else:
        random.Random(seed).shuffle(index)
    shuffled = [samples[i] for i in index]

    if output_dir:  # "" means the current directory, which always exists
        os.makedirs(output_dir, exist_ok=True)
    dev_end = train_num + dev_num
    _write_dataset_split(os.path.join(output_dir, "train.txt"), header,
                         shuffled[:train_num])
    _write_dataset_split(os.path.join(output_dir, "dev.txt"), header,
                         shuffled[train_num:dev_end])
    _write_dataset_split(os.path.join(output_dir, "test.txt"), header,
                         shuffled[dev_end:])
#### Parameters
PROC_DATA_DIR = "data/processed_data/"
TRAIN_DATA_FILENAME = "train_data.txt"
TRAIN_DATA_FILEPATH = os.path.join(PROC_DATA_DIR, TRAIN_DATA_FILENAME)
out_put_dir = "data/split_data/"
split_prop = [3, 1, 2]  # [train, dev, test]

#### Run the split
# Run it exactly once: a second call would reshuffle and overwrite the files
# written by the first. The guard keeps an import of this file from
# rewriting the split files as a side effect.
if __name__ == "__main__":
    split_data(TRAIN_DATA_FILEPATH, out_put_dir, split_prop)