Pytorch学习_定义自己的数据集2

文章目录

    • 1. Dataset类
    • 2. DataLoader类
    • 3. 实例
        • 代码实现:
        • 验证效果

Pytorch中定义数据集主要涉及到两个主要的类: Dataset、DataLoader。

1. Dataset类

Dataset类是Pytorch中所有数据集加载类中应该继承的父类。其中父类中的两个私有成员函数__len__、__getitem__必须被重载,否则将会触发错误提示:

Pytorch学习_定义自己的数据集2_第1张图片

其中__len__应该返回数据集的大小,而__getitem__实现可以通过索引来返回图像数据的功能。

我们要定义自己的数据集类,首先继承上面的Dataset类,然后在__init__()方法中对数据集进行整理,得到图像的路径,给图片打标签,划分数据集等。

另外,如果我们需要在读取数据的同时对图像进行增强的话,可以在__getitem__(self, index)函数中设置图像增强的代码,图像增强的方法可以使用Pytorch内置的图像增强方式,也可以使用自定义或者其他的图像增强库。这个很灵活,当然要记住一点,在Pytorch中得到的图像必须是tensor,也就是说我们还需要在__getitem__中将读取到的数据转换为tensor

2. DataLoader类

Dataset类是读入数据集数据并且对读入的数据进行了索引。但是光有这个功能是不够用的,在实际的加载数据集的过程中,我们的数据量往往都很大,对此我们还需要一下几个功能:

  • 可以分批次读取:batch-size
  • 可以对数据进行随机读取,可以对数据进行洗牌操作(shuffling),打乱数据集内数据分布的顺序
  • 可以并行加载数据(利用多核处理器加快载入数据的效率)

这时候就需要Dataloader类了,它为我们提供的常用操作有:batch_size(每个batch的大小), shuffle(是否进行shuffle操作), num_workers(加载数据的时候使用几个子进程)。Dataloader这个类并不需要我们自己设计代码,我们只需要利用DataLoader类读取我们设计好的Dataset子类即可:

# 利用dataloader读取我们的数据对象,并设定batch-size和工作进程
loader = DataLoader(train_dataset, batch_size=16, num_workers=4, shuffle=True)

这时候通过loader返回的数据就是按照batch_size来返回特定数量的训练数据的tensor;利用了多进程,读取数据的速度相比单进程快很多;设置了数据的随机读取,打乱了数据集分布的顺序。

参考:https://www.cnblogs.com/ranjiewen/p/10128046.html

3. 实例

下面通过网络上收集的神奇宝贝图片,制作图像分类数据集。

数据集链接:链接: https://pan.baidu.com/s/1pCfDEDHFn0UjSJTqfBBFng 提取码: 6b3c

上面的数据集中有1168张宝可梦的图片,其中皮卡丘234张、超梦239张、杰尼龟223张、小火龙238、张妙蛙种子234张

下载后的目录结构如下:

Pytorch学习_定义自己的数据集2_第2张图片

每个目录由神奇宝贝名字命名,对应目录有下是该神奇宝贝的图片,图片的格式有jpgpngjpeg三种。

数据集的划分如下:

训练集60%,验证集20%,测试集20%。

代码实现:

#coding=utf-8
import torch
import os, glob
import random, csv 
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
 
class Pokemon(Dataset):
    
    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()
        self.root = root
        self.resize = resize
        self.name2label = {}
        # 返回指定目录下的文件列表,并对文件列表进行排序,
        # os.listdir每次返回目录下的文件列表顺序会不一致,
        # 排序是为了每次返回文件列表顺序一致
        for name in sorted(os.listdir(os.path.join(root))):
            # 过滤掉非目录文件
            if not os.path.isdir(os.path.join(root, name)):
                continue
            #构建字典,名字:0~4数字
            self.name2label[name] = len(self.name2label.keys())
       
        # eg: {'squirtle': 4, 'bulbasaur': 0, 'pikachu': 3, 'mewtwo': 2, 'charmander': 1}
        print(self.name2label)
 
        # image, label
        self.images, self.labels = self.load_csv("images.csv")
 
        # 对数据集进行划分
        if mode == "train": # 60%
            self.images = self.images[:int(0.6*len(self.images))]
            self.labels = self.labels[:int(0.6*len(self.labels))]
        elif mode == "val": # 20% = 60%~80%
            self.images = self.images[int(0.6*len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(0.6*len(self.labels)):int(0.8 * len(self.labels))]
        else: # 20% = 80%~100%
            self.images = self.images[int(0.8 * len(self.images)):]
            self.labels = self.labels[int(0.8 * len(self.labels)):]
 
    # 将目录下的图片路径与其对应的标签写入csv文件,
    # 并将csv文件写入的内容读出,返回图片名与其标签
    def load_csv(self, filename):
        """
        :param filename:
        :return:
        """
        # 是否已经存在了cvs文件
        if not os.path.exists(os.path.join(self.root, filename)):
            images = []
            for name in self.name2label.keys():
                # 获取指定目录下所有的满足后缀的图像名
                # pokemon/mewtwo/00001.png
                images += glob.glob(os.path.join(self.root, name, "*.png"))
                images += glob.glob(os.path.join(self.root, name, "*.jpg"))
                images += glob.glob(os.path.join(self.root, name, "*.jpeg"))
 
            # 1165 'pokemon/pikachu/00000058.png'
            print(len(images), images)
 
            # 将元素打乱
            random.shuffle(images)
            with open(os.path.join(self.root, filename), mode="w", newline="") as f:
                writer = csv.writer(f)
                for img in images: # 'pokemon/pikachu/00000058.png'
                    name = img.split(os.sep)[-2]
                    label = self.name2label[name]
                    # 将图片路径以及对应的标签写入到csv文件中
                    # 'pokemon/pikachu/00000058.png', 0
                    writer.writerow([img, label])
                print("writen into csv file: ", filename)
 
        # 如果已经存在了csv文件,则读取csv文件
        images, labels = [], []
        with open(os.path.join(self.root, filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                # 'pokemon/pikachu/00000058.png', 0
                img, label = row
                label = int(label)
 
                images.append(img)
                labels.append(label)
        assert len(images) == len(labels)
 
        return images, labels
 
    def __len__(self):
        return len(self.images)
 
    def denormalize(self, x_hat):
 
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
 
        # x_hat = (x-mean)/std
        # x = x_hat*std = mean
        # x: [c, h, w]
        # mean: [3] => [3, 1, 1]
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        # print(mean.shape, std.shape)
        x = x_hat * std + mean
 
        return x
 
    def __getitem__(self, idx):
        # idx~[0~len(images)]
        # self.images, self.labels
        # img: 'pokemon/bulbasaur/00000000.png'
        # label: 0
        img, label = self.images[idx], self.labels[idx]
 
        tf = transforms.Compose([
            lambda x:Image.open(x).convert("RGB"), # string path => image data
            transforms.Resize((int(self.resize*1.25), int(self.resize*1.25))),
            transforms.RandomRotation(15),
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
 
        img = tf(img)
        label = torch.tensor(label)
 
        return img, label
 
def main():
    import visdom
    import time
 
    viz = visdom.Visdom()
    db = Pokemon("pokemon", 224, "train")
    x, y = next(iter(db))
    print("sample: ", x.shape, y.shape, y)
    viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))
    loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=4)
    for x, y in loader:
        viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
        viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
        time.sleep(10)
 
if __name__ == '__main__':
    main()

代码中定义的子类Pokemon继承自Dataset类,重写了父类的__len____getitem__方法。

__init__方法中首先读取图片路径,构建了标签字典name2label,接下来将图片路径以及对应的标签写入到csv文件中,通过保存在csv文件中的路径与标签信息,划分数据集,csv文件内容如下:Pytorch学习_定义自己的数据集2_第3张图片

__len__方法中返回了数据集的大小。

__getitem__方法中进行图像缩放、图像旋转、图像去中心、转换到tensor、归一化等数据增强。

另外,在Pokemon类中实现了denormalize方法,也就是反归一化,因为在对图像进行归一化处理后,在visdom显示图像的时候,可见度不高,因此denormalize方法仅在visdom显示图像的时候调用。

main()函数中,进行数据集的可视化,可以看到继承自Dataset类的Pokemon类可以通过迭代器iter进行访问,并通过visdom进行可视化展示,另外,通过DataLoader类实现了对数据集的加载,在visdom中以32个batch进行加载。

验证效果

启动两个终端,分别执行如下两条命令:

python -m visdom.server

Pytorch学习_定义自己的数据集2_第4张图片

python pokemon.py

img

复制第一个终端中visdom链接 http://localhost:8097到浏览器

Pytorch学习_定义自己的数据集2_第5张图片

可以看到,32个batch的图片与标签在浏览器中展示。

你可能感兴趣的:(pytorch)