Python将数据集分成train/validation/test,并平衡数据

1. 将数据集分成train/validation/test

将数据随机打乱后,按 7:2:1 的比例划分为 train/validation/test 三部分。

import os, random, shutil
#-*- coding: UTF-8 -*-
import os
import  random
import shutil

def eachFile(filepath):
    """Return the names of the entries directly inside ``filepath``.

    Args:
        filepath: Path of the directory to list.

    Returns:
        A list of entry names (not full paths), in the arbitrary order
        produced by ``os.listdir``.
    """
    return os.listdir(filepath)
 
def divideTrainValiTest(source, dist, pos_or_neg):
    """Randomly split the entries of ``source`` into train/validation/test.

    The entries are shuffled and then moved (not copied) into
    ``dist/train/<pos_or_neg>``, ``dist/validation/<pos_or_neg>`` and
    ``dist/test/<pos_or_neg>`` in a 7:2:1 ratio. Missing destination
    directories are created.

    Args:
        source: Directory whose entries are to be split.
        dist: Root directory of the resulting data set.
        pos_or_neg: Name of the class sub-directory, e.g. ``'pos'``.
    """
    # List the directory directly: the previous code called an undefined
    # helper name and raised NameError on every call.
    pic_names = os.listdir(source)
    random.shuffle(pic_names)

    total = len(pic_names)
    train_end = int(0.7 * total)
    validation_end = int(0.9 * total)
    splits = (
        ('train', pic_names[:train_end]),
        ('validation', pic_names[train_end:validation_end]),
        ('test', pic_names[validation_end:]),
    )

    for split_name, names in splits:
        target_dir = os.path.join(dist, split_name, pos_or_neg)
        # shutil.move fails when the destination directory does not exist.
        if not os.path.isdir(target_dir):
            os.makedirs(target_dir)
        for name in names:
            shutil.move(os.path.join(source, name),
                        os.path.join(target_dir, name))
    
 
if __name__ == '__main__':
    # Both paths look like placeholders; edit them before running.
    raw_data_dir = r'/your_path/raw_data'
    split_root = r'/your_path/data'
    # Only the 'pos' class is split by this invocation.
    divideTrainValiTest(raw_data_dir, split_root, 'pos')

2. 平衡数据

我的负样本比正样本多,所以做法比较粗暴:随机删除多余的负样本,使负样本数量与正样本数量相同。注意该操作会直接删除文件,无法恢复,请先备份数据。

def balancePosNeg(filepath):
    """Delete random negatives so ``neg`` holds no more entries than ``pos``.

    The entries of ``filepath/neg`` are shuffled and every entry beyond
    the number of entries in ``filepath/pos`` is removed from disk. If
    ``neg`` already has no more entries than ``pos``, nothing is deleted.

    WARNING: files are deleted permanently with ``os.remove``.

    Args:
        filepath: Directory containing the ``pos`` and ``neg``
            sub-directories.
    """
    neg_dir = os.path.join(filepath, 'neg')
    pos_dir = os.path.join(filepath, 'pos')

    neg_names = os.listdir(neg_dir)
    pos_count = len(os.listdir(pos_dir))

    # Shuffle so that the surviving negatives are a random subset.
    random.shuffle(neg_names)
    for name in neg_names[pos_count:]:
        os.remove(os.path.join(neg_dir, name))

你可能感兴趣的:(code)