首先是要下载数据集,下载地址:https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition
数据解压之后会有两个文件夹,一个是“train”,一个是“test”,顾名思义一个是用来训练的,另一个是作为检验正确性的数据,也是网站要求提交标签的。
在train文件夹里边是一些已经命名好的图像,有猫也有狗
而在test文件夹中是只有编号名的图像
大致了解了数据集后,下边就开始划分数据集
先放一段代码,这是从书中截取出来的:
# coding:utf8
import os
from PIL import Image
from torch.utils import data
import numpy as np
from torchvision import transforms as T
class DogCat(data.Dataset):
def __init__(self, root, transforms=None, train=True, test=False):
"""
主要目标: 获取所有图片的地址,并根据训练,验证,测试划分数据
"""
self.test = test
imgs = [os.path.join(root, img) for img in os.listdir(root)]
# test1: data/test1/8973.jpg
# train: data/train/cat.10004.jpg
if self.test:
imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2].split('/')[-1]))
else:
imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2]))
imgs_num = len(imgs)
if self.test:
self.imgs = imgs
elif train:
self.imgs = imgs[:int(0.7 * imgs_num)]
else:
self.imgs = imgs[int(0.7 * imgs_num):]
if transforms is None:
normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
if self.test or not train:
self.transforms = T.Compose([
T.Resize(224),
T.CenterCrop(224),
T.ToTensor(),
normalize
])
else:
self.transforms = T.Compose([
T.Resize(256),
T.RandomReSizedCrop(224),
T.RandomHorizontalFlip(),
T.ToTensor(),
normalize
])
def __getitem__(self, index):
"""
一次返回一张图片的数据
"""
img_path = self.imgs[index]
if self.test:
label = int(self.imgs[index].split('.')[-2].split('/')[-1])
else:
label = 1 if 'dog' in img_path.split('/')[-1] else 0
data = Image.open(img_path)
data = self.transforms(data)
return data, label
def __len__(self):
return len(self.imgs)
这里建立了一个类,继承自data.Dataset,里边有三个方法是必须重写的:
class DogCat(data.Dataset):
def __init__(self, root, transforms=None, train=True, test=False):
"""
主要目标: 获取所有图片的地址,并根据训练,验证,测试划分数据
"""
#这个__init__方法是初始化,里边可以对数据进行一些预处理
def __getitem__(self, index):
"""
一次返回一张图片的数据
"""
#__getitem__方法是迭代器需要,当读取数据集的时候就会调用__getitem__方法,
#一次读取一张照片,因此,这里主要实现返回图像与标签的功能
def __len__(self):
#这个函数的目的是返回数据集大小,也是必不可少的部分
下面开始解释每个方法中语句的功能
def __init__(self, root, transforms=None, train=True, test=False):
#root是根目录,用来存放数据
#transforms是对图像做出转换
#train和test是标记
self.test = test
#os.listdir(root)获取root目录下所有文件名
imgs = [os.path.join(root, img) for img in os.listdir(root)]
#根据测试集与训练集图片命名不同进行不同的划分
if self.test:
imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2].split('/')[-1]))
else:
imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2]))
#获取图像数量
imgs_num = len(imgs)
#将test文件夹中图像作为测试集
if self.test:
self.imgs = imgs
#将训练集70%作为训练集
elif train:
self.imgs = imgs[:int(0.7 * imgs_num)]
#将训练集30%作为验证集
else:
self.imgs = imgs[int(0.7 * imgs_num):]
#下边对图像做变换
if transforms is None:
normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
if self.test or not train:
self.transforms = T.Compose([
T.Resize(224),
T.CenterCrop(224),
T.ToTensor(),
normalize
])
else:
self.transforms = T.Compose([
T.Resize(256),
T.RandomReSizedCrop(224),
T.RandomHorizontalFlip(),
T.ToTensor(),
normalize
])
def __getitem__(self, index):
"""
一次返回一张图片的数据
"""
#根据下标获取标签
img_path = self.imgs[index]
if self.test:
label = int(self.imgs[index].split('.')[-2].split('/')[-1])
else:
label = 1 if 'dog' in img_path.split('/')[-1] else 0
data = Image.open(img_path)
data = self.transforms(data)
#返回图像数据与标签
return data, label
def __len__(self):
#返回数据集长度
return len(self.imgs)
到此位置,数据集的划分与数据类已经完成
完整训练过程可以看我另一篇博客:
https://blog.csdn.net/qq_41685265/article/details/104898848