I. Problem description:
We need to split the data into three sets:
Training set (train): cats, dogs
Validation set (validation): cats, dogs
Test set (test): cats, dogs
II. Solution:
1. First, set data_dir to the folder that holds the original images: data_dir="D:/暑假/data/cats_and_dogs/tra"
Then create the cats_and_dogs_small folder:
Code:
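import os
import shutil

# Folder holding the original cat and dog images
data_dir = 'D:/暑假/data/cats_and_dogs/tra'
# New folder that will hold the smaller, split copy of the dataset
base_dir = 'D:/暑假/data/cats_and_dogs_small'
os.mkdir(base_dir)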
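One practical caveat: os.mkdir raises FileExistsError if the folder already exists, so running the script a second time stops at this line. A minimal sketch of a rerun-safe alternative, assuming you are fine with silently reusing an existing folder, uses os.makedirs with exist_ok=True. The helper name make_dir is hypothetical, and the demo uses a temporary folder instead of the real D:/ path.

```python
import os
import tempfile


def make_dir(path):
    """Create path and any missing parent folders; do nothing if it already exists."""
    os.makedirs(path, exist_ok=True)
    return path


# Demonstration inside a temporary folder instead of the real D:/ path.
demo_root = tempfile.mkdtemp()
base_dir = make_dir(os.path.join(demo_root, 'cats_and_dogs_small'))
make_dir(base_dir)  # second call is a no-op; os.mkdir would raise FileExistsError here
print(os.path.isdir(base_dir))  # True
```

Note that reusing the folder also means leftover files from an earlier run stay in place, which matters for the count check at the end.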
2. Inside that folder, create train, validation, and test.
Code:
train_dir=os.path.join(base_dir,'train')
os.mkdir(train_dir)
validation_dir=os.path.join(base_dir,'validation')
os.mkdir(validation_dir)
test_dir=os.path.join(base_dir,'test')
os.mkdir(test_dir)
3. Then create cats and dogs inside each of those three folders.
Code:
train_cats_dir=os.path.join(train_dir,'cats')
os.mkdir(train_cats_dir)
train_dogs_dir=os.path.join(train_dir,'dogs')
os.mkdir(train_dogs_dir)
validation_cats_dir=os.path.join(validation_dir,'cats')
os.mkdir(validation_cats_dir)
validation_dogs_dir=os.path.join(validation_dir,'dogs')
os.mkdir(validation_dogs_dir)
test_cats_dir=os.path.join(test_dir,'cats')
os.mkdir(test_cats_dir)
test_dogs_dir=os.path.join(test_dir,'dogs')
os.mkdir(test_dogs_dir)
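Steps 2 and 3 create the same two class folders under each of the three splits, so they can be expressed as two nested loops. The sketch below is one possible way to write it, not the original code: the helper make_class_dirs and the (split, class) dictionary keys are hypothetical, and the dictionary plays the role of train_cats_dir, train_dogs_dir, and the other variables above.

```python
import os
import tempfile

SPLITS = ('train', 'validation', 'test')
CLASSES = ('cats', 'dogs')


def make_class_dirs(base_dir):
    """Create base_dir/<split>/<class> for every split and class.

    Returns a dict keyed by (split, class), e.g. ('train', 'cats').
    """
    dirs = {}
    for split in SPLITS:
        for cls in CLASSES:
            path = os.path.join(base_dir, split, cls)
            os.makedirs(path, exist_ok=True)  # also creates base_dir/<split>
            dirs[(split, cls)] = path
    return dirs


base_dir = tempfile.mkdtemp()  # stand-in for the real base_dir
dirs = make_class_dirs(base_dir)
print(len(dirs))  # 6
```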
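4. The dataset contains 25,000 images of cats and dogs (12,500 per class). For each class we use 8,000 images for training, 2,500 for validation, and 2,000 for testing (8,000 + 2,500 + 2,000 = 12,500).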
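Because the three splits take consecutive blocks of indices, the range boundaries used in the copy code follow directly from the per-class counts. A small sketch with a hypothetical helper split_ranges that derives them, so the numbers only have to be written once:

```python
def split_ranges(counts):
    """Turn consecutive per-split counts into index ranges.

    counts: list of (split_name, n) pairs, in the order the images are taken.
    Returns {split_name: range(start, start + n)}.
    """
    ranges = {}
    start = 0
    for name, n in counts:
        ranges[name] = range(start, start + n)
        start += n
    return ranges


# Per-class counts from the text: 8000 + 2500 + 2000 = 12500 images per class.
ranges = split_ranges([('train', 8000), ('validation', 2500), ('test', 2000)])
for name, r in ranges.items():
    print(name, r.start, r.stop)
# train 0 8000
# validation 8000 10500
# test 10500 12500
```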
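Looking at the image files, they are all named cat.{}.jpg or dog.{}.jpg, where {} is the image index (0 to 12499).
The code is as follows: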
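The naming pattern works in both directions: a file name can be built from the class prefix and an index, and parsed back into them, which is handy when checking which split a file belongs to. A minimal sketch with two hypothetical helpers, assuming every file name has exactly the form prefix.index.jpg:

```python
def image_names(animal, indices):
    """Build file names following the '<animal>.<index>.jpg' pattern."""
    return ['{}.{}.jpg'.format(animal, i) for i in indices]


def parse_name(filename):
    """Split 'cat.123.jpg' into ('cat', 123); assumes exactly two dots."""
    animal, index, _ext = filename.split('.')
    return animal, int(index)


print(image_names('cat', range(3)))   # ['cat.0.jpg', 'cat.1.jpg', 'cat.2.jpg']
print(parse_name('dog.12499.jpg'))    # ('dog', 12499)
```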
# Cats: indices 0-7999 -> train, 8000-10499 -> validation, 10500-12499 -> test
fnames = ['cat.{}.jpg'.format(i) for i in range(8000)]
for fname in fnames:
    src = os.path.join(data_dir, fname)
    dst = os.path.join(train_cats_dir, fname)
    shutil.copyfile(src, dst)
fnames = ['cat.{}.jpg'.format(i) for i in range(8000, 10500)]
for fname in fnames:
    src = os.path.join(data_dir, fname)
    dst = os.path.join(validation_cats_dir, fname)
    shutil.copyfile(src, dst)
fnames = ['cat.{}.jpg'.format(i) for i in range(10500, 12500)]
for fname in fnames:
    src = os.path.join(data_dir, fname)
    dst = os.path.join(test_cats_dir, fname)
    shutil.copyfile(src, dst)

# Dogs: same index ranges
fnames = ['dog.{}.jpg'.format(i) for i in range(8000)]
for fname in fnames:
    src = os.path.join(data_dir, fname)
    dst = os.path.join(train_dogs_dir, fname)
    shutil.copyfile(src, dst)
fnames = ['dog.{}.jpg'.format(i) for i in range(8000, 10500)]
for fname in fnames:
    src = os.path.join(data_dir, fname)
    dst = os.path.join(validation_dogs_dir, fname)
    shutil.copyfile(src, dst)
fnames = ['dog.{}.jpg'.format(i) for i in range(10500, 12500)]
for fname in fnames:
    src = os.path.join(data_dir, fname)
    dst = os.path.join(test_dogs_dir, fname)
    shutil.copyfile(src, dst)
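The six copy loops differ only in the class prefix, the index range, and the destination folder, so they can be driven by a small table. The sketch below is an alternative way to write the same copying, not the original code: split_dataset, PLAN, and ANIMALS are hypothetical names. The demo runs on 10 empty placeholder files per class in temporary folders rather than on the real images.

```python
import os
import shutil
import tempfile

# Split plan from the text: (split, first index, last index + 1), same for both classes.
PLAN = [('train', 0, 8000), ('validation', 8000, 10500), ('test', 10500, 12500)]
ANIMALS = {'cat': 'cats', 'dog': 'dogs'}  # file-name prefix -> sub-folder name


def split_dataset(data_dir, base_dir, plan=PLAN):
    """Copy <animal>.<i>.jpg from data_dir into base_dir/<split>/<cats|dogs>."""
    for animal, folder in ANIMALS.items():
        for split, start, stop in plan:
            dst_dir = os.path.join(base_dir, split, folder)
            os.makedirs(dst_dir, exist_ok=True)
            for i in range(start, stop):
                fname = '{}.{}.jpg'.format(animal, i)
                shutil.copyfile(os.path.join(data_dir, fname),
                                os.path.join(dst_dir, fname))


# Tiny demonstration: 10 empty placeholder files per class instead of real images.
data_dir = tempfile.mkdtemp()
base_dir = tempfile.mkdtemp()
for animal in ANIMALS:
    for i in range(10):
        with open(os.path.join(data_dir, '{}.{}.jpg'.format(animal, i)), 'w'):
            pass
small_plan = [('train', 0, 6), ('validation', 6, 8), ('test', 8, 10)]
split_dataset(data_dir, base_dir, small_plan)
print(len(os.listdir(os.path.join(base_dir, 'train', 'cats'))))  # 6
```

For the real data you would call split_dataset(data_dir, base_dir) with the default PLAN.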
Check the result:
print('total train cat images:',len(os.listdir(train_cats_dir)))
print('total train dog images:',len(os.listdir(train_dogs_dir)))
print('total validation cat images:',len(os.listdir(validation_cats_dir)))
print('total validation dog images:',len(os.listdir(validation_dogs_dir)))
print('total test cat images:',len(os.listdir(test_cats_dir)))
print('total test dog images:',len(os.listdir(test_dogs_dir)))
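Printing the counts relies on reading them by eye. A stricter check, sketched below under the assumption that file names are unique across the whole dataset, compares each folder against the expected per-class count and also reports any file that ended up in two splits (a leak between training and test data). check_split and EXPECTED are hypothetical names; the demo builds a tiny layout with a deliberate leak.

```python
import os
import tempfile

EXPECTED = {'train': 8000, 'validation': 2500, 'test': 2000}  # per class, from the text


def check_split(base_dir, expected=EXPECTED, classes=('cats', 'dogs')):
    """Return a list of problems: wrong counts, or a file appearing in two splits."""
    problems = []
    seen = {}  # file name -> split it was first found in
    for split, n in expected.items():
        for cls in classes:
            files = os.listdir(os.path.join(base_dir, split, cls))
            if len(files) != n:
                problems.append('{}/{}: {} files, expected {}'.format(split, cls, len(files), n))
            for f in files:
                if f in seen:
                    problems.append('{} in both {} and {}'.format(f, seen[f], split))
                seen[f] = split
    return problems


# Tiny demonstration: cat.1.jpg deliberately appears in both train and test.
demo_dir = tempfile.mkdtemp()
layout = {('train', 'cats'): ['cat.0.jpg', 'cat.1.jpg'],
          ('test', 'cats'): ['cat.1.jpg']}
for (split, cls), names in layout.items():
    os.makedirs(os.path.join(demo_dir, split, cls))
    for name in names:
        with open(os.path.join(demo_dir, split, cls, name), 'w'):
            pass
problems = check_split(demo_dir, expected={'train': 2, 'test': 1}, classes=('cats',))
print(problems)  # ['cat.1.jpg in both train and test']
```

On the real folders, check_split(base_dir) returning an empty list means every count matches and no image is shared between splits.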