狗猫分类数据集划分详解

数据集介绍

首先是要下载数据集,下载地址:https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition

数据解压之后会有两个文件夹,一个是“train”,一个是“test”,顾名思义一个是用来训练的,另一个是作为检验正确性的数据,也是网站要求提交标签的。

狗猫分类数据集划分详解_第1张图片

在train文件夹里边是一些已经命名好的图像,有猫也有狗

狗猫分类数据集划分详解_第2张图片

而在test文件夹中是只有编号名的图像

狗猫分类数据集划分详解_第3张图片

大致了解了数据集后,下边就开始划分数据集

代码

先放一段代码,这是从书中截取出来的:

# coding:utf8
import os
from PIL import Image
from torch.utils import data
import numpy as np
from torchvision import transforms as T


class DogCat(data.Dataset):

    def __init__(self, root, transforms=None, train=True, test=False):
        """
        主要目标: 获取所有图片的地址,并根据训练,验证,测试划分数据
        """
        self.test = test
        imgs = [os.path.join(root, img) for img in os.listdir(root)]

        # test1: data/test1/8973.jpg
        # train: data/train/cat.10004.jpg 
        if self.test:
            imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2].split('/')[-1]))
        else:
            imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2]))

        imgs_num = len(imgs)

        if self.test:
            self.imgs = imgs
        elif train:
            self.imgs = imgs[:int(0.7 * imgs_num)]
        else:
            self.imgs = imgs[int(0.7 * imgs_num):]

        if transforms is None:
            normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])

            if self.test or not train:
                self.transforms = T.Compose([
                    T.Resize(224),
                    T.CenterCrop(224),
                    T.ToTensor(),
                    normalize
                ])
            else:
                self.transforms = T.Compose([
                    T.Resize(256),
                    T.RandomReSizedCrop(224),
                    T.RandomHorizontalFlip(),
                    T.ToTensor(),
                    normalize
                ])

    def __getitem__(self, index):
        """
        一次返回一张图片的数据
        """
        img_path = self.imgs[index]
        if self.test:
            label = int(self.imgs[index].split('.')[-2].split('/')[-1])
        else:
            label = 1 if 'dog' in img_path.split('/')[-1] else 0
        data = Image.open(img_path)
        data = self.transforms(data)
        return data, label

    def __len__(self):
        return len(self.imgs)

详解

这里建立了一个类,继承自data.Dataset,里边有三个方法是必须重写的:

class DogCat(data.Dataset):

    def __init__(self, root, transforms=None, train=True, test=False):
        """
        主要目标: 获取所有图片的地址,并根据训练,验证,测试划分数据
        """
        #这个__init__方法是初始化,里边可以对数据进行一些预处理

    def __getitem__(self, index):
        """
        一次返回一张图片的数据
        """
        #__getitem__方法是迭代器需要,当读取数据集的时候就会调用__getitem__方法,
        #一次读取一张照片,因此,这里主要实现返回图像与标签的功能

    def __len__(self):
        #这个函数的目的是返回数据集大小,也是必不可少的部分

下面开始解释每个方法中语句的功能

    def __init__(self, root, transforms=None, train=True, test=False):
       
        #root是根目录,用来存放数据
        #transforms是对图像做出转换
        #train和test是标记
        self.test = test

        #os.listdir(root)获取root目录下所有文件名
        imgs = [os.path.join(root, img) for img in os.listdir(root)]

        #根据测试集与训练集图片命名不同进行不同的划分
        if self.test:
            imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2].split('/')[-1]))
        else:
            imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2]))
        
        #获取图像数量
        imgs_num = len(imgs)

        #将test文件夹中图像作为测试集
        if self.test:
            self.imgs = imgs
        #将训练集70%作为训练集
        elif train:
            self.imgs = imgs[:int(0.7 * imgs_num)]
        #将训练集30%作为验证集
        else:
            self.imgs = imgs[int(0.7 * imgs_num):]
        #下边对图像做变换
        if transforms is None:
            normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])

            if self.test or not train:
                self.transforms = T.Compose([
                    T.Resize(224),
                    T.CenterCrop(224),
                    T.ToTensor(),
                    normalize
                ])
            else:
                self.transforms = T.Compose([
                    T.Resize(256),
                    T.RandomReSizedCrop(224),
                    T.RandomHorizontalFlip(),
                    T.ToTensor(),
                    normalize
                ])
    def __getitem__(self, index):
        """
        一次返回一张图片的数据
        """
        #根据下标获取标签
        img_path = self.imgs[index]
        if self.test:
            label = int(self.imgs[index].split('.')[-2].split('/')[-1])
        else:
            label = 1 if 'dog' in img_path.split('/')[-1] else 0
        data = Image.open(img_path)
        data = self.transforms(data)
        #返回图像数据与标签
        return data, label
    def __len__(self):
        #返回数据集长度
        return len(self.imgs)

到此位置,数据集的划分与数据类已经完成

 

完整训练过程可以看我另一篇博客:

https://blog.csdn.net/qq_41685265/article/details/104898848

你可能感兴趣的:(pytorch)