python3-11.划分数据集为训练集和测试集

import os
import random
import shutil

def split_class_files(src_dir, train_dir, test_dir, train_ratio=0.8, rng=None):
    """Randomly copy the files of one class folder into train/test folders.

    Regular files directly inside ``src_dir`` are shuffled; the first
    ``int(count * train_ratio)`` of them are copied to ``train_dir`` and the
    remainder to ``test_dir``. The source files are left in place.

    Args:
        src_dir: Directory holding the files of a single class.
        train_dir: Destination directory for the training portion
            (created if missing).
        test_dir: Destination directory for the test portion
            (created if missing).
        train_ratio: Fraction of files, in [0, 1], sent to ``train_dir``.
        rng: Optional object with a ``shuffle`` method (for example a
            ``random.Random`` instance). Defaults to the ``random`` module.

    Returns:
        A ``(train_count, test_count)`` tuple of the number of files copied.

    Raises:
        ValueError: If ``train_ratio`` is outside [0, 1].
        OSError: If ``src_dir`` cannot be listed or a copy fails.
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError('train_ratio must be within [0, 1], got %r' % (train_ratio,))

    # Skip sub-directories: shutil.copyfile only handles regular files and
    # would abort the whole run on the first directory entry.
    # Sorting removes the dependence on os.listdir's arbitrary order, so a
    # seeded rng yields the same split on every run.
    names = sorted(
        name for name in os.listdir(src_dir)
        if os.path.isfile(os.path.join(src_dir, name))
    )
    shuffler = rng if rng is not None else random
    shuffler.shuffle(names)

    train_count = int(len(names) * train_ratio)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)

    for position, name in enumerate(names):
        dst_dir = train_dir if position < train_count else test_dir
        shutil.copyfile(os.path.join(src_dir, name), os.path.join(dst_dir, name))

    return train_count, len(names) - train_count


if __name__ == '__main__':
    className = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
    per = 0.8
    dataRoot = 'D:\\python\\database\\flowersdata'
    for val in className:
        # Each class is split independently so the ratio holds per class.
        split_class_files(
            os.path.join(dataRoot, val),
            os.path.join(dataRoot, 'flowersdata_train', val),
            os.path.join(dataRoot, 'flowersdata_test', val),
            per,
        )

 

你可能感兴趣的:(python笔记)