划分交叉验证集

进行交叉验证(cross-validation)时,需要先把全部样本划分成若干折。此处以五折交叉验证集的划分为例(Python 实现):

# -*- coding: utf-8 -*-
import os,os.path as op
import numpy as np
import random
# Input list file whose lines are loaded into TrainList via GetLines.
train_txt = '/media/dell/dell/data/huawei_remotesensing/train/train1.txt'
# Input list file whose lines are loaded into ValList via GetLines.
val_txt   = '/media/dell/dell/data/huawei_remotesensing/val/val.txt'
# Output path template: the two {} slots are filled by ReAllocation with the
# split name ('train' or 'val') and the 1-based fold number.
txt_tlpt  = '/home/dell/Desktop/train/folds_split/{}_{}.txt'

def GetLines(txt_path):
    """Return the non-blank lines of a list file, each newline-terminated.

    Args:
        txt_path: Path of the text file to read, one entry per line.

    Returns:
        list[str]: The non-blank lines in file order. Every returned line
        ends with a single trailing newline, including a final line that
        had none in the file.
    """
    # The context manager closes the handle even if reading fails.
    with open(txt_path) as f:
        # Lines read from a file keep their line break, so a length test
        # never filters anything; strip() is what detects a blank line.
        # Re-adding the newline keeps entries from fusing together after
        # the lists are concatenated, shuffled and written back out.
        return [line.rstrip('\n') + '\n' for line in f if line.strip()]

# Load both existing splits, pool them into one list, and shuffle the pool
# in place so the folds cut from it later are randomly composed.
TrainList, ValList = GetLines(train_txt), GetLines(val_txt)
AllList = [*TrainList, *ValList]
random.shuffle(AllList)
def WriteTxt(List,txt_path):
    """Write the strings in List to txt_path and print a one-line summary.

    Args:
        List: Sequence of strings to write. They are written back to back
            with no separator added, so each entry should already end with
            a newline.
        txt_path: Destination file path; an existing file is overwritten.

    Prints:
        The path and the number of entries written, separated by a tab.
    """
    # The with-block flushes and closes the file before the summary is
    # printed, so the reported file is complete on disk at that point.
    with open(txt_path, 'w') as f:
        f.writelines(List)
    # Reaching this line means open() succeeded, so the file exists.
    print('{}\t{}'.format(txt_path,len(List)))

def ReAllocation(AllList,fold=5):
    """Split AllList into `fold` cross-validation folds and write each split.

    AllList is cut into `fold` contiguous chunks. For fold number k
    (1-based) one chunk becomes the validation list and the remaining
    entries the training list; they are written with WriteTxt to
    txt_tlpt.format('train', k) and txt_tlpt.format('val', k).

    Args:
        AllList: List of entries to split (already shuffled by the caller).
        fold: Number of folds; must be at least 1.

    Raises:
        ValueError: If fold is smaller than 1.
    """
    if fold < 1:
        raise ValueError('fold must be >= 1, got {}'.format(fold))

    total = len(AllList)
    # Chunk c covers AllList[bounds[c]:bounds[c + 1]]. Integer boundaries
    # spread any remainder across the chunks, so sizes differ by at most one
    # and every entry is used for validation exactly once. When the length
    # divides evenly this equals fixed-size chunks of len(AllList) // fold.
    bounds = [total * k // fold for k in range(fold + 1)]

    for i in range(fold):
        # Fold file i+1 validates on chunk j, counting from the last chunk.
        j = fold - i - 1
        lo, hi = bounds[j], bounds[j + 1]
        train_part = AllList[:lo] + AllList[hi:]
        val_part = AllList[lo:hi]
        WriteTxt(train_part, txt_tlpt.format('train', i + 1))
        WriteTxt(val_part, txt_tlpt.format('val', i + 1))

# Write the train/val list files for a five-fold split of the shuffled pool.
ReAllocation(AllList, fold=5)

你可能感兴趣的:(python,sklearn,pytorch)