# -*- coding: utf-8 -*-
import csv
import os
import pandas as pd
home_path = 'D:\\工作文件\\项目\\分割语料测试'
data_set_path = os.path.join(home_path,'acd_simple_data.csv')
# total_len = len(open(data_set_path, 'r', encoding='utf-8').readlines())  # number of lines in the csv file
total_len = 157271  # hard-coded line count of acd_simple_data.csv
per_train = 80  # first 80% of the rows go to the training set
per_eval = 90   # rows between 80% and 90% go to the evaluation set; the remaining 10% to the test set
train_row = round(total_len * per_train / 100)
eval_row = round(total_len * per_eval / 100)
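# With the values above, the split boundaries work out to:
#   train_row = round(157271 * 80 / 100) = 125817  -> rows 0..125816 become the training set
#   eval_row  = round(157271 * 90 / 100) = 141544  -> rows 125817..141543 become the evaluation set,
#                                                     rows 141544..157270 become the test set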
def split_csv(path, total_len, per_train, per_eval):
    # Remove train.csv, eval.csv and test.csv if they already exist.
    for name in ('train.csv', 'eval.csv', 'test.csv'):
        old_file = os.path.join(home_path, name)
        if os.path.exists(old_file):
            os.remove(old_file)
    with open(path, 'r', newline='', encoding='utf-8') as file:
        csvreader = csv.reader(file)
        i = 0
        train_row = round(total_len * per_train / 100)
        eval_row = round(total_len * per_eval / 100)
        print("Training set size: " + str(train_row))
        print("Evaluation set size: " + str(eval_row - train_row))
        print("Test set size: " + str(total_len - eval_row))
        for row in csvreader:
            if i < train_row:
                # Destination for the training rows. Note that the csv's header
                # line is written into train.csv as an ordinary row.
                csv_path_train = os.path.join(home_path, 'train.csv')
                # Create the file on the first write, append afterwards.
                if not os.path.exists(csv_path_train):
                    with open(csv_path_train, 'w', newline='', encoding='utf-8') as f_out:
                        csv.writer(f_out).writerow(row)
                else:
                    with open(csv_path_train, 'a', newline='', encoding='utf-8') as f_out:
                        csv.writer(f_out).writerow(row)
                i += 1
            elif i < eval_row:
                # Destination for the evaluation rows.
                csv_path_eval = os.path.join(home_path, 'eval.csv')
                if not os.path.exists(csv_path_eval):
                    with open(csv_path_eval, 'w', newline='', encoding='utf-8') as f_out:
                        csv.writer(f_out).writerow(row)
                else:
                    with open(csv_path_eval, 'a', newline='', encoding='utf-8') as f_out:
                        csv.writer(f_out).writerow(row)
                i += 1
            elif i < total_len:
                # Destination for the test rows.
                csv_path_test = os.path.join(home_path, 'test.csv')
                if not os.path.exists(csv_path_test):
                    with open(csv_path_test, 'w', newline='', encoding='utf-8') as f_out:
                        csv.writer(f_out).writerow(row)
                else:
                    with open(csv_path_test, 'a', newline='', encoding='utf-8') as f_out:
                        csv.writer(f_out).writerow(row)
                i += 1
            else:
                break
    print("Dataset split finished")
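# Alternative sketch (not part of the original script): the same split can be done in a
# single pass with all three output files kept open, instead of reopening a file in
# append mode for every row. It reuses home_path, train_row and eval_row from above.
def split_csv_single_pass(path):
    with open(path, 'r', newline='', encoding='utf-8') as src, \
         open(os.path.join(home_path, 'train.csv'), 'w', newline='', encoding='utf-8') as f_train, \
         open(os.path.join(home_path, 'eval.csv'), 'w', newline='', encoding='utf-8') as f_eval, \
         open(os.path.join(home_path, 'test.csv'), 'w', newline='', encoding='utf-8') as f_test:
        writers = (csv.writer(f_train), csv.writer(f_eval), csv.writer(f_test))
        for i, row in enumerate(csv.reader(src)):
            if i < train_row:        # first per_train percent of the rows
                writers[0].writerow(row)
            elif i < eval_row:       # next (per_eval - per_train) percent
                writers[1].writerow(row)
            else:                    # everything after eval_row
                writers[2].writerow(row)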
def split_csv_by_panda():
    # Split the dataset with pandas by slicing the DataFrame at the
    # precomputed row boundaries (train_row / eval_row).
    data = pd.read_csv(data_set_path)
    print("Total number of rows: " + str(data.shape[0]))
    train_data = data.iloc[0:train_row]
    file_name_train = os.path.join(home_path, 'train_data.csv')  # output path and file name
    train_data.to_csv(file_name_train, index=False)  # saved as .csv; use to_excel instead for .xlsx output
    eval_data = data.iloc[train_row:eval_row]
    file_name_eval = os.path.join(home_path, 'eval_data.csv')
    eval_data.to_csv(file_name_eval, index=False)
    test_data = data.iloc[eval_row:total_len]
    file_name_test = os.path.join(home_path, 'test_data.csv')
    test_data.to_csv(file_name_test, index=False)
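# Note: pandas .iloc slices are end-exclusive, so iloc[0:train_row], iloc[train_row:eval_row]
# and iloc[eval_row:] are disjoint and together cover every row. A quick sanity check
# (hypothetical helper, not in the original script):
def check_pandas_split():
    data = pd.read_csv(data_set_path)
    parts = [data.iloc[0:train_row], data.iloc[train_row:eval_row], data.iloc[eval_row:]]
    assert sum(len(part) for part in parts) == data.shape[0]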
if __name__ == '__main__':
    path = 'D:\\工作文件\\项目\\分割语料测试\\acd_simple_data.csv'
    print("Total number of lines in the file: " + str(total_len))
    split_csv(path, total_len, per_train, per_eval)
    # split_csv_by_panda()